From 3a3d347179fcee0bcd5dca42bb681661750edaf5 Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Mon, 12 Aug 2024 04:16:56 +0000 Subject: [PATCH] Revert "Update InMemoryIndex.save to save to mmap instead of specialized serialization format" This reverts commit 21acdee5bd7de8a98362cf5907777a1ce3d84201. --- src/in_memory_index.rs | 41 +++++++++++++---------------------------- src/table.rs | 15 +++++++++------ 2 files changed, 22 insertions(+), 34 deletions(-) diff --git a/src/in_memory_index.rs b/src/in_memory_index.rs index 463c6f6..84ebb2c 100644 --- a/src/in_memory_index.rs +++ b/src/in_memory_index.rs @@ -1,21 +1,19 @@ use anyhow::Result; -use bincode::deserialize; +use bincode::{deserialize, serialize}; use pyo3::prelude::*; use std::collections::HashMap; use std::fs::File; use std::io::Read; -use std::fs::OpenOptions; use crate::sample::{KneserNeyCache, Sample}; use crate::table::SuffixTable; use crate::util::transmute_slice; -use crate::table::Table; -use crate::mmap_slice::MmapSliceMut; +use crate::table::InMemoryTable; /// An in-memory index exposes suffix table functionality over text corpora small enough to fit in memory. #[pyclass] pub struct InMemoryIndex { - table: Box, + table: Box, cache: KneserNeyCache } @@ -23,7 +21,7 @@ impl InMemoryIndex { pub fn new(tokens: Vec, vocab: Option, verbose: bool) -> Self { let vocab = vocab.unwrap_or(u16::MAX as usize + 1); - let table: Box = if vocab <= u16::MAX as usize + 1 { + let table: Box = if vocab <= u16::MAX as usize + 1 { let tokens: Vec = tokens.iter().map(|&x| x as u16).collect(); Box::new(SuffixTable::, Box<[u64]>>::new(tokens, Some(vocab), verbose)) } else { @@ -63,7 +61,7 @@ impl InMemoryIndex { #[new] #[pyo3(signature = (tokens, vocab=u16::MAX as usize + 1, verbose=false))] pub fn new_py(_py: Python, tokens: Vec, vocab: usize, verbose: bool) -> Self { - let table: Box = if vocab <= u16::MAX as usize + 1 { + let table: Box = if vocab <= u16::MAX as usize + 1 { let tokens: Vec = tokens.iter().map(|&x| x as u16).collect(); Box::new(SuffixTable::, Box<[u64]>>::new(tokens, Some(vocab), verbose)) } else { @@ -119,7 +117,7 @@ impl InMemoryIndex { file.read_to_end(&mut buffer)?; }; - let table: Box = if vocab <= u16::MAX as usize + 1 { + let table: Box = if vocab <= u16::MAX as usize + 1 { let tokens = transmute_slice::(buffer.as_slice()); Box::new(SuffixTable::new(tokens, Some(vocab), verbose)) } else { @@ -158,6 +156,13 @@ impl InMemoryIndex { self.table.batch_count_next(&queries) } + pub fn save(&self, path: String) -> PyResult<()> { + // TODO: handle errors here + let bytes = serialize(&self.table).unwrap(); + std::fs::write(&path, bytes)?; + Ok(()) + } + /// Autoregressively sample num_samples of k characters from an unsmoothed n-gram model.""" pub fn sample_unsmoothed( &self, @@ -202,24 +207,4 @@ impl InMemoryIndex { pub fn estimate_deltas(&mut self, n: usize) { self.estimate_deltas_rs(n); } - - pub fn save(&self, path: String) -> PyResult<()> { - let table_file = OpenOptions::new() - .create(true) - .read(true) - .write(true) - .open(path)?; - - let table_data = self.table.get_table(); - let table_size = table_data.len() * std::mem::size_of::(); - table_file.set_len(table_size as u64)?; - - let mut table_mmap = MmapSliceMut::::new(&table_file)?; - - assert_eq!(table_mmap.len(), table_data.len(), "Mismatch in table data length"); - - table_mmap.copy_from_slice(table_data); - - Ok(()) - } } \ No newline at end of file diff --git a/src/table.rs b/src/table.rs index b090780..c8f91b2 100644 --- a/src/table.rs +++ b/src/table.rs @@ -395,10 +395,17 @@ pub trait Table { // For a given n, produce a map from an occurrence count to the number of unique n-grams with that occurrence count. fn count_ngrams(&self, n: usize) -> HashMap; - - fn get_table(&self) -> &[u64]; } +#[typetag::serialize] +pub trait InMemoryTable: Table {} + +#[typetag::serialize] +impl InMemoryTable for SuffixTable> {} + +#[typetag::serialize] +impl InMemoryTable for SuffixTable> {} + impl Table for SuffixTable where T: Deref + Sync, @@ -442,10 +449,6 @@ where fn count_ngrams(&self, n: usize) -> HashMap { self.count_ngrams(n) } - - fn get_table(&self) -> &[u64] { - self.table.deref() - } } impl fmt::Debug for SuffixTable {