From 6aafc8326b742637c2a74cbf09aa394895f86717 Mon Sep 17 00:00:00 2001 From: fubuloubu <3859395+fubuloubu@users.noreply.github.com> Date: Tue, 9 Apr 2024 18:23:29 -0400 Subject: [PATCH] refactor: convert to TaskType for better processing --- silverback/middlewares.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/silverback/middlewares.py b/silverback/middlewares.py index b9830c3f..44101d3f 100644 --- a/silverback/middlewares.py +++ b/silverback/middlewares.py @@ -56,10 +56,16 @@ def _create_label(self, message: TaskiqMessage) -> str: return message.task_name def pre_execute(self, message: TaskiqMessage) -> TaskiqMessage: - task_type = message.labels.pop("task_type", "") + if not (task_type := message.labels.pop("task_type")): + return message # Not a silverback task - # NOTE: Don't compare `str` to `TaskType` using `is` - if task_type == TaskType.NEW_BLOCKS: + try: + task_type = TaskType(task_type) + except ValueError: + return message # Not a silverback task + + # Add extra labels for our task to see what their source was + if task_type is TaskType.NEW_BLOCKS: # NOTE: Necessary because we don't know the exact block class message.args[0] = self.provider.network.ecosystem.decode_block( hexbytes_dict(message.args[0]) @@ -67,7 +73,7 @@ def pre_execute(self, message: TaskiqMessage) -> TaskiqMessage: message.labels["block_number"] = str(message.args[0].number) message.labels["block_hash"] = message.args[0].hash.hex() - elif task_type == TaskType.EVENT_LOG: + elif task_type is TaskType.EVENT_LOG: # NOTE: Just in case the user doesn't specify type as `ContractLog` message.args[0] = ContractLog.model_validate(message.args[0]) message.labels["block_number"] = str(message.args[0].block_number)