Skip to content

Commit

Permalink
feat: update code to support vdb_api_key
Browse files Browse the repository at this point in the history
Signed-off-by: Xin Liu <[email protected]>
  • Loading branch information
apepkuss committed Dec 10, 2024
1 parent 5764426 commit 7f1ff2d
Showing 1 changed file with 71 additions and 23 deletions.
94 changes: 71 additions & 23 deletions src/backend/ggml.rs
Original file line number Diff line number Diff line change
Expand Up @@ -536,14 +536,21 @@ async fn retrieve_context_with_single_qdrant_config(
}
};

// get vdb_api_key if it is provided in the request, otherwise get it from the environment variable `VDB_API_KEY`
let vdb_api_key = chat_request
.vdb_api_key
.clone()
.or_else(|| std::env::var("VDB_API_KEY").ok());

// create an embedding request
let embedding_request = EmbeddingRequest {
model: Some(embedding_model_names[0].clone()),
input: InputText::String(query_text),
encoding_format: None,
user: chat_request.user.clone(),
qdrant_url: Some(qdrant_config.url.clone()),
qdrant_collection_name: Some(qdrant_config.collection_name.clone()),
vdb_server_url: Some(qdrant_config.url.clone()),
vdb_collection_name: Some(qdrant_config.collection_name.clone()),
vdb_api_key,
};

// compute embeddings for query
Expand Down Expand Up @@ -572,13 +579,20 @@ async fn retrieve_context_with_single_qdrant_config(
}
};

// get vdb_api_key if it is provided in the request, otherwise get it from the environment variable `VDB_API_KEY`
let vdb_api_key = chat_request
.vdb_api_key
.clone()
.or_else(|| std::env::var("VDB_API_KEY").ok());

