From d9495dffbbf0487a3c920aed66eccea01fd1c890 Mon Sep 17 00:00:00 2001 From: fubuloubu <3859395+fubuloubu@users.noreply.github.com> Date: Tue, 16 Apr 2024 22:37:38 -0400 Subject: [PATCH] refactor!: migrate recorder config to CLI callback also refactor result handling --- silverback/_cli.py | 19 ++++- silverback/middlewares.py | 24 +----- silverback/runner.py | 159 ++++++++++++++++++-------------------- 3 files changed, 93 insertions(+), 109 deletions(-) diff --git a/silverback/_cli.py b/silverback/_cli.py index 4a604690..f3a48f1d 100644 --- a/silverback/_cli.py +++ b/silverback/_cli.py @@ -34,6 +34,16 @@ def _runner_callback(ctx, param, val): raise ValueError(f"Failed to import runner '{val}'.") +def _recorder_callback(ctx, param, val): + if not val: + return None + + elif recorder := import_from_string(val): + return recorder() + + raise ValueError(f"Failed to import recorder '{val}'.") + + def _account_callback(ctx, param, val): if val: val = val.alias.replace("dev_", "TEST::") @@ -92,11 +102,16 @@ async def run_worker(broker: AsyncBroker, worker_count=2, shutdown_timeout=90): help="An import str in format ':'", callback=_runner_callback, ) +@click.option( + "--recorder", + help="An import string in format ':'", + callback=_recorder_callback, +) @click.option("-x", "--max-exceptions", type=int, default=3) @click.argument("path") -def run(cli_ctx, account, runner, max_exceptions, path): +def run(cli_ctx, account, runner, recorder, max_exceptions, path): app = import_from_string(path) - runner = runner(app, max_exceptions=max_exceptions) + runner = runner(app, recorder=recorder, max_exceptions=max_exceptions) asyncio.run(runner.run()) diff --git a/silverback/middlewares.py b/silverback/middlewares.py index 98cadd49..02ab15e0 100644 --- a/silverback/middlewares.py +++ b/silverback/middlewares.py @@ -6,8 +6,7 @@ from eth_utils.conversions import to_hex from taskiq import TaskiqMessage, TaskiqMiddleware, TaskiqResult -from silverback.recorder import HandlerResult -from silverback.types import SilverbackID, TaskType +from silverback.types import TaskType from silverback.utils import hexbytes_dict @@ -22,11 +21,7 @@ def compute_block_time() -> int: return int((head.timestamp - genesis.timestamp) / head.number) - settings = kwargs.pop("silverback_settings") - self.block_time = self.chain_manager.provider.network.block_time or compute_block_time() - self.ident = SilverbackID.from_settings(settings) - self.recorder = settings.get_recorder() def pre_send(self, message: TaskiqMessage) -> TaskiqMessage: # TODO: Necessary because bytes/HexBytes doesn't encode/deocde well for some reason @@ -95,21 +90,4 @@ def post_execute(self, message: TaskiqMessage, result: TaskiqResult): f"{self._create_label(message)} " f"- {result.execution_time:.3f}s{percent_display}" ) - async def post_save(self, message: TaskiqMessage, result: TaskiqResult): - if not self.recorder: - return - - handler_result = HandlerResult.from_taskiq( - self.ident, - message.task_name, - message.labels.get("block_number"), - message.labels.get("log_index"), - result, - ) - - try: - await self.recorder.add_result(handler_result) - except Exception as err: - logger.error(f"Error storing result: {err}") - # NOTE: Unless stdout is ignored, error traceback appears in stdout, no need for `on_error` diff --git a/silverback/runner.py b/silverback/runner.py index 7ba5e738..907068b7 100644 --- a/silverback/runner.py +++ b/silverback/runner.py @@ -1,73 +1,88 @@ import asyncio from abc import ABC, abstractmethod -from typing import Optional, Tuple +from typing import Optional from ape import chain from ape.contracts import ContractEvent, ContractInstance from ape.logging import logger from ape.utils import ManagerAccessMixin from ape_ethereum.ecosystem import keccak -from taskiq import AsyncTaskiqDecoratedTask, TaskiqResult +from taskiq import AsyncTaskiqDecoratedTask, AsyncTaskiqTask from .application import SilverbackApp from .exceptions import Halt, NoWebsocketAvailableError -from .recorder import BaseRecorder -from .settings import Settings +from .recorder import BaseRecorder, TaskResult from .subscriptions import SubscriptionType, Web3SubscriptionsManager -from .types import SilverbackID, SilverbackStartupState, TaskType +from .types import AppState, SilverbackID, TaskType from .utils import async_wrap_iter, hexbytes_dict -settings = Settings() - class BaseRunner(ABC): - def __init__(self, app: SilverbackApp, *args, max_exceptions: int = 3, **kwargs): + def __init__( + self, + app: SilverbackApp, + *args, + max_exceptions: int = 3, + recorder: Optional[BaseRecorder] = None, + **kwargs, + ): self.app = app + self.recorder = recorder self.max_exceptions = max_exceptions self.exceptions = 0 - self.last_block_seen = 0 - self.last_block_processed = 0 - self.recorder: Optional[BaseRecorder] = None - self.ident = SilverbackID.from_settings(settings) - def _handle_result(self, result: TaskiqResult): - if result.is_err: - self.exceptions += 1 + ecosystem_name, network_name = app.network_choice.split(":") + self.identifier = SilverbackID( + name=app.name, + ecosystem=ecosystem_name, + network=network_name, + ) + + logger.info(f"Using {self.__class__.__name__}: max_exceptions={self.max_exceptions}") + + async def _handle_task(self, task: AsyncTaskiqTask): + result = await task.wait_result() + + if self.recorder: + await self.recorder.add_result(TaskResult.from_taskiq(result)) - else: + if not result.is_err: # NOTE: Reset exception counter self.exceptions = 0 + return - if self.exceptions > self.max_exceptions: - raise Halt() from result.error + self.exceptions += 1 + + if self.exceptions > self.max_exceptions or isinstance(result.error, Halt): + result.raise_for_error() async def _checkpoint( - self, last_block_seen: int = 0, last_block_processed: int = 0 - ) -> Tuple[int, int]: + self, + last_block_seen: Optional[int] = None, + last_block_processed: Optional[int] = None, + ): """Set latest checkpoint block number""" - if ( - last_block_seen > self.last_block_seen - or last_block_processed > self.last_block_processed - ): - logger.debug( - ( - f"Checkpoint block [seen={self.last_block_seen}, " - f"procssed={self.last_block_processed}]" - ) + assert self.app.state, f"{self.__class__.__name__}.run() not triggered." + + logger.debug( + ( + f"Checkpoint block [seen={self.app.state.last_block_seen}, " + f"procssed={self.app.state.last_block_processed}]" ) - self.last_block_seen = max(last_block_seen, self.last_block_seen) - self.last_block_processed = max(last_block_processed, self.last_block_processed) + ) - if self.recorder: - try: - await self.recorder.set_state( - self.ident, self.last_block_seen, self.last_block_processed - ) - except Exception as err: - logger.error(f"Error settings state: {err}") + if last_block_seen: + self.app.state.last_block_seen = last_block_seen + if last_block_processed: + self.app.state.last_block_processed = last_block_processed + + if self.recorder: + try: + await self.recorder.set_state(self.app.state) - return self.last_block_seen, self.last_block_processed + except Exception as err: + logger.error(f"Error setting state: {err}") @abstractmethod async def _block_task(self, block_handler: AsyncTaskiqDecoratedTask): @@ -93,14 +108,14 @@ async def run(self): Raises: :class:`~silverback.exceptions.Halt`: If there are no configured tasks to execute. """ - self.recorder = settings.get_recorder() + # Initialize recorder (if available) and fetch state if app has been run previously + if self.recorder and (startup_state := (await self.recorder.init(app_id=self.identifier))): + self.app.state = startup_state - if self.recorder: - boot_state = await self.recorder.get_state(self.ident) - if boot_state: - self.last_block_seen = boot_state.last_block_seen - self.last_block_processed = boot_state.last_block_processed + else: # use empty state + self.app.state = AppState(last_block_seen=-1, last_block_processed=-1) + # Initialize broker (run worker startup events) await self.app.broker.startup() # Execute Silverback startup task before we init the rest @@ -144,7 +159,6 @@ class WebsocketRunner(BaseRunner, ManagerAccessMixin): def __init__(self, app: SilverbackApp, *args, **kwargs): super().__init__(app, *args, **kwargs) - logger.info(f"Using {self.__class__.__name__}: max_exceptions={self.max_exceptions}") # Check for websocket support if not (ws_uri := app.chain_manager.provider.ws_uri): @@ -159,16 +173,9 @@ async def _block_task(self, block_handler: AsyncTaskiqDecoratedTask): async for raw_block in self.subscriptions.get_subscription_data(sub_id): block = self.provider.network.ecosystem.decode_block(hexbytes_dict(raw_block)) - if block.number is not None: - await self._checkpoint(last_block_seen=block.number) - - block_task = await block_handler.kiq(raw_block) - result = await block_task.wait_result() - - self._handle_result(result) - - if block.number is not None: - await self._checkpoint(last_block_processed=block.number) + await self._checkpoint(last_block_seen=block.number) + await self._handle_task(await block_handler.kiq(raw_block)) + await self._checkpoint(last_block_processed=block.number) async def _event_task( self, contract_event: ContractEvent, event_handler: AsyncTaskiqDecoratedTask @@ -182,7 +189,9 @@ async def _event_task( address=contract_event.contract.address, topics=["0x" + keccak(text=contract_event.abi.selector).hex()], ) - logger.debug(f"Handling '{contract_event.name}' events via {sub_id}") + logger.debug( + f"Handling '{contract_event.contract.address}:{contract_event.name}' logs via {sub_id}" + ) async for raw_event in self.subscriptions.get_subscription_data(sub_id): event = next( # NOTE: `next` is okay since it only has one item @@ -192,15 +201,9 @@ async def _event_task( ) ) - if event.block_number is not None: - await self._checkpoint(last_block_seen=event.block_number) - - event_task = await event_handler.kiq(event) - result = await event_task.wait_result() - self._handle_result(result) - - if event.block_number is not None: - await self._checkpoint(last_block_processed=event.block_number) + await self._checkpoint(last_block_seen=event.block_number) + await self._handle_task(await event_handler.kiq(event)) + await self._checkpoint(last_block_processed=event.block_number) async def run(self): async with Web3SubscriptionsManager(self.ws_uri) as subscriptions: @@ -235,15 +238,9 @@ async def _block_task(self, block_handler: AsyncTaskiqDecoratedTask): async for block in async_wrap_iter( chain.blocks.poll_blocks(start_block=start_block, new_block_timeout=new_block_timeout) ): - if block.number is not None: - await self._checkpoint(last_block_seen=block.number) - - block_task = await block_handler.kiq(block) - result = await block_task.wait_result() - self._handle_result(result) - - if block.number is not None: - await self._checkpoint(last_block_processed=block.number) + await self._checkpoint(last_block_seen=block.number) + await self._handle_task(await block_handler.kiq(block)) + await self._checkpoint(last_block_processed=block.number) async def _event_task( self, contract_event: ContractEvent, event_handler: AsyncTaskiqDecoratedTask @@ -265,12 +262,6 @@ async def _event_task( async for event in async_wrap_iter( contract_event.poll_logs(start_block=start_block, new_block_timeout=new_block_timeout) ): - if event.block_number is not None: - await self._checkpoint(last_block_seen=event.block_number) - - event_task = await event_handler.kiq(event) - result = await event_task.wait_result() - self._handle_result(result) - - if event.block_number is not None: - await self._checkpoint(last_block_processed=event.block_number) + await self._checkpoint(last_block_seen=event.block_number) + await self._handle_task(await event_handler.kiq(event)) + await self._checkpoint(last_block_processed=event.block_number)