Skip to content

Commit

Permalink
Disable download feature by default
Browse files Browse the repository at this point in the history
  • Loading branch information
ZJaume committed Oct 29, 2024
1 parent 5c8cf88 commit 4927abf
Show file tree
Hide file tree
Showing 5 changed files with 158 additions and 119 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,6 @@ test-log = "0.2.15"
[features]
# Put log features in default, so that crates using heli as a library can disable them
default = ["cli", "log/max_level_debug", "log/release_max_level_debug"]
cli = ["download", "python", "dep:clap", "dep:target"]
cli = ["python", "dep:clap", "dep:target"]
download = ["dep:tokio", "dep:tempfile", "dep:reqwest", "dep:futures-util"]
python = ["dep:pyo3"]
66 changes: 66 additions & 0 deletions src/cli/binarize.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
use std::path::{Path, PathBuf};
use std::fs;
use std::process::exit;

use clap::Args;
use log::{error, warn, info};
use pyo3::prelude::*;
use strum::IntoEnumIterator;

use heliport_model::languagemodel::{Model, ModelNgram, OrderNgram};
use crate::utils::Abort;
use crate::python::module_path;

// Command-line arguments accepted by the binarize command.
// Plain `//` comments are used on purpose: `///` doc comments on a clap-derived
// item can be picked up as help text.
#[derive(Args, Clone)]
pub struct BinarizeCmd {
// Optional directory to read from; `cli` falls back to "./LanguageModels" when absent.
#[arg(help="Input directory where ngram frequency files are located")]
input_dir: Option<PathBuf>,
// Optional directory to write to; `cli` falls back to `module_path()` when absent.
#[arg(help="Output directory to place the binary files")]
output_dir: Option<PathBuf>,
// When false, `cli` refuses to run if the word model file already exists in the output directory.
#[arg(short, long, help="Force overwrite of output files if they already exist")]
force: bool,
}

impl BinarizeCmd {
    /// Convert the text n-gram models into binary files.
    ///
    /// Models are read from `input_dir` (default: `./LanguageModels`). For every
    /// [`OrderNgram`] variant a `<variant>.bin` file is written into `output_dir`
    /// (default: the path returned by `module_path()`), and the confidence
    /// thresholds file is copied alongside them.
    ///
    /// Unless `force` is set, the process logs an error and exits with status 1
    /// when the word model file already exists in the output directory.
    /// Load, save and copy failures are handed to `or_abort(1)`.
    ///
    /// # Panics
    ///
    /// Panics if no output directory was given and `module_path()` fails.
    pub fn cli(self) -> PyResult<()> {
        let model_path = self
            .input_dir
            .unwrap_or_else(|| PathBuf::from("./LanguageModels"));
        // Resolve the module path lazily: it is only needed (and only allowed to
        // fail) when the user did not provide an explicit output directory.
        let save_path = self
            .output_dir
            .unwrap_or_else(|| module_path().unwrap());

        // Fail and warn the user if there is already a model
        if !self.force &&
            save_path.join(
                format!("{}.bin", OrderNgram::Word.to_string())
            ).exists()
        {
            warn!("Binarized models are now included in the PyPi package, \
            there is no need to binarize the model unless you are training a new one"
            );
            error!("Output model already exists, use '-f' to force overwrite");
            exit(1);
        }

        for model_type in OrderNgram::iter() {
            let type_repr = model_type.to_string();
            info!("Loading {type_repr} model");
            let model = ModelNgram::from_text(&model_path, model_type, None)
                .or_abort(1);
            let size = model.dic.len();
            info!("Created {size} entries");
            let filename = save_path.join(format!("{type_repr}.bin"));
            info!("Saving {type_repr} model");
            model.save(Path::new(&filename)).or_abort(1);
        }
        info!("Copying confidence thresholds file");
        fs::copy(
            model_path.join(Model::CONFIDENCE_FILE),
            save_path.join(Model::CONFIDENCE_FILE),
        ).or_abort(1);

        info!("Saved models at '{}'", save_path.display());
        info!("Finished");

        Ok(())
    }
}


34 changes: 34 additions & 0 deletions src/cli/download.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
use std::path::PathBuf;
use std::env;

use clap::Args;
use pyo3::prelude::*;
use log::info;
use target;

use crate::python::module_path;
use crate::download;

// Command-line arguments accepted by the download command.
// Plain `//` comments are used on purpose: `///` doc comments on a clap-derived
// item can be picked up as help text.
#[derive(Args, Clone)]
pub struct DownloadCmd {
// Optional destination; `cli` falls back to `module_path()` when absent.
#[arg(help="Path to download the model, defaults to the module path")]
path: Option<PathBuf>,
}

