Skip to content

Commit

Permalink
refactor!: migrate recorder config to CLI callback
Browse files Browse the repository at this point in the history
also refactor result handling
  • Loading branch information
fubuloubu committed Apr 17, 2024
1 parent a48f34a commit d9495df
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 109 deletions.
19 changes: 17 additions & 2 deletions silverback/_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,16 @@ def _runner_callback(ctx, param, val):
raise ValueError(f"Failed to import runner '{val}'.")


def _recorder_callback(ctx, param, val):
if not val:
return None

elif recorder := import_from_string(val):
return recorder()

raise ValueError(f"Failed to import recorder '{val}'.")


def _account_callback(ctx, param, val):
if val:
val = val.alias.replace("dev_", "TEST::")
Expand Down Expand Up @@ -92,11 +102,16 @@ async def run_worker(broker: AsyncBroker, worker_count=2, shutdown_timeout=90):
help="An import str in format '<module>:<CustomRunner>'",
callback=_runner_callback,
)
@click.option(
"--recorder",
help="An import string in format '<module>:<CustomRecorder>'",
callback=_recorder_callback,
)
@click.option("-x", "--max-exceptions", type=int, default=3)
@click.argument("path")
def run(cli_ctx, account, runner, max_exceptions, path):
def run(cli_ctx, account, runner, recorder, max_exceptions, path):
app = import_from_string(path)
runner = runner(app, max_exceptions=max_exceptions)
runner = runner(app, recorder=recorder, max_exceptions=max_exceptions)
asyncio.run(runner.run())


Expand Down
24 changes: 1 addition & 23 deletions silverback/middlewares.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
from eth_utils.conversions import to_hex
from taskiq import TaskiqMessage, TaskiqMiddleware, TaskiqResult

from silverback.recorder import HandlerResult
from silverback.types import SilverbackID, TaskType
from silverback.types import TaskType
from silverback.utils import hexbytes_dict


Expand All @@ -22,11 +21,7 @@ def compute_block_time() -> int:

return int((head.timestamp - genesis.timestamp) / head.number)

settings = kwargs.pop("silverback_settings")

self.block_time = self.chain_manager.provider.network.block_time or compute_block_time()
self.ident = SilverbackID.from_settings(settings)
self.recorder = settings.get_recorder()

def pre_send(self, message: TaskiqMessage) -> TaskiqMessage:
# TODO: Necessary because bytes/HexBytes doesn't encode/deocde well for some reason
Expand Down Expand Up @@ -95,21 +90,4 @@ def post_execute(self, message: TaskiqMessage, result: TaskiqResult):
f"{self._create_label(message)} " f"- {result.execution_time:.3f}s{percent_display}"
)

async def post_save(self, message: TaskiqMessage, result: TaskiqResult):
if not self.recorder:
return

handler_result = HandlerResult.from_taskiq(
self.ident,
message.task_name,
message.labels.get("block_number"),
message.labels.get("log_index"),
result,
)

try:
await self.recorder.add_result(handler_result)
except Exception as err:
logger.error(f"Error storing result: {err}")

# NOTE: Unless stdout is ignored, error traceback appears in stdout, no need for `on_error`
159 changes: 75 additions & 84 deletions silverback/runner.py
Original file line number Diff line number Diff line change
@@ -1,73 +1,88 @@
import asyncio
from abc import ABC, abstractmethod
from typing import Optional, Tuple
from typing import Optional

from ape import chain
from ape.contracts import ContractEvent, ContractInstance
from ape.logging import logger
from ape.utils import ManagerAccessMixin
from ape_ethereum.ecosystem import keccak
from taskiq import AsyncTaskiqDecoratedTask, TaskiqResult
from taskiq import AsyncTaskiqDecoratedTask, AsyncTaskiqTask

