From 8d0d5297449ffab40c1b4946b53b8a3d089d88a6 Mon Sep 17 00:00:00 2001 From: chasingegg Date: Thu, 9 Jan 2025 17:25:33 +0800 Subject: [PATCH] enhance: support recalls for milvus_client Signed-off-by: chasingegg --- pymilvus/client/types.py | 14 +++++++++++--- pymilvus/milvus_client/async_milvus_client.py | 2 +- pymilvus/milvus_client/milvus_client.py | 2 +- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/pymilvus/client/types.py b/pymilvus/client/types.py index 2916d41f8..c5283242a 100644 --- a/pymilvus/client/types.py +++ b/pymilvus/client/types.py @@ -967,15 +967,23 @@ class ExtraList(list): ExtraList([1, 2, 3], extra={"total": 3}) """ - def __init__(self, *args, extra: Optional[Dict] = None, **kwargs) -> None: + def __init__( + self, *args, extra: Optional[Dict] = None, recalls: Optional[List[float]] = None, **kwargs + ) -> None: super().__init__(*args, **kwargs) self.extra = OmitZeroDict(extra or {}) + self.recalls = recalls def __str__(self) -> str: """Only print at most 10 query results""" + recall_msg = ( + f", recalls: {list(map(str, self.recalls))}" + if self.recalls is not None and len(self.recalls) > 0 + else "" + ) if self.extra and self.extra.omit_zero_len() != 0: - return f"data: {list(map(str, self[:10]))} {'...' if len(self) > 10 else ''}, extra_info: {self.extra}" - return f"data: {list(map(str, self[:10]))} {'...' if len(self) > 10 else ''}" + return f"data: {list(map(str, self[:10]))} {'...' if len(self) > 10 else ''}{recall_msg}, extra_info: {self.extra}" + return f"data: {list(map(str, self[:10]))} {'...' if len(self) > 10 else ''}{recall_msg}" __repr__ = __str__ diff --git a/pymilvus/milvus_client/async_milvus_client.py b/pymilvus/milvus_client/async_milvus_client.py index 2544caeb0..f95ade792 100644 --- a/pymilvus/milvus_client/async_milvus_client.py +++ b/pymilvus/milvus_client/async_milvus_client.py @@ -400,7 +400,7 @@ async def search( query_result.append(hit.to_dict()) ret.append(query_result) - return ExtraList(ret, extra=construct_cost_extra(res.cost)) + return ExtraList(ret, extra=construct_cost_extra(res.cost), recalls=res.recalls) async def query( self, diff --git a/pymilvus/milvus_client/milvus_client.py b/pymilvus/milvus_client/milvus_client.py index 8535036e6..483bc029c 100644 --- a/pymilvus/milvus_client/milvus_client.py +++ b/pymilvus/milvus_client/milvus_client.py @@ -421,7 +421,7 @@ def search( query_result.append(hit.to_dict()) ret.append(query_result) - return ExtraList(ret, extra=construct_cost_extra(res.cost)) + return ExtraList(ret, extra=construct_cost_extra(res.cost), recalls=res.recalls) def query( self,