From 54beec1809df6c8d545afe83fcb5acb2b944d3b7 Mon Sep 17 00:00:00 2001
From: Olivier Desenfans
Date: Mon, 23 Oct 2023 13:34:04 +0200
Subject: [PATCH] Internal: TX processing is now event-based

Problem: the pending TX processor uses a polling loop to determine which
transactions must be processed. This adds latency when a new transaction
arrives while none is pending, as the task sleeps between polls instead
of reacting to the new entry.

Solution: make the TX processor event-based by using a RabbitMQ exchange +
message queue. Upon receiving a new transaction from the chain, we now
publish a RabbitMQ message containing the hash of the transaction. The
processor can then find the transaction in the DB. We still store an entry
in the pending TXs table for monitoring purposes. A minimal sketch of the
resulting publish/consume flow is appended after the diff.
---
 src/aleph/chains/bsc.py                       |   8 +-
 src/aleph/chains/chain_data_service.py        |  61 +++++++-
 src/aleph/chains/connector.py                 |  17 ++-
 src/aleph/chains/ethereum.py                  |  25 ++--
 src/aleph/chains/indexer_reader.py            |  19 ++-
 src/aleph/chains/nuls2.py                     |  16 +-
 src/aleph/chains/tezos.py                     |  12 +-
 src/aleph/commands.py                         |  11 +-
 src/aleph/config.py                           |   2 +
 src/aleph/db/accessors/pending_txs.py         |  16 +-
 src/aleph/jobs/process_pending_txs.py         | 140 +++++++++---------
 .../test_process_pending_txs.py               |  10 +-
 12 files changed, 215 insertions(+), 122 deletions(-)

diff --git a/src/aleph/chains/bsc.py b/src/aleph/chains/bsc.py
index 1a0a06086..56dff95c3 100644
--- a/src/aleph/chains/bsc.py
+++ b/src/aleph/chains/bsc.py
@@ -1,8 +1,8 @@
 from aleph_message.models import Chain
 from configmanager import Config
 
-from aleph.chains.chain_data_service import ChainDataService
 from aleph.chains.abc import ChainReader
+from aleph.chains.chain_data_service import PendingTxPublisher
 from aleph.chains.indexer_reader import AlephIndexerReader
 from aleph.types.chain_sync import ChainEventType
 from aleph.types.db_session import DbSessionFactory
@@ -10,12 +10,14 @@
 
 class BscConnector(ChainReader):
     def __init__(
-        self, session_factory: DbSessionFactory, chain_data_service: ChainDataService
+        self,
+        session_factory: DbSessionFactory,
+        pending_tx_publisher: PendingTxPublisher,
     ):
         self.indexer_reader = AlephIndexerReader(
             chain=Chain.BSC,
             session_factory=session_factory,
-            chain_data_service=chain_data_service,
+            pending_tx_publisher=pending_tx_publisher,
         )
 
     async def fetcher(self, config: Config):
diff --git a/src/aleph/chains/chain_data_service.py b/src/aleph/chains/chain_data_service.py
index 498b472ea..bbc7fa741 100644
--- a/src/aleph/chains/chain_data_service.py
+++ b/src/aleph/chains/chain_data_service.py
@@ -1,8 +1,10 @@
 import asyncio
 from io import StringIO
-from typing import Dict, Optional, List, Any, Mapping, Set, cast, Type, Union
+from typing import Dict, Optional, List, Any, Mapping, Set, cast, Type, Union, Self
 
+import aio_pika.abc
 from aleph_message.models import StoreContent, ItemType, Chain, MessageType
+from configmanager import Config
 from pydantic import ValidationError
 
 from aleph.chains.common import LOGGER
@@ -36,7 +38,9 @@
 
 class ChainDataService:
     def __init__(
-        self, session_factory: DbSessionFactory, storage_service: StorageService
+        self,
+        session_factory: DbSessionFactory,
+        storage_service: StorageService,
     ):
         self.session_factory = session_factory
         self.storage_service = storage_service
@@ -215,11 +219,54 @@ async def get_tx_messages(
             LOGGER.info("%s", error_msg)
             raise InvalidContent(error_msg)
 
+
+async def make_pending_tx_exchange(config: Config) -> aio_pika.abc.AbstractExchange:
+    mq_conn = await aio_pika.connect_robust(
host=config.rabbitmq.host.value, + port=config.rabbitmq.port.value, + login=config.rabbitmq.username.value, + password=config.rabbitmq.password.value, + ) + channel = await mq_conn.channel() + pending_tx_exchange = await channel.declare_exchange( + name=config.rabbitmq.pending_tx_exchange.value, + type=aio_pika.ExchangeType.TOPIC, + auto_delete=False, + ) + return pending_tx_exchange + + +class PendingTxPublisher: + def __init__(self, pending_tx_exchange: aio_pika.abc.AbstractExchange): + self.pending_tx_exchange = pending_tx_exchange + @staticmethod - async def incoming_chaindata(session: DbSession, tx: ChainTxDb): - """Incoming data from a chain. - Content can be inline of "offchain" through an ipfs hash. - For now, we only add it to the database, it will be processed later. - """ + def add_pending_tx(session: DbSession, tx: ChainTxDb): upsert_chain_tx(session=session, tx=tx) upsert_pending_tx(session=session, tx_hash=tx.hash) + + async def publish_pending_tx(self, tx: ChainTxDb): + message = aio_pika.Message(body=tx.hash.encode("utf-8")) + await self.pending_tx_exchange.publish( + message=message, routing_key=f"{tx.chain.value}.{tx.publisher}.{tx.hash}" + ) + + async def add_and_publish_pending_tx(self, session: DbSession, tx: ChainTxDb): + """ + Add an event published on one of the supported chains. + Adds the tx to the database, creates a pending tx entry in the pending tx table + and publishes a message on the pending tx exchange. + + Note that this function commits changes to the database for consistency + between the DB and the message queue. + """ + self.add_pending_tx(session=session, tx=tx) + session.commit() + await self.publish_pending_tx(tx) + + @classmethod + async def new(cls, config: Config) -> Self: + pending_tx_exchange = await make_pending_tx_exchange(config=config) + return cls( + pending_tx_exchange=pending_tx_exchange, + ) diff --git a/src/aleph/chains/connector.py b/src/aleph/chains/connector.py index 625d0da72..5018dd5bf 100644 --- a/src/aleph/chains/connector.py +++ b/src/aleph/chains/connector.py @@ -5,11 +5,10 @@ from aleph_message.models import Chain from configmanager import Config -from aleph.storage import StorageService from aleph.types.db_session import DbSessionFactory -from .bsc import BscConnector -from .chain_data_service import ChainDataService from .abc import ChainReader, ChainWriter +from .bsc import BscConnector +from .chain_data_service import ChainDataService, PendingTxPublisher from .ethereum import EthereumConnector from .nuls2 import Nuls2Connector from .tezos import TezosConnector @@ -28,9 +27,13 @@ class ChainConnector: writers: Dict[Chain, ChainWriter] def __init__( - self, session_factory: DbSessionFactory, chain_data_service: ChainDataService + self, + session_factory: DbSessionFactory, + pending_tx_publisher: PendingTxPublisher, + chain_data_service: ChainDataService, ): self._session_factory = session_factory + self.pending_tx_publisher = pending_tx_publisher self._chain_data_service = chain_data_service self.readers = {} @@ -101,13 +104,14 @@ def _register_chains(self): Chain.BSC, BscConnector( session_factory=self._session_factory, - chain_data_service=self._chain_data_service, + pending_tx_publisher=self.pending_tx_publisher, ), ) self._add_chain( Chain.NULS2, Nuls2Connector( session_factory=self._session_factory, + pending_tx_publisher=self.pending_tx_publisher, chain_data_service=self._chain_data_service, ), ) @@ -115,6 +119,7 @@ def _register_chains(self): Chain.ETH, EthereumConnector( session_factory=self._session_factory, + 
pending_tx_publisher=self.pending_tx_publisher, chain_data_service=self._chain_data_service, ), ) @@ -122,6 +127,6 @@ def _register_chains(self): Chain.TEZOS, TezosConnector( session_factory=self._session_factory, - chain_data_service=self._chain_data_service, + pending_tx_publisher=self.pending_tx_publisher, ), ) diff --git a/src/aleph/chains/ethereum.py b/src/aleph/chains/ethereum.py index e63ceeafe..f2a5812b0 100644 --- a/src/aleph/chains/ethereum.py +++ b/src/aleph/chains/ethereum.py @@ -22,16 +22,16 @@ from aleph.db.accessors.messages import get_unconfirmed_messages from aleph.db.accessors.pending_messages import count_pending_messages from aleph.db.accessors.pending_txs import count_pending_txs +from aleph.db.models.chains import ChainTxDb from aleph.schemas.chains.tx_context import TxContext from aleph.schemas.pending_messages import BasePendingMessage from aleph.toolkit.timestamp import utc_now +from aleph.types.chain_sync import ChainEventType from aleph.types.db_session import DbSessionFactory from aleph.utils import run_in_executor -from .chain_data_service import ChainDataService -from .abc import ChainWriter, Verifier, ChainReader +from .abc import ChainWriter, Verifier +from .chain_data_service import ChainDataService, PendingTxPublisher from .indexer_reader import AlephIndexerReader -from ..db.models import ChainTxDb -from ..types.chain_sync import ChainEventType LOGGER = logging.getLogger("chains.ethereum") CHAIN_NAME = "ETH" @@ -105,15 +105,17 @@ class EthereumConnector(ChainWriter): def __init__( self, session_factory: DbSessionFactory, + pending_tx_publisher: PendingTxPublisher, chain_data_service: ChainDataService, ): self.session_factory = session_factory + self.pending_tx_publisher = pending_tx_publisher self.chain_data_service = chain_data_service self.indexer_reader = AlephIndexerReader( chain=Chain.ETH, session_factory=session_factory, - chain_data_service=chain_data_service, + pending_tx_publisher=pending_tx_publisher, ) async def get_last_height(self, sync_type: ChainEventType) -> int: @@ -212,7 +214,9 @@ async def _request_transactions( except json.JSONDecodeError: # if it's not valid json, just ignore it... - LOGGER.info("Incoming logic data is not JSON, ignoring. %r" % message) + LOGGER.info( + "Incoming logic data is not JSON, ignoring. %r" % message + ) except Exception: LOGGER.exception("Can't decode incoming logic data %r" % message) @@ -256,7 +260,7 @@ async def fetch_ethereum_sync_events(self, config: Config): ): tx = ChainTxDb.from_sync_tx_context(tx_context=context, tx_data=jdata) with self.session_factory() as session: - await self.chain_data_service.incoming_chaindata( + await self.pending_tx_publisher.add_and_publish_pending_tx( session=session, tx=tx ) session.commit() @@ -313,7 +317,6 @@ async def packer(self, config: Config): gas_price = web3.eth.generate_gas_price() while True: with self.session_factory() as session: - # Wait for sync operations to complete if (count_pending_txs(session=session, chain=Chain.ETH)) or ( count_pending_messages(session=session, chain=Chain.ETH) @@ -344,8 +347,10 @@ async def packer(self, config: Config): LOGGER.info("Chain sync: %d unconfirmed messages") # This function prepares a chain data file and makes it downloadable from the node. 
- sync_event_payload = await self.chain_data_service.prepare_sync_event_payload( - session=session, messages=messages + sync_event_payload = ( + await self.chain_data_service.prepare_sync_event_payload( + session=session, messages=messages + ) ) # Required to apply update to the files table in get_chaindata session.commit() diff --git a/src/aleph/chains/indexer_reader.py b/src/aleph/chains/indexer_reader.py index 933a4b1f0..a11419873 100644 --- a/src/aleph/chains/indexer_reader.py +++ b/src/aleph/chains/indexer_reader.py @@ -21,7 +21,7 @@ from pydantic import BaseModel import aleph.toolkit.json as aleph_json -from aleph.chains.chain_data_service import ChainDataService +from aleph.chains.chain_data_service import PendingTxPublisher from aleph.db.accessors.chains import ( get_missing_indexer_datetime_multirange, add_indexer_range, @@ -154,7 +154,6 @@ async def fetch_account_state( blockchain: IndexerBlockchain, accounts: List[str], ) -> IndexerAccountStateResponse: - query = make_account_state_query( blockchain=blockchain, accounts=accounts, type_=EntityType.LOG ) @@ -194,7 +193,6 @@ def indexer_event_to_chain_tx( chain: Chain, indexer_event: Union[MessageEvent, SyncEvent], ) -> ChainTxDb: - if isinstance(indexer_event, MessageEvent): protocol = ChainSyncProtocol.SMART_CONTRACT protocol_version = 1 @@ -225,7 +223,6 @@ async def extract_aleph_messages_from_indexer_response( chain: Chain, indexer_response: IndexerEventResponse, ) -> List[ChainTxDb]: - message_events = indexer_response.data.message_events sync_events = indexer_response.data.sync_events @@ -240,7 +237,6 @@ async def extract_aleph_messages_from_indexer_response( class AlephIndexerReader: - BLOCKCHAIN_MAP: Mapping[Chain, IndexerBlockchain] = { Chain.BSC: IndexerBlockchain.BSC, Chain.ETH: IndexerBlockchain.ETHEREUM, @@ -251,11 +247,11 @@ def __init__( self, chain: Chain, session_factory: DbSessionFactory, - chain_data_service: ChainDataService, + pending_tx_publisher: PendingTxPublisher, ): self.chain = chain self.session_factory = session_factory - self.chain_data_service = chain_data_service + self.pending_tx_publisher = pending_tx_publisher self.blockchain = self.BLOCKCHAIN_MAP[chain] @@ -299,9 +295,7 @@ async def fetch_range( LOGGER.info("%d new txs", len(txs)) # Events are listed in reverse order in the indexer response for tx in txs: - await self.chain_data_service.incoming_chaindata( - session=session, tx=tx - ) + self.pending_tx_publisher.add_pending_tx(session=session, tx=tx) if nb_events_fetched >= limit: last_event_datetime = txs[-1].datetime @@ -317,6 +311,7 @@ async def fetch_range( ) else: synced_range = Range(start_datetime, end_datetime, upper_inc=True) + txs = [] LOGGER.info( "%s %s indexer: fetched %s", @@ -336,6 +331,10 @@ async def fetch_range( # of events. 
session.commit() + # Now that the txs are committed to the DB, add them to the pending tx message queue + for tx in txs: + await self.pending_tx_publisher.publish_pending_tx(tx) + if nb_events_fetched < limit: LOGGER.info( "%s %s event indexer: done fetching events.", diff --git a/src/aleph/chains/nuls2.py b/src/aleph/chains/nuls2.py index 8e08d3a0b..39f9b8ec4 100644 --- a/src/aleph/chains/nuls2.py +++ b/src/aleph/chains/nuls2.py @@ -30,7 +30,7 @@ from aleph.toolkit.timestamp import utc_now from aleph.types.db_session import DbSessionFactory from aleph.utils import run_in_executor -from .chain_data_service import ChainDataService +from .chain_data_service import ChainDataService, PendingTxPublisher from .abc import Verifier, ChainWriter from aleph.schemas.chains.tx_context import TxContext from ..db.models import ChainTxDb @@ -78,9 +78,13 @@ async def verify_signature(self, message: BasePendingMessage) -> bool: class Nuls2Connector(ChainWriter): def __init__( - self, session_factory: DbSessionFactory, chain_data_service: ChainDataService + self, + session_factory: DbSessionFactory, + pending_tx_publisher: PendingTxPublisher, + chain_data_service: ChainDataService, ): self.session_factory = session_factory + self.pending_tx_publisher = pending_tx_publisher self.chain_data_service = chain_data_service async def get_last_height(self, sync_type: ChainEventType) -> int: @@ -154,7 +158,7 @@ async def fetcher(self, config: Config): tx_context=context, tx_data=jdata ) with self.session_factory() as db_session: - await self.chain_data_service.incoming_chaindata( + await self.pending_tx_publisher.add_and_publish_pending_tx( session=db_session, tx=tx ) db_session.commit() @@ -197,8 +201,10 @@ async def packer(self, config: Config): if len(messages): # This function prepares a chain data file and makes it downloadable from the node. 
- sync_event_payload = await self.chain_data_service.prepare_sync_event_payload( - session=session, messages=messages + sync_event_payload = ( + await self.chain_data_service.prepare_sync_event_payload( + session=session, messages=messages + ) ) # Required to apply update to the files table in get_chaindata session.commit() diff --git a/src/aleph/chains/tezos.py b/src/aleph/chains/tezos.py index 5e29dd197..279c1cb2c 100644 --- a/src/aleph/chains/tezos.py +++ b/src/aleph/chains/tezos.py @@ -11,9 +11,9 @@ from nacl.exceptions import BadSignatureError import aleph.toolkit.json as aleph_json -from aleph.chains.chain_data_service import ChainDataService -from aleph.chains.common import get_verification_buffer from aleph.chains.abc import Verifier, ChainReader +from aleph.chains.chain_data_service import PendingTxPublisher +from aleph.chains.common import get_verification_buffer from aleph.db.accessors.chains import get_last_height, upsert_chain_sync_status from aleph.db.models import PendingMessageDb, ChainTxDb from aleph.schemas.chains.tezos_indexer_response import ( @@ -248,10 +248,12 @@ async def verify_signature(self, message: BasePendingMessage) -> bool: class TezosConnector(ChainReader): def __init__( - self, session_factory: DbSessionFactory, chain_data_service: ChainDataService + self, + session_factory: DbSessionFactory, + pending_tx_publisher: PendingTxPublisher, ): self.session_factory = session_factory - self.chain_data_service = chain_data_service + self.pending_tx_publisher = pending_tx_publisher async def get_last_height(self, sync_type: ChainEventType) -> int: """Returns the last height for which we already have the ethereum data.""" @@ -307,7 +309,7 @@ async def fetch_incoming_messages( ) LOGGER.info("%d new txs", len(txs)) for tx in txs: - await self.chain_data_service.incoming_chaindata( + await self.pending_tx_publisher.add_and_publish_pending_tx( session=session, tx=tx ) diff --git a/src/aleph/commands.py b/src/aleph/commands.py index 20cd8e35e..c5a8a750e 100644 --- a/src/aleph/commands.py +++ b/src/aleph/commands.py @@ -23,7 +23,7 @@ from configmanager import Config import aleph.config -from aleph.chains.chain_data_service import ChainDataService +from aleph.chains.chain_data_service import ChainDataService, PendingTxPublisher from aleph.chains.connector import ChainConnector from aleph.cli.args import parse_args from aleph.db.connection import make_engine, make_session_factory, make_db_url @@ -45,7 +45,6 @@ __copyright__ = "Moshe Malawach" __license__ = "mit" - LOGGER = logging.getLogger(__name__) @@ -138,10 +137,14 @@ async def main(args: List[str]) -> None: node_cache=node_cache, ) chain_data_service = ChainDataService( - session_factory=session_factory, storage_service=storage_service + session_factory=session_factory, + storage_service=storage_service, ) + pending_tx_publisher = await PendingTxPublisher.new(config=config) chain_connector = ChainConnector( - session_factory=session_factory, chain_data_service=chain_data_service + session_factory=session_factory, + pending_tx_publisher=pending_tx_publisher, + chain_data_service=chain_data_service, ) set_start_method("spawn") diff --git a/src/aleph/config.py b/src/aleph/config.py index a15ef1240..9cded5516 100644 --- a/src/aleph/config.py +++ b/src/aleph/config.py @@ -174,6 +174,8 @@ def get_defaults(): "sub_exchange": "p2p-subscribe", # Name of the exchange used to publish processed messages (output of the message processor). 
"message_exchange": "aleph-messages", + "pending_message_exchange": "aleph-pending-messages", + "pending_tx_exchange": "aleph-pending-txs", }, "redis": { # Hostname of the Redis service. diff --git a/src/aleph/db/accessors/pending_txs.py b/src/aleph/db/accessors/pending_txs.py index 3c64f41b0..ada9802fc 100644 --- a/src/aleph/db/accessors/pending_txs.py +++ b/src/aleph/db/accessors/pending_txs.py @@ -1,7 +1,7 @@ from typing import Optional, Iterable from aleph_message.models import Chain -from sqlalchemy import select, func +from sqlalchemy import select, func, delete from sqlalchemy.dialects.postgresql import insert from sqlalchemy.orm import selectinload @@ -9,6 +9,15 @@ from aleph.types.db_session import DbSession +def get_pending_tx(session: DbSession, tx_hash: str) -> Optional[PendingTxDb]: + select_stmt = ( + select(PendingTxDb) + .where(PendingTxDb.tx_hash == tx_hash) + .options(selectinload(PendingTxDb.tx)) + ) + return (session.execute(select_stmt)).scalar_one_or_none() + + def get_pending_txs(session: DbSession, limit: int = 200) -> Iterable[PendingTxDb]: select_stmt = ( select(PendingTxDb) @@ -33,3 +42,8 @@ def count_pending_txs(session: DbSession, chain: Optional[Chain] = None) -> int: def upsert_pending_tx(session: DbSession, tx_hash: str) -> None: upsert_stmt = insert(PendingTxDb).values(tx_hash=tx_hash).on_conflict_do_nothing() session.execute(upsert_stmt) + + +def delete_pending_tx(session: DbSession, tx_hash: str) -> None: + delete_stmt = delete(PendingTxDb).where(PendingTxDb.tx_hash == tx_hash) + session.execute(delete_stmt) diff --git a/src/aleph/jobs/process_pending_txs.py b/src/aleph/jobs/process_pending_txs.py index 0fd522f72..d90837226 100644 --- a/src/aleph/jobs/process_pending_txs.py +++ b/src/aleph/jobs/process_pending_txs.py @@ -6,14 +6,16 @@ import logging from typing import Dict, Optional, Set +import aio_pika.abc from configmanager import Config from setproctitle import setproctitle -from sqlalchemy import delete from aleph.chains.chain_data_service import ChainDataService -from aleph.db.accessors.pending_txs import get_pending_txs +from aleph.db.accessors.pending_txs import ( + get_pending_tx, + delete_pending_tx, +) from aleph.db.connection import make_engine, make_session_factory -from aleph.db.models.pending_txs import PendingTxDb from aleph.handlers.message_handler import MessagePublisher from aleph.services.cache.node_cache import NodeCache from aleph.services.ipfs.common import make_ipfs_client @@ -34,86 +36,85 @@ class PendingTxProcessor: def __init__( self, session_factory: DbSessionFactory, - storage_service: StorageService, message_publisher: MessagePublisher, + chain_data_service: ChainDataService, + pending_tx_queue: aio_pika.abc.AbstractQueue, ): self.session_factory = session_factory - self.storage_service = storage_service self.message_publisher = message_publisher - self.chain_data_service = ChainDataService( - session_factory=session_factory, storage_service=storage_service - ) + self.chain_data_service = chain_data_service + self.pending_tx_queue = pending_tx_queue async def handle_pending_tx( - self, pending_tx: PendingTxDb, seen_ids: Optional[Set[str]] = None + self, pending_tx_hash: str, seen_ids: Optional[Set[str]] = None ) -> None: - LOGGER.info( - "%s Handling TX in block %s", pending_tx.tx.chain, pending_tx.tx.height - ) - - tx = pending_tx.tx - - # If the chain data file is unavailable, we leave it to the pending tx - # processor to log the content unavailable exception and retry later. 
- messages = await self.chain_data_service.get_tx_messages( - tx=pending_tx.tx, seen_ids=seen_ids - ) - - if messages: - for i, message_dict in enumerate(messages): - await self.message_publisher.add_pending_message( - message_dict=message_dict, - reception_time=utc_now(), - tx_hash=tx.hash, - check_message=tx.protocol != ChainSyncProtocol.SMART_CONTRACT, - ) - - else: - LOGGER.debug("TX contains no message") - - if messages is not None: - # bogus or handled, we remove it. - with self.session_factory() as session: - session.execute( - delete(PendingTxDb).where( - PendingTxDb.tx_hash == pending_tx.tx_hash - ), - ) + with self.session_factory() as session: + pending_tx = get_pending_tx(session=session, tx_hash=pending_tx_hash) + + if pending_tx is None: + LOGGER.warning("TX %s is not pending anymore", pending_tx_hash) + return + + tx = pending_tx.tx + LOGGER.info("%s Handling TX in block %s", tx.chain, tx.height) + + # If the chain data file is unavailable, we leave it to the pending tx + # processor to log the content unavailable exception and retry later. + messages = await self.chain_data_service.get_tx_messages( + tx=pending_tx.tx, seen_ids=seen_ids + ) + + if messages: + for i, message_dict in enumerate(messages): + await self.message_publisher.add_pending_message( + message_dict=message_dict, + reception_time=utc_now(), + tx_hash=tx.hash, + check_message=tx.protocol != ChainSyncProtocol.SMART_CONTRACT, + ) + + else: + LOGGER.debug("TX contains no message") + + if messages is not None: + # bogus or handled, we remove it. + delete_pending_tx(session=session, tx_hash=pending_tx_hash) session.commit() - async def process_pending_txs(self, max_concurrent_tasks: int): + async def process_pending_txs(self) -> None: """ Process chain transactions in the Pending TX queue. """ - tasks: Set[asyncio.Task] = set() - - seen_offchain_hashes = set() seen_ids: Set[str] = set() LOGGER.info("handling TXs") - with self.session_factory() as session: - for pending_tx in get_pending_txs(session): - # TODO: remove this feature? It doesn't seem necessary. 
-                if pending_tx.tx.protocol == ChainSyncProtocol.OFF_CHAIN_SYNC:
-                    if pending_tx.tx.content in seen_offchain_hashes:
-                        continue
-
-                if len(tasks) == max_concurrent_tasks:
-                    done, tasks = await asyncio.wait(
-                        tasks, return_when=asyncio.FIRST_COMPLETED
-                    )
-                if pending_tx.tx.protocol == ChainSyncProtocol.OFF_CHAIN_SYNC:
-                    seen_offchain_hashes.add(pending_tx.tx.content)
-                tx_task = asyncio.create_task(
-                    self.handle_pending_tx(pending_tx, seen_ids=seen_ids)
-                )
-                tasks.add(tx_task)
-
-        # Wait for the last tasks
-        if tasks:
-            done, _ = await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)
+        async with self.pending_tx_queue.iterator() as queue_iter:
+            async for pending_tx_message in queue_iter:
+                async with pending_tx_message.process():
+                    pending_tx_hash = pending_tx_message.body.decode("utf-8")
+                    await self.handle_pending_tx(
+                        pending_tx_hash=pending_tx_hash, seen_ids=seen_ids
+                    )
+
+
+async def make_pending_tx_queue(config: Config) -> aio_pika.abc.AbstractQueue:
+    mq_conn = await aio_pika.connect_robust(
+        host=config.rabbitmq.host.value,
+        port=config.rabbitmq.port.value,
+        login=config.rabbitmq.username.value,
+        password=config.rabbitmq.password.value,
+    )
+    channel = await mq_conn.channel()
+    pending_tx_exchange = await channel.declare_exchange(
+        name=config.rabbitmq.pending_tx_exchange.value,
+        type=aio_pika.ExchangeType.TOPIC,
+        auto_delete=False,
+    )
+    pending_tx_queue = await channel.declare_queue(
+        name="pending-tx-queue", durable=True, auto_delete=False
+    )
+    await pending_tx_queue.bind(pending_tx_exchange, routing_key="#")
+    return pending_tx_queue
 
 
 async def handle_txs_task(config: Config):
@@ -138,15 +139,20 @@ async def handle_txs_task(config: Config):
         storage_service=storage_service,
         config=config,
     )
+    chain_data_service = ChainDataService(
+        session_factory=session_factory, storage_service=storage_service
+    )
+    pending_tx_queue = await make_pending_tx_queue(config=config)
     pending_tx_processor = PendingTxProcessor(
         session_factory=session_factory,
-        storage_service=storage_service,
         message_publisher=message_publisher,
+        chain_data_service=chain_data_service,
+        pending_tx_queue=pending_tx_queue,
     )
 
     while True:
         try:
-            await pending_tx_processor.process_pending_txs(max_concurrent_tasks)
+            await pending_tx_processor.process_pending_txs()
             await asyncio.sleep(5)
         except Exception:
             LOGGER.exception("Error in pending txs job")
diff --git a/tests/message_processing/test_process_pending_txs.py b/tests/message_processing/test_process_pending_txs.py
index d30bdc818..a21a913a7 100644
--- a/tests/message_processing/test_process_pending_txs.py
+++ b/tests/message_processing/test_process_pending_txs.py
@@ -41,12 +41,13 @@ async def test_process_pending_tx_on_chain_protocol(
     chain_data_service.get_tx_messages = get_fixture_chaindata_messages
     pending_tx_processor = PendingTxProcessor(
         session_factory=session_factory,
-        storage_service=test_storage_service,
         message_publisher=MessagePublisher(
             session_factory=session_factory,
             storage_service=test_storage_service,
             config=mock_config,
         ),
+        chain_data_service=chain_data_service,
+        pending_tx_queue=mocker.AsyncMock(),
     )
     pending_tx_processor.chain_data_service = chain_data_service
 
@@ -69,7 +70,7 @@ async def test_process_pending_tx_on_chain_protocol(
 
     seen_ids: Set[str] = set()
     await pending_tx_processor.handle_pending_tx(
-        pending_tx=pending_tx, seen_ids=seen_ids
+        pending_tx_hash=pending_tx.tx_hash, seen_ids=seen_ids
     )
 
     fixture_messages = load_fixture_messages(f"{pending_tx.tx.content}.json")
@@ -114,12 +115,13 @@ async def _process_smart_contract_tx(
     )
pending_tx_processor = PendingTxProcessor( session_factory=session_factory, - storage_service=test_storage_service, message_publisher=MessagePublisher( session_factory=session_factory, storage_service=test_storage_service, config=mock_config, ), + chain_data_service=chain_data_service, + pending_tx_queue=mocker.AsyncMock(), ) pending_tx_processor.chain_data_service = chain_data_service @@ -140,7 +142,7 @@ async def _process_smart_contract_tx( session.add(pending_tx) session.commit() - await pending_tx_processor.handle_pending_tx(pending_tx=pending_tx) + await pending_tx_processor.handle_pending_tx(pending_tx_hash=pending_tx.tx_hash) with session_factory() as session: pending_txs = session.execute(select(PendingTxDb)).scalars().all()
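
For reference, the sketch announced in the commit message: a minimal, self-contained
illustration of the publish/consume flow this patch introduces. It assumes a RabbitMQ
broker on localhost:5672 with the default guest credentials; the exchange and queue
names mirror the defaults added above ("aleph-pending-txs", "pending-tx-queue"), while
the chain, publisher, and tx hash in the routing key are made-up placeholders for the
values the node reads from ChainTxDb.

import asyncio

import aio_pika

EXCHANGE_NAME = "aleph-pending-txs"  # default added to config.py in this patch


async def publish_one(tx_hash: str) -> None:
    # Publisher side: what PendingTxPublisher.publish_pending_tx does once the
    # chain tx row and its pending-tx entry have been committed to the DB.
    connection = await aio_pika.connect_robust(host="localhost", port=5672)
    async with connection:
        channel = await connection.channel()
        exchange = await channel.declare_exchange(
            name=EXCHANGE_NAME, type=aio_pika.ExchangeType.TOPIC, auto_delete=False
        )
        # Routing key format used by the patch: "<chain>.<publisher>.<hash>".
        await exchange.publish(
            aio_pika.Message(body=tx_hash.encode("utf-8")),
            routing_key=f"ETH.0xpublisher.{tx_hash}",
        )


async def consume_one() -> str:
    # Consumer side: what PendingTxProcessor.process_pending_txs does. The "#"
    # binding key matters: the routing keys contain dots, and in a topic
    # exchange "*" matches exactly one dot-separated word.
    connection = await aio_pika.connect_robust(host="localhost", port=5672)
    async with connection:
        channel = await connection.channel()
        exchange = await channel.declare_exchange(
            name=EXCHANGE_NAME, type=aio_pika.ExchangeType.TOPIC, auto_delete=False
        )
        queue = await channel.declare_queue(
            name="pending-tx-queue", durable=True, auto_delete=False
        )
        await queue.bind(exchange, routing_key="#")
        async with queue.iterator() as queue_iter:
            async for message in queue_iter:
                async with message.process():  # acks the message on success
                    # The node would call get_pending_tx() with this hash and
                    # hand the resulting messages to the message publisher.
                    return message.body.decode("utf-8")
    return ""  # unreachable; keeps type checkers happy


async def main() -> None:
    consumer = asyncio.create_task(consume_one())
    await asyncio.sleep(1)  # let the consumer declare and bind the queue first
    await publish_one("0xdeadbeef")
    print("consumed pending tx hash:", await consumer)


if __name__ == "__main__":
    asyncio.run(main())

In the node itself, the publish only happens after session.commit() in
add_and_publish_pending_tx, so by the time a consumer dequeues a hash the
corresponding row is already visible to get_pending_tx.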