from .application import SilverbackApp
from .exceptions import Halt, NoWebsocketAvailableError
from .recorder import BaseRecorder
from .settings import Settings
from .recorder import BaseRecorder, TaskResult
from .subscriptions import SubscriptionType, Web3SubscriptionsManager
from .types import SilverbackID, SilverbackStartupState, TaskType
from .types import AppState, SilverbackID, TaskType
from .utils import async_wrap_iter, hexbytes_dict

settings = Settings()


class BaseRunner(ABC):
def __init__(self, app: SilverbackApp, *args, max_exceptions: int = 3, **kwargs):
def __init__(
self,
app: SilverbackApp,
*args,
max_exceptions: int = 3,
recorder: Optional[BaseRecorder] = None,
**kwargs,
):
self.app = app
self.recorder = recorder

self.max_exceptions = max_exceptions
self.exceptions = 0
self.last_block_seen = 0
self.last_block_processed = 0
self.recorder: Optional[BaseRecorder] = None
self.ident = SilverbackID.from_settings(settings)

def _handle_result(self, result: TaskiqResult):
if result.is_err:
self.exceptions += 1
ecosystem_name, network_name = app.network_choice.split(":")
self.identifier = SilverbackID(
name=app.name,
ecosystem=ecosystem_name,
network=network_name,
)

logger.info(f"Using {self.__class__.__name__}: max_exceptions={self.max_exceptions}")

async def _handle_task(self, task: AsyncTaskiqTask):
result = await task.wait_result()

if self.recorder:
await self.recorder.add_result(TaskResult.from_taskiq(result))

else:
if not result.is_err:
# NOTE: Reset exception counter
self.exceptions = 0
return

if self.exceptions > self.max_exceptions:
raise Halt() from result.error
self.exceptions += 1

if self.exceptions > self.max_exceptions or isinstance(result.error, Halt):
result.raise_for_error()

