From 0e4a3ea2c365cb7f47be381d8b42e16f8a5dc590 Mon Sep 17 00:00:00 2001
From: Olivier Desenfans <desenfans.olivier@gmail.com>
Date: Fri, 3 Nov 2023 12:13:26 +0100
Subject: [PATCH] Fix: warning on failure to close node cache properly (#514)

Problem: a warning is emitted in the tests because the Redis client
object is not cleaned up properly.

Solution: make the NodeCache class an asynchronous context manager so
that it cleans up after itself.
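
With this change, call sites are responsible for opening and closing the
cache, either through open()/close() or the async context manager
protocol. A minimal usage sketch (the host/port values below are
placeholders):

    node_cache = NodeCache(redis_host="localhost", redis_port=6379)
    async with node_cache:        # calls open() on entry, close() on exit
        await node_cache.reset()  # redis_client is available inside the block
    # after the block, accessing node_cache.redis_client raises ValueError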
---
 src/aleph/api_entrypoint.py                |   3 +
 src/aleph/commands.py                      | 103 +++++++++++----------
 src/aleph/jobs/fetch_pending_messages.py   |  72 +++++++-------
 src/aleph/jobs/process_pending_messages.py |  94 ++++++++++---------
 src/aleph/jobs/process_pending_txs.py      |  72 +++++++-------
 src/aleph/services/cache/node_cache.py     |  32 ++++++-
 tests/conftest.py                          |   6 +-
 7 files changed, 210 insertions(+), 172 deletions(-)

diff --git a/src/aleph/api_entrypoint.py b/src/aleph/api_entrypoint.py
index 394dedba5..deaf03d48 100644
--- a/src/aleph/api_entrypoint.py
+++ b/src/aleph/api_entrypoint.py
@@ -45,6 +45,9 @@ async def configure_aiohttp_app(
         node_cache = NodeCache(
             redis_host=config.redis.host.value, redis_port=config.redis.port.value
         )
+        # TODO: find a way to close the node cache when exiting the API process;
+        #       not closing it causes a warning.
+        await node_cache.open()
 
         ipfs_client = make_ipfs_client(config)
         ipfs_service = IpfsService(ipfs_client=ipfs_client)
diff --git a/src/aleph/commands.py b/src/aleph/commands.py
index b861f483c..d82276bd8 100644
--- a/src/aleph/commands.py
+++ b/src/aleph/commands.py
@@ -62,9 +62,6 @@ async def init_node_cache(config: Config) -> NodeCache:
     node_cache = NodeCache(
         redis_host=config.redis.host.value, redis_port=config.redis.port.value
     )
-
-    # Reset the cache
-    await node_cache.reset()
     return node_cache
 
 
@@ -135,64 +132,68 @@ async def main(args: List[str]) -> None:
     mq_channel = await mq_conn.channel()
 
     node_cache = await init_node_cache(config)
-    ipfs_service = IpfsService(ipfs_client=make_ipfs_client(config))
-    storage_service = StorageService(
-        storage_engine=FileSystemStorageEngine(folder=config.storage.folder.value),
-        ipfs_service=ipfs_service,
-        node_cache=node_cache,
-    )
-    chain_data_service = ChainDataService(
-        session_factory=session_factory,
-        storage_service=storage_service,
-    )
-    pending_tx_publisher = await PendingTxPublisher.new(config=config)
-    chain_connector = ChainConnector(
-        session_factory=session_factory,
-        pending_tx_publisher=pending_tx_publisher,
-        chain_data_service=chain_data_service,
-    )
+    async with node_cache:
+        # Reset the cache
+        await node_cache.reset()
 
-    set_start_method("spawn")
+        ipfs_service = IpfsService(ipfs_client=make_ipfs_client(config))
+        storage_service = StorageService(
+            storage_engine=FileSystemStorageEngine(folder=config.storage.folder.value),
+            ipfs_service=ipfs_service,
+            node_cache=node_cache,
+        )
+        chain_data_service = ChainDataService(
+            session_factory=session_factory,
+            storage_service=storage_service,
+        )
+        pending_tx_publisher = await PendingTxPublisher.new(config=config)
+        chain_connector = ChainConnector(
+            session_factory=session_factory,
+            pending_tx_publisher=pending_tx_publisher,
+            chain_data_service=chain_data_service,
+        )
+
+        set_start_method("spawn")
 
-    tasks: List[Coroutine] = []
+        tasks: List[Coroutine] = []
 
-    if not args.no_jobs:
-        LOGGER.debug("Creating jobs")
-        tasks += start_jobs(
+        if not args.no_jobs:
+            LOGGER.debug("Creating jobs")
+            tasks += start_jobs(
+                config=config,
+                session_factory=session_factory,
+                ipfs_service=ipfs_service,
+                use_processes=True,
+            )
+
+        LOGGER.debug("Initializing p2p")
+        p2p_client, p2p_tasks = await p2p.init_p2p(
             config=config,
             session_factory=session_factory,
+            service_name="network-monitor",
             ipfs_service=ipfs_service,
-            use_processes=True,
+            node_cache=node_cache,
         )
+        tasks += p2p_tasks
+        LOGGER.debug("Initialized p2p")
 
-    LOGGER.debug("Initializing p2p")
-    p2p_client, p2p_tasks = await p2p.init_p2p(
-        config=config,
-        session_factory=session_factory,
-        service_name="network-monitor",
-        ipfs_service=ipfs_service,
-        node_cache=node_cache,
-    )
-    tasks += p2p_tasks
-    LOGGER.debug("Initialized p2p")
-
-    LOGGER.debug("Initializing listeners")
-    tasks += await listener_tasks(
-        config=config,
-        session_factory=session_factory,
-        node_cache=node_cache,
-        p2p_client=p2p_client,
-        mq_channel=mq_channel,
-    )
-    tasks.append(chain_connector.chain_event_loop(config))
-    LOGGER.debug("Initialized listeners")
+        LOGGER.debug("Initializing listeners")
+        tasks += await listener_tasks(
+            config=config,
+            session_factory=session_factory,
+            node_cache=node_cache,
+            p2p_client=p2p_client,
+            mq_channel=mq_channel,
+        )
+        tasks.append(chain_connector.chain_event_loop(config))
+        LOGGER.debug("Initialized listeners")
 
-    LOGGER.debug("Initializing cache tasks")
-    tasks.append(refresh_cache_materialized_views(session_factory))
-    LOGGER.debug("Initialized cache tasks")
+        LOGGER.debug("Initializing cache tasks")
+        tasks.append(refresh_cache_materialized_views(session_factory))
+        LOGGER.debug("Initialized cache tasks")
 
-    LOGGER.debug("Running event loop")
-    await asyncio.gather(*tasks)
+        LOGGER.debug("Running event loop")
+        await asyncio.gather(*tasks)
 
 
 def run():
diff --git a/src/aleph/jobs/fetch_pending_messages.py b/src/aleph/jobs/fetch_pending_messages.py
index d8e4ecd45..bc3e4d6db 100644
--- a/src/aleph/jobs/fetch_pending_messages.py
+++ b/src/aleph/jobs/fetch_pending_messages.py
@@ -175,46 +175,46 @@ async def fetch_messages_task(config: Config):
         config=config, routing_key="fetch.*", channel=mq_channel
     )
 
-    node_cache = NodeCache(
+    async with NodeCache(
         redis_host=config.redis.host.value, redis_port=config.redis.port.value
-    )
-    ipfs_client = make_ipfs_client(config)
-    ipfs_service = IpfsService(ipfs_client=ipfs_client)
-    storage_service = StorageService(
-        storage_engine=FileSystemStorageEngine(folder=config.storage.folder.value),
-        ipfs_service=ipfs_service,
-        node_cache=node_cache,
-    )
-    signature_verifier = SignatureVerifier()
-    message_handler = MessageHandler(
-        signature_verifier=signature_verifier,
-        storage_service=storage_service,
-        config=config,
-    )
-    fetcher = PendingMessageFetcher(
-        session_factory=session_factory,
-        message_handler=message_handler,
-        max_retries=config.aleph.jobs.pending_messages.max_retries.value,
-        pending_message_queue=pending_message_queue,
-    )
+    ) as node_cache:
+        ipfs_client = make_ipfs_client(config)
+        ipfs_service = IpfsService(ipfs_client=ipfs_client)
+        storage_service = StorageService(
+            storage_engine=FileSystemStorageEngine(folder=config.storage.folder.value),
+            ipfs_service=ipfs_service,
+            node_cache=node_cache,
+        )
+        signature_verifier = SignatureVerifier()
+        message_handler = MessageHandler(
+            signature_verifier=signature_verifier,
+            storage_service=storage_service,
+            config=config,
+        )
+        fetcher = PendingMessageFetcher(
+            session_factory=session_factory,
+            message_handler=message_handler,
+            max_retries=config.aleph.jobs.pending_messages.max_retries.value,
+            pending_message_queue=pending_message_queue,
+        )
 
-    async with fetcher:
-        while True:
-            try:
-                fetch_pipeline = fetcher.make_pipeline(
-                    config=config, node_cache=node_cache
-                )
-                async for fetched_messages in fetch_pipeline:
-                    for fetched_message in fetched_messages:
-                        LOGGER.info(
-                            "Successfully fetched %s", fetched_message.item_hash
-                        )
+        async with fetcher:
+            while True:
+                try:
+                    fetch_pipeline = fetcher.make_pipeline(
+                        config=config, node_cache=node_cache
+                    )
+                    async for fetched_messages in fetch_pipeline:
+                        for fetched_message in fetched_messages:
+                            LOGGER.info(
+                                "Successfully fetched %s", fetched_message.item_hash
+                            )
 
-            except Exception:
-                LOGGER.exception("Unexpected error in pending messages fetch job")
+                except Exception:
+                    LOGGER.exception("Unexpected error in pending messages fetch job")
 
-            LOGGER.debug("Waiting 1 second(s) for new pending messages...")
-            await asyncio.sleep(1)
+                LOGGER.debug("Waiting 1 second(s) for new pending messages...")
+                await asyncio.sleep(1)
 
 
 def fetch_pending_messages_subprocess(config_values: Dict):
diff --git a/src/aleph/jobs/process_pending_messages.py b/src/aleph/jobs/process_pending_messages.py
index f43d23fc3..c825062e7 100644
--- a/src/aleph/jobs/process_pending_messages.py
+++ b/src/aleph/jobs/process_pending_messages.py
@@ -154,54 +154,58 @@ async def fetch_and_process_messages_task(config: Config):
     engine = make_engine(config=config, application_name="aleph-process")
     session_factory = make_session_factory(engine)
 
-    node_cache = NodeCache(
+    async with NodeCache(
         redis_host=config.redis.host.value, redis_port=config.redis.port.value
-    )
-    ipfs_client = make_ipfs_client(config)
-    ipfs_service = IpfsService(ipfs_client=ipfs_client)
-    storage_service = StorageService(
-        storage_engine=FileSystemStorageEngine(folder=config.storage.folder.value),
-        ipfs_service=ipfs_service,
-        node_cache=node_cache,
-    )
-    signature_verifier = SignatureVerifier()
-    message_handler = MessageHandler(
-        signature_verifier=signature_verifier,
-        storage_service=storage_service,
-        config=config,
-    )
-    pending_message_processor = await PendingMessageProcessor.new(
-        session_factory=session_factory,
-        message_handler=message_handler,
-        max_retries=config.aleph.jobs.pending_messages.max_retries.value,
-        mq_host=config.p2p.mq_host.value,
-        mq_port=config.rabbitmq.port.value,
-        mq_username=config.rabbitmq.username.value,
-        mq_password=config.rabbitmq.password.value,
-        message_exchange_name=config.rabbitmq.message_exchange.value,
-        pending_message_exchange_name=config.rabbitmq.pending_message_exchange.value,
-    )
+    ) as node_cache:
+        ipfs_client = make_ipfs_client(config)
+        ipfs_service = IpfsService(ipfs_client=ipfs_client)
+        storage_service = StorageService(
+            storage_engine=FileSystemStorageEngine(folder=config.storage.folder.value),
+            ipfs_service=ipfs_service,
+            node_cache=node_cache,
+        )
+        signature_verifier = SignatureVerifier()
+        message_handler = MessageHandler(
+            signature_verifier=signature_verifier,
+            storage_service=storage_service,
+            config=config,
+        )
+        pending_message_processor = await PendingMessageProcessor.new(
+            session_factory=session_factory,
+            message_handler=message_handler,
+            max_retries=config.aleph.jobs.pending_messages.max_retries.value,
+            mq_host=config.p2p.mq_host.value,
+            mq_port=config.rabbitmq.port.value,
+            mq_username=config.rabbitmq.username.value,
+            mq_password=config.rabbitmq.password.value,
+            message_exchange_name=config.rabbitmq.message_exchange.value,
+            pending_message_exchange_name=config.rabbitmq.pending_message_exchange.value,
+        )
 
-    async with pending_message_processor:
-        while True:
-            with session_factory() as session:
+        async with pending_message_processor:
+            while True:
+                with session_factory() as session:
+                    try:
+                        message_processing_pipeline = (
+                            pending_message_processor.make_pipeline()
+                        )
+                        async for processing_results in message_processing_pipeline:
+                            for result in processing_results:
+                                LOGGER.info(
+                                    "Successfully processed %s", result.item_hash
+                                )
+
+                    except Exception:
+                        LOGGER.exception("Error in pending messages job")
+                        session.rollback()
+
+                LOGGER.info("Waiting for new pending messages...")
+                # We still loop periodically for retried messages as we do not bother sending a message
+                # on the MQ for these.
                 try:
-                    message_processing_pipeline = pending_message_processor.make_pipeline()
-                    async for processing_results in message_processing_pipeline:
-                        for result in processing_results:
-                            LOGGER.info("Successfully processed %s", result.item_hash)
-
-                except Exception:
-                    LOGGER.exception("Error in pending messages job")
-                    session.rollback()
-
-            LOGGER.info("Waiting for new pending messages...")
-            # We still loop periodically for retried messages as we do not bother sending a message
-            # on the MQ for these.
-            try:
-                await asyncio.wait_for(pending_message_processor.ready(), 1)
-            except TimeoutError:
-                pass
+                    await asyncio.wait_for(pending_message_processor.ready(), 1)
+                except TimeoutError:
+                    pass
 
 
 def pending_messages_subprocess(config_values: Dict):
diff --git a/src/aleph/jobs/process_pending_txs.py b/src/aleph/jobs/process_pending_txs.py
index 2903f6baa..77f609377 100644
--- a/src/aleph/jobs/process_pending_txs.py
+++ b/src/aleph/jobs/process_pending_txs.py
@@ -129,45 +129,45 @@ async def handle_txs_task(config: Config):
     )
     pending_tx_queue = await make_pending_tx_queue(config=config, channel=mq_channel)
 
-    node_cache = NodeCache(
+    async with NodeCache(
         redis_host=config.redis.host.value, redis_port=config.redis.port.value
-    )
-    ipfs_client = make_ipfs_client(config)
-    ipfs_service = IpfsService(ipfs_client=ipfs_client)
-    storage_service = StorageService(
-        storage_engine=FileSystemStorageEngine(folder=config.storage.folder.value),
-        ipfs_service=ipfs_service,
-        node_cache=node_cache,
-    )
-    message_publisher = MessagePublisher(
-        session_factory=session_factory,
-        storage_service=storage_service,
-        config=config,
-        pending_message_exchange=pending_message_exchange,
-    )
-    chain_data_service = ChainDataService(
-        session_factory=session_factory, storage_service=storage_service
-    )
-    pending_tx_processor = PendingTxProcessor(
-        session_factory=session_factory,
-        message_publisher=message_publisher,
-        chain_data_service=chain_data_service,
-        pending_tx_queue=pending_tx_queue,
-    )
+    ) as node_cache:
+        ipfs_client = make_ipfs_client(config)
+        ipfs_service = IpfsService(ipfs_client=ipfs_client)
+        storage_service = StorageService(
+            storage_engine=FileSystemStorageEngine(folder=config.storage.folder.value),
+            ipfs_service=ipfs_service,
+            node_cache=node_cache,
+        )
+        message_publisher = MessagePublisher(
+            session_factory=session_factory,
+            storage_service=storage_service,
+            config=config,
+            pending_message_exchange=pending_message_exchange,
+        )
+        chain_data_service = ChainDataService(
+            session_factory=session_factory, storage_service=storage_service
+        )
+        pending_tx_processor = PendingTxProcessor(
+            session_factory=session_factory,
+            message_publisher=message_publisher,
+            chain_data_service=chain_data_service,
+            pending_tx_queue=pending_tx_queue,
+        )
 
-    async with pending_tx_processor:
-        while True:
-            try:
-                await pending_tx_processor.process_pending_txs(
-                    max_concurrent_tasks=max_concurrent_tasks
-                )
-            except Exception:
-                LOGGER.exception("Error in pending txs job")
+        async with pending_tx_processor:
+            while True:
+                try:
+                    await pending_tx_processor.process_pending_txs(
+                        max_concurrent_tasks=max_concurrent_tasks
+                    )
+                except Exception:
+                    LOGGER.exception("Error in pending txs job")
 
-            try:
-                await asyncio.wait_for(pending_tx_processor.ready(), 5)
-            except TimeoutError:
-                pass
+                try:
+                    await asyncio.wait_for(pending_tx_processor.ready(), 5)
+                except TimeoutError:
+                    pass
 
 
 def pending_txs_subprocess(config_values: Dict):
diff --git a/src/aleph/services/cache/node_cache.py b/src/aleph/services/cache/node_cache.py
index df23b653d..24d6d8842 100644
--- a/src/aleph/services/cache/node_cache.py
+++ b/src/aleph/services/cache/node_cache.py
@@ -9,13 +9,41 @@
 class NodeCache:
     API_SERVERS_KEY = "api_servers"
     PUBLIC_ADDRESSES_KEY = "public_addresses"
-    redis_client: redis_asyncio.Redis
 
     def __init__(self, redis_host: str, redis_port: int):
         self.redis_host = redis_host
         self.redis_port = redis_port
 
-        self.redis_client = redis_asyncio.Redis(host=redis_host, port=redis_port)
+        self._redis_client: Optional[redis_asyncio.Redis] = None
+
+
+    @property
+    def redis_client(self) -> redis_asyncio.Redis:
+        if (redis_client := self._redis_client) is None:
+            raise ValueError(
+                "Redis client must be initialized. "
+                f"Call open() first or use `async with {self.__class__.__name__}()`."
+            )
+
+        return redis_client
+
+
+    async def open(self):
+        self._redis_client = redis_asyncio.Redis(
+            host=self.redis_host, port=self.redis_port
+        )
+
+    async def __aenter__(self):
+        await self.open()
+        return self
+
+    async def close(self):
+        if self._redis_client is not None:
+            await self._redis_client.close()
+            self._redis_client = None
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        await self.close()
 
     async def reset(self):
         """
diff --git a/tests/conftest.py b/tests/conftest.py
index b0aa84a99..67b333930 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -109,9 +109,11 @@ def mock_config(mocker):
 
 @pytest_asyncio.fixture
 async def node_cache(mock_config: Config):
-    return NodeCache(
+    async with NodeCache(
         redis_host=mock_config.redis.host.value, redis_port=mock_config.redis.port.value
-    )
+    ) as node_cache:
+        yield node_cache
+
 
 
 @pytest_asyncio.fixture