
Commit

Update languagemodel tests
ZJaume committed Oct 3, 2024
1 parent 595ca4d commit d8a3fbe
Showing 1 changed file with 45 additions and 38 deletions.
83 changes: 45 additions & 38 deletions src/languagemodel.rs
@@ -58,6 +58,7 @@ impl ModelNgram {
        }
    }

    /// Load the model from plain text for a subset of languages
    pub fn from_text_langs(
        model_dir: &Path,
        model_type: OrderNgram,
@@ -80,6 +81,7 @@ impl ModelNgram {
        Ok(model)
    }

    /// Load the model from plain text for all languages
    pub fn from_text_all(model_dir: &Path, model_type: OrderNgram) -> Result<Self> {
        let mut model = ModelNgram {
            dic: HashMap::default(),
@@ -112,6 +114,7 @@ impl ModelNgram {
        Ok(model)
    }

    /// Parse the ngram file, compute probabilities and insert into the model
    fn read_model(&mut self, p: &Path, langcode: &Lang) -> Result<()> {
        // Read the language model file to a string all at once
        let modelfile =
@@ -289,47 +292,51 @@ impl Index<usize> for Model {
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::thread;

    use tempfile::NamedTempFile;

    #[test]
    fn test_langs() {
        let tempf = NamedTempFile::new().unwrap();
        let temppath = tempf.into_temp_path();
        let modelpath = Path::new("./LanguageModels");
        let wordmodel = ModelNgram::from_text(&modelpath, OrderNgram::Word);
        let path = Path::new("wordict.ser");
        wordmodel.save(path);

        let charmodel = ModelNgram::from_text(&modelpath, OrderNgram::Quadgram);
        let path = Path::new("gramdict.ser");
        charmodel.save(path);

        let char_handle = thread::spawn(move || {
            let path = Path::new("gramdict.ser");
            ModelNgram::from_bin(path)
        });

        let word_handle = thread::spawn(move || {
            let path = Path::new("wordict.ser");
            ModelNgram::from_bin(path)
        });

        // let word_model = word_handle.join().unwrap();
        let char_model = char_handle.join().unwrap();

        // failing because the original HeLI uses a Java float
        // instead of a double for accumulating frequencies
        let mut expected = HashMap::default();
        expected.insert(Lang::Cat, 3.4450269f32);
        expected.insert(Lang::Epo, 4.5279417f32);
        expected.insert(Lang::Ext, 2.5946937f32);
        expected.insert(Lang::Gla, 4.7058706f32);
        expected.insert(Lang::Glg, 2.3187783f32);
        expected.insert(Lang::Grn, 2.9653773f32);
        expected.insert(Lang::Nhn, 4.774119f32);
        expected.insert(Lang::Que, 3.8074818f32);
        expected.insert(Lang::Spa, 2.480955f32);

        let probs = char_model.dic.get("ación").unwrap();
        assert_eq!(probs, &expected);

        let model = ModelNgram::from_text(&modelpath,
                                          OrderNgram::Quingram,
                                          None).unwrap();
        // let path = Path::new("gramdict.ser");
        model.save(&temppath).unwrap();
        let model = ModelNgram::from_bin(&temppath).unwrap();
        temppath.close().unwrap();

        let mut expected = Vec::new();
        expected.push((Lang::ayr, 4.2863530f32));
        expected.push((Lang::cat, 3.3738296f32));
        expected.push((Lang::epo, 4.5279417f32));
        expected.push((Lang::ext, 2.5946038f32));
        expected.push((Lang::gla, 4.7052390f32));
        expected.push((Lang::glg, 2.3186955f32));
        expected.push((Lang::grn, 3.1885893f32));
        expected.push((Lang::kac, 5.5482570f32));
        expected.push((Lang::lmo, 5.2805230f32));
        expected.push((Lang::nhn, 5.0725970f32));
        expected.push((Lang::que, 3.8049161f32));
        expected.push((Lang::spa, 2.3922930f32));
        expected.push((Lang::vol, 5.1173210f32));

        let mut probs = model.dic.get("ación")
            .expect("Could not find the ngram in the model")
            .clone();
        // round to fewer decimals to be a little permissive
        // as there are differences between Java and Rust
        let round_to = 10000.0;
        for i in expected.iter_mut() {
            i.1 = (i.1 * round_to).round() / round_to;
        }
        for i in probs.iter_mut() {
            i.1 = (i.1 * round_to).round() / round_to;
        }
        assert_eq!(&probs, &expected);
    }
}
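
For reference, the rounding at the end of the test exists to tolerate small accumulation differences between the original Java HeLI (which accumulates frequencies in a float) and this Rust port. A minimal sketch of the same idea factored into a helper, assuming it lives inside the tests module where Lang is in scope; the function name is hypothetical and not part of this commit:

    /// Round each (Lang, f32) probability to four decimal places so that
    /// tiny float drift between implementations does not fail assert_eq!.
    /// Hypothetical helper, shown for illustration only.
    fn rounded(probs: &[(Lang, f32)]) -> Vec<(Lang, f32)> {
        let round_to = 10000.0;
        probs
            .iter()
            .map(|(lang, p)| (lang.clone(), (p * round_to).round() / round_to))
            .collect()
    }

    // Usage in the test body, equivalent to the in-place rounding loops above:
    // assert_eq!(rounded(&probs), rounded(&expected));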