impl DownloadCmd {
pub fn cli(self) -> PyResult<()> {
let download_path = self.path.unwrap_or(module_path().unwrap());

let url = format!(
"https://github.com/ZJaume/{}/releases/download/v{}/models-{}-{}.tgz",
env!("CARGO_PKG_NAME"),
env!("CARGO_PKG_VERSION"),
target::os(),
target::arch());

download::download_file_and_extract(&url, download_path.to_str().unwrap()).unwrap();
info!("Finished");

Ok(())
}
}
123 changes: 5 additions & 118 deletions src/cli.rs → src/cli/identify.rs
Original file line number Diff line number Diff line change
@@ -1,121 +1,21 @@
use std::io::{self, BufRead, BufReader, Write, BufWriter};
use std::fs::{copy, File};
use std::fs::File;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::env;
use std::process::exit;

use anyhow::{Context, Result};
use clap::{Parser, Subcommand, Args};
use clap::Args;
use itertools::Itertools;
use log::{debug};
use pyo3::prelude::*;
use log::{error, warn, info, debug};
use env_logger::Env;
use strum::IntoEnumIterator;
use target;

use heliport_model::languagemodel::{Model, ModelNgram, OrderNgram};
use heliport_model::lang::Lang;
use crate::identifier::Identifier;
use crate::utils::Abort;
use crate::python::module_path;
use crate::download;

// Top-level command-line parser: it only holds the selected subcommand.
// Plain `//` comments are used on purpose: `///` doc comments on a clap-derived
// item can be picked up as help text.
#[derive(Parser, Clone)]
#[command(version, about, long_about = None)]
pub struct Cli {
// The subcommand chosen on the command line, dispatched in `cli_run`.
#[command(subcommand)]
command: Commands,
}

// Subcommands of the command-line interface; each variant wraps the argument
// struct of one command. Plain `//` comments are used on purpose: `///` doc
// comments on a clap-derived item can be picked up as help text.
#[derive(Subcommand, Clone)]
enum Commands {
#[command(about="Download heliport model from GitHub")]
Download(DownloadCmd),
#[command(about="Binarize heliport model")]
Binarize(BinarizeCmd),
// Also reachable through the visible alias `detect`.
#[command(about="Identify languages of input text", visible_alias="detect")]
Identify(IdentifyCmd),
}

// Command-line arguments accepted by the binarize command.
// Plain `//` comments are used on purpose: `///` doc comments on a clap-derived
// item can be picked up as help text.
#[derive(Args, Clone)]
struct BinarizeCmd {
// Optional directory to read from; `cli` falls back to "./LanguageModels" when absent.
#[arg(help="Input directory where ngram frequency files are located")]
input_dir: Option<PathBuf>,
// Optional directory to write to; `cli` falls back to `module_path()` when absent.
#[arg(help="Output directory to place the binary files")]
output_dir: Option<PathBuf>,
// When false, `cli` refuses to run if the word model file already exists in the output directory.
#[arg(short, long, help="Force overwrite of output files if they already exist")]
force: bool,
}

impl BinarizeCmd {
/// Convert the text n-gram models into binary files.
///
/// Reads from `input_dir` (default `./LanguageModels`), writes one
/// `<variant>.bin` file per `OrderNgram` variant into `output_dir`
/// (default: the path returned by `module_path()`) and copies the confidence
/// thresholds file next to them.
fn cli(self) -> PyResult<()> {
let model_path = self.input_dir.unwrap_or(PathBuf::from("./LanguageModels"));
// NOTE(review): `module_path().unwrap()` is evaluated even when an output
// directory was supplied, because `unwrap_or` takes its argument eagerly.
let save_path = self.output_dir.unwrap_or(module_path().unwrap());

// Fail and warn the user if there is already a model
if !self.force &&
save_path.join(
format!("{}.bin", OrderNgram::Word.to_string())
).exists()
{
warn!("Binarized models are now included in the PyPi package, \
there is no need to binarize the model unless you are training a new one"
);
error!("Output model already exists, use '-f' to force overwrite");
exit(1);
}

// Build and save one binary model per n-gram order/type.
for model_type in OrderNgram::iter() {
let type_repr = model_type.to_string();
info!("Loading {type_repr} model");
let model = ModelNgram::from_text(&model_path, model_type, None)
.or_abort(1);
let size = model.dic.len();
info!("Created {size} entries");
let filename = save_path.join(format!("{type_repr}.bin"));
info!("Saving {type_repr} model");
model.save(Path::new(&filename)).or_abort(1);
}
info!("Copying confidence thresholds file");
// The thresholds file keeps the same name in the output directory.
copy(
model_path.join(Model::CONFIDENCE_FILE),
save_path.join(Model::CONFIDENCE_FILE),
).or_abort(1);

info!("Saved models at '{}'", save_path.display());
info!("Finished");

Ok(())
}
}

// Command-line arguments accepted by the download command.
// Plain `//` comments are used on purpose: `///` doc comments on a clap-derived
// item can be picked up as help text.
#[derive(Args, Clone)]
struct DownloadCmd {
// Optional destination; `cli` falls back to `module_path()` when absent.
#[arg(help="Path to download the model, defaults to the module path")]
path: Option<PathBuf>,
}

