Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LONK-WS: prevent duplicate subscription and implement unsubscription #576

Merged
merged 1 commit into from
Sep 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions chris_backend/pacsfiles/consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
validate_subscription,
LonkWsSubscription,
Lonk,
validate_unsubscription,
)
from pacsfiles.permissions import IsChrisOrIsPACSUserReadOnly

Expand Down Expand Up @@ -37,6 +38,9 @@ async def receive_json(self, content, **kwargs):
content['pacs_name'], content['SeriesInstanceUID']
)
return
if validate_unsubscription(content):
await self._unsubscribe_all()
return
await self.close(code=400, reason='Invalid subscription')

async def _subscribe(self, pacs_name: str, series_instance_uid: str):
Expand All @@ -63,6 +67,17 @@ async def _subscribe(self, pacs_name: str, series_instance_uid: str):
await self.close(code=500)
raise e

async def _unsubscribe_all(self):
"""
Unsubscribe from *all* series notifications.
"""
try:
await self.client.unsubscribe_all()
await self.send_json({'message': {'subscribed': False}})
except Exception as e:
await self.close(code=500)
raise e

async def disconnect(self, code):
await super().disconnect(code)
await self.client.close()
Expand Down
30 changes: 27 additions & 3 deletions chris_backend/pacsfiles/lonk.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@ class SubscriptionRequest(TypedDict):
action: Literal['subscribe']


class UnsubscriptionRequest(TypedDict):
"""
A request to unsubscribe from *all* series notifications.
"""

action: Literal['unsubscribe']


def validate_subscription(data: Any) -> TypeGuard[SubscriptionRequest]:
if not isinstance(data, dict):
return False
Expand All @@ -44,6 +52,12 @@ def validate_subscription(data: Any) -> TypeGuard[SubscriptionRequest]:
)


def validate_unsubscription(data: Any) -> TypeGuard[UnsubscriptionRequest]:
if not isinstance(data, dict):
return False
return data.get('action', None) == 'unsubscribe'


class LonkProgress(TypedDict):
"""
LONK "done" message.
Expand Down Expand Up @@ -114,7 +128,7 @@ class LonkClient:

def __init__(self, nc: NATS):
self._nc = nc
self._subscriptions: list[Subscription] = []
self._subscriptions: dict[str, Subscription] = {}

@classmethod
async def connect(cls, servers: str | list[str]) -> Self:
Expand All @@ -127,13 +141,23 @@ async def subscribe(
cb: Callable[[Lonk], Awaitable[None]],
):
subject = subject_of(pacs_name, series_instance_uid)
if (
subscription := self._subscriptions.get(subject, None)
) is not None:
return subscription # already subscribed
cb = _curry_message2json(pacs_name, series_instance_uid, cb)
subscription = await self._nc.subscribe(subject, cb=cb)
self._subscriptions.append(subscription)
self._subscriptions[subscription.subject] = subscription
return subscription

async def unsubscribe_all(self):
await asyncio.gather(
*(s.unsubscribe() for s in self._subscriptions.values())
)
self._subscriptions = {}

async def close(self):
await asyncio.gather(*(s.unsubscribe() for s in self._subscriptions))
await self.unsubscribe_all()
await self._nc.close()


Expand Down
79 changes: 70 additions & 9 deletions chris_backend/pacsfiles/tests/test_consumers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
LonkProgress,
LonkDone,
LonkError,
UnsubscriptionRequest,
)
from pacsfiles.consumers import PACSFileProgress
from pacsfiles.tests.mocks import Mockidicom
Expand All @@ -39,15 +40,7 @@ def setUp(self):

@tag('integration')
async def test_lonk_ws(self):
token = await self._get_download_token()
app = TokenQsAuthMiddleware(PACSFileProgress.as_asgi())
communicator = WebsocketCommunicator(
app, f'v1/pacs/ws/?token={token.token}'
)
connected, subprotocol = await communicator.connect()
assert connected

oxidicom: Mockidicom = await Mockidicom.connect(settings.NATS_ADDRESS)
communicator, oxidicom = await self.connect()

series1 = {'pacs_name': 'MyPACS', 'SeriesInstanceUID': '1.234.567890'}
subscription_request = SubscriptionRequest(
Expand Down Expand Up @@ -103,6 +96,74 @@ async def test_lonk_ws(self):
Lonk(message=LonkDone(done=True), **series1),
)

@tag('integration')
async def test_unsubscribe(self):
"""
https://chrisproject.org/docs/oxidicom/lonk-ws#unsubscribe
"""
communicator, oxidicom = await self.connect()

series1 = {
'pacs_name': 'MyPACSUnsub',
'SeriesInstanceUID': '1.234.567890',
}
subscription_request = SubscriptionRequest(
action='subscribe', **series1
)
await communicator.send_json_to(subscription_request)
self.assertEqual(
await communicator.receive_json_from(),
Lonk(
message=LonkWsSubscription(subscribed=True),
**series1,
),
)

unsubscription_request = UnsubscriptionRequest(action='unsubscribe')
await communicator.send_json_to(unsubscription_request)
self.assertEqual(
await communicator.receive_json_from(),
{'message': {'subscribed': False}},
)

series2 = {
'pacs_name': 'MyPACSUnsub',
'SeriesInstanceUID': '5.678.90123',
}
subscription_request = SubscriptionRequest(
action='subscribe', **series2
)
await communicator.send_json_to(subscription_request)
self.assertEqual(
await communicator.receive_json_from(),
Lonk(
message=LonkWsSubscription(subscribed=True),
**series2,
),
)

await oxidicom.send_progress(ndicom=1, **series1)
await oxidicom.send_progress(ndicom=2, **series2)
self.assertEqual(
await communicator.receive_json_from(),
Lonk(
message=LonkProgress(ndicom=2),
**series2, # unsubscribed from series1, should not be a message for it
),
)

async def connect(self) -> tuple[WebsocketCommunicator, Mockidicom]:
token = await self._get_download_token()
app = TokenQsAuthMiddleware(PACSFileProgress.as_asgi())
communicator = WebsocketCommunicator(
app, f'v1/pacs/ws/?token={token.token}'
)
connected, subprotocol = await communicator.connect()
assert connected

oxidicom: Mockidicom = await Mockidicom.connect(settings.NATS_ADDRESS)
return communicator, oxidicom

async def test_unauthenticated_not_connected(self):
app = TokenQsAuthMiddleware(PACSFileProgress.as_asgi())
communicator = WebsocketCommunicator(app, 'v1/pacs/ws/') # no token
Expand Down
Loading