Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Nonstreaming API #85

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open

Conversation

JNeuvonen
Copy link
Collaborator

@JNeuvonen JNeuvonen commented Jul 24, 2023

The implementation uses the same start function inside process.rs for multithreading but just doesn't send server events back to the request sender on every new token but collects the tokens into a string buffer.

Currently, there is no client-side implementation, so merging should not affect client-side at all. Next, we could open an issue for client-side implementation as well.

Here is a request body for quickly testing the API (stream flag is false):

{"sampler":"top-p-top-k","prompt":"AI: Greeting! I am a friendly AI assistant. Feel free to ask me anything.\nHuman: Hello world\nAI: ","max_tokens":200,"temperature":1,"seed":147,"frequency_penalty":0.6,"presence_penalty":0,"top_k":42,"top_p":1,"stop":["AI: ","Human: "],"stream":false}

Issue

@vercel
Copy link

vercel bot commented Jul 24, 2023

The latest updates on your projects. Learn more about Vercel for Git ↗︎

Name Status Preview Comments Updated (UTC)
local-ai-web ✅ Ready (Inspect) Visit Preview 💬 Add feedback Sep 24, 2023 4:14am

@louisgv
Copy link
Owner

louisgv commented Jul 31, 2023

@JNeuvonen hey sorry about the slow review on my end, I've been pretty busy with summer chores/errand and also other works xD... Also was investigating #62 and why the upstream llama metal doesn't seem to work on Mac anymore :d..... Will get to this by Wednesday.

Is it ok for me to cook it up a bit if I find something wrong/missing, or would you prefer just comment and you can take care of it? LMK what type of feedback is cool for you :)

@JNeuvonen
Copy link
Collaborator Author

JNeuvonen commented Aug 1, 2023 via email

@louisgv louisgv changed the title Nonstreaming api feat: Nonstreaming api Aug 2, 2023
@louisgv louisgv changed the title feat: Nonstreaming api feat: Nonstreaming API Aug 2, 2023
});

HttpResponse::Ok()
.append_header(("Content-Type", "text/plain"))
Copy link
Owner

@louisgv louisgv Aug 2, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should return application/json type here instead I think, it helps the client know to do JSON chunk parsing as needed as well based on that header type

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that makes sense, yes. Will fix those. Thanks for looking at my code.

tx: Some(tx),
});

rx.recv().unwrap();
Copy link
Owner

@louisgv louisgv Aug 2, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should match for error and return HTTP error here IMO, otherwise would be hard to triage :d

} else {
if let Some(tx) = req.tx {
//Tell server thread that inference completed, and let it respond
let _ = tx.send(());
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need that _ or can we just call send here?

@@ -105,7 +109,10 @@ pub fn start(req: InferenceThreadRequest) -> JoinHandle<()> {
let start_at = std::time::SystemTime::now();

println!("Feeding prompt ...");
req.send_event("FEEDING_PROMPT");

if stream_enabled {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we do this check at the trait level instead. This way we can unify the interface call (in this file), and handle the stream/non-stream logic at the trait implementation level instead, would make it much nicer and more cohesive :)

pub model_guard: ModelGuard,
pub completion_request: CompletionRequest,
pub nonstream_completion_tokens: Arc<Mutex<String>>,
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can make this private if we use it as a trait state for the non-stream feature. Making it pub would allow others to inspect it while it's writing/locked, which could potentially deadlock the Mutex writer if we're not careful... :d

} else {
let abort_flag = Arc::new(RwLock::new(false));
let completion_tokens = Arc::new(Mutex::new(String::new()));
let (tx, rx) = flume::unbounded::<()>();
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we can make the tokensender generic, so that we can reuse that argument. The token_sender and the tx serve very similar function here, we just need to reconcile the Byte/String type. That'd make for nicer interface I think

}),
)
})
if let Some(true) = payload.stream {
Copy link
Owner

@louisgv louisgv Aug 2, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be payload.0.stream I think, since it's a JSON.

If we can reconcile our trait above, we can infer the stream boolean via the completion_request as well, skipping a couple of lookup hoop!

Copy link
Owner

@louisgv louisgv left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The overall idea is great thus far, added some comment and idea on improvement 👍

Comment on lines +73 to +81
start(InferenceThreadRequest {
model_guard: model_guard.clone(),
abort_flag: abort_flag.clone(),
token_sender,
completion_request: payload.0,
nonstream_completion_tokens: str_buffer.clone(),
stream: true,
tx: None,
}),
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have this idea which I think would make this nicer - we can create the InferenceThreadRequest before the isStream check actually, since it's non-blocking state. We can then do

let request = InferenceThreadRequest {
            model_guard: model_guard.clone(),
            abort_flag: abort_flag.clone(),
            token_sender,
            completion_request: payload.0,
            nonstream_completion_tokens: str_buffer.clone(),          
}

if request.isStream() {} else {} 

And the .isStream is a trait public method we expose via InferenceThreadRequest, which basically return completion_request.stream

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I really like your attention to detail and design thinking! I will try to implement this one, I agree, it is indeed cleaner.

@louisgv
Copy link
Owner

louisgv commented Aug 2, 2023

@JNeuvonen invited you as repo collaborator

@louisgv
Copy link
Owner

louisgv commented Sep 16, 2023

@JNeuvonen lmk if you're still able to update the PR - otherwise I can get on it sometime next week!

@JNeuvonen
Copy link
Collaborator Author

Hey, I apologize that I didn't come back earlier. Back when I was working on this, I was on a summer vacation, now I am back on my work schedule, and I have less time & focus. Please feel free to finish the feature.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants