Skip to content

Commit

Permalink
Revert "Update InMemoryIndex.save to save to mmap instead of speciali…
Browse files Browse the repository at this point in the history
…zed serialization format"

This reverts commit 21acdee.
  • Loading branch information
luciaquirke committed Aug 12, 2024
1 parent 9167c24 commit 3a3d347
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 34 deletions.
41 changes: 13 additions & 28 deletions src/in_memory_index.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,27 @@
use anyhow::Result;
use bincode::deserialize;
use bincode::{deserialize, serialize};
use pyo3::prelude::*;
use std::collections::HashMap;
use std::fs::File;
use std::io::Read;
use std::fs::OpenOptions;

use crate::sample::{KneserNeyCache, Sample};
use crate::table::SuffixTable;
use crate::util::transmute_slice;
use crate::table::Table;
use crate::mmap_slice::MmapSliceMut;
use crate::table::InMemoryTable;

/// An in-memory index exposes suffix table functionality over text corpora small enough to fit in memory.
#[pyclass]
pub struct InMemoryIndex {
table: Box<dyn Table + Send + Sync>,
table: Box<dyn InMemoryTable + Send + Sync>,
cache: KneserNeyCache
}

impl InMemoryIndex {
pub fn new(tokens: Vec<usize>, vocab: Option<usize>, verbose: bool) -> Self {
let vocab = vocab.unwrap_or(u16::MAX as usize + 1);

let table: Box<dyn Table + Send + Sync> = if vocab <= u16::MAX as usize + 1 {
let table: Box<dyn InMemoryTable + Send + Sync> = if vocab <= u16::MAX as usize + 1 {
let tokens: Vec<u16> = tokens.iter().map(|&x| x as u16).collect();
Box::new(SuffixTable::<Box<[u16]>, Box<[u64]>>::new(tokens, Some(vocab), verbose))
} else {
Expand Down Expand Up @@ -63,7 +61,7 @@ impl InMemoryIndex {
#[new]
#[pyo3(signature = (tokens, vocab=u16::MAX as usize + 1, verbose=false))]
pub fn new_py(_py: Python, tokens: Vec<usize>, vocab: usize, verbose: bool) -> Self {
let table: Box<dyn Table + Send + Sync> = if vocab <= u16::MAX as usize + 1 {
let table: Box<dyn InMemoryTable + Send + Sync> = if vocab <= u16::MAX as usize + 1 {
let tokens: Vec<u16> = tokens.iter().map(|&x| x as u16).collect();
Box::new(SuffixTable::<Box<[u16]>, Box<[u64]>>::new(tokens, Some(vocab), verbose))
} else {
Expand Down Expand Up @@ -119,7 +117,7 @@ impl InMemoryIndex {
file.read_to_end(&mut buffer)?;
};

let table: Box<dyn Table + Send + Sync> = if vocab <= u16::MAX as usize + 1 {
let table: Box<dyn InMemoryTable + Send + Sync> = if vocab <= u16::MAX as usize + 1 {
let tokens = transmute_slice::<u8, u16>(buffer.as_slice());
Box::new(SuffixTable::new(tokens, Some(vocab), verbose))
} else {
Expand Down Expand Up @@ -158,6 +156,13 @@ impl InMemoryIndex {
self.table.batch_count_next(&queries)
}

pub fn save(&self, path: String) -> PyResult<()> {
// TODO: handle errors here
let bytes = serialize(&self.table).unwrap();
std::fs::write(&path, bytes)?;
Ok(())
}

/// Autoregressively sample num_samples of k characters from an unsmoothed n-gram model."""
pub fn sample_unsmoothed(
&self,
Expand Down Expand Up @@ -202,24 +207,4 @@ impl InMemoryIndex {
pub fn estimate_deltas(&mut self, n: usize) {
self.estimate_deltas_rs(n);
}

pub fn save(&self, path: String) -> PyResult<()> {
let table_file = OpenOptions::new()
.create(true)
.read(true)
.write(true)
.open(path)?;

let table_data = self.table.get_table();
let table_size = table_data.len() * std::mem::size_of::<u64>();
table_file.set_len(table_size as u64)?;

let mut table_mmap = MmapSliceMut::<u64>::new(&table_file)?;

assert_eq!(table_mmap.len(), table_data.len(), "Mismatch in table data length");

table_mmap.copy_from_slice(table_data);

Ok(())
}
}
15 changes: 9 additions & 6 deletions src/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -395,10 +395,17 @@ pub trait Table {

// For a given n, produce a map from an occurrence count to the number of unique n-grams with that occurrence count.
fn count_ngrams(&self, n: usize) -> HashMap<usize, usize>;

fn get_table(&self) -> &[u64];
}

#[typetag::serialize]
pub trait InMemoryTable: Table {}

#[typetag::serialize]
impl InMemoryTable for SuffixTable<Box<[u16]>> {}

#[typetag::serialize]
impl InMemoryTable for SuffixTable<Box<[u32]>> {}

impl<T, U, E> Table for SuffixTable<T, U>
where
T: Deref<Target = [E]> + Sync,
Expand Down Expand Up @@ -442,10 +449,6 @@ where
fn count_ngrams(&self, n: usize) -> HashMap<usize, usize> {
self.count_ngrams(n)
}

fn get_table(&self) -> &[u64] {
self.table.deref()
}
}

impl fmt::Debug for SuffixTable {
Expand Down

0 comments on commit 3a3d347

Please sign in to comment.