From 55df22d0e228e117027254bde174c017df6c9794 Mon Sep 17 00:00:00 2001 From: aliel Date: Thu, 21 Nov 2024 23:23:44 +0100 Subject: [PATCH] Feature: Add new endpoint to list message hashes --- src/aleph/db/accessors/messages.py | 55 +++++++++++++++++ src/aleph/schemas/api/messages.py | 5 ++ src/aleph/web/controllers/messages.py | 87 +++++++++++++++++++++++++++ src/aleph/web/controllers/routes.py | 1 + 4 files changed, 148 insertions(+) diff --git a/src/aleph/db/accessors/messages.py b/src/aleph/db/accessors/messages.py index 34b5d6e3a..b5c084ba7 100644 --- a/src/aleph/db/accessors/messages.py +++ b/src/aleph/db/accessors/messages.py @@ -623,3 +623,58 @@ def get_programs_triggered_by_messages(session: DbSession, sort_order: SortOrder ) return session.execute(select_stmt).all() + + +def make_matching_hashes_query( + start_date: Optional[Union[float, dt.datetime]] = None, + end_date: Optional[Union[float, dt.datetime]] = None, + status: Optional[MessageStatus] = None, + sort_order: SortOrder = SortOrder.DESCENDING, + page: int = 1, + pagination: int = 20, + hash_only: bool = True, +) -> Select: + select_stmt = select(MessageStatusDb.item_hash if hash_only else MessageStatusDb) + + start_datetime = coerce_to_datetime(start_date) + end_datetime = coerce_to_datetime(end_date) + + if start_datetime: + select_stmt = select_stmt.where( + MessageStatusDb.reception_time >= start_datetime + ) + if end_datetime: + select_stmt = select_stmt.where(MessageStatusDb.reception_time < end_datetime) + if status: + select_stmt = select_stmt.where(MessageStatusDb.status == status) + + if sort_order == SortOrder.ASCENDING: + select_stmt = select_stmt.order_by(MessageStatusDb.reception_time.asc()) + else: + select_stmt = select_stmt.order_by(MessageStatusDb.reception_time.desc()) + + select_stmt = select_stmt.offset((page - 1) * pagination) + + # If pagination == 0, return all matching results + if pagination: + select_stmt = select_stmt.limit(pagination) + + return select_stmt + + +def 
get_matching_hashes( + session: DbSession, + **kwargs, # Same as make_matching_hashes_query +): + select_stmt = make_matching_hashes_query(**kwargs) + return (session.execute(select_stmt)).scalars() + + +def count_matching_hashes( + session: DbSession, + pagination: int = 0, + **kwargs, +) -> int: + select_stmt = make_matching_hashes_query(pagination=0, **kwargs).subquery() + select_count_stmt = select(func.count()).select_from(select_stmt) + return session.execute(select_count_stmt).scalar_one() diff --git a/src/aleph/schemas/api/messages.py b/src/aleph/schemas/api/messages.py index 8cd6fd1af..dea39f1aa 100644 --- a/src/aleph/schemas/api/messages.py +++ b/src/aleph/schemas/api/messages.py @@ -210,6 +210,11 @@ class Config: fields = {"item_hash": {"exclude": True}} +class MessageHashes(BaseMessageStatus): + class Config: + orm_mode = True + + MessageWithStatus = Union[ PendingMessageStatus, ProcessedMessageStatus, diff --git a/src/aleph/web/controllers/messages.py b/src/aleph/web/controllers/messages.py index 961b6c217..b512d47c4 100644 --- a/src/aleph/web/controllers/messages.py +++ b/src/aleph/web/controllers/messages.py @@ -10,8 +10,10 @@ import aleph.toolkit.json as aleph_json from aleph.db.accessors.messages import ( + count_matching_hashes, count_matching_messages, get_forgotten_message, + get_matching_hashes, get_matching_messages, get_message_by_item_hash, get_message_status, @@ -23,6 +25,7 @@ AlephMessage, ForgottenMessage, ForgottenMessageStatus, + MessageHashes, MessageStatusInfo, MessageWithStatus, PendingMessage, @@ -195,6 +198,57 @@ class WsMessageQueryParams(BaseMessageQueryParams): ) +class MessageHashesQueryParams(BaseModel): + status: Optional[MessageStatus] = Field( + default=None, + description="Message status.", + ) + page: int = Field( + default=DEFAULT_PAGE, ge=1, description="Offset in pages. Starts at 1." + ) + pagination: int = Field( + default=DEFAULT_MESSAGES_PER_PAGE, + ge=0, + description="Maximum number of messages to return. 
Specifying 0 removes this limit.", + ) + start_date: float = Field( + default=0, + ge=0, + alias="startDate", + description="Start date timestamp. If specified, only messages with " + "a time field greater or equal to this value will be returned.", + ) + end_date: float = Field( + default=0, + ge=0, + alias="endDate", + description="End date timestamp. If specified, only messages with " + "a time field lower than this value will be returned.", + ) + sort_order: SortOrder = Field( + default=SortOrder.DESCENDING, + alias="sortOrder", + description="Order in which messages should be listed: " + "-1 means most recent messages first, 1 means older messages first.", + ) + hash_only: bool = Field( + default=True, + description="By default, only hashes are returned. " + "Set this to false to include metadata alongside the hashes in the response.", + ) + + @root_validator + def validate_field_dependencies(cls, values): + start_date = values.get("start_date") + end_date = values.get("end_date") + if start_date and end_date and (end_date < start_date): + raise ValueError("end date cannot be lower than start date.") + return values + + class Config: + allow_population_by_field_name = True + + def message_to_dict(message: MessageDb) -> Dict[str, Any]: message_dict = message.to_dict() message_dict["time"] = message.time.timestamp() @@ -585,3 +639,35 @@ async def view_message_status(request: web.Request): status_info = MessageStatusInfo.from_orm(message_status) return web.json_response(text=status_info.json()) + + +async def view_message_hashes(request: web.Request): + try: + query_params = MessageHashesQueryParams.parse_obj(request.query) + except ValidationError as e: + raise web.HTTPUnprocessableEntity(text=e.json(indent=4)) + + find_filters = query_params.dict(exclude_none=True) + + pagination_page = query_params.page + pagination_per_page = query_params.pagination + + session_factory = get_session_factory_from_request(request) + with session_factory() as 
session: + hashes = get_matching_hashes(session, **find_filters) + + if find_filters["hash_only"]: + formatted_hashes = [h for h in hashes] + else: + formatted_hashes = [MessageHashes.from_orm(h) for h in hashes] + + total_hashes = count_matching_hashes(session, **find_filters) + response = { + "hashes": formatted_hashes, + "pagination_per_page": pagination_per_page, + "pagination_page": pagination_page, + "pagination_total": total_hashes, + "pagination_item": "hashes", + } + + return web.json_response(text=aleph_json.dumps(response).decode("utf-8")) diff --git a/src/aleph/web/controllers/routes.py b/src/aleph/web/controllers/routes.py index f58bfbb55..97cd87c54 100644 --- a/src/aleph/web/controllers/routes.py +++ b/src/aleph/web/controllers/routes.py @@ -46,6 +46,7 @@ def register_routes(app: web.Application): # Note that this endpoint is implemented in the p2p module out of simplicity because # of the large amount of code shared with pub_json. app.router.add_post("/api/v0/messages", p2p.pub_message) + app.router.add_get("/api/v0/messages/hashes", messages.view_message_hashes) app.router.add_get("/api/v0/messages/{item_hash}", messages.view_message) app.router.add_get( "/api/v0/messages/{item_hash}/content", messages.view_message_content