Skip to content

Commit

Permalink
IWF-357: Add internal channel TypeStore (#70)
Browse files Browse the repository at this point in the history
* IWF-357: Add internal channel TypeStore

* IWF-357: Lint

* IWF-357: Fix

* IWF-357: Fix

* IWF-357: Fix

* IWF-357: Fix

* IWF-357: Fix

* IWF-357: Lint

* IWF-357: Add test

* IWF-357: Lint

* IWF-357: Change class name

* IWF-357: Address MR comments

* IWF-357: Lint

* IWF-357: Refactor

* IWF-357: Fix test
  • Loading branch information
lwolczynski authored Jan 7, 2025
1 parent b810adb commit a85eb32
Show file tree
Hide file tree
Showing 8 changed files with 236 additions and 45 deletions.
20 changes: 8 additions & 12 deletions iwf/command_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
from dataclasses import dataclass
from typing import Any, Union

from iwf.errors import WorkflowDefinitionError
from iwf.errors import WorkflowDefinitionError, NotRegisteredError
from iwf.iwf_api.models import (
ChannelRequestStatus,
CommandResults as IdlCommandResults,
TimerStatus,
)
from iwf.iwf_api.types import Unset
from iwf.object_encoder import ObjectEncoder
from iwf.type_store import TypeStore


@dataclass
Expand Down Expand Up @@ -43,7 +44,7 @@ class CommandResults:

def from_idl_command_results(
idl_results: Union[Unset, IdlCommandResults],
internal_channel_types: dict[str, typing.Optional[type]],
internal_channel_types: TypeStore,
signal_channel_types: dict[str, typing.Optional[type]],
object_encoder: ObjectEncoder,
) -> CommandResults:
Expand All @@ -58,18 +59,13 @@ def from_idl_command_results(

if not isinstance(idl_results.inter_state_channel_results, Unset):
for inter in idl_results.inter_state_channel_results:
val_type = internal_channel_types.get(inter.channel_name)
if val_type is None:
# fallback to assume it's prefix
# TODO use is_prefix to implement like Java SDK
for name, t in internal_channel_types.items():
if inter.channel_name.startswith(name):
val_type = t
break
if val_type is None:

try:
val_type = internal_channel_types.get_type(inter.channel_name)
except NotRegisteredError as exception:
raise WorkflowDefinitionError(
"internal channel is not registered: " + inter.channel_name
)
) from exception

encoded = object_encoder.decode(inter.value, val_type)

Expand Down
31 changes: 12 additions & 19 deletions iwf/communication.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Optional, Union

from iwf.errors import WorkflowDefinitionError
from iwf.errors import WorkflowDefinitionError, NotRegisteredError
from iwf.iwf_api.models import (
EncodedObject,
InterStateChannelPublishing,
Expand All @@ -9,10 +9,11 @@
)
from iwf.object_encoder import ObjectEncoder
from iwf.state_movement import StateMovement
from iwf.type_store import TypeStore


class Communication:
_internal_channel_type_store: dict[str, Optional[type]]
_internal_channel_type_store: TypeStore
_signal_channel_type_store: dict[str, Optional[type]]
_object_encoder: ObjectEncoder
_to_publish_internal_channel: dict[str, list[EncodedObject]]
Expand All @@ -22,7 +23,7 @@ class Communication:

def __init__(
self,
internal_channel_type_store: dict[str, Optional[type]],
internal_channel_type_store: TypeStore,
signal_channel_type_store: dict[str, Optional[type]],
object_encoder: ObjectEncoder,
internal_channel_infos: Optional[WorkflowWorkerRpcRequestInternalChannelInfos],
Expand All @@ -47,17 +48,12 @@ def trigger_state_execution(self, state: Union[str, type], state_input: Any = No
self._state_movements.append(movement)

def publish_to_internal_channel(self, channel_name: str, value: Any = None):
registered_type = self._internal_channel_type_store.get(channel_name)

if registered_type is None:
for name, t in self._internal_channel_type_store.items():
if channel_name.startswith(name):
registered_type = t

if registered_type is None:
try:
registered_type = self._internal_channel_type_store.get_type(channel_name)
except NotRegisteredError as exception:
raise WorkflowDefinitionError(
f"InternalChannel channel_name is not defined {channel_name}"
)
) from exception

if (
value is not None
Expand All @@ -84,14 +80,11 @@ def get_to_trigger_state_movements(self) -> list[StateMovement]:
return self._state_movements

def get_internal_channel_size(self, channel_name):
registered_type = self._internal_channel_type_store.get(channel_name)

if registered_type is None:
for name, t in self._internal_channel_type_store.items():
if channel_name.startswith(name):
registered_type = t
is_type_registered = self._internal_channel_type_store.is_valid_name_or_prefix(
channel_name
)

if registered_type is None:
if is_type_registered is False:
raise WorkflowDefinitionError(
f"InternalChannel channel_name is not defined {channel_name}"
)
Expand Down
4 changes: 4 additions & 0 deletions iwf/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ class InvalidArgumentError(Exception):
pass


class NotRegisteredError(Exception):
pass


class HttpError(RuntimeError):
def __init__(self, status: int, err_resp: ErrorResponse):
super().__init__(err_resp.detail)
Expand Down
19 changes: 12 additions & 7 deletions iwf/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from iwf.errors import InvalidArgumentError, WorkflowDefinitionError
from iwf.persistence_schema import PersistenceFieldType
from iwf.rpc import RPCInfo
from iwf.type_store import TypeStore, Type
from iwf.workflow import ObjectWorkflow, get_workflow_type
from iwf.workflow_state import WorkflowState, get_state_id

Expand All @@ -12,7 +13,7 @@ class Registry:
_workflow_store: dict[str, ObjectWorkflow]
_starting_state_store: dict[str, WorkflowState]
_state_store: dict[str, dict[str, WorkflowState]]
_internal_channel_type_store: dict[str, dict[str, Optional[type]]]
_internal_channel_type_store: dict[str, TypeStore]
_signal_channel_type_store: dict[str, dict[str, Optional[type]]]
_data_attribute_types: dict[str, dict[str, Optional[type]]]
_rpc_infos: dict[str, dict[str, RPCInfo]]
Expand Down Expand Up @@ -63,7 +64,7 @@ def get_workflow_state_with_check(
def get_state_store(self, wf_type: str) -> dict[str, WorkflowState]:
return self._state_store[wf_type]

def get_internal_channel_types(self, wf_type: str) -> dict[str, Optional[type]]:
def get_internal_channel_type_store(self, wf_type: str) -> TypeStore:
return self._internal_channel_type_store[wf_type]

def get_signal_channel_types(self, wf_type: str) -> dict[str, Optional[type]]:
Expand All @@ -83,13 +84,17 @@ def _register_workflow_type(self, wf: ObjectWorkflow):

def _register_internal_channels(self, wf: ObjectWorkflow):
wf_type = get_workflow_type(wf)
types: dict[str, Optional[type]] = {}

if wf_type not in self._internal_channel_type_store:
self._internal_channel_type_store[wf_type] = TypeStore(
Type.INTERNAL_CHANNEL
)

for method in wf.get_communication_schema().communication_methods:
if method.method_type == CommunicationMethodType.InternalChannel:
types[method.name] = method.value_type
# TODO use is_prefix to implement like Java SDK
#
self._internal_channel_type_store[wf_type] = types
self._internal_channel_type_store[wf_type].add_internal_channel_def(
method
)

def _register_signal_channels(self, wf: ObjectWorkflow):
wf_type = get_workflow_type(wf)
Expand Down
10 changes: 6 additions & 4 deletions iwf/tests/test_internal_channel.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import inspect
import time
import unittest

from iwf.client import Client
from iwf.command_request import CommandRequest, InternalChannelCommand
Expand Down Expand Up @@ -133,8 +134,9 @@ def get_communication_schema(self) -> CommunicationSchema:
client = Client(registry)


def test_internal_channel_workflow():
wf_id = f"{inspect.currentframe().f_code.co_name}-{time.time_ns()}"
class TestConditionalComplete(unittest.TestCase):
def test_internal_channel_workflow(self):
wf_id = f"{inspect.currentframe().f_code.co_name}-{time.time_ns()}"

client.start_workflow(InternalChannelWorkflow, wf_id, 100, None)
client.get_simple_workflow_result_with_wait(wf_id, None)
client.start_workflow(InternalChannelWorkflow, wf_id, 100, None)
client.get_simple_workflow_result_with_wait(wf_id, None)
123 changes: 123 additions & 0 deletions iwf/tests/test_internal_channel_with_no_prefix_channel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import inspect
import time
import unittest

from iwf.client import Client
from iwf.command_request import CommandRequest, InternalChannelCommand
from iwf.command_results import CommandResults
from iwf.communication import Communication
from iwf.communication_schema import CommunicationMethod, CommunicationSchema
from iwf.persistence import Persistence
from iwf.state_decision import StateDecision
from iwf.state_schema import StateSchema
from iwf.tests.worker_server import registry
from iwf.workflow import ObjectWorkflow
from iwf.workflow_context import WorkflowContext
from iwf.workflow_state import T, WorkflowState

internal_channel_name = "internal-channel-1"

test_non_prefix_channel_name = "test-channel-"
test_non_prefix_channel_name_with_suffix = test_non_prefix_channel_name + "abc"


class InitState(WorkflowState[None]):
def execute(
self,
ctx: WorkflowContext,
input: T,
command_results: CommandResults,
persistence: Persistence,
communication: Communication,
) -> StateDecision:
return StateDecision.multi_next_states(
WaitAnyWithPublishState, WaitAllThenPublishState
)


class WaitAnyWithPublishState(WorkflowState[None]):
def wait_until(
self,
ctx: WorkflowContext,
input: T,
persistence: Persistence,
communication: Communication,
) -> CommandRequest:
# Trying to publish to a non-existing channel; this would only work if test_channel_name_non_prefix was defined as a prefix channel
communication.publish_to_internal_channel(
test_non_prefix_channel_name_with_suffix, "str-value-for-prefix"
)
return CommandRequest.for_any_command_completed(
InternalChannelCommand.by_name(internal_channel_name),
)

def execute(
self,
ctx: WorkflowContext,
input: T,
command_results: CommandResults,
persistence: Persistence,
communication: Communication,
) -> StateDecision:
return StateDecision.graceful_complete_workflow()


class WaitAllThenPublishState(WorkflowState[None]):
def wait_until(
self,
ctx: WorkflowContext,
input: T,
persistence: Persistence,
communication: Communication,
) -> CommandRequest:
return CommandRequest.for_all_command_completed(
InternalChannelCommand.by_name(test_non_prefix_channel_name),
)

def execute(
self,
ctx: WorkflowContext,
input: T,
command_results: CommandResults,
persistence: Persistence,
communication: Communication,
) -> StateDecision:
communication.publish_to_internal_channel(internal_channel_name, None)
return StateDecision.dead_end


class InternalChannelWorkflowWithNoPrefixChannel(ObjectWorkflow):
def get_workflow_states(self) -> StateSchema:
return StateSchema.with_starting_state(
InitState(), WaitAnyWithPublishState(), WaitAllThenPublishState()
)

def get_communication_schema(self) -> CommunicationSchema:
return CommunicationSchema.create(
CommunicationMethod.internal_channel_def(internal_channel_name, type(None)),
# Defining a standard channel (non-prefix) to make sure messages to the channel with a suffix added will not be accepted
CommunicationMethod.internal_channel_def(test_non_prefix_channel_name, str),
)


wf = InternalChannelWorkflowWithNoPrefixChannel()
registry.add_workflow(wf)
client = Client(registry)


class TestInternalChannelWithNoPrefix(unittest.TestCase):
def test_internal_channel_workflow_with_no_prefix_channel(self):
wf_id = f"{inspect.currentframe().f_code.co_name}-{time.time_ns()}"

client.start_workflow(
InternalChannelWorkflowWithNoPrefixChannel, wf_id, 5, None
)

with self.assertRaises(Exception) as context:
client.wait_for_workflow_completion(wf_id, None)

self.assertIn("FAILED", context.exception.workflow_status)
self.assertIn(
f"WorkerExecutionError: InternalChannel channel_name is not defined {test_non_prefix_channel_name_with_suffix}",
context.exception.error_message,
)
68 changes: 68 additions & 0 deletions iwf/type_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from typing import Optional
from enum import Enum

from iwf.communication_schema import CommunicationMethod
from iwf.errors import WorkflowDefinitionError, NotRegisteredError


class Type(Enum):
INTERNAL_CHANNEL = 1
# TODO: extend to other types
# DATA_ATTRIBUTE = 2
# SIGNAL_CHANNEL = 3


class TypeStore:
_class_type: Type
_name_to_type_store: dict[str, Optional[type]]
_prefix_to_type_store: dict[str, Optional[type]]

def __init__(self, class_type: Type):
self._class_type = class_type
self._name_to_type_store = dict()
self._prefix_to_type_store = dict()

def is_valid_name_or_prefix(self, name: str) -> bool:
t = self._do_get_type(name)
return t is not None

def get_type(self, name: str) -> type:
t = self._do_get_type(name)

if t is None:
raise NotRegisteredError(f"{self._class_type} not registered: {name}")

return t

def add_internal_channel_def(self, obj: CommunicationMethod):
if self._class_type != Type.INTERNAL_CHANNEL:
raise ValueError(
f"Cannot add internal channel definition to {self._class_type}"
)
self._do_add_to_store(obj.is_prefix, obj.name, obj.value_type)

def _do_get_type(self, name: str) -> Optional[type]:
if name in self._name_to_type_store:
return self._name_to_type_store[name]

prefixes = self._prefix_to_type_store.keys()

first = next((prefix for prefix in prefixes if name.startswith(prefix)), None)

if first is None:
return None

return self._prefix_to_type_store.get(first, None)

def _do_add_to_store(self, is_prefix: bool, name: str, t: Optional[type]):
if is_prefix:
store = self._prefix_to_type_store
else:
store = self._name_to_type_store

if name in store:
raise WorkflowDefinitionError(
f"{self._class_type} name/prefix {name} already exists"
)

store[name] = t
Loading

0 comments on commit a85eb32

Please sign in to comment.