Skip to content

Commit

Permalink
refactor!: update --batch-size CLI option
Browse files Browse the repository at this point in the history
Signed-off-by: Xin Liu <[email protected]>
  • Loading branch information
apepkuss committed May 18, 2024
1 parent 4589d37 commit 2680316
Showing 1 changed file with 21 additions and 6 deletions.
27 changes: 21 additions & 6 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ struct Cli {
default_value = "default,embedding"
)]
model_alias: Vec<String>,
/// Sets context sizes for chat and embedding models. The sizes are separated by comma without space, for example, '--ctx-size 4096,384'. The first value is for the chat model, and the second is for the embedding model.
/// Sets context sizes for chat and embedding models, respectively. The sizes are separated by comma without space, for example, '--ctx-size 4096,384'. The first value is for the chat model, and the second is for the embedding model.
#[arg(
short = 'c',
long,
Expand All @@ -61,9 +61,9 @@ struct Cli {
/// Halt generation at PROMPT, return control.
#[arg(short, long)]
reverse_prompt: Option<String>,
/// Batch size for prompt processing
#[arg(short, long, default_value = "512")]
batch_size: u64,
/// Sets batch sizes for chat and embedding models, respectively. The sizes are separated by comma without space, for example, '--batch-size 128,64'. The first value is for the chat model, and the second is for the embedding model.
#[arg(short, long, value_delimiter = ',', default_value = "512,512", value_parser = clap::value_parser!(u64))]
batch_size: Vec<u64>,
/// Custom rag prompt.
#[arg(long)]
rag_prompt: Option<String>,
Expand Down Expand Up @@ -155,6 +155,21 @@ async fn main() -> Result<(), ServerError> {
"[INFO] Context sizes: {ctx_sizes}",
ctx_sizes = ctx_sizes_str
));
// The `--batch-size` option must carry exactly two values: index 0 is used
// for the chat model and index 1 for the embedding model further down.
if cli.batch_size.len() != 2 {
    return Err(ServerError::ArgumentError(
        "LlamaEdge RAG API server requires two batch sizes: one for chat model, one for embedding model.".to_owned(),
    ));
}
// Render the batch sizes as a comma-separated list for the startup log.
// This must read `cli.batch_size`; reading `cli.ctx_size` here would log the
// context sizes under the "Batch sizes" label.
let batch_sizes_str: String = cli
    .batch_size
    .iter()
    .map(|n| n.to_string())
    .collect::<Vec<String>>()
    .join(",");
log(format!(
    "[INFO] Batch sizes: {batch_sizes}",
    batch_sizes = batch_sizes_str
));
log(format!("[INFO] Prompt template: {}", &cli.prompt_template));
if let Some(reverse_prompt) = &cli.reverse_prompt {
log(format!("[INFO] reverse prompt: {}", reverse_prompt));
Expand Down Expand Up @@ -218,7 +233,7 @@ async fn main() -> Result<(), ServerError> {
)
.with_ctx_size(cli.ctx_size[0])
.with_reverse_prompt(cli.reverse_prompt)
.with_batch_size(cli.batch_size)
.with_batch_size(cli.batch_size[0])
.enable_prompts_log(cli.log_prompts || cli.log_all)
.enable_plugin_log(cli.log_stat || cli.log_all)
.enable_debug_log(plugin_debug)
Expand Down Expand Up @@ -250,7 +265,7 @@ async fn main() -> Result<(), ServerError> {
cli.prompt_template,
)
.with_ctx_size(cli.ctx_size[1])
.with_batch_size(cli.batch_size)
.with_batch_size(cli.batch_size[1])
.enable_prompts_log(cli.log_prompts || cli.log_all)
.enable_plugin_log(cli.log_stat || cli.log_all)
.enable_debug_log(plugin_debug)
Expand Down

0 comments on commit 2680316

Please sign in to comment.