Skip to content

Commit

Permalink
add locking and check phase transitions
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 672840224
Change-Id: Ic9bf4bfb223b78fb08cf278821b61cfe5fb4dc39
  • Loading branch information
jagapiou authored and copybara-github committed Sep 10, 2024
1 parent f4d0804 commit 0a89136
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 29 deletions.
59 changes: 37 additions & 22 deletions concordia/agents/entity_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from collections.abc import Mapping
import functools
import threading
import types
from typing import cast

Expand Down Expand Up @@ -64,7 +65,9 @@ def __init__(
"""
super().__init__()
self._agent_name = agent_name
self._phase = entity_component.Phase.INIT
self._control_lock = threading.Lock()
self._phase_lock = threading.Lock()
self._phase = entity_component.Phase.READY

self._act_component = act_component
self._act_component.set_entity(self)
Expand All @@ -86,7 +89,13 @@ def name(self) -> str:

@override
def get_phase(self) -> entity_component.Phase:
return self._phase
with self._phase_lock:
return self._phase

def _set_phase(self, phase: entity_component.Phase) -> None:
with self._phase_lock:
self._phase.check_successor(phase)
self._phase = phase

@override
def get_component(
Expand Down Expand Up @@ -125,31 +134,37 @@ def _parallel_call_(
def act(
self, action_spec: entity.ActionSpec = entity.DEFAULT_ACTION_SPEC
) -> str:
self._phase = entity_component.Phase.PRE_ACT
contexts = self._parallel_call_('pre_act', action_spec)
self._context_processor.pre_act(types.MappingProxyType(contexts))
action_attempt = self._act_component.get_action_attempt(
contexts, action_spec
)
with self._control_lock:
self._set_phase(entity_component.Phase.PRE_ACT)
contexts = self._parallel_call_('pre_act', action_spec)
self._context_processor.pre_act(types.MappingProxyType(contexts))
action_attempt = self._act_component.get_action_attempt(
contexts, action_spec
)

self._set_phase(entity_component.Phase.POST_ACT)
contexts = self._parallel_call_('post_act', action_attempt)
self._context_processor.post_act(contexts)

self._phase = entity_component.Phase.POST_ACT
contexts = self._parallel_call_('post_act', action_attempt)
self._context_processor.post_act(contexts)
self._set_phase(entity_component.Phase.UPDATE)
self._parallel_call_('update')

self._phase = entity_component.Phase.UPDATE
self._parallel_call_('update')
self._set_phase(entity_component.Phase.READY)

return action_attempt
return action_attempt

@override
def observe(self, observation: str) -> None:
self._phase = entity_component.Phase.PRE_OBSERVE
contexts = self._parallel_call_('pre_observe', observation)
self._context_processor.pre_observe(contexts)
with self._control_lock:
self._set_phase(entity_component.Phase.PRE_OBSERVE)
contexts = self._parallel_call_('pre_observe', observation)
self._context_processor.pre_observe(contexts)

self._set_phase(entity_component.Phase.POST_OBSERVE)
contexts = self._parallel_call_('post_observe')
self._context_processor.post_observe(contexts)

self._phase = entity_component.Phase.POST_OBSERVE
contexts = self._parallel_call_('post_observe')
self._context_processor.post_observe(contexts)
self._set_phase(entity_component.Phase.UPDATE)
self._parallel_call_('update')

self._phase = entity_component.Phase.UPDATE
self._parallel_call_('update')
self._set_phase(entity_component.Phase.READY)
40 changes: 33 additions & 7 deletions concordia/typing/entity_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
"""Base classes for Entity components."""

import abc
from collections.abc import Mapping
from collections.abc import Collection, Mapping
import enum
import functools
from typing import TypeVar

from concordia.typing import entity as entity_lib
Expand All @@ -31,9 +32,8 @@ class Phase(enum.Enum):
"""Phases of a component entity lifecycle.
Attributes:
INIT: The agent has just been created. No action has been requested nor
observation has been received. This can be followed by a call to `pre_act`
or `pre_observe`.
READY: The agent is ready to `observe` or `act`. This will be followed by
`PRE_ACT` or `PRE_OBSERVE`.
PRE_ACT: The agent has received a request to act. Components are being
requested for their action context. This will be followed by `POST_ACT`.
POST_ACT: The agent has just submitted an action attempt. Components are
Expand All @@ -44,16 +44,42 @@ class Phase(enum.Enum):
POST_OBSERVE: The agent has just observed. Components are given a chance to
provide context after processing the observation. This will be followed by
`UPDATE`.
UPDATE: The agent is about to update its internal state. This will be
followed by `PRE_ACT` or `PRE_OBSERVE`.
UPDATE: The agent is updating its internal state. This will be followed by
`READY`.
"""
INIT = enum.auto()
READY = enum.auto()
PRE_ACT = enum.auto()
POST_ACT = enum.auto()
PRE_OBSERVE = enum.auto()
POST_OBSERVE = enum.auto()
UPDATE = enum.auto()

@functools.cached_property
def successors(self) -> Collection["Phase"]:
"""Returns the phases which may follow the current phase."""
match self:
case Phase.READY:
return (Phase.PRE_ACT, Phase.PRE_OBSERVE)
case Phase.PRE_ACT:
return (Phase.POST_ACT,)
case Phase.POST_ACT:
return (Phase.UPDATE,)
case Phase.PRE_OBSERVE:
return (Phase.POST_OBSERVE,)
case Phase.POST_OBSERVE:
return (Phase.UPDATE,)
case Phase.UPDATE:
return (Phase.READY,)
case _:
raise NotImplementedError()

def check_successor(self, successor: "Phase") -> None:
"""Raises ValueError if successor is not a valid next phase."""
if successor not in self.successors:
raise ValueError(
f"The transition from {self} to {successor} is invalid."
)


class BaseComponent:
"""A base class for components."""
Expand Down

0 comments on commit 0a89136

Please sign in to comment.