feat: add --split-mode CLI option
Signed-off-by: Xin Liu <[email protected]>
apepkuss committed Jan 9, 2025
1 parent f4f7c15 commit ccae7a3
Showing 1 changed file with 23 additions and 4 deletions.
src/main.rs
```diff
@@ -78,6 +78,9 @@ struct Cli {
     /// Number of layers to run on the GPU
     #[arg(short = 'g', long, default_value = "100")]
     n_gpu_layers: u64,
+    /// Split the model across multiple GPUs. Possible values: `none` (use one GPU only), `layer` (split layers and KV across GPUs, default), `row` (split rows across GPUs)
+    #[arg(long, default_value = "layer")]
+    split_mode: String,
     /// The main GPU to use.
     #[arg(long)]
     main_gpu: Option<u64>,
```
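The new flag is declared as a plain `String` with a default of `layer`, so clap accepts any value at parse time and an unsupported mode would presumably only surface once the backend consumes it. If stricter validation were wanted, clap can restrict a string argument to a fixed set of values; a minimal sketch of that alternative (hypothetical, not part of this commit):

```rust
use clap::Parser;

#[derive(Debug, Parser)]
struct Cli {
    /// Split the model across multiple GPUs: `none`, `layer` (default), or `row`
    #[arg(long, default_value = "layer", value_parser = ["none", "layer", "row"])]
    split_mode: String,
}

fn main() {
    // With the value parser above, `--split-mode tensor` fails at parse time
    // with an error listing the three possible values, instead of reaching
    // the runtime as an arbitrary string.
    let cli = Cli::parse();
    println!("split_mode: {}", cli.split_mode);
}
```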
```diff
@@ -246,6 +249,9 @@ async fn main() -> Result<(), ServerError> {
     // log n_gpu_layers
     info!(target: "stdout", "n_gpu_layers: {}", &cli.n_gpu_layers);
 
+    // log split_mode
+    info!(target: "stdout", "split_mode: {}", cli.split_mode);
+
     // log main GPU
     if let Some(main_gpu) = &cli.main_gpu {
         info!(target: "stdout", "main_gpu: {}", main_gpu);
```
```diff
@@ -395,6 +401,7 @@ async fn main() -> Result<(), ServerError> {
         .with_batch_size(cli.batch_size[0])
         .with_n_predict(cli.n_predict)
         .with_n_gpu_layers(cli.n_gpu_layers)
+        .with_split_mode(cli.split_mode.clone())
         .with_main_gpu(cli.main_gpu)
         .with_tensor_split(cli.tensor_split.clone())
         .with_threads(cli.threads)
```
```diff
@@ -418,6 +425,9 @@ async fn main() -> Result<(), ServerError> {
         repeat_penalty: chat_metadata.repeat_penalty,
         presence_penalty: chat_metadata.presence_penalty,
         frequency_penalty: chat_metadata.frequency_penalty,
+        split_mode: chat_metadata.split_mode.clone(),
+        main_gpu: chat_metadata.main_gpu,
+        tensor_split: chat_metadata.tensor_split.clone(),
     };
 
     // chat model
```
```diff
@@ -431,6 +441,7 @@ async fn main() -> Result<(), ServerError> {
         )
         .with_ctx_size(cli.ctx_size[1])
         .with_batch_size(cli.batch_size[1])
+        .with_split_mode(cli.split_mode)
         .with_main_gpu(cli.main_gpu)
         .with_tensor_split(cli.tensor_split)
         .with_threads(cli.threads)
```
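Both metadata builders now receive the split mode: the first call passes `cli.split_mode.clone()` because the `String` is still needed afterwards, while this second call can move it, since it is the last use. `with_split_mode` itself is defined in the metadata builder outside this file; a plausible shape for such a chainable method, shown purely as an assumption (the struct and field names here are illustrative, not the real builder):

```rust
// Sketch of a consuming builder method in the style of the calls above.
#[derive(Debug, Default)]
struct MetadataBuilder {
    split_mode: String,
    n_gpu_layers: u64,
}

impl MetadataBuilder {
    /// Record the GPU split mode (`none`, `layer`, or `row`) and return the
    /// builder by value so further `.with_*` calls can be chained.
    fn with_split_mode(mut self, mode: String) -> Self {
        self.split_mode = mode;
        self
    }
}

fn main() {
    let builder = MetadataBuilder::default().with_split_mode("row".to_string());
    println!("{builder:?}");
}
```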
```diff
@@ -441,17 +452,20 @@ async fn main() -> Result<(), ServerError> {
         let embedding_model_info = ModelConfig {
             name: embedding_metadata.model_name.clone(),
             ty: "embedding".to_string(),
+            ctx_size: embedding_metadata.ctx_size,
+            batch_size: embedding_metadata.batch_size,
             prompt_template: embedding_metadata.prompt_template,
             n_predict: embedding_metadata.n_predict,
             reverse_prompt: embedding_metadata.reverse_prompt.clone(),
             n_gpu_layers: embedding_metadata.n_gpu_layers,
-            ctx_size: embedding_metadata.ctx_size,
-            batch_size: embedding_metadata.batch_size,
             temperature: embedding_metadata.temperature,
             top_p: embedding_metadata.top_p,
             repeat_penalty: embedding_metadata.repeat_penalty,
             presence_penalty: embedding_metadata.presence_penalty,
             frequency_penalty: embedding_metadata.frequency_penalty,
+            split_mode: embedding_metadata.split_mode.clone(),
+            main_gpu: embedding_metadata.main_gpu,
+            tensor_split: embedding_metadata.tensor_split.clone(),
         };
 
         // embedding model
```
Expand Down Expand Up @@ -686,18 +700,23 @@ pub(crate) struct ModelConfig {
// type: chat or embedding
#[serde(rename = "type")]
ty: String,
pub ctx_size: u64,
pub batch_size: u64,
pub prompt_template: PromptTemplateType,
pub n_predict: i32,
#[serde(skip_serializing_if = "Option::is_none")]
pub reverse_prompt: Option<String>,
pub n_gpu_layers: u64,
pub ctx_size: u64,
pub batch_size: u64,
pub temperature: f64,
pub top_p: f64,
pub repeat_penalty: f64,
pub presence_penalty: f64,
pub frequency_penalty: f64,
pub split_mode: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub main_gpu: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tensor_split: Option<String>,
}

#[derive(Debug, Serialize, Deserialize)]
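Of the three new `ModelConfig` fields, `split_mode` always serializes (the CLI default guarantees a value), while `main_gpu` and `tensor_split` are `Option`s tagged `skip_serializing_if = "Option::is_none"` and drop out of the serialized output when unset. A self-contained illustration of that serde behavior, with the field set reduced to just the additions:

```rust
use serde::{Deserialize, Serialize};

// Reduced stand-in for ModelConfig, keeping only the new fields, to show
// which keys appear in the serialized JSON.
#[derive(Debug, Serialize, Deserialize)]
struct GpuConfig {
    split_mode: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    main_gpu: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tensor_split: Option<String>,
}

fn main() -> Result<(), serde_json::Error> {
    let cfg = GpuConfig {
        split_mode: "layer".to_string(),
        main_gpu: None,
        tensor_split: None,
    };
    // Prints {"split_mode":"layer"} -- the two None fields are skipped.
    println!("{}", serde_json::to_string(&cfg)?);
    Ok(())
}
```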
