From 93597d9f57249b8aac3a70a369ecc2308c17f70d Mon Sep 17 00:00:00 2001 From: David Gardner Date: Mon, 23 Sep 2024 12:45:21 -0700 Subject: [PATCH] Fix casting of control messages --- .../stages/llm/llm_engine_stage.py | 31 ++++++++++++++----- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/python/morpheus_llm/morpheus_llm/stages/llm/llm_engine_stage.py b/python/morpheus_llm/morpheus_llm/stages/llm/llm_engine_stage.py index ae9b0bd4d9..a85a40b791 100644 --- a/python/morpheus_llm/morpheus_llm/stages/llm/llm_engine_stage.py +++ b/python/morpheus_llm/morpheus_llm/stages/llm/llm_engine_stage.py @@ -15,6 +15,8 @@ import functools import logging import types +import typing +from collections import deque import mrc from mrc.core import operators as ops @@ -77,7 +79,22 @@ def _store_payload(self, message: ControlMessage) -> ControlMessage: message.set_metadata("llm_message_meta", message.payload()) return message - def _cast_to_cpp_control_message(self, message: ControlMessage, *, + def _copy_tasks_and_metadata(self, + src: ControlMessage, + dst: ControlMessage, + metadata: dict[str, typing.Any] = None): + if metadata is None: + metadata = src.get_metadata() + + for (key, value) in metadata.items(): + dst.set_metadata(key, value) + + tasks = src.get_tasks() + for (task, task_value) in tasks.items(): + for tv in task_value: + dst.add_task(task, tv) + + def _cast_to_cpp_control_message(self, py_message: ControlMessage, *, cpp_messages_lib: types.ModuleType) -> ControlMessage: """ LLMEngineStage does not contain a Python implementation, however it is capable of running in cpu-only mode. @@ -85,12 +102,10 @@ def _cast_to_cpp_control_message(self, message: ControlMessage, *, This is different than casting from the Python bindings for the C++ ControlMessage to a C++ ControlMessage. """ - cm = cpp_messages_lib.ControlMessage() - metadata = message.get_metadata() - for (key, value) in metadata.items(): - cm.set_metadata(key, value) + cpp_message = cpp_messages_lib.ControlMessage() + self._copy_tasks_and_metadata(py_message, cpp_message) - return cm + return cpp_message def _restore_payload(self, message: ControlMessage) -> ControlMessage: """ @@ -103,8 +118,8 @@ def _restore_payload(self, message: ControlMessage) -> ControlMessage: out_message = ControlMessage() out_message.payload(message_meta) - for (key, value) in metadata.items(): - out_message.set_metadata(key, value) + + self._copy_tasks_and_metadata(message, out_message, metadata=metadata) return out_message