-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
158 additions
and
119 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
use std::path::{Path, PathBuf}; | ||
use std::fs; | ||
use std::process::exit; | ||
|
||
use clap::Args; | ||
use log::{error, warn, info}; | ||
use pyo3::prelude::*; | ||
use strum::IntoEnumIterator; | ||
|
||
use heliport_model::languagemodel::{Model, ModelNgram, OrderNgram}; | ||
use crate::utils::Abort; | ||
use crate::python::module_path; | ||
|
||
#[derive(Args, Clone)] | ||
pub struct BinarizeCmd { | ||
#[arg(help="Input directory where ngram frequency files are located")] | ||
input_dir: Option<PathBuf>, | ||
#[arg(help="Output directory to place the binary files")] | ||
output_dir: Option<PathBuf>, | ||
#[arg(short, long, help="Force overwrite of output files if they already exist")] | ||
force: bool, | ||
} | ||
|
||
impl BinarizeCmd { | ||
pub fn cli(self) -> PyResult<()> { | ||
let model_path = self.input_dir.unwrap_or(PathBuf::from("./LanguageModels")); | ||
let save_path = self.output_dir.unwrap_or(module_path().unwrap()); | ||
|
||
// Fail and warn the use if there is already a model | ||
if !self.force && | ||
save_path.join( | ||
format!("{}.bin", OrderNgram::Word.to_string()) | ||
).exists() | ||
{ | ||
warn!("Binarized models are now included in the PyPi package, \ | ||
there is no need to binarize the model unless you are training a new one" | ||
); | ||
error!("Output model already exists, use '-f' to force overwrite"); | ||
exit(1); | ||
} | ||
|
||
for model_type in OrderNgram::iter() { | ||
let type_repr = model_type.to_string(); | ||
info!("Loading {type_repr} model"); | ||
let model = ModelNgram::from_text(&model_path, model_type, None) | ||
.or_abort(1); | ||
let size = model.dic.len(); | ||
info!("Created {size} entries"); | ||
let filename = save_path.join(format!("{type_repr}.bin")); | ||
info!("Saving {type_repr} model"); | ||
model.save(Path::new(&filename)).or_abort(1); | ||
} | ||
info!("Copying confidence thresholds file"); | ||
fs::copy( | ||
model_path.join(Model::CONFIDENCE_FILE), | ||
save_path.join(Model::CONFIDENCE_FILE), | ||
).or_abort(1); | ||
|
||
info!("Saved models at '{}'", save_path.display()); | ||
info!("Finished"); | ||
|
||
Ok(()) | ||
} | ||
} | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
use std::path::PathBuf; | ||
use std::env; | ||
|
||
use clap::Args; | ||
use pyo3::prelude::*; | ||
use log::info; | ||
use target; | ||
|
||
use crate::python::module_path; | ||
use crate::download; | ||
|
||
#[derive(Args, Clone)] | ||
pub struct DownloadCmd { | ||
#[arg(help="Path to download the model, defaults to the module path")] | ||
path: Option<PathBuf>, | ||
} | ||
|
||
impl DownloadCmd { | ||
pub fn cli(self) -> PyResult<()> { | ||
let download_path = self.path.unwrap_or(module_path().unwrap()); | ||
|
||
let url = format!( | ||
"https://github.com/ZJaume/{}/releases/download/v{}/models-{}-{}.tgz", | ||
env!("CARGO_PKG_NAME"), | ||
env!("CARGO_PKG_VERSION"), | ||
target::os(), | ||
target::arch()); | ||
|
||
download::download_file_and_extract(&url, download_path.to_str().unwrap()).unwrap(); | ||
info!("Finished"); | ||
|
||
Ok(()) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
mod identify; | ||
#[cfg(feature = "download")] | ||
mod download; | ||
mod binarize; | ||
|
||
use clap::{Subcommand, Parser}; | ||
use log::{debug}; | ||
use pyo3::prelude::*; | ||
use env_logger::Env; | ||
|
||
use crate::python::module_path; | ||
#[cfg(feature = "download")] | ||
use self::download::DownloadCmd; | ||
use self::binarize::BinarizeCmd; | ||
use self::identify::IdentifyCmd; | ||
|
||
#[derive(Parser, Clone)] | ||
#[command(version, about, long_about = None)] | ||
pub struct Cli { | ||
#[command(subcommand)] | ||
command: Commands, | ||
} | ||
|
||
#[derive(Subcommand, Clone)] | ||
enum Commands { | ||
#[cfg(feature = "download")] | ||
#[command(about="Download heliport model from GitHub")] | ||
#[cfg(feature = "download")] | ||
Download(DownloadCmd), | ||
#[command(about="Binarize heliport model")] | ||
Binarize(BinarizeCmd), | ||
#[command(about="Identify languages of input text", visible_alias="detect")] | ||
Identify(IdentifyCmd), | ||
} | ||
|
||
|
||
|
||
#[pyfunction] | ||
pub fn cli_run() -> PyResult<()> { | ||
// parse the cli arguments, skip the first one that is the path to the Python entry point | ||
let os_args = std::env::args_os().skip(1); | ||
let args = Cli::parse_from(os_args); | ||
debug!("Module path found at: {}", module_path().expect("Could not found module path").display()); | ||
env_logger::Builder::from_env(Env::default().default_filter_or("info")).init(); | ||
|
||
match args.command { | ||
#[cfg(feature = "download")] | ||
Commands::Download(cmd) => { cmd.cli() }, | ||
Commands::Binarize(cmd) => { cmd.cli() }, | ||
Commands::Identify(cmd) => { cmd.cli() }, | ||
} | ||
} |