diff --git a/src/backend/ggml.rs b/src/backend/ggml.rs index 743def5..2dbab3d 100644 --- a/src/backend/ggml.rs +++ b/src/backend/ggml.rs @@ -536,14 +536,21 @@ async fn retrieve_context_with_single_qdrant_config( } }; + // get vdb_api_key if it is provided in the request, otherwise get it from the environment variable `VDB_API_KEY` + let vdb_api_key = chat_request + .vdb_api_key + .clone() + .or_else(|| std::env::var("VDB_API_KEY").ok()); + // create a embedding request let embedding_request = EmbeddingRequest { model: Some(embedding_model_names[0].clone()), input: InputText::String(query_text), encoding_format: None, user: chat_request.user.clone(), - qdrant_url: Some(qdrant_config.url.clone()), - qdrant_collection_name: Some(qdrant_config.collection_name.clone()), + vdb_server_url: Some(qdrant_config.url.clone()), + vdb_collection_name: Some(qdrant_config.collection_name.clone()), + vdb_api_key, }; // compute embeddings for query @@ -572,6 +579,12 @@ async fn retrieve_context_with_single_qdrant_config( } }; + // get vdb_api_key if it is provided in the request, otherwise get it from the environment variable `VDB_API_KEY` + let vdb_api_key = chat_request + .vdb_api_key + .clone() + .or_else(|| std::env::var("VDB_API_KEY").ok()); + // perform the context retrieval let mut retrieve_object: RetrieveObject = match rag_retrieve_context( query_embedding.as_slice(), @@ -579,6 +592,7 @@ async fn retrieve_context_with_single_qdrant_config( qdrant_config.collection_name.as_str(), qdrant_config.limit as usize, Some(qdrant_config.score_threshold), + vdb_api_key, ) .await { @@ -1514,7 +1528,9 @@ pub(crate) async fn create_rag_handler( info!(target: "stdout", "Handling the coming doc_to_embeddings request."); // upload the target rag document - let (file_object, url_vdb_server, collection_name) = if req.method() == Method::POST { + let (file_object, vdb_server_url, vdb_collection_name, vdb_api_key) = if req.method() + == Method::POST + { let boundary = "boundary="; let boundary = req.headers().get("content-type").and_then(|ct| { @@ -1541,8 +1557,9 @@ pub(crate) async fn create_rag_handler( let mut multipart = Multipart::with_body(cursor, boundary.unwrap()); let mut file_object: Option = None; - let mut url_vdb_server: String = String::new(); - let mut collection_name: String = String::new(); + let mut vdb_server_url: String = String::new(); + let mut vdb_collection_name: String = String::new(); + let mut vdb_api_key: String = String::new(); while let ReadEntryResult::Entry(mut field) = multipart.read_entry_mut() { match &*field.headers.name { "file" => { @@ -1634,10 +1651,11 @@ pub(crate) async fn create_rag_handler( purpose: "assistants".to_string(), }); } - "url_vdb_server" => match field.is_text() { + "vdb_server_url" => match field.is_text() { true => { - if let Err(e) = field.data.read_to_string(&mut url_vdb_server) { - let err_msg = format!("Failed to read the url_vdb_server field. {}", e); + if let Err(e) = field.data.read_to_string(&mut vdb_server_url) { + let err_msg = + format!("Failed to read the `vdb_server_url` field. {}", e); // log error!(target: "stdout", "{}", &err_msg); @@ -1647,7 +1665,7 @@ pub(crate) async fn create_rag_handler( } false => { let err_msg = - "Failed to get `url_vdb_server`. The `url_vdb_server` field in the request should be a text field."; + "Failed to get `vdb_server_url`. The `vdb_server_url` field in the request should be a text field."; // log error!(target: "stdout", "{}", &err_msg); @@ -1655,11 +1673,11 @@ pub(crate) async fn create_rag_handler( return error::internal_server_error(err_msg); } }, - "collection_name" => match field.is_text() { + "vdb_collection_name" => match field.is_text() { true => { - if let Err(e) = field.data.read_to_string(&mut collection_name) { + if let Err(e) = field.data.read_to_string(&mut vdb_collection_name) { let err_msg = - format!("Failed to read the collection_name field. {}", e); + format!("Failed to read the `vdb_collection_name` field. {}", e); // log error!(target: "stdout", "{}", &err_msg); @@ -1668,7 +1686,27 @@ pub(crate) async fn create_rag_handler( } } false => { - let err_msg = "Failed to get `collection_name`. The `collection_name` field in the request should be a text field."; + let err_msg = "Failed to get `vdb_collection_name`. The `vdb_collection_name` field in the request should be a text field."; + + // log + error!(target: "stdout", "{}", &err_msg); + + return error::internal_server_error(err_msg); + } + }, + "vdb_api_key" => match field.is_text() { + true => { + if let Err(e) = field.data.read_to_string(&mut vdb_api_key) { + let err_msg = format!("Failed to read the `vdb_api_key` field. {}", e); + + // log + error!(target: "stdout", "{}", &err_msg); + + return error::internal_server_error(err_msg); + } + } + false => { + let err_msg = "Failed to get `vdb_api_key`. The `vdb_api_key` field in the request should be a text field."; // log error!(target: "stdout", "{}", &err_msg); @@ -1687,7 +1725,8 @@ pub(crate) async fn create_rag_handler( } } - match (url_vdb_server.is_empty(), collection_name.is_empty()) { + // If the request does not provide the vdb_server_url and vdb_collection_name, use the default vdb config from the server info, and get the vdb_api_key from the environment variable `VDB_API_KEY` if it is set. + match (vdb_server_url.is_empty(), vdb_collection_name.is_empty()) { (true, true) => { let qdrant_config_vec = match SERVER_INFO.get() { Some(server_info) => server_info.read().await.qdrant_config.clone(), @@ -1702,11 +1741,14 @@ pub(crate) async fn create_rag_handler( }; // use the first qdrant config as the default config - url_vdb_server = qdrant_config_vec[0].url.clone(); - collection_name = qdrant_config_vec[0].collection_name.clone(); + vdb_server_url = qdrant_config_vec[0].url.clone(); + vdb_collection_name = qdrant_config_vec[0].collection_name.clone(); + if vdb_api_key.is_empty() { + vdb_api_key = std::env::var("VDB_API_KEY").unwrap_or_default(); + } } (true, false) | (false, true) => { - let err_msg = "Failed to get `url_vdb_server` or `collection_name`. The `url_vdb_server` and `collection_name` fields in the request should be provided at the same time."; + let err_msg = "Failed to get `vdb_server_url` or `vdb_collection_name`. The `vdb_server_url` and `vdb_collection_name` fields in the request should be provided at the same time."; // log error!(target: "stdout", "{}", &err_msg); @@ -1716,10 +1758,10 @@ pub(crate) async fn create_rag_handler( (false, false) => {} } - info!(target: "stdout", "url_vdb_server: {}, collection_name: {}", &url_vdb_server, &collection_name); + info!(target: "stdout", "vdb_server_url: {}, vdb_collection_name: {}", &vdb_server_url, &vdb_collection_name); match file_object { - Some(fo) => (fo, url_vdb_server, collection_name), + Some(fo) => (fo, vdb_server_url, vdb_collection_name, vdb_api_key), None => { let err_msg = "Failed to upload the target file. Not found the target file."; @@ -1878,14 +1920,20 @@ pub(crate) async fn create_rag_handler( info!(target: "stdout", "Prepare the rag embedding request."); + let api_key = match vdb_api_key.is_empty() { + true => None, + false => Some(vdb_api_key), + }; + // create an embedding request let embedding_request = EmbeddingRequest { model: Some(model), input: chunks.into(), encoding_format: None, user: None, - qdrant_url: Some(url_vdb_server), - qdrant_collection_name: Some(collection_name), + vdb_server_url: Some(vdb_server_url), + vdb_collection_name: Some(vdb_collection_name), + vdb_api_key: api_key, }; match rag_doc_chunks_to_embeddings(&embedding_request).await { @@ -2122,8 +2170,8 @@ async fn get_qdrant_configs( chat_request: &ChatCompletionRequest, ) -> Result, error::ServerError> { match ( - chat_request.url_vdb_server.as_deref(), - chat_request.collection_name.as_deref(), + chat_request.vdb_server_url.as_deref(), + chat_request.vdb_collection_name.as_deref(), chat_request.limit.as_deref(), chat_request.score_threshold.as_deref(), ) {