diff --git a/src/aleph/chains/bsc.py b/src/aleph/chains/bsc.py index 1a0a06086..56dff95c3 100644 --- a/src/aleph/chains/bsc.py +++ b/src/aleph/chains/bsc.py @@ -1,8 +1,8 @@ from aleph_message.models import Chain from configmanager import Config -from aleph.chains.chain_data_service import ChainDataService from aleph.chains.abc import ChainReader +from aleph.chains.chain_data_service import PendingTxPublisher from aleph.chains.indexer_reader import AlephIndexerReader from aleph.types.chain_sync import ChainEventType from aleph.types.db_session import DbSessionFactory @@ -10,12 +10,14 @@ class BscConnector(ChainReader): def __init__( - self, session_factory: DbSessionFactory, chain_data_service: ChainDataService + self, + session_factory: DbSessionFactory, + pending_tx_publisher: PendingTxPublisher, ): self.indexer_reader = AlephIndexerReader( chain=Chain.BSC, session_factory=session_factory, - chain_data_service=chain_data_service, + pending_tx_publisher=pending_tx_publisher, ) async def fetcher(self, config: Config): diff --git a/src/aleph/chains/chain_data_service.py b/src/aleph/chains/chain_data_service.py index 498b472ea..bbc7fa741 100644 --- a/src/aleph/chains/chain_data_service.py +++ b/src/aleph/chains/chain_data_service.py @@ -1,8 +1,10 @@ import asyncio from io import StringIO -from typing import Dict, Optional, List, Any, Mapping, Set, cast, Type, Union +from typing import Dict, Optional, List, Any, Mapping, Set, cast, Type, Union, Self +import aio_pika.abc from aleph_message.models import StoreContent, ItemType, Chain, MessageType +from configmanager import Config from pydantic import ValidationError from aleph.chains.common import LOGGER @@ -36,7 +38,9 @@ class ChainDataService: def __init__( - self, session_factory: DbSessionFactory, storage_service: StorageService + self, + session_factory: DbSessionFactory, + storage_service: StorageService, ): self.session_factory = session_factory self.storage_service = storage_service @@ -215,11 +219,54 @@ async def get_tx_messages( LOGGER.info("%s", error_msg) raise InvalidContent(error_msg) + +async def make_pending_tx_exchange(config: Config) -> aio_pika.abc.AbstractExchange: + mq_conn = await aio_pika.connect_robust( + host=config.rabbitmq.host.value, + port=config.rabbitmq.port.value, + login=config.rabbitmq.username.value, + password=config.rabbitmq.password.value, + ) + channel = await mq_conn.channel() + pending_tx_exchange = await channel.declare_exchange( + name=config.rabbitmq.pending_tx_exchange.value, + type=aio_pika.ExchangeType.TOPIC, + auto_delete=False, + ) + return pending_tx_exchange + + +class PendingTxPublisher: + def __init__(self, pending_tx_exchange: aio_pika.abc.AbstractExchange): + self.pending_tx_exchange = pending_tx_exchange + @staticmethod - async def incoming_chaindata(session: DbSession, tx: ChainTxDb): - """Incoming data from a chain. - Content can be inline of "offchain" through an ipfs hash. - For now, we only add it to the database, it will be processed later. - """ + def add_pending_tx(session: DbSession, tx: ChainTxDb): upsert_chain_tx(session=session, tx=tx) upsert_pending_tx(session=session, tx_hash=tx.hash) + + async def publish_pending_tx(self, tx: ChainTxDb): + message = aio_pika.Message(body=tx.hash.encode("utf-8")) + await self.pending_tx_exchange.publish( + message=message, routing_key=f"{tx.chain.value}.{tx.publisher}.{tx.hash}" + ) + + async def add_and_publish_pending_tx(self, session: DbSession, tx: ChainTxDb): + """ + Add an event published on one of the supported chains. + Adds the tx to the database, creates a pending tx entry in the pending tx table + and publishes a message on the pending tx exchange. + + Note that this function commits changes to the database for consistency + between the DB and the message queue. + """ + self.add_pending_tx(session=session, tx=tx) + session.commit() + await self.publish_pending_tx(tx) + + @classmethod + async def new(cls, config: Config) -> Self: + pending_tx_exchange = await make_pending_tx_exchange(config=config) + return cls( + pending_tx_exchange=pending_tx_exchange, + ) diff --git a/src/aleph/chains/connector.py b/src/aleph/chains/connector.py index 825c0e00c..112ca9075 100644 --- a/src/aleph/chains/connector.py +++ b/src/aleph/chains/connector.py @@ -5,11 +5,10 @@ from aleph_message.models import Chain from configmanager import Config -from aleph.storage import StorageService from aleph.types.db_session import DbSessionFactory -from .bsc import BscConnector -from .chain_data_service import ChainDataService from .abc import ChainReader, ChainWriter +from .bsc import BscConnector +from .chain_data_service import ChainDataService, PendingTxPublisher from .ethereum import EthereumConnector from .nuls2 import Nuls2Connector from .tezos import TezosConnector @@ -28,9 +27,13 @@ class ChainConnector: writers: Dict[Chain, ChainWriter] def __init__( - self, session_factory: DbSessionFactory, chain_data_service: ChainDataService + self, + session_factory: DbSessionFactory, + pending_tx_publisher: PendingTxPublisher, + chain_data_service: ChainDataService, ): self._session_factory = session_factory + self.pending_tx_publisher = pending_tx_publisher self._chain_data_service = chain_data_service self.readers = {} @@ -96,13 +99,14 @@ def _register_chains(self): Chain.BSC, BscConnector( session_factory=self._session_factory, - chain_data_service=self._chain_data_service, + pending_tx_publisher=self.pending_tx_publisher, ), ) self._add_chain( Chain.NULS2, Nuls2Connector( session_factory=self._session_factory, + pending_tx_publisher=self.pending_tx_publisher, chain_data_service=self._chain_data_service, ), ) @@ -110,6 +114,7 @@ def _register_chains(self): Chain.ETH, EthereumConnector( session_factory=self._session_factory, + pending_tx_publisher=self.pending_tx_publisher, chain_data_service=self._chain_data_service, ), ) @@ -117,6 +122,6 @@ def _register_chains(self): Chain.TEZOS, TezosConnector( session_factory=self._session_factory, - chain_data_service=self._chain_data_service, + pending_tx_publisher=self.pending_tx_publisher, ), ) diff --git a/src/aleph/chains/ethereum.py b/src/aleph/chains/ethereum.py index e63ceeafe..f2a5812b0 100644 --- a/src/aleph/chains/ethereum.py +++ b/src/aleph/chains/ethereum.py @@ -22,16 +22,16 @@ from aleph.db.accessors.messages import get_unconfirmed_messages from aleph.db.accessors.pending_messages import count_pending_messages from aleph.db.accessors.pending_txs import count_pending_txs +from aleph.db.models.chains import ChainTxDb from aleph.schemas.chains.tx_context import TxContext from aleph.schemas.pending_messages import BasePendingMessage from aleph.toolkit.timestamp import utc_now +from aleph.types.chain_sync import ChainEventType from aleph.types.db_session import DbSessionFactory from aleph.utils import run_in_executor -from .chain_data_service import ChainDataService -from .abc import ChainWriter, Verifier, ChainReader +from .abc import ChainWriter, Verifier +from .chain_data_service import ChainDataService, PendingTxPublisher from .indexer_reader import AlephIndexerReader -from ..db.models import ChainTxDb -from ..types.chain_sync import ChainEventType LOGGER = logging.getLogger("chains.ethereum") CHAIN_NAME = "ETH" @@ -105,15 +105,17 @@ class EthereumConnector(ChainWriter): def __init__( self, session_factory: DbSessionFactory, + pending_tx_publisher: PendingTxPublisher, chain_data_service: ChainDataService, ): self.session_factory = session_factory + self.pending_tx_publisher = pending_tx_publisher self.chain_data_service = chain_data_service self.indexer_reader = AlephIndexerReader( chain=Chain.ETH, session_factory=session_factory, - chain_data_service=chain_data_service, + pending_tx_publisher=pending_tx_publisher, ) async def get_last_height(self, sync_type: ChainEventType) -> int: @@ -212,7 +214,9 @@ async def _request_transactions( except json.JSONDecodeError: # if it's not valid json, just ignore it... - LOGGER.info("Incoming logic data is not JSON, ignoring. %r" % message) + LOGGER.info( + "Incoming logic data is not JSON, ignoring. %r" % message + ) except Exception: LOGGER.exception("Can't decode incoming logic data %r" % message) @@ -256,7 +260,7 @@ async def fetch_ethereum_sync_events(self, config: Config): ): tx = ChainTxDb.from_sync_tx_context(tx_context=context, tx_data=jdata) with self.session_factory() as session: - await self.chain_data_service.incoming_chaindata( + await self.pending_tx_publisher.add_and_publish_pending_tx( session=session, tx=tx ) session.commit() @@ -313,7 +317,6 @@ async def packer(self, config: Config): gas_price = web3.eth.generate_gas_price() while True: with self.session_factory() as session: - # Wait for sync operations to complete if (count_pending_txs(session=session, chain=Chain.ETH)) or ( count_pending_messages(session=session, chain=Chain.ETH) @@ -344,8 +347,10 @@ async def packer(self, config: Config): LOGGER.info("Chain sync: %d unconfirmed messages") # This function prepares a chain data file and makes it downloadable from the node. - sync_event_payload = await self.chain_data_service.prepare_sync_event_payload( - session=session, messages=messages + sync_event_payload = ( + await self.chain_data_service.prepare_sync_event_payload( + session=session, messages=messages + ) ) # Required to apply update to the files table in get_chaindata session.commit() diff --git a/src/aleph/chains/indexer_reader.py b/src/aleph/chains/indexer_reader.py index 933a4b1f0..a11419873 100644 --- a/src/aleph/chains/indexer_reader.py +++ b/src/aleph/chains/indexer_reader.py @@ -21,7 +21,7 @@ from pydantic import BaseModel import aleph.toolkit.json as aleph_json -from aleph.chains.chain_data_service import ChainDataService +from aleph.chains.chain_data_service import PendingTxPublisher from aleph.db.accessors.chains import ( get_missing_indexer_datetime_multirange, add_indexer_range, @@ -154,7 +154,6 @@ async def fetch_account_state( blockchain: IndexerBlockchain, accounts: List[str], ) -> IndexerAccountStateResponse: - query = make_account_state_query( blockchain=blockchain, accounts=accounts, type_=EntityType.LOG ) @@ -194,7 +193,6 @@ def indexer_event_to_chain_tx( chain: Chain, indexer_event: Union[MessageEvent, SyncEvent], ) -> ChainTxDb: - if isinstance(indexer_event, MessageEvent): protocol = ChainSyncProtocol.SMART_CONTRACT protocol_version = 1 @@ -225,7 +223,6 @@ async def extract_aleph_messages_from_indexer_response( chain: Chain, indexer_response: IndexerEventResponse, ) -> List[ChainTxDb]: - message_events = indexer_response.data.message_events sync_events = indexer_response.data.sync_events @@ -240,7 +237,6 @@ async def extract_aleph_messages_from_indexer_response( class AlephIndexerReader: - BLOCKCHAIN_MAP: Mapping[Chain, IndexerBlockchain] = { Chain.BSC: IndexerBlockchain.BSC, Chain.ETH: IndexerBlockchain.ETHEREUM, @@ -251,11 +247,11 @@ def __init__( self, chain: Chain, session_factory: DbSessionFactory, - chain_data_service: ChainDataService, + pending_tx_publisher: PendingTxPublisher, ): self.chain = chain self.session_factory = session_factory - self.chain_data_service = chain_data_service + self.pending_tx_publisher = pending_tx_publisher self.blockchain = self.BLOCKCHAIN_MAP[chain] @@ -299,9 +295,7 @@ async def fetch_range( LOGGER.info("%d new txs", len(txs)) # Events are listed in reverse order in the indexer response for tx in txs: - await self.chain_data_service.incoming_chaindata( - session=session, tx=tx - ) + self.pending_tx_publisher.add_pending_tx(session=session, tx=tx) if nb_events_fetched >= limit: last_event_datetime = txs[-1].datetime @@ -317,6 +311,7 @@ async def fetch_range( ) else: synced_range = Range(start_datetime, end_datetime, upper_inc=True) + txs = [] LOGGER.info( "%s %s indexer: fetched %s", @@ -336,6 +331,10 @@ async def fetch_range( # of events. session.commit() + # Now that the txs are committed to the DB, add them to the pending tx message queue + for tx in txs: + await self.pending_tx_publisher.publish_pending_tx(tx) + if nb_events_fetched < limit: LOGGER.info( "%s %s event indexer: done fetching events.", diff --git a/src/aleph/chains/nuls2.py b/src/aleph/chains/nuls2.py index 8e08d3a0b..39f9b8ec4 100644 --- a/src/aleph/chains/nuls2.py +++ b/src/aleph/chains/nuls2.py @@ -30,7 +30,7 @@ from aleph.toolkit.timestamp import utc_now from aleph.types.db_session import DbSessionFactory from aleph.utils import run_in_executor -from .chain_data_service import ChainDataService +from .chain_data_service import ChainDataService, PendingTxPublisher from .abc import Verifier, ChainWriter from aleph.schemas.chains.tx_context import TxContext from ..db.models import ChainTxDb @@ -78,9 +78,13 @@ async def verify_signature(self, message: BasePendingMessage) -> bool: class Nuls2Connector(ChainWriter): def __init__( - self, session_factory: DbSessionFactory, chain_data_service: ChainDataService + self, + session_factory: DbSessionFactory, + pending_tx_publisher: PendingTxPublisher, + chain_data_service: ChainDataService, ): self.session_factory = session_factory + self.pending_tx_publisher = pending_tx_publisher self.chain_data_service = chain_data_service async def get_last_height(self, sync_type: ChainEventType) -> int: @@ -154,7 +158,7 @@ async def fetcher(self, config: Config): tx_context=context, tx_data=jdata ) with self.session_factory() as db_session: - await self.chain_data_service.incoming_chaindata( + await self.pending_tx_publisher.add_and_publish_pending_tx( session=db_session, tx=tx ) db_session.commit() @@ -197,8 +201,10 @@ async def packer(self, config: Config): if len(messages): # This function prepares a chain data file and makes it downloadable from the node. - sync_event_payload = await self.chain_data_service.prepare_sync_event_payload( - session=session, messages=messages + sync_event_payload = ( + await self.chain_data_service.prepare_sync_event_payload( + session=session, messages=messages + ) ) # Required to apply update to the files table in get_chaindata session.commit() diff --git a/src/aleph/chains/tezos.py b/src/aleph/chains/tezos.py index 5e29dd197..279c1cb2c 100644 --- a/src/aleph/chains/tezos.py +++ b/src/aleph/chains/tezos.py @@ -11,9 +11,9 @@ from nacl.exceptions import BadSignatureError import aleph.toolkit.json as aleph_json -from aleph.chains.chain_data_service import ChainDataService -from aleph.chains.common import get_verification_buffer from aleph.chains.abc import Verifier, ChainReader +from aleph.chains.chain_data_service import PendingTxPublisher +from aleph.chains.common import get_verification_buffer from aleph.db.accessors.chains import get_last_height, upsert_chain_sync_status from aleph.db.models import PendingMessageDb, ChainTxDb from aleph.schemas.chains.tezos_indexer_response import ( @@ -248,10 +248,12 @@ async def verify_signature(self, message: BasePendingMessage) -> bool: class TezosConnector(ChainReader): def __init__( - self, session_factory: DbSessionFactory, chain_data_service: ChainDataService + self, + session_factory: DbSessionFactory, + pending_tx_publisher: PendingTxPublisher, ): self.session_factory = session_factory - self.chain_data_service = chain_data_service + self.pending_tx_publisher = pending_tx_publisher async def get_last_height(self, sync_type: ChainEventType) -> int: """Returns the last height for which we already have the ethereum data.""" @@ -307,7 +309,7 @@ async def fetch_incoming_messages( ) LOGGER.info("%d new txs", len(txs)) for tx in txs: - await self.chain_data_service.incoming_chaindata( + await self.pending_tx_publisher.add_and_publish_pending_tx( session=session, tx=tx ) diff --git a/src/aleph/commands.py b/src/aleph/commands.py index 1e55ad984..6b816db8f 100644 --- a/src/aleph/commands.py +++ b/src/aleph/commands.py @@ -23,7 +23,7 @@ from configmanager import Config import aleph.config -from aleph.chains.chain_data_service import ChainDataService +from aleph.chains.chain_data_service import ChainDataService, PendingTxPublisher from aleph.chains.connector import ChainConnector from aleph.cli.args import parse_args from aleph.db.connection import make_engine, make_session_factory, make_db_url @@ -45,7 +45,6 @@ __copyright__ = "Moshe Malawach" __license__ = "mit" - LOGGER = logging.getLogger(__name__) @@ -145,10 +144,14 @@ async def main(args: List[str]) -> None: node_cache=node_cache, ) chain_data_service = ChainDataService( - session_factory=session_factory, storage_service=storage_service + session_factory=session_factory, + storage_service=storage_service, ) + pending_tx_publisher = await PendingTxPublisher.new(config=config) chain_connector = ChainConnector( - session_factory=session_factory, chain_data_service=chain_data_service + session_factory=session_factory, + pending_tx_publisher=pending_tx_publisher, + chain_data_service=chain_data_service, ) set_start_method("spawn") diff --git a/src/aleph/config.py b/src/aleph/config.py index 34df86ecf..e996c7130 100644 --- a/src/aleph/config.py +++ b/src/aleph/config.py @@ -125,6 +125,8 @@ def get_defaults(): "pub_exchange": "p2p-publish", "sub_exchange": "p2p-subscribe", "message_exchange": "aleph-messages", + "pending_message_exchange": "aleph-pending-messages", + "pending_tx_exchange": "aleph-pending-txs", }, "redis": { "host": "redis", diff --git a/src/aleph/db/accessors/pending_txs.py b/src/aleph/db/accessors/pending_txs.py index 3c64f41b0..ada9802fc 100644 --- a/src/aleph/db/accessors/pending_txs.py +++ b/src/aleph/db/accessors/pending_txs.py @@ -1,7 +1,7 @@ from typing import Optional, Iterable from aleph_message.models import Chain -from sqlalchemy import select, func +from sqlalchemy import select, func, delete from sqlalchemy.dialects.postgresql import insert from sqlalchemy.orm import selectinload @@ -9,6 +9,15 @@ from aleph.types.db_session import DbSession +def get_pending_tx(session: DbSession, tx_hash: str) -> Optional[PendingTxDb]: + select_stmt = ( + select(PendingTxDb) + .where(PendingTxDb.tx_hash == tx_hash) + .options(selectinload(PendingTxDb.tx)) + ) + return (session.execute(select_stmt)).scalar_one_or_none() + + def get_pending_txs(session: DbSession, limit: int = 200) -> Iterable[PendingTxDb]: select_stmt = ( select(PendingTxDb) @@ -33,3 +42,8 @@ def count_pending_txs(session: DbSession, chain: Optional[Chain] = None) -> int: def upsert_pending_tx(session: DbSession, tx_hash: str) -> None: upsert_stmt = insert(PendingTxDb).values(tx_hash=tx_hash).on_conflict_do_nothing() session.execute(upsert_stmt) + + +def delete_pending_tx(session: DbSession, tx_hash: str) -> None: + delete_stmt = delete(PendingTxDb).where(PendingTxDb.tx_hash == tx_hash) + session.execute(delete_stmt) diff --git a/src/aleph/jobs/process_pending_txs.py b/src/aleph/jobs/process_pending_txs.py index e1f7e9e2d..5d0df4371 100644 --- a/src/aleph/jobs/process_pending_txs.py +++ b/src/aleph/jobs/process_pending_txs.py @@ -6,16 +6,19 @@ import logging from typing import Dict, Optional, Set +import aio_pika.abc from configmanager import Config from setproctitle import setproctitle -from sqlalchemy import delete -from ..chains.signature_verifier import SignatureVerifier from aleph.chains.chain_data_service import ChainDataService -from aleph.db.accessors.pending_txs import get_pending_txs +from aleph.chains.signature_verifier import SignatureVerifier +from aleph.db.accessors.pending_txs import ( + get_pending_tx, + delete_pending_tx, +) from aleph.db.connection import make_engine, make_session_factory -from aleph.db.models.pending_txs import PendingTxDb from aleph.handlers.message_handler import MessageHandler +from aleph.services.cache.node_cache import NodeCache from aleph.services.ipfs.common import make_ipfs_client from aleph.services.ipfs.service import IpfsService from aleph.services.storage.fileystem_engine import FileSystemStorageEngine @@ -26,7 +29,6 @@ from aleph.types.chain_sync import ChainSyncProtocol from aleph.types.db_session import DbSessionFactory from .job_utils import prepare_loop -from ..services.cache.node_cache import NodeCache LOGGER = logging.getLogger(__name__) @@ -35,86 +37,85 @@ class PendingTxProcessor: def __init__( self, session_factory: DbSessionFactory, - storage_service: StorageService, message_handler: MessageHandler, + chain_data_service: ChainDataService, + pending_tx_queue: aio_pika.abc.AbstractQueue, ): self.session_factory = session_factory - self.storage_service = storage_service self.message_handler = message_handler - self.chain_data_service = ChainDataService( - session_factory=session_factory, storage_service=storage_service - ) + self.chain_data_service = chain_data_service + self.pending_tx_queue = pending_tx_queue async def handle_pending_tx( - self, pending_tx: PendingTxDb, seen_ids: Optional[Set[str]] = None + self, pending_tx_hash: str, seen_ids: Optional[Set[str]] = None ) -> None: - LOGGER.info( - "%s Handling TX in block %s", pending_tx.tx.chain, pending_tx.tx.height - ) - - tx = pending_tx.tx - - # If the chain data file is unavailable, we leave it to the pending tx - # processor to log the content unavailable exception and retry later. - messages = await self.chain_data_service.get_tx_messages( - tx=pending_tx.tx, seen_ids=seen_ids - ) - - if messages: - for i, message_dict in enumerate(messages): - await self.message_handler.add_pending_message( - message_dict=message_dict, - reception_time=utc_now(), - tx_hash=tx.hash, - check_message=tx.protocol != ChainSyncProtocol.SMART_CONTRACT, - ) - - else: - LOGGER.debug("TX contains no message") - - if messages is not None: - # bogus or handled, we remove it. - with self.session_factory() as session: - session.execute( - delete(PendingTxDb).where( - PendingTxDb.tx_hash == pending_tx.tx_hash - ), - ) + with self.session_factory() as session: + pending_tx = get_pending_tx(session=session, tx_hash=pending_tx_hash) + + if pending_tx is None: + LOGGER.warning("TX %s is not pending anymore", pending_tx_hash) + return + + tx = pending_tx.tx + LOGGER.info("%s Handling TX in block %s", tx.chain, tx.height) + + # If the chain data file is unavailable, we leave it to the pending tx + # processor to log the content unavailable exception and retry later. + messages = await self.chain_data_service.get_tx_messages( + tx=pending_tx.tx, seen_ids=seen_ids + ) + + if messages: + for i, message_dict in enumerate(messages): + await self.message_handler.add_pending_message( + message_dict=message_dict, + reception_time=utc_now(), + tx_hash=tx.hash, + check_message=tx.protocol != ChainSyncProtocol.SMART_CONTRACT, + ) + + else: + LOGGER.debug("TX contains no message") + + if messages is not None: + # bogus or handled, we remove it. + delete_pending_tx(session=session, tx_hash=pending_tx_hash) session.commit() - async def process_pending_txs(self, max_concurrent_tasks: int): + async def process_pending_txs(self) -> None: """ Process chain transactions in the Pending TX queue. """ - tasks: Set[asyncio.Task] = set() - - seen_offchain_hashes = set() seen_ids: Set[str] = set() LOGGER.info("handling TXs") - with self.session_factory() as session: - for pending_tx in get_pending_txs(session): - # TODO: remove this feature? It doesn't seem necessary. - if pending_tx.tx.protocol == ChainSyncProtocol.OFF_CHAIN_SYNC: - if pending_tx.tx.content in seen_offchain_hashes: - continue - - if len(tasks) == max_concurrent_tasks: - done, tasks = await asyncio.wait( - tasks, return_when=asyncio.FIRST_COMPLETED + async with self.pending_tx_queue.iterator() as queue_iter: + async for pending_tx_message in queue_iter: + async with pending_tx_message.process(): + pending_tx_hash = pending_tx_message.body.decode("utf-8") + await self.handle_pending_tx( + pending_tx_hash=pending_tx_hash, seen_ids=seen_ids ) - if pending_tx.tx.protocol == ChainSyncProtocol.OFF_CHAIN_SYNC: - seen_offchain_hashes.add(pending_tx.tx.content) - tx_task = asyncio.create_task( - self.handle_pending_tx(pending_tx, seen_ids=seen_ids) - ) - tasks.add(tx_task) - - # Wait for the last tasks - if tasks: - done, _ = await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED) +async def make_pending_tx_queue(config: Config) -> aio_pika.abc.AbstractQueue: + mq_conn = await aio_pika.connect_robust( + host=config.rabbitmq.host.value, + port=config.rabbitmq.port.value, + login=config.rabbitmq.username.value, + password=config.rabbitmq.password.value, + ) + channel = await mq_conn.channel() + pending_tx_exchange = await channel.declare_exchange( + name=config.rabbitmq.pending_tx_exchange.value, + type=aio_pika.ExchangeType.TOPIC, + auto_delete=False, + ) + pending_tx_queue = await channel.declare_queue( + name="pending-tx-queue", durable=True, auto_delete=False + ) + await pending_tx_queue.bind(pending_tx_exchange, routing_key="*") + return pending_tx_queue async def handle_txs_task(config: Config): @@ -141,15 +142,20 @@ async def handle_txs_task(config: Config): storage_service=storage_service, config=config, ) + chain_data_service = ChainDataService( + session_factory=session_factory, storage_service=storage_service + ) + pending_tx_queue = await make_pending_tx_queue(config=config) pending_tx_processor = PendingTxProcessor( session_factory=session_factory, - storage_service=storage_service, message_handler=message_handler, + chain_data_service=chain_data_service, + pending_tx_queue=pending_tx_queue, ) while True: try: - await pending_tx_processor.process_pending_txs(max_concurrent_tasks) + await pending_tx_processor.process_pending_txs() await asyncio.sleep(5) except Exception: LOGGER.exception("Error in pending txs job") diff --git a/tests/message_processing/test_process_pending_txs.py b/tests/message_processing/test_process_pending_txs.py index 602a3c0a3..ee45b1649 100644 --- a/tests/message_processing/test_process_pending_txs.py +++ b/tests/message_processing/test_process_pending_txs.py @@ -43,13 +43,14 @@ async def test_process_pending_tx_on_chain_protocol( signature_verifier = SignatureVerifier() pending_tx_processor = PendingTxProcessor( session_factory=session_factory, - storage_service=test_storage_service, message_handler=MessageHandler( session_factory=session_factory, signature_verifier=signature_verifier, storage_service=test_storage_service, config=mock_config, ), + chain_data_service=chain_data_service, + pending_tx_queue=mocker.AsyncMock(), ) pending_tx_processor.chain_data_service = chain_data_service @@ -72,7 +73,7 @@ async def test_process_pending_tx_on_chain_protocol( seen_ids: Set[str] = set() await pending_tx_processor.handle_pending_tx( - pending_tx=pending_tx, seen_ids=seen_ids + pending_tx_hash=pending_tx.tx_hash, seen_ids=seen_ids ) fixture_messages = load_fixture_messages(f"{pending_tx.tx.content}.json") @@ -118,13 +119,14 @@ async def _process_smart_contract_tx( signature_verifier = SignatureVerifier() pending_tx_processor = PendingTxProcessor( session_factory=session_factory, - storage_service=test_storage_service, message_handler=MessageHandler( session_factory=session_factory, signature_verifier=signature_verifier, storage_service=test_storage_service, config=mock_config, ), + chain_data_service=chain_data_service, + pending_tx_queue=mocker.AsyncMock(), ) pending_tx_processor.chain_data_service = chain_data_service @@ -145,7 +147,7 @@ async def _process_smart_contract_tx( session.add(pending_tx) session.commit() - await pending_tx_processor.handle_pending_tx(pending_tx=pending_tx) + await pending_tx_processor.handle_pending_tx(pending_tx_hash=pending_tx.tx_hash) with session_factory() as session: pending_txs = session.execute(select(PendingTxDb)).scalars().all()