diff --git a/.github/workflows/_test_futures_private.yaml b/.github/workflows/_test_futures_private.yaml index 95148a03..ca4d966a 100644 --- a/.github/workflows/_test_futures_private.yaml +++ b/.github/workflows/_test_futures_private.yaml @@ -46,7 +46,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install pytest + python -m pip install pytest pytest-mock - name: Install package run: python -m pip install . diff --git a/.github/workflows/_test_futures_public.yaml b/.github/workflows/_test_futures_public.yaml index 0de24810..01cf174c 100644 --- a/.github/workflows/_test_futures_public.yaml +++ b/.github/workflows/_test_futures_public.yaml @@ -36,7 +36,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install pytest + python -m pip install pytest pytest-mock - name: Install package run: python -m pip install . diff --git a/.github/workflows/_test_spot_private.yaml b/.github/workflows/_test_spot_private.yaml index 8b330d81..b61c0253 100644 --- a/.github/workflows/_test_spot_private.yaml +++ b/.github/workflows/_test_spot_private.yaml @@ -43,7 +43,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install pytest + python -m pip install pytest pytest-mock - name: Install package run: python -m pip install . diff --git a/.github/workflows/_test_spot_public.yaml b/.github/workflows/_test_spot_public.yaml index f5a37f43..82cd5ae5 100644 --- a/.github/workflows/_test_spot_public.yaml +++ b/.github/workflows/_test_spot_public.yaml @@ -36,7 +36,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install pytest + python -m pip install pytest pytest-mock - name: Install package run: python -m pip install . diff --git a/examples/futures_trading_bot_template.py b/examples/futures_trading_bot_template.py index 7d2eaecb..c69f034d 100644 --- a/examples/futures_trading_bot_template.py +++ b/examples/futures_trading_bot_template.py @@ -88,7 +88,7 @@ async def on_message(self: TradingBot, message: list | dict) -> None: # Add more functions to customize the trading strategy … - def save_exit(self: TradingBot, reason: Optional[str] = "") -> None: + def save_exit(self: TradingBot, reason: str = "") -> None: """Controlled shutdown of the strategy""" logging.warning("Save exit triggered, reason: %s", reason) # some ideas: diff --git a/kraken/spot/websocket/connectors.py b/kraken/spot/websocket/connectors.py index 25aadce7..2f9b29d0 100644 --- a/kraken/spot/websocket/connectors.py +++ b/kraken/spot/websocket/connectors.py @@ -20,7 +20,7 @@ from copy import deepcopy from random import random from time import time -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Final, Optional import websockets @@ -176,6 +176,10 @@ async def __run_forever(self: ConnectSpotWebsocketBase) -> None: ) self.__client.exception_occur = True + async def close_connection(self: ConnectSpotWebsocketBase) -> None: + """Closes the websocket connection and thus forces a reconnect""" + await self.socket.close() + async def __reconnect(self: ConnectSpotWebsocketBase) -> None: """ Handles the reconnect - before starting the connection and after an @@ -345,7 +349,7 @@ async def _recover_subscriptions( :type event: asyncio.Event """ log_msg: str = f'Recover {"authenticated" if self.is_auth else "public"} subscriptions {self._subscriptions}' - self.LOG.info("%s waiting.", log_msg) + self.LOG.info("%s: waiting", log_msg) await event.wait() for sub in self._subscriptions: @@ -362,7 +366,7 @@ async def _recover_subscriptions( await self.client.send_message(cpy, private=private) self.LOG.info("%s: OK", sub) - self.LOG.info("%s done.", log_msg) + self.LOG.info("%s: done", log_msg) def _manage_subscriptions( self: ConnectSpotWebsocketV1, @@ -501,14 +505,14 @@ async def _recover_subscriptions( :type event: asyncio.Event """ log_msg: str = f'Recover {"authenticated" if self.is_auth else "public"} subscriptions {self._subscriptions}' - self.LOG.info("%s waiting.", log_msg) + self.LOG.info("%s: waiting", log_msg) await event.wait() for subscription in self._subscriptions: await self.client.subscribe(params=subscription) self.LOG.info("%s: OK", subscription) - self.LOG.info("%s done.", log_msg) + self.LOG.info("%s: done", log_msg) def _manage_subscriptions(self: ConnectSpotWebsocketV2, message: dict) -> None: # type: ignore[override] """ @@ -574,24 +578,32 @@ def __transform_subscription( # Without deepcopy, the passed message will be modified, which is *not* # intended. subscription_copy: dict = deepcopy(subscription) - - # Subscriptions for specific symbols must contain the 'symbols' key with - # a value of type list[str]. The python-kraken-sdk is caching active - # subscriptions from that moment, the successful response arrives. These - # responses must be parsed to use them to resubscribe on connection - # losses. - if subscription["result"].get("channel", "") in { - "book", - "ticker", - "ohlc", - "trade", - } and not isinstance( - subscription["result"].get("symbol"), - list, - ): - subscription_copy["result"]["symbol"] = [ - subscription_copy["result"]["symbol"], - ] + channel: Final[str] = subscription["result"].get("channel", "") + + match channel: + case "book" | "ticker" | "ohlc" | "trade": + # Subscriptions for specific symbols must contain the 'symbols' + # key with a value of type list[str]. The python-kraken-sdk is + # caching active subscriptions from that moment, the successful + # response arrives. These responses must be parsed to use them + # to resubscribe on connection losses. + if not isinstance( + subscription["result"].get("symbol"), + list, + ): + subscription_copy["result"]["symbol"] = [ + subscription_copy["result"]["symbol"], + ] + case "executions": + # Kraken somehow responds with this key - but this is not + # accepted when subscribing (Dec 2023). + if subscription_copy["method"] == "unsubscribe": + del subscription_copy["result"]["maxratecount"] + + # Sometimes Kraken responds with hints about deprecation - we don't want + # to save those data as resubscribing would fail for those cases. + if "warnings" in subscription["result"]: + del subscription_copy["result"]["warnings"] return subscription_copy diff --git a/tests/futures/test_futures_websocket.py b/tests/futures/test_futures_websocket.py index 2cbb64b8..e34f4960 100644 --- a/tests/futures/test_futures_websocket.py +++ b/tests/futures/test_futures_websocket.py @@ -287,7 +287,6 @@ async def check_subscriptions() -> None: assert expected in caplog.text -@pytest.mark.wip() @pytest.mark.futures() @pytest.mark.futures_auth() @pytest.mark.futures_websocket() diff --git a/tests/spot/test_spot_websocket_v1.py b/tests/spot/test_spot_websocket_v1.py index 94051c79..10e0af83 100644 --- a/tests/spot/test_spot_websocket_v1.py +++ b/tests/spot/test_spot_websocket_v1.py @@ -18,11 +18,11 @@ output. todo: check also if reqid matches -todo: check recover subscriptions """ from __future__ import annotations +import logging from asyncio import run as asyncio_run from typing import Any @@ -657,7 +657,6 @@ async def execute_edit_order() -> None: @pytest.mark.spot() -@pytest.mark.spot_auth() @pytest.mark.spot_websocket() @pytest.mark.spot_websocket_v1() def test_edit_order_failing_no_connection(caplog: Any) -> None: @@ -851,7 +850,6 @@ async def execute_cancel_after() -> None: @pytest.mark.spot() -@pytest.mark.spot_auth() @pytest.mark.spot_websocket() @pytest.mark.spot_websocket_v1() def test_cancel_all_orders_after_failing_no_connection(caplog: Any) -> None: @@ -877,3 +875,62 @@ async def execute_cancel_all_orders() -> None: "Can't cancel all orders after - Authenticated websocket not connected!" not in caplog.text ) + + +@pytest.mark.spot() +@pytest.mark.spot_auth() +@pytest.mark.spot_websocket() +@pytest.mark.spot_websocket_v1() +def test_reconnect( + spot_api_key: str, + spot_secret_key: str, + caplog: Any, + mocker: Any, +) -> None: + """ + Checks if the reconnect works properly when forcing a closed connection. + """ + caplog.set_level(logging.INFO) + + async def check_reconnect() -> None: + client: SpotWebsocketClientV1TestWrapper = SpotWebsocketClientV1TestWrapper( + key=spot_api_key, + secret=spot_secret_key, + ) + await async_wait(seconds=2) + + await client.subscribe(subscription={"name": "ticker"}, pair=["XBT/USD"]) + await client.subscribe(subscription={"name": "openOrders"}) + await async_wait(seconds=2) + + for obj in (client._priv_conn, client._pub_conn): + mocker.patch.object( + obj, + "_ConnectSpotWebsocketBase__get_reconnect_wait", + return_value=2, + ) + await client._pub_conn.close_connection() + await client._priv_conn.close_connection() + + await async_wait(seconds=5) + + asyncio_run(check_reconnect()) + + for phrase in ( + "Recover public subscriptions []: waiting", + "Recover authenticated subscriptions []: waiting", + "Recover public subscriptions []: done", + "Recover authenticated subscriptions []: done", + "Websocket connected!", + "'event': 'systemStatus', 'status': 'online', 'version': '1.9.1'}", + "'openOrders', 'event': 'subscriptionStatus', 'status': 'subscribed',", + "'channelName': 'ticker', 'event': 'subscriptionStatus', 'pair': 'XBT/USD', 'status': 'subscribed', 'subscription': {'name': 'ticker'}", + "exception=ConnectionClosedOK(Close(code=1000, reason=''), Close(code=1000, reason=''), False)> got an exception sent 1000 (OK); then received 1000 (OK)", + "Recover public subscriptions [{'event': 'subscribe', 'pair': ['XBT/USD'], 'subscription': {'name': 'ticker'}}]: waiting", + "Recover authenticated subscriptions [{'event': 'subscribe', 'subscription': {'name': 'openOrders'}}]: waiting", + "{'event': 'subscribe', 'pair': ['XBT/USD'], 'subscription': {'name': 'ticker'}}: OK", + "{'event': 'subscribe', 'subscription': {'name': 'openOrders'}}: OK", + "Recover public subscriptions [{'event': 'subscribe', 'pair': ['XBT/USD'], 'subscription': {'name': 'ticker'}}]: done", + "Recover authenticated subscriptions [{'event': 'subscribe', 'subscription': {'name': 'openOrders'}}]: done", + ): + assert phrase in caplog.text diff --git a/tests/spot/test_spot_websocket_v2.py b/tests/spot/test_spot_websocket_v2.py index e2015a83..8be98aca 100644 --- a/tests/spot/test_spot_websocket_v2.py +++ b/tests/spot/test_spot_websocket_v2.py @@ -14,11 +14,11 @@ finally the logs are read out and its input is checked for the expected output. -todo: check recover subscriptions """ from __future__ import annotations +import logging from asyncio import run as asyncio_run from copy import deepcopy from typing import Any @@ -285,11 +285,11 @@ async def test_subscription() -> None: await async_wait(seconds=2) asyncio_run(test_subscription()) - - assert ( - '{"method": "subscribe", "req_id": 123456789, "result": {"channel": "executions", "maxratecount": 180, "snapshot": true}, "success": true, "time_in": ' - in caplog.text - ) + for phrase in ( + '{"method": "subscribe", "req_id": 123456789, "result": {"channel": "executions", "maxratecount": 180, "snapshot": true,', # for some reason they provide a "warnings" key + '"success": true, "time_in": ', + ): + assert phrase in caplog.text @pytest.mark.spot() @@ -458,3 +458,63 @@ def test___transform_subscription_no_change() -> None: ) == incoming_subscription ) + + +@pytest.mark.spot() +@pytest.mark.spot_auth() +@pytest.mark.spot_websocket() +@pytest.mark.spot_websocket_v2() +def test_reconnect( + spot_api_key: str, + spot_secret_key: str, + caplog: Any, + mocker: Any, +) -> None: + """ + Checks if the reconnect works properly when forcing a closed connection. + """ + caplog.set_level(logging.INFO) + + async def check_reconnect() -> None: + client: SpotWebsocketClientV2TestWrapper = SpotWebsocketClientV2TestWrapper( + key=spot_api_key, + secret=spot_secret_key, + ) + await async_wait(seconds=2) + + await client.subscribe(params={"channel": "ticker", "symbol": ["BTC/USD"]}) + await client.subscribe(params={"channel": "executions"}) + await async_wait(seconds=2) + + for obj in (client._priv_conn, client._pub_conn): + mocker.patch.object( + obj, + "_ConnectSpotWebsocketBase__get_reconnect_wait", + return_value=2, + ) + await client._pub_conn.close_connection() + await client._priv_conn.close_connection() + + await async_wait(seconds=5) + + asyncio_run(check_reconnect()) + + for phrase in ( + "Recover public subscriptions []: waiting", + "Recover authenticated subscriptions []: waiting", + "Recover public subscriptions []: done", + "Recover authenticated subscriptions []: done", + "Websocket connected!", + '{"channel": "status", "data": [{"api_version": "v2", "connection_id": ', + '"system": "online", "version": "2.0.0"}], "type": "update"}', + '{"method": "subscribe", "result": {"channel": "ticker", "snapshot": true, "symbol": "BTC/USD"}, "success": true,', + '"channel": "ticker", "type": "snapshot", "data": [{"symbol": "BTC/USD", ', + "exception=ConnectionClosedOK(Close(code=1000, reason=''), Close(code=1000, reason=''), False)> got an exception sent 1000 (OK); then received 1000 (OK)", + "Recover public subscriptions [{'channel': 'ticker', 'snapshot': True, 'symbol': ['BTC/USD']}]: waiting", + "Recover public subscriptions [{'channel': 'ticker', 'snapshot': True, 'symbol': ['BTC/USD']}]: done", + "Recover authenticated subscriptions [{'channel': 'executions', 'snapshot': True}]: waiting", + "Recover authenticated subscriptions [{'channel': 'executions', 'snapshot': True}]: done", + ): + assert phrase in caplog.text + + assert '"success": False' not in caplog.text