From 0a89136a29109ab4fcccea13e5d745dcba7ae033 Mon Sep 17 00:00:00 2001 From: John Agapiou Date: Tue, 10 Sep 2024 00:39:05 -0700 Subject: [PATCH] add locking and check phase transitions PiperOrigin-RevId: 672840224 Change-Id: Ic9bf4bfb223b78fb08cf278821b61cfe5fb4dc39 --- concordia/agents/entity_agent.py | 59 +++++++++++++++++----------- concordia/typing/entity_component.py | 40 +++++++++++++++---- 2 files changed, 70 insertions(+), 29 deletions(-) diff --git a/concordia/agents/entity_agent.py b/concordia/agents/entity_agent.py index 3756360b..b8cc2692 100644 --- a/concordia/agents/entity_agent.py +++ b/concordia/agents/entity_agent.py @@ -16,6 +16,7 @@ from collections.abc import Mapping import functools +import threading import types from typing import cast @@ -64,7 +65,9 @@ def __init__( """ super().__init__() self._agent_name = agent_name - self._phase = entity_component.Phase.INIT + self._control_lock = threading.Lock() + self._phase_lock = threading.Lock() + self._phase = entity_component.Phase.READY self._act_component = act_component self._act_component.set_entity(self) @@ -86,7 +89,13 @@ def name(self) -> str: @override def get_phase(self) -> entity_component.Phase: - return self._phase + with self._phase_lock: + return self._phase + + def _set_phase(self, phase: entity_component.Phase) -> None: + with self._phase_lock: + self._phase.check_successor(phase) + self._phase = phase @override def get_component( @@ -125,31 +134,37 @@ def _parallel_call_( def act( self, action_spec: entity.ActionSpec = entity.DEFAULT_ACTION_SPEC ) -> str: - self._phase = entity_component.Phase.PRE_ACT - contexts = self._parallel_call_('pre_act', action_spec) - self._context_processor.pre_act(types.MappingProxyType(contexts)) - action_attempt = self._act_component.get_action_attempt( - contexts, action_spec - ) + with self._control_lock: + self._set_phase(entity_component.Phase.PRE_ACT) + contexts = self._parallel_call_('pre_act', action_spec) + self._context_processor.pre_act(types.MappingProxyType(contexts)) + action_attempt = self._act_component.get_action_attempt( + contexts, action_spec + ) + + self._set_phase(entity_component.Phase.POST_ACT) + contexts = self._parallel_call_('post_act', action_attempt) + self._context_processor.post_act(contexts) - self._phase = entity_component.Phase.POST_ACT - contexts = self._parallel_call_('post_act', action_attempt) - self._context_processor.post_act(contexts) + self._set_phase(entity_component.Phase.UPDATE) + self._parallel_call_('update') - self._phase = entity_component.Phase.UPDATE - self._parallel_call_('update') + self._set_phase(entity_component.Phase.READY) - return action_attempt + return action_attempt @override def observe(self, observation: str) -> None: - self._phase = entity_component.Phase.PRE_OBSERVE - contexts = self._parallel_call_('pre_observe', observation) - self._context_processor.pre_observe(contexts) + with self._control_lock: + self._set_phase(entity_component.Phase.PRE_OBSERVE) + contexts = self._parallel_call_('pre_observe', observation) + self._context_processor.pre_observe(contexts) + + self._set_phase(entity_component.Phase.POST_OBSERVE) + contexts = self._parallel_call_('post_observe') + self._context_processor.post_observe(contexts) - self._phase = entity_component.Phase.POST_OBSERVE - contexts = self._parallel_call_('post_observe') - self._context_processor.post_observe(contexts) + self._set_phase(entity_component.Phase.UPDATE) + self._parallel_call_('update') - self._phase = entity_component.Phase.UPDATE - self._parallel_call_('update') + self._set_phase(entity_component.Phase.READY) diff --git a/concordia/typing/entity_component.py b/concordia/typing/entity_component.py index 678f6eb8..2af4e904 100644 --- a/concordia/typing/entity_component.py +++ b/concordia/typing/entity_component.py @@ -15,8 +15,9 @@ """Base classes for Entity components.""" import abc -from collections.abc import Mapping +from collections.abc import Collection, Mapping import enum +import functools from typing import TypeVar from concordia.typing import entity as entity_lib @@ -31,9 +32,8 @@ class Phase(enum.Enum): """Phases of a component entity lifecycle. Attributes: - INIT: The agent has just been created. No action has been requested nor - observation has been received. This can be followed by a call to `pre_act` - or `pre_observe`. + READY: The agent is ready to `observe` or `act`. This will be followed by + `PRE_ACT` or `PRE_OBSERVE`. PRE_ACT: The agent has received a request to act. Components are being requested for their action context. This will be followed by `POST_ACT`. POST_ACT: The agent has just submitted an action attempt. Components are @@ -44,16 +44,42 @@ class Phase(enum.Enum): POST_OBSERVE: The agent has just observed. Components are given a chance to provide context after processing the observation. This will be followed by `UPDATE`. - UPDATE: The agent is about to update its internal state. This will be - followed by `PRE_ACT` or `PRE_OBSERVE`. + UPDATE: The agent is updating its internal state. This will be followed by + `READY`. """ - INIT = enum.auto() + READY = enum.auto() PRE_ACT = enum.auto() POST_ACT = enum.auto() PRE_OBSERVE = enum.auto() POST_OBSERVE = enum.auto() UPDATE = enum.auto() + @functools.cached_property + def successors(self) -> Collection["Phase"]: + """Returns the phases which may follow the current phase.""" + match self: + case Phase.READY: + return (Phase.PRE_ACT, Phase.PRE_OBSERVE) + case Phase.PRE_ACT: + return (Phase.POST_ACT,) + case Phase.POST_ACT: + return (Phase.UPDATE,) + case Phase.PRE_OBSERVE: + return (Phase.POST_OBSERVE,) + case Phase.POST_OBSERVE: + return (Phase.UPDATE,) + case Phase.UPDATE: + return (Phase.READY,) + case _: + raise NotImplementedError() + + def check_successor(self, successor: "Phase") -> None: + """Raises ValueError if successor is not a valid next phase.""" + if successor not in self.successors: + raise ValueError( + f"The transition from {self} to {successor} is invalid." + ) + class BaseComponent: """A base class for components."""