Skip to content

Commit

Permalink
refactor!: update the type and default value of --n-predict CLI option
Browse files Browse the repository at this point in the history
Signed-off-by: Xin Liu <[email protected]>
  • Loading branch information
apepkuss committed Jan 9, 2025
1 parent e8027f5 commit 05d18d3
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ struct Cli {
/// Halt generation at PROMPT, return control.
#[arg(short, long)]
reverse_prompt: Option<String>,
/// Number of tokens to predict
#[arg(short, long, default_value = "1024")]
n_predict: u64,
/// Number of tokens to predict, -1 = infinity, -2 = until context filled.
#[arg(short, long, default_value = "-1")]
n_predict: i32,
/// Number of layers to run on the GPU
#[arg(short = 'g', long, default_value = "100")]
n_gpu_layers: u64,
Expand Down Expand Up @@ -158,8 +158,6 @@ async fn main() -> Result<(), ServerError> {
wasi_logger::Logger::install().expect("failed to install wasi_logger::Logger");
log::set_max_level(log_level.into());

info!(target: "stdout", "log_level: {}", log_level);

if let Ok(api_key) = std::env::var("API_KEY") {
// define a const variable for the API key
if let Err(e) = LLAMA_API_KEY.set(api_key) {
Expand All @@ -174,6 +172,8 @@ async fn main() -> Result<(), ServerError> {
// parse the command line arguments
let cli = Cli::parse();

info!(target: "stdout", "log_level: {}", log_level);

// log the version of the server
info!(target: "stdout", "server_version: {}", env!("CARGO_PKG_VERSION"));

Expand Down Expand Up @@ -687,7 +687,7 @@ pub(crate) struct ModelConfig {
#[serde(rename = "type")]
ty: String,
pub prompt_template: PromptTemplateType,
pub n_predict: u64,
pub n_predict: i32,
#[serde(skip_serializing_if = "Option::is_none")]
pub reverse_prompt: Option<String>,
pub n_gpu_layers: u64,
Expand Down

0 comments on commit 05d18d3

Please sign in to comment.