// perform the context retrieval
let mut retrieve_object: RetrieveObject = match rag_retrieve_context(
query_embedding.as_slice(),
qdrant_config.url.to_string().as_str(),
qdrant_config.collection_name.as_str(),
qdrant_config.limit as usize,
Some(qdrant_config.score_threshold),
vdb_api_key,
)
.await
{
Expand Down Expand Up @@ -1514,7 +1528,9 @@ pub(crate) async fn create_rag_handler(
info!(target: "stdout", "Handling the coming doc_to_embeddings request.");

// upload the target rag document
let (file_object, url_vdb_server, collection_name) = if req.method() == Method::POST {
let (file_object, vdb_server_url, vdb_collection_name, vdb_api_key) = if req.method()
== Method::POST
{
let boundary = "boundary=";

let boundary = req.headers().get("content-type").and_then(|ct| {
Expand All @@ -1541,8 +1557,9 @@ pub(crate) async fn create_rag_handler(
let mut multipart = Multipart::with_body(cursor, boundary.unwrap());

let mut file_object: Option<FileObject> = None;
let mut url_vdb_server: String = String::new();
let mut collection_name: String = String::new();
let mut vdb_server_url: String = String::new();
let mut vdb_collection_name: String = String::new();
let mut vdb_api_key: String = String::new();
while let ReadEntryResult::Entry(mut field) = multipart.read_entry_mut() {
match &*field.headers.name {
"file" => {
Expand Down Expand Up @@ -1634,10 +1651,11 @@ pub(crate) async fn create_rag_handler(
purpose: "assistants".to_string(),
});
}
"url_vdb_server" => match field.is_text() {
"vdb_server_url" => match field.is_text() {
true => {
if let Err(e) = field.data.read_to_string(&mut url_vdb_server) {
let err_msg = format!("Failed to read the url_vdb_server field. {}", e);
if let Err(e) = field.data.read_to_string(&mut vdb_server_url) {
let err_msg =
format!("Failed to read the `vdb_server_url` field. {}", e);

// log
error!(target: "stdout", "{}", &err_msg);
Expand All @@ -1647,19 +1665,19 @@ pub(crate) async fn create_rag_handler(
}
false => {
let err_msg =
"Failed to get `url_vdb_server`. The `url_vdb_server` field in the request should be a text field.";
"Failed to get `vdb_server_url`. The `vdb_server_url` field in the request should be a text field.";

// log
error!(target: "stdout", "{}", &err_msg);

return error::internal_server_error(err_msg);
}
},
"collection_name" => match field.is_text() {
"vdb_collection_name" => match field.is_text() {
true => {
if let Err(e) = field.data.read_to_string(&mut collection_name) {
if let Err(e) = field.data.read_to_string(&mut vdb_collection_name) {
let err_msg =
format!("Failed to read the collection_name field. {}", e);
format!("Failed to read the `vdb_collection_name` field. {}", e);

// log
error!(target: "stdout", "{}", &err_msg);
Expand All @@ -1668,7 +1686,27 @@ pub(crate) async fn create_rag_handler(
}
}
false => {
let err_msg = "Failed to get `collection_name`. The `collection_name` field in the request should be a text field.";
let err_msg = "Failed to get `vdb_collection_name`. The `vdb_collection_name` field in the request should be a text field.";

// log
error!(target: "stdout", "{}", &err_msg);

return error::internal_server_error(err_msg);
}
},
"vdb_api_key" => match field.is_text() {
true => {
if let Err(e) = field.data.read_to_string(&mut vdb_api_key) {
let err_msg = format!("Failed to read the `vdb_api_key` field. {}", e);

// log
error!(target: "stdout", "{}", &err_msg);

return error::internal_server_error(err_msg);
}
}
false => {
let err_msg = "Failed to get `vdb_api_key`. The `vdb_api_key` field in the request should be a text field.";

// log
error!(target: "stdout", "{}", &err_msg);
Expand All @@ -1687,7 +1725,8 @@ pub(crate) async fn create_rag_handler(
}
}

match (url_vdb_server.is_empty(), collection_name.is_empty()) {
// If the request does not provide the vdb_server_url and vdb_collection_name, use the default vdb config from the server info, and get the vdb_api_key from the environment variable `VDB_API_KEY` if it is set.
match (vdb_server_url.is_empty(), vdb_collection_name.is_empty()) {
(true, true) => {
let qdrant_config_vec = match SERVER_INFO.get() {
Some(server_info) => server_info.read().await.qdrant_config.clone(),
Expand All @@ -1702,11 +1741,14 @@ pub(crate) async fn create_rag_handler(
};

// use the first qdrant config as the default config
url_vdb_server = qdrant_config_vec[0].url.clone();
collection_name = qdrant_config_vec[0].collection_name.clone();
vdb_server_url = qdrant_config_vec[0].url.clone();
vdb_collection_name = qdrant_config_vec[0].collection_name.clone();
if vdb_api_key.is_empty() {
vdb_api_key = std::env::var("VDB_API_KEY").unwrap_or_default();
}
}
(true, false) | (false, true) => {
let err_msg = "Failed to get `url_vdb_server` or `collection_name`. The `url_vdb_server` and `collection_name` fields in the request should be provided at the same time.";
let err_msg = "Failed to get `vdb_server_url` or `vdb_collection_name`. The `vdb_server_url` and `vdb_collection_name` fields in the request should be provided at the same time.";

// log
error!(target: "stdout", "{}", &err_msg);
Expand All @@ -1716,10 +1758,10 @@ pub(crate) async fn create_rag_handler(
(false, false) => {}
}

info!(target: "stdout", "url_vdb_server: {}, collection_name: {}", &url_vdb_server, &collection_name);
info!(target: "stdout", "vdb_server_url: {}, vdb_collection_name: {}", &vdb_server_url, &vdb_collection_name);

match file_object {
Some(fo) => (fo, url_vdb_server, collection_name),
Some(fo) => (fo, vdb_server_url, vdb_collection_name, vdb_api_key),
None => {
let err_msg = "Failed to upload the target file. Not found the target file.";

Expand Down Expand Up @@ -1878,14 +1920,20 @@ pub(crate) async fn create_rag_handler(

info!(target: "stdout", "Prepare the rag embedding request.");

let api_key = match vdb_api_key.is_empty() {
true => None,
false => Some(vdb_api_key),
};

// create an embedding request
let embedding_request = EmbeddingRequest {
model: Some(model),
input: chunks.into(),
encoding_format: None,
user: None,
qdrant_url: Some(url_vdb_server),
qdrant_collection_name: Some(collection_name),
vdb_server_url: Some(vdb_server_url),
vdb_collection_name: Some(vdb_collection_name),
vdb_api_key: api_key,
};

match rag_doc_chunks_to_embeddings(&embedding_request).await {
Expand Down Expand Up @@ -2122,8 +2170,8 @@ async fn get_qdrant_configs(
chat_request: &ChatCompletionRequest,
) -> Result<Vec<QdrantConfig>, error::ServerError> {
match (
chat_request.url_vdb_server.as_deref(),
chat_request.collection_name.as_deref(),
chat_request.vdb_server_url.as_deref(),
chat_request.vdb_collection_name.as_deref(),
chat_request.limit.as_deref(),
chat_request.score_threshold.as_deref(),
) {
Expand Down

0 comments on commit 7f1ff2d

Please sign in to comment.