Skip to content

Commit

Permalink
Feature: Add new endpoint to list message hashes
Browse files Browse the repository at this point in the history
  • Loading branch information
aliel committed Dec 16, 2024
1 parent 5eb61c6 commit 6e7166a
Show file tree
Hide file tree
Showing 4 changed files with 163 additions and 0 deletions.
66 changes: 66 additions & 0 deletions src/aleph/db/accessors/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,3 +623,69 @@ def get_programs_triggered_by_messages(session: DbSession, sort_order: SortOrder
)

return session.execute(select_stmt).all()


def make_matching_hashes_query(
start_date: Optional[Union[float, dt.datetime]] = None,
end_date: Optional[Union[float, dt.datetime]] = None,
status: Optional[MessageStatus] = None,
sort_order: SortOrder = SortOrder.DESCENDING,
page: int = 1,
pagination: int = 20,
hash_only: bool = True
) -> Select:
select_stmt = select(
MessageStatusDb.item_hash if hash_only else MessageStatusDb
)

start_datetime = coerce_to_datetime(start_date)
end_datetime = coerce_to_datetime(end_date)

if start_datetime:
select_stmt = select_stmt.where(
MessageStatusDb.reception_time >= start_datetime
)
if end_datetime:
select_stmt = select_stmt.where(
MessageStatusDb.reception_time < end_datetime
)
if status:
select_stmt = select_stmt.where(MessageStatusDb.status == status)

if sort_order == SortOrder.ASCENDING:
select_stmt = select_stmt.order_by(
MessageStatusDb.reception_time.asc()
)
else:
select_stmt = select_stmt.order_by(
MessageStatusDb.reception_time.desc()
)

select_stmt = select_stmt.offset((page - 1) * pagination)

# If pagination == 0, return all matching results
if pagination:
select_stmt = select_stmt.limit(pagination)

return select_stmt


def get_matching_hashes(
session: DbSession,
**kwargs, # Same as make_matching_hashes_query
):
select_stmt = make_matching_hashes_query(**kwargs)
return (session.execute(select_stmt)).scalars()


def count_matching_hashes(
session: DbSession,
pagination: int = 0,
**kwargs,
) -> Select:
select_stmt = make_matching_hashes_query(
pagination=0,
**kwargs
).subquery()
select_count_stmt = select(func.count()).select_from(select_stmt)
return session.execute(select_count_stmt).scalar_one()
5 changes: 5 additions & 0 deletions src/aleph/schemas/api/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,11 @@ class Config:
fields = {"item_hash": {"exclude": True}}


class MessageHashes(BaseMessageStatus):
class Config:
orm_mode = True


MessageWithStatus = Union[
PendingMessageStatus,
ProcessedMessageStatus,
Expand Down
89 changes: 89 additions & 0 deletions src/aleph/web/controllers/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@

import aleph.toolkit.json as aleph_json
from aleph.db.accessors.messages import (
count_matching_hashes,
count_matching_messages,
get_forgotten_message,
get_matching_hashes,
get_matching_messages,
get_message_by_item_hash,
get_message_status,
Expand All @@ -23,6 +25,7 @@
AlephMessage,
ForgottenMessage,
ForgottenMessageStatus,
MessageHashes,
MessageStatusInfo,
MessageWithStatus,
PendingMessage,
Expand Down Expand Up @@ -195,6 +198,57 @@ class WsMessageQueryParams(BaseMessageQueryParams):
)


class MessageHashesQueryParams(BaseModel):
status: Optional[MessageStatus] = Field(
default=None,
description="Message status.",
)
page: int = Field(
default=DEFAULT_PAGE, ge=1, description="Offset in pages. Starts at 1."
)
pagination: int = Field(
default=DEFAULT_MESSAGES_PER_PAGE,
ge=0,
description="Maximum number of messages to return. Specifying 0 removes this limit.",
)
start_date: float = Field(
default=0,
ge=0,
alias="startDate",
description="Start date timestamp. If specified, only messages with "
"a time field greater or equal to this value will be returned.",
)
end_date: float = Field(
default=0,
ge=0,
alias="endDate",
description="End date timestamp. If specified, only messages with "
"a time field lower than this value will be returned.",
)
sort_order: SortOrder = Field(
default=SortOrder.DESCENDING,
alias="sortOrder",
description="Order in which messages should be listed: "
"-1 means most recent messages first, 1 means older messages first.",
)
hash_only: bool = Field(
default=True,
description="By default, only hashes are returned. "
"Set this to false to include metadata alongside the hashes in the response.",
)

@root_validator
def validate_field_dependencies(cls, values):
start_date = values.get("start_date")
end_date = values.get("end_date")
if start_date and end_date and (end_date < start_date):
raise ValueError("end date cannot be lower than start date.")
return values

class Config:
allow_population_by_field_name = True


def message_to_dict(message: MessageDb) -> Dict[str, Any]:
message_dict = message.to_dict()
message_dict["time"] = message.time.timestamp()
Expand Down Expand Up @@ -576,3 +630,38 @@ async def view_message_status(request: web.Request):

status_info = MessageStatusInfo.from_orm(message_status)
return web.json_response(text=status_info.json())


async def view_message_hashes(request: web.Request):
print("run once?")
try:
query_params = MessageHashesQueryParams.parse_obj(request.query)
except ValidationError as e:
raise web.HTTPUnprocessableEntity(text=e.json(indent=4))

find_filters = query_params.dict(exclude_none=True)

pagination_page = query_params.page
pagination_per_page = query_params.pagination

session_factory = get_session_factory_from_request(request)
with session_factory() as session:
hashes = get_matching_hashes(session, **find_filters)

if find_filters["hash_only"]:
formatted_hashes = [h for h in hashes]
else:
formatted_hashes = [MessageHashes.from_orm(h) for h in hashes]

total_hashes = count_matching_hashes(session, **find_filters)
response = {
"hashes": formatted_hashes,
"pagination_per_page": pagination_per_page,
"pagination_page": pagination_page,
"pagination_total": total_hashes,
"pagination_item": "hashes",
}

return web.json_response(
text=aleph_json.dumps(response).decode("utf-8")
)
3 changes: 3 additions & 0 deletions src/aleph/web/controllers/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ def register_routes(app: web.Application):
# Note that this endpoint is implemented in the p2p module out of simplicity because
# of the large amount of code shared with pub_json.
app.router.add_post("/api/v0/messages", p2p.pub_message)
app.router.add_get(
"/api/v0/messages/hashes", messages.view_message_hashes
)
app.router.add_get("/api/v0/messages/{item_hash}", messages.view_message)
app.router.add_get(
"/api/v0/messages/{item_hash}/content", messages.view_message_content
Expand Down

0 comments on commit 6e7166a

Please sign in to comment.