From e0c3dc7eb210a6aba120b3005d0b4bf6efd807a7 Mon Sep 17 00:00:00 2001 From: fubuloubu <3859395+fubuloubu@users.noreply.github.com> Date: Tue, 9 Apr 2024 17:04:05 -0400 Subject: [PATCH] refactor: use defaultdict instead of custom collection type --- silverback/application.py | 49 +++++++++++++++++++-------------------- silverback/exceptions.py | 7 ++++++ silverback/runner.py | 8 +++---- 3 files changed, 35 insertions(+), 29 deletions(-) diff --git a/silverback/application.py b/silverback/application.py index d1fba94b..7801aef3 100644 --- a/silverback/application.py +++ b/silverback/application.py @@ -1,4 +1,5 @@ import atexit +from collections import defaultdict from dataclasses import dataclass from datetime import timedelta from typing import Callable, Dict, Optional, Union @@ -10,35 +11,17 @@ from ape.utils import ManagerAccessMixin from taskiq import AsyncTaskiqDecoratedTask, TaskiqEvents -from .exceptions import InvalidContainerTypeError +from .exceptions import ContainerTypeMismatchError, InvalidContainerTypeError from .settings import Settings from .types import TaskType @dataclass -class Task: +class TaskData: container: Union[BlockContainer, ContractEvent, None] handler: AsyncTaskiqDecoratedTask -class TaskCollection(dict): - def insert(self, task_type: TaskType, task: Task): - if not isinstance(task_type, TaskType): - raise ValueError("Unexpected key type") - - elif not isinstance(task, Task): - raise ValueError("Unexpected value type") - - elif task_type is TaskType.NEW_BLOCKS and not isinstance(task.container, BlockContainer): - raise ValueError("Mismatch between key and value types") - - elif task_type is TaskType.EVENT_LOG and not isinstance(task.container, ContractEvent): - raise ValueError("Mismatch between key and value types") - - task_list = super().get(task_type) or [] - super().__setitem__(task_type, task_list + [task]) - - class SilverbackApp(ManagerAccessMixin): """ The application singleton. Must be initialized prior to use. @@ -77,7 +60,8 @@ def __init__(self, settings: Optional[Settings] = None): logger.info(f"Loading Silverback App with settings:\n {settings_str}") self.broker = settings.get_broker() - self.tasks = TaskCollection() + # NOTE: If no tasks registered yet, defaults to empty list instead of raising KeyError + self.tasks: defaultdict[TaskType, list[TaskData]] = defaultdict(list) self.poll_settings: Dict[str, Dict] = {} atexit.register(self.network.__exit__, None, None, None) @@ -102,15 +86,30 @@ def broker_task_decorator( task_type: TaskType, container: Union[BlockContainer, ContractEvent, None] = None, ): + if ( + (task_type is TaskType.NEW_BLOCKS and not isinstance(container, BlockContainer)) + or (task_type is TaskType.EVENT_LOG and not isinstance(container, ContractEvent)) + or ( + task_type + not in ( + TaskType.NEW_BLOCKS, + TaskType.EVENT_LOG, + ) + and container is not None + ) + ): + raise ContainerTypeMismatchError(task_type, container) + + # Register user function as task handler with our broker def add_taskiq_task(handler: Callable): - # TODO: Support generic registration - task = self.broker.register_task( + broker_task = self.broker.register_task( handler, task_name=handler.__name__, task_type=str(task_type), ) - self.tasks.insert(task_type, Task(container=container, handler=task)) - return task + + self.tasks[task_type].append(TaskData(container=container, handler=broker_task)) + return broker_task return add_taskiq_task diff --git a/silverback/exceptions.py b/silverback/exceptions.py index a4da1258..125e85a0 100644 --- a/silverback/exceptions.py +++ b/silverback/exceptions.py @@ -3,6 +3,8 @@ from ape.exceptions import ApeException from ape.logging import logger +from .types import TaskType + class ImportFromStringError(Exception): pass @@ -13,6 +15,11 @@ def __init__(self, container: Any): super().__init__(f"Invalid container type: {container.__class__}") +class ContainerTypeMismatchError(Exception): + def __init__(self, task_type: TaskType, container: Any): + super().__init__(f"Invalid container type for '{task_type}': {container.__class__}") + + class NoWebsocketAvailableError(Exception): def __init__(self): super().__init__( diff --git a/silverback/runner.py b/silverback/runner.py index 233f4656..1aa7a263 100644 --- a/silverback/runner.py +++ b/silverback/runner.py @@ -103,7 +103,7 @@ async def run(self): await self.app.broker.startup() # Execute Silverback startup task before we init the rest - for startup_task in self.app.tasks.get(TaskType.STARTUP): + for startup_task in self.app.tasks[TaskType.STARTUP]: task = await startup_task.handler.kiq( SilverbackStartupState( last_block_seen=self.last_block_seen, @@ -114,10 +114,10 @@ async def run(self): self._handle_result(result) tasks = [] - for task in self.app.tasks.get(TaskType.NEW_BLOCKS): + for task in self.app.tasks[TaskType.NEW_BLOCKS]: tasks.append(self._block_task(task.handler)) - for task in self.app.tasks.get(TaskType.EVENT_LOG): + for task in self.app.tasks[TaskType.EVENT_LOG]: tasks.append(self._event_task(task.container, task.handler)) if len(tasks) == 0: @@ -126,7 +126,7 @@ async def run(self): await asyncio.gather(*tasks) # Execute Silverback shutdown task before shutting down the broker - for shutdown_task in self.app.tasks.get(TaskType.SHUTDOWN): + for shutdown_task in self.app.tasks[TaskType.SHUTDOWN]: task = await shutdown_task.handler.kiq() result = self._handle_result(await task.wait_result())