diff --git a/apps/desktop/src-tauri/src/inference/completion.rs b/apps/desktop/src-tauri/src/inference/completion.rs
index 3fefbc8..de05188 100644
--- a/apps/desktop/src-tauri/src/inference/completion.rs
+++ b/apps/desktop/src-tauri/src/inference/completion.rs
@@ -35,7 +35,7 @@ pub struct CompletionRequest {
 
   sampler: Option,
 
-  stream: Option<bool>,
+  pub stream: Option<bool>,
 
   max_tokens: Option,
 
diff --git a/apps/desktop/src-tauri/src/inference/process.rs b/apps/desktop/src-tauri/src/inference/process.rs
index 0450057..94675ca 100644
--- a/apps/desktop/src-tauri/src/inference/process.rs
+++ b/apps/desktop/src-tauri/src/inference/process.rs
@@ -20,9 +20,11 @@ pub type ModelGuard = Arc<Mutex<Option<Box<dyn Model>>>>;
 pub struct InferenceThreadRequest {
   pub token_sender: Sender<Bytes>,
   pub abort_flag: Arc<RwLock<bool>>,
-
   pub model_guard: ModelGuard,
   pub completion_request: CompletionRequest,
+  pub nonstream_completion_tokens: Arc<Mutex<String>>,
+  pub stream: bool,
+  pub tx: Option<Sender<()>>,
 }
 
 impl InferenceThreadRequest {
@@ -77,7 +79,7 @@ fn get_inference_params(
 }
 
 // Perhaps might be better to clone the model for each thread...
-pub fn start(req: InferenceThreadRequest) -> JoinHandle<()> {
+pub fn start<'a>(req: InferenceThreadRequest) -> JoinHandle<()> {
   println!("Spawning inference thread...");
   actix_web::rt::task::spawn_blocking(move || {
     let mut rng = req.completion_request.get_rng();
@@ -86,6 +88,8 @@ pub fn start(req: InferenceThreadRequest) -> JoinHandle<()> {
     let mut token_utf8_buf = TokenUtf8Buffer::new();
 
     let guard = req.model_guard.lock();
+    let stream_enabled = req.stream;
+    let mut nonstream_res_str_buf = req.nonstream_completion_tokens.lock();
 
     let model = match guard.as_ref() {
       Some(m) => m,
@@ -105,7 +109,10 @@ pub fn start(req: InferenceThreadRequest) -> JoinHandle<()> {
     let start_at = std::time::SystemTime::now();
 
     println!("Feeding prompt ...");
-    req.send_event("FEEDING_PROMPT");
+
+    if stream_enabled {
+      req.send_event("FEEDING_PROMPT");
+    }
 
     match session.feed_prompt::(
       model.as_ref(),
@@ -118,7 +125,9 @@ pub fn start(req: InferenceThreadRequest) -> JoinHandle<()> {
         }
 
         if let Some(token) = token_utf8_buf.push(t) {
-          req.send_comment(format!("Processing token: {:?}", token).as_str());
+          if stream_enabled {
+            req.send_comment(format!("Processing token: {:?}", token).as_str());
+          }
         }
 
         Ok(InferenceFeedback::Continue)
@@ -138,8 +147,10 @@ pub fn start(req: InferenceThreadRequest) -> JoinHandle<()> {
       }
     };
 
-    req.send_comment("Generating tokens ...");
-    req.send_event("GENERATING_TOKENS");
+    if stream_enabled {
+      req.send_comment("Generating tokens ...");
+      req.send_event("GENERATING_TOKENS");
+    }
 
     // Reset the utf8 buf
     token_utf8_buf = TokenUtf8Buffer::new();
@@ -176,14 +187,19 @@ pub fn start(req: InferenceThreadRequest) -> JoinHandle<()> {
 
       // Buffer the token until it's valid UTF-8, then call the callback.
       if let Some(tokens) = token_utf8_buf.push(&token) {
-        match req
-          .token_sender
-          .send(CompletionResponse::to_data_bytes(tokens))
-        {
-          Ok(_) => {}
-          Err(_) => {
-            break;
+        if req.stream {
+          match req
+            .token_sender
+            .send(CompletionResponse::to_data_bytes(tokens))
+          {
+            Ok(_) => {}
+            Err(_) => {
+              break;
+            }
           }
+        } else {
+          //Collect tokens into str buffer
+          *nonstream_res_str_buf += &tokens;
         }
       }
 
@@ -195,8 +211,15 @@ pub fn start(req: InferenceThreadRequest) -> JoinHandle<()> {
 
     println!("Inference stats: {:?}", stats);
 
-    if !req.token_sender.is_disconnected() {
-      req.send_done();
+    if stream_enabled {
+      if !req.token_sender.is_disconnected() {
+        req.send_done();
+      }
+    } else {
+      if let Some(tx) = req.tx {
+        //Tell server thread that inference completed, and let it respond
+        let _ = tx.send(());
+      }
     }
 
     // TODO: Might make this into a callback later, for now we just abuse the singleton
diff --git a/apps/desktop/src-tauri/src/inference/server.rs b/apps/desktop/src-tauri/src/inference/server.rs
index 5193b9c..3b87547 100644
--- a/apps/desktop/src-tauri/src/inference/server.rs
+++ b/apps/desktop/src-tauri/src/inference/server.rs
@@ -1,10 +1,10 @@
 use actix_cors::Cors;
 use actix_web::dev::ServerHandle;
 use actix_web::web::{Bytes, Json};
-
 use actix_web::{get, post, App, HttpResponse, HttpServer, Responder};
 use parking_lot::{Mutex, RwLock};
 use serde::Serialize;
+use serde_json::json;
 
 use std::sync::{
   atomic::{AtomicBool, Ordering},
@@ -58,24 +58,55 @@ async fn post_completions(payload: Json<CompletionRequest>) -> impl Responder {
 
   let (token_sender, receiver) = flume::unbounded::<Bytes>();
 
-  HttpResponse::Ok()
-    .append_header(("Content-Type", "text/event-stream"))
-    .append_header(("Cache-Control", "no-cache"))
-    .keep_alive()
-    .streaming({
-      let abort_flag = Arc::new(RwLock::new(false));
-
-      AbortStream::new(
-        receiver,
-        abort_flag.clone(),
-        start(InferenceThreadRequest {
-          model_guard: model_guard.clone(),
-          abort_flag: abort_flag.clone(),
-          token_sender,
-          completion_request: payload.0,
-        }),
-      )
-    })
+  if let Some(true) = payload.stream {
+    HttpResponse::Ok()
+      .append_header(("Content-Type", "text/event-stream"))
+      .append_header(("Cache-Control", "no-cache"))
+      .keep_alive()
+      .streaming({
+        let abort_flag = Arc::new(RwLock::new(false));
+        let str_buffer = Arc::new(Mutex::new(String::new()));
+
+        AbortStream::new(
+          receiver,
+          abort_flag.clone(),
+          start(InferenceThreadRequest {
+            model_guard: model_guard.clone(),
+            abort_flag: abort_flag.clone(),
+            token_sender,
+            completion_request: payload.0,
+            nonstream_completion_tokens: str_buffer.clone(),
+            stream: true,
+            tx: None,
+          }),
+        )
+      })
+  } else {
+    let abort_flag = Arc::new(RwLock::new(false));
+    let completion_tokens = Arc::new(Mutex::new(String::new()));
+    let (tx, rx) = flume::unbounded::<()>();
+    start(InferenceThreadRequest {
+      model_guard: model_guard.clone(),
+      abort_flag: abort_flag.clone(),
+      token_sender,
+      completion_request: payload.0,
+      nonstream_completion_tokens: completion_tokens.clone(),
+      stream: false,
+      tx: Some(tx),
+    });
+
+    rx.recv().unwrap();
+
+    let locked_str_buffer = completion_tokens.lock();
+    let completion_body = json!({
+      "completion": locked_str_buffer.clone()
+    });
+
+    HttpResponse::Ok()
+      .append_header(("Content-Type", "text/plain"))
+      .append_header(("Cache-Control", "no-cache"))
+      .json(completion_body)
+  }
 }
 
 #[tauri::command]
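For reference, a minimal client-side sketch of the new non-streaming path. Only the `stream` flag, the `max_tokens` field, and the `{"completion": ...}` response shape come from the patch above; the server address/route and the `prompt` field name are assumptions, and the client uses `reqwest` (with the `blocking` and `json` features) plus `serde_json`:

```rust
use serde_json::{json, Value};

fn main() -> Result<(), Box<dyn std::error::Error>> {
  let client = reqwest::blocking::Client::new();

  // With `stream` false (or omitted), the handler takes the new else branch:
  // it blocks until the inference thread signals completion over the flume
  // channel, then returns the whole generation as one JSON body instead of SSE.
  let res: Value = client
    .post("http://localhost:8000/completions") // assumed host/port/route
    .json(&json!({
      "prompt": "The quick brown fox", // assumed field name
      "stream": false,
      "max_tokens": 64
    }))
    .send()?
    .json()?;

  println!("{}", res["completion"]);
  Ok(())
}
```

One server-side note on this design: the non-streaming branch parks an actix worker on the blocking `rx.recv()` for the full generation, and the inference thread holds the `nonstream_completion_tokens` lock until its closure exits, so the handler's `completion_tokens.lock()` only proceeds once the spawned thread winds down.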