async def _checkpoint(
self, last_block_seen: int = 0, last_block_processed: int = 0
) -> Tuple[int, int]:
self,
last_block_seen: Optional[int] = None,
last_block_processed: Optional[int] = None,
):
"""Set latest checkpoint block number"""
if (
last_block_seen > self.last_block_seen
or last_block_processed > self.last_block_processed
):
logger.debug(
(
f"Checkpoint block [seen={self.last_block_seen}, "
f"procssed={self.last_block_processed}]"
)
assert self.app.state, f"{self.__class__.__name__}.run() not triggered."

logger.debug(
(
f"Checkpoint block [seen={self.app.state.last_block_seen}, "
f"procssed={self.app.state.last_block_processed}]"
)
self.last_block_seen = max(last_block_seen, self.last_block_seen)
self.last_block_processed = max(last_block_processed, self.last_block_processed)
)

if self.recorder:
try:
await self.recorder.set_state(
self.ident, self.last_block_seen, self.last_block_processed
)
except Exception as err:
logger.error(f"Error settings state: {err}")
if last_block_seen:
self.app.state.last_block_seen = last_block_seen
if last_block_processed:
self.app.state.last_block_processed = last_block_processed

if self.recorder:
try:
await self.recorder.set_state(self.app.state)

return self.last_block_seen, self.last_block_processed
except Exception as err:
logger.error(f"Error setting state: {err}")

@abstractmethod
async def _block_task(self, block_handler: AsyncTaskiqDecoratedTask):
Expand All @@ -93,14 +108,14 @@ async def run(self):
Raises:
:class:`~silverback.exceptions.Halt`: If there are no configured tasks to execute.
"""
self.recorder = settings.get_recorder()
# Initialize recorder (if available) and fetch state if app has been run previously
if self.recorder and (startup_state := (await self.recorder.init(app_id=self.identifier))):
self.app.state = startup_state

if self.recorder:
boot_state = await self.recorder.get_state(self.ident)
if boot_state:
self.last_block_seen = boot_state.last_block_seen
self.last_block_processed = boot_state.last_block_processed
else: # use empty state
self.app.state = AppState(last_block_seen=-1, last_block_processed=-1)

# Initialize broker (run worker startup events)
await self.app.broker.startup()

# Execute Silverback startup task before we init the rest
Expand Down Expand Up @@ -144,7 +159,6 @@ class WebsocketRunner(BaseRunner, ManagerAccessMixin):

def __init__(self, app: SilverbackApp, *args, **kwargs):
super().__init__(app, *args, **kwargs)
logger.info(f"Using {self.__class__.__name__}: max_exceptions={self.max_exceptions}")

# Check for websocket support
if not (ws_uri := app.chain_manager.provider.ws_uri):
Expand All @@ -159,16 +173,9 @@ async def _block_task(self, block_handler: AsyncTaskiqDecoratedTask):
async for raw_block in self.subscriptions.get_subscription_data(sub_id):
block = self.provider.network.ecosystem.decode_block(hexbytes_dict(raw_block))

if block.number is not None:
await self._checkpoint(last_block_seen=block.number)

block_task = await block_handler.kiq(raw_block)
result = await block_task.wait_result()

self._handle_result(result)

if block.number is not None:
await self._checkpoint(last_block_processed=block.number)
await self._checkpoint(last_block_seen=block.number)
await self._handle_task(await block_handler.kiq(raw_block))
await self._checkpoint(last_block_processed=block.number)

async def _event_task(
self, contract_event: ContractEvent, event_handler: AsyncTaskiqDecoratedTask
Expand All @@ -182,7 +189,9 @@ async def _event_task(
address=contract_event.contract.address,
topics=["0x" + keccak(text=contract_event.abi.selector).hex()],
)
logger.debug(f"Handling '{contract_event.name}' events via {sub_id}")
logger.debug(
f"Handling '{contract_event.contract.address}:{contract_event.name}' logs via {sub_id}"
)

async for raw_event in self.subscriptions.get_subscription_data(sub_id):
event = next( # NOTE: `next` is okay since it only has one item
Expand All @@ -192,15 +201,9 @@ async def _event_task(
)
)

if event.block_number is not None:
await self._checkpoint(last_block_seen=event.block_number)

event_task = await event_handler.kiq(event)
result = await event_task.wait_result()
self._handle_result(result)

if event.block_number is not None:
await self._checkpoint(last_block_processed=event.block_number)
await self._checkpoint(last_block_seen=event.block_number)
await self._handle_task(await event_handler.kiq(event))
await self._checkpoint(last_block_processed=event.block_number)

async def run(self):
async with Web3SubscriptionsManager(self.ws_uri) as subscriptions:
Expand Down Expand Up @@ -235,15 +238,9 @@ async def _block_task(self, block_handler: AsyncTaskiqDecoratedTask):
async for block in async_wrap_iter(
chain.blocks.poll_blocks(start_block=start_block, new_block_timeout=new_block_timeout)
):
if block.number is not None:
await self._checkpoint(last_block_seen=block.number)

block_task = await block_handler.kiq(block)
result = await block_task.wait_result()
self._handle_result(result)

if block.number is not None:
await self._checkpoint(last_block_processed=block.number)
await self._checkpoint(last_block_seen=block.number)
await self._handle_task(await block_handler.kiq(block))
await self._checkpoint(last_block_processed=block.number)

async def _event_task(
self, contract_event: ContractEvent, event_handler: AsyncTaskiqDecoratedTask
Expand All @@ -265,12 +262,6 @@ async def _event_task(
async for event in async_wrap_iter(
contract_event.poll_logs(start_block=start_block, new_block_timeout=new_block_timeout)
):
if event.block_number is not None:
await self._checkpoint(last_block_seen=event.block_number)

event_task = await event_handler.kiq(event)
result = await event_task.wait_result()
self._handle_result(result)

if event.block_number is not None:
await self._checkpoint(last_block_processed=event.block_number)
await self._checkpoint(last_block_seen=event.block_number)
await self._handle_task(await event_handler.kiq(event))
await self._checkpoint(last_block_processed=event.block_number)

0 comments on commit d9495df

Please sign in to comment.