diff --git a/src/main.rs b/src/main.rs index ba6aac1..310e1f3 100644 --- a/src/main.rs +++ b/src/main.rs @@ -46,7 +46,7 @@ struct Cli { default_value = "default,embedding" )] model_alias: Vec<String>, - /// Sets context sizes for chat and embedding models. The sizes are separated by comma without space, for example, '--ctx-size 4096,384'. The first value is for the chat model, and the second is for the embedding model. + /// Sets context sizes for chat and embedding models, respectively. The sizes are separated by comma without space, for example, '--ctx-size 4096,384'. The first value is for the chat model, and the second is for the embedding model. #[arg( short = 'c', long, @@ -61,9 +61,9 @@ struct Cli { /// Halt generation at PROMPT, return control. #[arg(short, long)] reverse_prompt: Option<String>, - /// Batch size for prompt processing - #[arg(short, long, default_value = "512")] - batch_size: u64, + /// Sets batch sizes for chat and embedding models, respectively. The sizes are separated by comma without space, for example, '--batch-size 128,64'. The first value is for the chat model, and the second is for the embedding model. + #[arg(short, long, value_delimiter = ',', default_value = "512,512", value_parser = clap::value_parser!(u64))] + batch_size: Vec<u64>, /// Custom rag prompt. 
#[arg(long)] rag_prompt: Option<String>, @@ -155,6 +155,21 @@ async fn main() -> Result<(), ServerError> { "[INFO] Context sizes: {ctx_sizes}", ctx_sizes = ctx_sizes_str )); + if cli.batch_size.len() != 2 { + return Err(ServerError::ArgumentError( + "LlamaEdge RAG API server requires two batch sizes: one for chat model, one for embedding model.".to_owned(), + )); + } + let batch_sizes_str: String = cli + .batch_size + .iter() + .map(|n| n.to_string()) + .collect::<Vec<String>>() + .join(","); + log(format!( + "[INFO] Batch sizes: {batch_sizes}", + batch_sizes = batch_sizes_str + )); log(format!("[INFO] Prompt template: {}", &cli.prompt_template)); if let Some(reverse_prompt) = &cli.reverse_prompt { log(format!("[INFO] reverse prompt: {}", reverse_prompt)); @@ -218,7 +233,7 @@ async fn main() -> Result<(), ServerError> { ) .with_ctx_size(cli.ctx_size[0]) .with_reverse_prompt(cli.reverse_prompt) - .with_batch_size(cli.batch_size) + .with_batch_size(cli.batch_size[0]) .enable_prompts_log(cli.log_prompts || cli.log_all) .enable_plugin_log(cli.log_stat || cli.log_all) .enable_debug_log(plugin_debug) @@ -250,7 +265,7 @@ async fn main() -> Result<(), ServerError> { cli.prompt_template, ) .with_ctx_size(cli.ctx_size[1]) - .with_batch_size(cli.batch_size) + .with_batch_size(cli.batch_size[1]) .enable_prompts_log(cli.log_prompts || cli.log_all) .enable_plugin_log(cli.log_stat || cli.log_all) .enable_debug_log(plugin_debug)