diff --git a/src/main.rs b/src/main.rs index c5e6bcd..31a9e14 100644 --- a/src/main.rs +++ b/src/main.rs @@ -78,6 +78,9 @@ struct Cli { /// Number of layers to run on the GPU #[arg(short = 'g', long, default_value = "100")] n_gpu_layers: u64, + /// Split the model across multiple GPUs. Possible values: `none` (use one GPU only), `layer` (split layers and KV across GPUs, default), `row` (split rows across GPUs) + #[arg(long, default_value = "layer")] + split_mode: String, /// The main GPU to use. #[arg(long)] main_gpu: Option, @@ -246,6 +249,9 @@ async fn main() -> Result<(), ServerError> { // log n_gpu_layers info!(target: "stdout", "n_gpu_layers: {}", &cli.n_gpu_layers); + // log split_mode + info!(target: "stdout", "split_mode: {}", cli.split_mode); + // log main GPU if let Some(main_gpu) = &cli.main_gpu { info!(target: "stdout", "main_gpu: {}", main_gpu); @@ -395,6 +401,7 @@ async fn main() -> Result<(), ServerError> { .with_batch_size(cli.batch_size[0]) .with_n_predict(cli.n_predict) .with_n_gpu_layers(cli.n_gpu_layers) + .with_split_mode(cli.split_mode.clone()) .with_main_gpu(cli.main_gpu) .with_tensor_split(cli.tensor_split.clone()) .with_threads(cli.threads) @@ -418,6 +425,9 @@ async fn main() -> Result<(), ServerError> { repeat_penalty: chat_metadata.repeat_penalty, presence_penalty: chat_metadata.presence_penalty, frequency_penalty: chat_metadata.frequency_penalty, + split_mode: chat_metadata.split_mode.clone(), + main_gpu: chat_metadata.main_gpu, + tensor_split: chat_metadata.tensor_split.clone(), }; // chat model @@ -431,6 +441,7 @@ async fn main() -> Result<(), ServerError> { ) .with_ctx_size(cli.ctx_size[1]) .with_batch_size(cli.batch_size[1]) + .with_split_mode(cli.split_mode) .with_main_gpu(cli.main_gpu) .with_tensor_split(cli.tensor_split) .with_threads(cli.threads) @@ -441,17 +452,20 @@ async fn main() -> Result<(), ServerError> { let embedding_model_info = ModelConfig { name: embedding_metadata.model_name.clone(), ty: "embedding".to_string(), + ctx_size: embedding_metadata.ctx_size, + batch_size: embedding_metadata.batch_size, prompt_template: embedding_metadata.prompt_template, n_predict: embedding_metadata.n_predict, reverse_prompt: embedding_metadata.reverse_prompt.clone(), n_gpu_layers: embedding_metadata.n_gpu_layers, - ctx_size: embedding_metadata.ctx_size, - batch_size: embedding_metadata.batch_size, temperature: embedding_metadata.temperature, top_p: embedding_metadata.top_p, repeat_penalty: embedding_metadata.repeat_penalty, presence_penalty: embedding_metadata.presence_penalty, frequency_penalty: embedding_metadata.frequency_penalty, + split_mode: embedding_metadata.split_mode.clone(), + main_gpu: embedding_metadata.main_gpu, + tensor_split: embedding_metadata.tensor_split.clone(), }; // embedding model @@ -686,18 +700,23 @@ pub(crate) struct ModelConfig { // type: chat or embedding #[serde(rename = "type")] ty: String, + pub ctx_size: u64, + pub batch_size: u64, pub prompt_template: PromptTemplateType, pub n_predict: i32, #[serde(skip_serializing_if = "Option::is_none")] pub reverse_prompt: Option, pub n_gpu_layers: u64, - pub ctx_size: u64, - pub batch_size: u64, pub temperature: f64, pub top_p: f64, pub repeat_penalty: f64, pub presence_penalty: f64, pub frequency_penalty: f64, + pub split_mode: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub main_gpu: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tensor_split: Option, } #[derive(Debug, Serialize, Deserialize)]