Skip to content

Commit

Permalink
enhance: support recalls for milvus_client
Browse files Browse the repository at this point in the history
Signed-off-by: chasingegg <[email protected]>
  • Loading branch information
chasingegg committed Jan 10, 2025
1 parent c70d44c commit c3dad23
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 6 deletions.
8 changes: 7 additions & 1 deletion pymilvus/client/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,13 @@ def __iter__(self) -> SequenceIterator:
def __str__(self) -> str:
"""Only print at most 10 query results"""
reminder = f" ... and {len(self) - 10} results remaining" if len(self) > 10 else ""
recall_msg = f", recalls: {list(map(str, self.recalls))}" if len(self.recalls) > 0 else ""
recall_msg = (
f", recalls: {list(map(str, self.recalls[:10]))}" if len(self.recalls) > 0 else ""
) + (
f" ... and {len(self.recalls) - 10} recall results remaining"
if len(self.recalls) > 10
else ""
)
cost_msg = f", cost: {self.cost}" if self.cost else ""
return f"data: {list(map(str, self[:10]))}{reminder}{recall_msg}{cost_msg}"

Expand Down
18 changes: 15 additions & 3 deletions pymilvus/client/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -967,15 +967,27 @@ class ExtraList(list):
ExtraList([1, 2, 3], extra={"total": 3})
"""

def __init__(self, *args, extra: Optional[Dict] = None, **kwargs) -> None:
def __init__(
self, *args, extra: Optional[Dict] = None, recalls: Optional[List[float]] = None, **kwargs
) -> None:
super().__init__(*args, **kwargs)
self.extra = OmitZeroDict(extra or {})
self.recalls = recalls

def __str__(self) -> str:
"""Only print at most 10 query results"""
recall_msg = (
f", recalls: {list(map(str, self.recalls[:10]))}"
if self.recalls is not None and len(self.recalls) > 0
else ""
) + (
f" ... and {len(self.recalls) - 10} recall results remaining"
if self.recalls is not None and len(self.recalls) > 10
else ""
)
if self.extra and self.extra.omit_zero_len() != 0:
return f"data: {list(map(str, self[:10]))} {'...' if len(self) > 10 else ''}, extra_info: {self.extra}"
return f"data: {list(map(str, self[:10]))} {'...' if len(self) > 10 else ''}"
return f"data: {list(map(str, self[:10]))}{' ...' if len(self) > 10 else ''}{recall_msg}, extra_info: {self.extra}"
return f"data: {list(map(str, self[:10]))}{' ...' if len(self) > 10 else ''}{recall_msg}"

__repr__ = __str__

Expand Down
2 changes: 1 addition & 1 deletion pymilvus/milvus_client/async_milvus_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ async def search(
query_result.append(hit.to_dict())
ret.append(query_result)

return ExtraList(ret, extra=construct_cost_extra(res.cost))
return ExtraList(ret, extra=construct_cost_extra(res.cost), recalls=res.recalls)

async def query(
self,
Expand Down
2 changes: 1 addition & 1 deletion pymilvus/milvus_client/milvus_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ def search(
query_result.append(hit.to_dict())
ret.append(query_result)

return ExtraList(ret, extra=construct_cost_extra(res.cost))
return ExtraList(ret, extra=construct_cost_extra(res.cost), recalls=res.recalls)

def query(
self,
Expand Down

0 comments on commit c3dad23

Please sign in to comment.