Skip to content

Commit

Permalink
refactor: use defaultdict instead of custom collection type
Browse files Browse the repository at this point in the history
  • Loading branch information
fubuloubu committed Apr 9, 2024
1 parent 42247be commit e0c3dc7
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 29 deletions.
49 changes: 24 additions & 25 deletions silverback/application.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import atexit
from collections import defaultdict
from dataclasses import dataclass
from datetime import timedelta
from typing import Callable, Dict, Optional, Union
Expand All @@ -10,35 +11,17 @@
from ape.utils import ManagerAccessMixin
from taskiq import AsyncTaskiqDecoratedTask, TaskiqEvents

from .exceptions import InvalidContainerTypeError
from .exceptions import ContainerTypeMismatchError, InvalidContainerTypeError
from .settings import Settings
from .types import TaskType


@dataclass
class Task:
class TaskData:
container: Union[BlockContainer, ContractEvent, None]
handler: AsyncTaskiqDecoratedTask


class TaskCollection(dict):
def insert(self, task_type: TaskType, task: Task):
if not isinstance(task_type, TaskType):
raise ValueError("Unexpected key type")

elif not isinstance(task, Task):
raise ValueError("Unexpected value type")

elif task_type is TaskType.NEW_BLOCKS and not isinstance(task.container, BlockContainer):
raise ValueError("Mismatch between key and value types")

elif task_type is TaskType.EVENT_LOG and not isinstance(task.container, ContractEvent):
raise ValueError("Mismatch between key and value types")

task_list = super().get(task_type) or []
super().__setitem__(task_type, task_list + [task])


class SilverbackApp(ManagerAccessMixin):
"""
The application singleton. Must be initialized prior to use.
Expand Down Expand Up @@ -77,7 +60,8 @@ def __init__(self, settings: Optional[Settings] = None):
logger.info(f"Loading Silverback App with settings:\n {settings_str}")

self.broker = settings.get_broker()
self.tasks = TaskCollection()
# NOTE: If no tasks registered yet, defaults to empty list instead of raising KeyError
self.tasks: defaultdict[TaskType, list[TaskData]] = defaultdict(list)
self.poll_settings: Dict[str, Dict] = {}

atexit.register(self.network.__exit__, None, None, None)
Expand All @@ -102,15 +86,30 @@ def broker_task_decorator(
task_type: TaskType,
container: Union[BlockContainer, ContractEvent, None] = None,
):
if (
(task_type is TaskType.NEW_BLOCKS and not isinstance(container, BlockContainer))
or (task_type is TaskType.EVENT_LOG and not isinstance(container, ContractEvent))
or (
task_type
not in (
TaskType.NEW_BLOCKS,
TaskType.EVENT_LOG,
)
and container is not None
)
):
raise ContainerTypeMismatchError(task_type, container)

# Register user function as task handler with our broker
def add_taskiq_task(handler: Callable):
# TODO: Support generic registration
task = self.broker.register_task(
broker_task = self.broker.register_task(
handler,
task_name=handler.__name__,
task_type=str(task_type),
)
self.tasks.insert(task_type, Task(container=container, handler=task))
return task

self.tasks[task_type].append(TaskData(container=container, handler=broker_task))
return broker_task

return add_taskiq_task

Expand Down
7 changes: 7 additions & 0 deletions silverback/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from ape.exceptions import ApeException
from ape.logging import logger

from .types import TaskType


class ImportFromStringError(Exception):
pass
Expand All @@ -13,6 +15,11 @@ def __init__(self, container: Any):
super().__init__(f"Invalid container type: {container.__class__}")


class ContainerTypeMismatchError(Exception):
def __init__(self, task_type: TaskType, container: Any):
super().__init__(f"Invalid container type for '{task_type}': {container.__class__}")


class NoWebsocketAvailableError(Exception):
def __init__(self):
super().__init__(
Expand Down
8 changes: 4 additions & 4 deletions silverback/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ async def run(self):
await self.app.broker.startup()

# Execute Silverback startup task before we init the rest
for startup_task in self.app.tasks.get(TaskType.STARTUP):
for startup_task in self.app.tasks[TaskType.STARTUP]:
task = await startup_task.handler.kiq(
SilverbackStartupState(
last_block_seen=self.last_block_seen,
Expand All @@ -114,10 +114,10 @@ async def run(self):
self._handle_result(result)

tasks = []
for task in self.app.tasks.get(TaskType.NEW_BLOCKS):
for task in self.app.tasks[TaskType.NEW_BLOCKS]:
tasks.append(self._block_task(task.handler))

for task in self.app.tasks.get(TaskType.EVENT_LOG):
for task in self.app.tasks[TaskType.EVENT_LOG]:
tasks.append(self._event_task(task.container, task.handler))

if len(tasks) == 0:
Expand All @@ -126,7 +126,7 @@ async def run(self):
await asyncio.gather(*tasks)

# Execute Silverback shutdown task before shutting down the broker
for shutdown_task in self.app.tasks.get(TaskType.SHUTDOWN):
for shutdown_task in self.app.tasks[TaskType.SHUTDOWN]:
task = await shutdown_task.handler.kiq()
result = self._handle_result(await task.wait_result())

Expand Down

0 comments on commit e0c3dc7

Please sign in to comment.