impl DownloadCmd {
/// Download the model archive matching this build and extract it into
/// `path`, or into the path returned by `module_path()` when none was given.
fn cli(self) -> PyResult<()> {
// NOTE(review): `module_path().unwrap()` is evaluated even when a path was
// supplied, because `unwrap_or` takes its argument eagerly.
let download_path = self.path.unwrap_or(module_path().unwrap());

// Crate name and version are fixed at compile time through `env!`.
let url = format!(
"https://github.com/ZJaume/{}/releases/download/v{}/models-{}-{}.tgz",
env!("CARGO_PKG_NAME"),
env!("CARGO_PKG_VERSION"),
target::os(),
target::arch());

// Panics when the path is not valid UTF-8 (`to_str`) or when the helper's
// result is unwrapped unsuccessfully.
download::download_file_and_extract(&url, download_path.to_str().unwrap()).unwrap();
info!("Finished");

Ok(())
}
}

#[derive(Args, Clone, Debug)]
struct IdentifyCmd {
pub struct IdentifyCmd {
#[arg(help="Number of parallel threads to use.\n0 means no multi-threading\n1 means running the identification in a separated thread\n>1 run multithreading",
short='j',
long,
Expand Down Expand Up @@ -170,7 +70,7 @@ fn parse_langs(langs_text: &Vec<String>) -> Result<Vec<Lang>> {
}

impl IdentifyCmd {
fn cli(self) -> PyResult<()> {
pub fn cli(self) -> PyResult<()> {
// If provided, parse the list of relevant languages
let mut relevant_langs = None;
if let Some(r) = &self.relevant_langs {
Expand Down Expand Up @@ -280,17 +180,4 @@ impl IdentifyCmd {
}
}

// Entry point of the command-line interface, exposed to Python: parses the
// process arguments, sets up logging and dispatches to the chosen subcommand.
#[pyfunction]
pub fn cli_run() -> PyResult<()> {
// parse the cli arguments, skip the first one that is the path to the Python entry point
let os_args = std::env::args_os().skip(1);
let args = Cli::parse_from(os_args);
// NOTE(review): this `debug!` runs before the logger is initialised on the next
// line, so the message is presumably never shown — confirm and consider swapping.
debug!("Module path found at: {}", module_path().expect("Could not found module path").display());
// Log level comes from the environment, with "info" as the default filter.
env_logger::Builder::from_env(Env::default().default_filter_or("info")).init();

match args.command {
Commands::Download(cmd) => { cmd.cli() },
Commands::Binarize(cmd) => { cmd.cli() },
Commands::Identify(cmd) => { cmd.cli() },
}
}
52 changes: 52 additions & 0 deletions src/cli/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
mod identify;
#[cfg(feature = "download")]
mod download;
mod binarize;

use clap::{Subcommand, Parser};
use log::{debug};
use pyo3::prelude::*;
use env_logger::Env;

use crate::python::module_path;
#[cfg(feature = "download")]
use self::download::DownloadCmd;
use self::binarize::BinarizeCmd;
use self::identify::IdentifyCmd;

// Top-level command-line parser: it only holds the selected subcommand.
// Plain `//` comments are used on purpose: `///` doc comments on a clap-derived
// item can be picked up as help text.
#[derive(Parser, Clone)]
#[command(version, about, long_about = None)]
pub struct Cli {
// The subcommand chosen on the command line, dispatched in `cli_run`.
#[command(subcommand)]
command: Commands,
}

// Subcommands of the command-line interface; each variant wraps the argument
// struct of one command. Plain `//` comments are used on purpose: `///` doc
// comments on a clap-derived item can be picked up as help text.
#[derive(Subcommand, Clone)]
enum Commands {
    // Only compiled in when the crate is built with the `download` feature.
    #[cfg(feature = "download")]
    #[command(about="Download heliport model from GitHub")]
    Download(DownloadCmd),
    #[command(about="Binarize heliport model")]
    Binarize(BinarizeCmd),
    // Also reachable through the visible alias `detect`.
    #[command(about="Identify languages of input text", visible_alias="detect")]
    Identify(IdentifyCmd),
}



/// Entry point of the command-line interface, exposed to Python.
///
/// Parses the process arguments (skipping the first one), initialises
/// `env_logger` with `info` as the default filter, and dispatches to the
/// selected subcommand, returning that subcommand's result.
#[pyfunction]
pub fn cli_run() -> PyResult<()> {
    // parse the cli arguments, skip the first one that is the path to the Python entry point
    let os_args = std::env::args_os().skip(1);
    let args = Cli::parse_from(os_args);
    // Install the logger before the first log call: records emitted while no
    // logger is set are dropped, so the debug line below could never be shown.
    env_logger::Builder::from_env(Env::default().default_filter_or("info")).init();
    debug!("Module path found at: {}", module_path().expect("Could not found module path").display());

    match args.command {
        // The download subcommand only exists when the `download` feature is enabled.
        #[cfg(feature = "download")]
        Commands::Download(cmd) => { cmd.cli() },
        Commands::Binarize(cmd) => { cmd.cli() },
        Commands::Identify(cmd) => { cmd.cli() },
    }
}

0 comments on commit 4927abf

Please sign in to comment.