diff --git a/src/main.rs b/src/main.rs
index 287c97f..c5e6bcd 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -72,9 +72,9 @@ struct Cli {
     /// Halt generation at PROMPT, return control.
     #[arg(short, long)]
     reverse_prompt: Option<String>,
-    /// Number of tokens to predict
-    #[arg(short, long, default_value = "1024")]
-    n_predict: u64,
+    /// Number of tokens to predict, -1 = infinity, -2 = until context filled.
+    #[arg(short, long, default_value = "-1")]
+    n_predict: i32,
     /// Number of layers to run on the GPU
     #[arg(short = 'g', long, default_value = "100")]
     n_gpu_layers: u64,
@@ -158,8 +158,6 @@ async fn main() -> Result<(), ServerError> {
     wasi_logger::Logger::install().expect("failed to install wasi_logger::Logger");
     log::set_max_level(log_level.into());

-    info!(target: "stdout", "log_level: {}", log_level);
-
     if let Ok(api_key) = std::env::var("API_KEY") {
         // define a const variable for the API key
         if let Err(e) = LLAMA_API_KEY.set(api_key) {
@@ -174,6 +172,8 @@ async fn main() -> Result<(), ServerError> {
     // parse the command line arguments
     let cli = Cli::parse();

+    info!(target: "stdout", "log_level: {}", log_level);
+
     // log the version of the server
     info!(target: "stdout", "server_version: {}", env!("CARGO_PKG_VERSION"));

@@ -687,7 +687,7 @@ pub(crate) struct ModelConfig {
     #[serde(rename = "type")]
     ty: String,
     pub prompt_template: PromptTemplateType,
-    pub n_predict: u64,
+    pub n_predict: i32,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub reverse_prompt: Option<String>,
     pub n_gpu_layers: u64,
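
Note on the type change: moving n_predict from u64 to i32 makes room for the negative sentinel values documented in the new help text (-1 = generate with no fixed limit, -2 = generate until the context window is full), which matches the convention llama.cpp uses for n_predict. Below is a minimal sketch of how a caller might turn the sentinel into an effective token budget; the helper and its ctx_size/prompt_tokens parameters are hypothetical illustrations, not part of this patch.

// Hypothetical helper, not in this patch: resolve a signed n_predict
// into an optional hard limit on generated tokens.
fn resolve_max_tokens(n_predict: i32, ctx_size: usize, prompt_tokens: usize) -> Option<usize> {
    match n_predict {
        // -1: no fixed limit; generate until an EOS token or a reverse prompt.
        -1 => None,
        // -2: fill whatever context remains after the prompt.
        -2 => Some(ctx_size.saturating_sub(prompt_tokens)),
        // Non-negative: an explicit token budget.
        n if n >= 0 => Some(n as usize),
        // Any other negative value: treat as zero rather than panicking.
        _ => Some(0),
    }
}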