-
Notifications
You must be signed in to change notification settings - Fork 56
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Nonstreaming API #85
base: main
Are you sure you want to change the base?
Changes from 7 commits
e19274d
b7e238a
ca81d59
715b6da
48a0a83
16d82d9
b7befbf
a89acd2
f201f6f
50da5f3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,9 +20,11 @@ pub type ModelGuard = Arc<Mutex<Option<Box<dyn Model>>>>; | |
pub struct InferenceThreadRequest { | ||
pub token_sender: Sender<Bytes>, | ||
pub abort_flag: Arc<RwLock<bool>>, | ||
|
||
pub model_guard: ModelGuard, | ||
pub completion_request: CompletionRequest, | ||
pub nonstream_completion_tokens: Arc<Mutex<String>>, | ||
pub stream: bool, | ||
pub tx: Option<Sender<()>>, | ||
} | ||
|
||
impl InferenceThreadRequest { | ||
|
@@ -77,7 +79,7 @@ fn get_inference_params( | |
} | ||
|
||
// Perhaps might be better to clone the model for each thread... | ||
pub fn start(req: InferenceThreadRequest) -> JoinHandle<()> { | ||
pub fn start<'a>(req: InferenceThreadRequest) -> JoinHandle<()> { | ||
println!("Spawning inference thread..."); | ||
actix_web::rt::task::spawn_blocking(move || { | ||
let mut rng = req.completion_request.get_rng(); | ||
|
@@ -86,6 +88,8 @@ pub fn start(req: InferenceThreadRequest) -> JoinHandle<()> { | |
|
||
let mut token_utf8_buf = TokenUtf8Buffer::new(); | ||
let guard = req.model_guard.lock(); | ||
let stream_enabled = req.stream; | ||
let mut nonstream_res_str_buf = req.nonstream_completion_tokens.lock(); | ||
|
||
let model = match guard.as_ref() { | ||
Some(m) => m, | ||
|
@@ -105,7 +109,10 @@ pub fn start(req: InferenceThreadRequest) -> JoinHandle<()> { | |
let start_at = std::time::SystemTime::now(); | ||
|
||
println!("Feeding prompt ..."); | ||
req.send_event("FEEDING_PROMPT"); | ||
|
||
if stream_enabled { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can we do this check at the trait level instead? This way we can unify the interface call (in this file) and handle the stream/non-stream logic at the trait implementation level instead; that would make it much nicer and more cohesive :) |
||
req.send_event("FEEDING_PROMPT"); | ||
} | ||
|
||
match session.feed_prompt::<Infallible, Prompt>( | ||
model.as_ref(), | ||
|
@@ -118,7 +125,9 @@ pub fn start(req: InferenceThreadRequest) -> JoinHandle<()> { | |
} | ||
|
||
if let Some(token) = token_utf8_buf.push(t) { | ||
req.send_comment(format!("Processing token: {:?}", token).as_str()); | ||
if stream_enabled { | ||
req.send_comment(format!("Processing token: {:?}", token).as_str()); | ||
} | ||
} | ||
|
||
Ok(InferenceFeedback::Continue) | ||
|
@@ -138,8 +147,10 @@ pub fn start(req: InferenceThreadRequest) -> JoinHandle<()> { | |
} | ||
}; | ||
|
||
req.send_comment("Generating tokens ..."); | ||
req.send_event("GENERATING_TOKENS"); | ||
if stream_enabled { | ||
req.send_comment("Generating tokens ..."); | ||
req.send_event("GENERATING_TOKENS"); | ||
} | ||
|
||
// Reset the utf8 buf | ||
token_utf8_buf = TokenUtf8Buffer::new(); | ||
|
@@ -176,14 +187,19 @@ pub fn start(req: InferenceThreadRequest) -> JoinHandle<()> { | |
|
||
// Buffer the token until it's valid UTF-8, then call the callback. | ||
if let Some(tokens) = token_utf8_buf.push(&token) { | ||
match req | ||
.token_sender | ||
.send(CompletionResponse::to_data_bytes(tokens)) | ||
{ | ||
Ok(_) => {} | ||
Err(_) => { | ||
break; | ||
if req.stream { | ||
match req | ||
.token_sender | ||
.send(CompletionResponse::to_data_bytes(tokens)) | ||
{ | ||
Ok(_) => {} | ||
Err(_) => { | ||
break; | ||
} | ||
} | ||
} else { | ||
//Collect tokens into str buffer | ||
*nonstream_res_str_buf += &tokens; | ||
} | ||
} | ||
|
||
|
@@ -195,8 +211,15 @@ pub fn start(req: InferenceThreadRequest) -> JoinHandle<()> { | |
|
||
println!("Inference stats: {:?}", stats); | ||
|
||
if !req.token_sender.is_disconnected() { | ||
req.send_done(); | ||
if stream_enabled { | ||
if !req.token_sender.is_disconnected() { | ||
req.send_done(); | ||
} | ||
} else { | ||
if let Some(tx) = req.tx { | ||
//Tell server thread that inference completed, and let it respond | ||
let _ = tx.send(()); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do we need that |
||
} | ||
} | ||
|
||
// TODO: Might make this into a callback later, for now we just abuse the singleton | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,10 @@ | ||
use actix_cors::Cors; | ||
use actix_web::dev::ServerHandle; | ||
use actix_web::web::{Bytes, Json}; | ||
|
||
use actix_web::{get, post, App, HttpResponse, HttpServer, Responder}; | ||
use parking_lot::{Mutex, RwLock}; | ||
use serde::Serialize; | ||
use serde_json::json; | ||
|
||
use std::sync::{ | ||
atomic::{AtomicBool, Ordering}, | ||
|
@@ -58,24 +58,55 @@ async fn post_completions(payload: Json<CompletionRequest>) -> impl Responder { | |
|
||
let (token_sender, receiver) = flume::unbounded::<Bytes>(); | ||
|
||
HttpResponse::Ok() | ||
.append_header(("Content-Type", "text/event-stream")) | ||
.append_header(("Cache-Control", "no-cache")) | ||
.keep_alive() | ||
.streaming({ | ||
let abort_flag = Arc::new(RwLock::new(false)); | ||
|
||
AbortStream::new( | ||
receiver, | ||
abort_flag.clone(), | ||
start(InferenceThreadRequest { | ||
model_guard: model_guard.clone(), | ||
abort_flag: abort_flag.clone(), | ||
token_sender, | ||
completion_request: payload.0, | ||
}), | ||
) | ||
}) | ||
if let Some(true) = payload.stream { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This should be `payload.0.stream`, I think, since it's JSON. If we can reconcile our trait above, we can infer the stream boolean via the |
||
HttpResponse::Ok() | ||
.append_header(("Content-Type", "text/event-stream")) | ||
.append_header(("Cache-Control", "no-cache")) | ||
.keep_alive() | ||
.streaming({ | ||
let abort_flag = Arc::new(RwLock::new(false)); | ||
let str_buffer = Arc::new(Mutex::new(String::new())); | ||
|
||
AbortStream::new( | ||
receiver, | ||
abort_flag.clone(), | ||
start(InferenceThreadRequest { | ||
model_guard: model_guard.clone(), | ||
abort_flag: abort_flag.clone(), | ||
token_sender, | ||
completion_request: payload.0, | ||
nonstream_completion_tokens: str_buffer.clone(), | ||
stream: true, | ||
tx: None, | ||
}), | ||
Comment on lines
+73
to
+81
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have this idea which I think would make this nicer - we can create the let request = InferenceThreadRequest {
model_guard: model_guard.clone(),
abort_flag: abort_flag.clone(),
token_sender,
completion_request: payload.0,
nonstream_completion_tokens: str_buffer.clone(),
}
if request.is_stream() {} else {} And the There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I really like your attention to detail and design thinking! I will try to implement this one — I agree, it is indeed cleaner. |
||
) | ||
}) | ||
} else { | ||
let abort_flag = Arc::new(RwLock::new(false)); | ||
let completion_tokens = Arc::new(Mutex::new(String::new())); | ||
let (tx, rx) = flume::unbounded::<()>(); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I wonder if we can make the `token_sender` generic, so that we can reuse that argument. The |
||
start(InferenceThreadRequest { | ||
model_guard: model_guard.clone(), | ||
abort_flag: abort_flag.clone(), | ||
token_sender, | ||
completion_request: payload.0, | ||
nonstream_completion_tokens: completion_tokens.clone(), | ||
stream: false, | ||
tx: Some(tx), | ||
}); | ||
|
||
rx.recv().unwrap(); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We should match on the error and return an HTTP error here IMO; otherwise it would be hard to triage :d |
||
|
||
let locked_str_buffer = completion_tokens.lock(); | ||
let completion_body = json!({ | ||
"completion": locked_str_buffer.clone() | ||
}); | ||
|
||
HttpResponse::Ok() | ||
.append_header(("Content-Type", "text/plain")) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We should return `application/json` here instead of `text/plain`, since the response body is JSON. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think that makes sense, yes. Will fix those. Thanks for looking at my code. |
||
.append_header(("Cache-Control", "no-cache")) | ||
.json(completion_body) | ||
} | ||
} | ||
|
||
#[tauri::command] | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can make this private if we use it as a trait state for the non-stream feature. Making it pub would allow others to inspect it while it's writing/locked, which could potentially deadlock the Mutex writer if we're not careful... :d