From 2d03833a90187711b56bb313ebebe0352c111c4e Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Wed, 22 Feb 2023 19:06:18 +0000 Subject: [PATCH 01/27] Add agent communication wrapper. --- .../gymnasium/wrappers/agent_communication.py | 61 +++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 smarts/env/gymnasium/wrappers/agent_communication.py diff --git a/smarts/env/gymnasium/wrappers/agent_communication.py b/smarts/env/gymnasium/wrappers/agent_communication.py new file mode 100644 index 0000000000..92796561a2 --- /dev/null +++ b/smarts/env/gymnasium/wrappers/agent_communication.py @@ -0,0 +1,61 @@ +import gymnasium as gym +from typing import Any, Dict, NamedTuple, Optional, Tuple +from smarts.env.gymnasium.hiway_env_v1 import HiWayEnvV1 + +import numpy as np + + +class Transmitter(NamedTuple): + pass + + +class Receiver(NamedTuple): + pass + + +class MessagePasser(gym.Wrapper): + def __init__(self, env: gym.Env): + super().__init__(env) + assert isinstance(env, HiWayEnvV1) + o_action_space: gym.spaces.Dict = self.env.action_space + msg_space = (gym.spaces.Box(low=0, high=256, shape=(125000,), dtype=np.uint8),) + self.action_space = gym.spaces.Dict( + { + a_id: gym.spaces.Tuple( + ( + action_space, + msg_space, + ) + ) + for a_id, action_space in o_action_space.spaces.items() + } + ) + o_observation_space: gym.spaces.Dict = self.env.observation_space + self.observation_space = gym.spaces.Dict( + { + "agents": o_observation_space, + "messages": gym.spaces.Dict( + {a_id: msg_space for a_id in o_action_space} + ), + } + ) + + def step(self, actions): + std_actions = {} + msgs = {} + for a_id, ma in actions.items(): + std_actions[a_id] = ma[0] + msgs[a_id] = ma[1] + + obs, rewards, terms, truncs, infos = self.env.step(std_actions) + obs_with_msgs = { + "agents": obs, + "messages": {a_id: msgs for a_id, ob in obs.items()}, + } + return obs_with_msgs, rewards, terms, truncs, infos + + def reset( + self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None + ) -> Tuple[Any, Dict[str, Any]]: + obs, info = super().reset(seed=seed, options=options) + return {"agents": obs, "messages": self.observation_space["messages"].sample()} From 05e1f872530e97666eb32965190ab6d262b55e22 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Wed, 22 Feb 2023 19:28:14 +0000 Subject: [PATCH 02/27] Add docstrings. --- .../gymnasium/wrappers/agent_communication.py | 20 ++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/smarts/env/gymnasium/wrappers/agent_communication.py b/smarts/env/gymnasium/wrappers/agent_communication.py index 92796561a2..94df64e4f3 100644 --- a/smarts/env/gymnasium/wrappers/agent_communication.py +++ b/smarts/env/gymnasium/wrappers/agent_communication.py @@ -6,19 +6,29 @@ class Transmitter(NamedTuple): + """A configuration utility to set up agents to transmit messages.""" + pass class Receiver(NamedTuple): + """A configuratoin utility to set up agent to receive messages.""" + pass class MessagePasser(gym.Wrapper): - def __init__(self, env: gym.Env): + """This wrapper augments the observations and actions to require passing messages from agents.""" + + def __init__(self, env: gym.Env, message_size_bytes=125000): super().__init__(env) assert isinstance(env, HiWayEnvV1) o_action_space: gym.spaces.Dict = self.env.action_space - msg_space = (gym.spaces.Box(low=0, high=256, shape=(125000,), dtype=np.uint8),) + msg_space = ( + gym.spaces.Box( + low=0, high=256, shape=(message_size_bytes,), dtype=np.uint8 + ), + ) self.action_space = gym.spaces.Dict( { a_id: gym.spaces.Tuple( @@ -40,17 +50,17 @@ def __init__(self, env: gym.Env): } ) - def step(self, actions): + def step(self, action): std_actions = {} msgs = {} - for a_id, ma in actions.items(): + for a_id, ma in action.items(): std_actions[a_id] = ma[0] msgs[a_id] = ma[1] obs, rewards, terms, truncs, infos = self.env.step(std_actions) obs_with_msgs = { "agents": obs, - "messages": {a_id: msgs for a_id, ob in obs.items()}, + "messages": {a_id: msgs for a_id in obs}, } return obs_with_msgs, rewards, terms, truncs, infos From 9f8d428cdb4104338ffabb2dc16f4e5bfc10c4bf Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Wed, 22 Feb 2023 19:36:05 +0000 Subject: [PATCH 03/27] Add header. --- .../gymnasium/wrappers/agent_communication.py | 23 +++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/smarts/env/gymnasium/wrappers/agent_communication.py b/smarts/env/gymnasium/wrappers/agent_communication.py index 94df64e4f3..33fa95e248 100644 --- a/smarts/env/gymnasium/wrappers/agent_communication.py +++ b/smarts/env/gymnasium/wrappers/agent_communication.py @@ -1,6 +1,25 @@ -import gymnasium as gym +# MIT License +# +# Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. from typing import Any, Dict, NamedTuple, Optional, Tuple -from smarts.env.gymnasium.hiway_env_v1 import HiWayEnvV1 import numpy as np From eacbeaafe6ec9b5ed3e792b5672526e557e90495 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Wed, 22 Feb 2023 19:36:34 +0000 Subject: [PATCH 04/27] Update configuration. --- smarts/env/gymnasium/wrappers/agent_communication.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/smarts/env/gymnasium/wrappers/agent_communication.py b/smarts/env/gymnasium/wrappers/agent_communication.py index 33fa95e248..dd2dc8989f 100644 --- a/smarts/env/gymnasium/wrappers/agent_communication.py +++ b/smarts/env/gymnasium/wrappers/agent_communication.py @@ -21,8 +21,11 @@ # THE SOFTWARE. from typing import Any, Dict, NamedTuple, Optional, Tuple +import gymnasium as gym import numpy as np +from smarts.env.gymnasium.hiway_env_v1 import HiWayEnvV1 + class Transmitter(NamedTuple): """A configuration utility to set up agents to transmit messages.""" @@ -39,14 +42,12 @@ class Receiver(NamedTuple): class MessagePasser(gym.Wrapper): """This wrapper augments the observations and actions to require passing messages from agents.""" - def __init__(self, env: gym.Env, message_size_bytes=125000): + def __init__(self, env: gym.Env, message_size=125000): super().__init__(env) assert isinstance(env, HiWayEnvV1) o_action_space: gym.spaces.Dict = self.env.action_space msg_space = ( - gym.spaces.Box( - low=0, high=256, shape=(message_size_bytes,), dtype=np.uint8 - ), + gym.spaces.Box(low=0, high=256, shape=(message_size,), dtype=np.uint8), ) self.action_space = gym.spaces.Dict( { From 7efa7248d4cdea10047f95cfcee44f018dc2aa1b Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Wed, 22 Feb 2023 20:05:26 +0000 Subject: [PATCH 05/27] Mock up interface for message --- .../gymnasium/wrappers/agent_communication.py | 32 +++++++++++++++---- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/smarts/env/gymnasium/wrappers/agent_communication.py b/smarts/env/gymnasium/wrappers/agent_communication.py index dd2dc8989f..c91c012222 100644 --- a/smarts/env/gymnasium/wrappers/agent_communication.py +++ b/smarts/env/gymnasium/wrappers/agent_communication.py @@ -1,17 +1,17 @@ # MIT License -# +# # Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved. -# +# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: -# +# # The above copyright notice and this permission notice shall be included in # all copies or substantial portions of the Software. -# +# # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE @@ -19,7 +19,8 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. -from typing import Any, Dict, NamedTuple, Optional, Tuple +from typing import Any, Dict, NamedTuple, Optional, Sequence, Tuple +from enum import Enum import gymnasium as gym import numpy as np @@ -27,16 +28,33 @@ from smarts.env.gymnasium.hiway_env_v1 import HiWayEnvV1 +class Sensitivity(Enum): + LOW = 0 + STANDARD = 1 + HIGH = 2 + + class Transmitter(NamedTuple): """A configuration utility to set up agents to transmit messages.""" - pass + # TODO MTA: move this to agent interface. + + bands: Sequence[int] + range: float + mission: bool + position: bool + heading: bool + breaking: bool + throttle: bool + steering: bool class Receiver(NamedTuple): """A configuratoin utility to set up agent to receive messages.""" - pass + # TODO MTA: move this to agent interface. + bands: Sequence[int] + sensitivity: Sensitivity = Sensitivity.HIGH class MessagePasser(gym.Wrapper): From b2a4c8d58209e80adb966d5a606884e35bfb04a5 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Wed, 22 Feb 2023 20:09:47 +0000 Subject: [PATCH 06/27] Add custom block to interface option. --- smarts/env/gymnasium/wrappers/agent_communication.py | 1 + 1 file changed, 1 insertion(+) diff --git a/smarts/env/gymnasium/wrappers/agent_communication.py b/smarts/env/gymnasium/wrappers/agent_communication.py index c91c012222..9f2fe33115 100644 --- a/smarts/env/gymnasium/wrappers/agent_communication.py +++ b/smarts/env/gymnasium/wrappers/agent_communication.py @@ -47,6 +47,7 @@ class Transmitter(NamedTuple): breaking: bool throttle: bool steering: bool + custom_message_blob_size: int class Receiver(NamedTuple): From d2081ba92147b37badcfba62345a0743c1128cfd Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Thu, 23 Feb 2023 18:16:09 +0000 Subject: [PATCH 07/27] Bring wrapper close to completion. --- .../gymnasium/wrappers/agent_communication.py | 235 +++++++++++++++--- 1 file changed, 201 insertions(+), 34 deletions(-) diff --git a/smarts/env/gymnasium/wrappers/agent_communication.py b/smarts/env/gymnasium/wrappers/agent_communication.py index 9f2fe33115..0ef5f55c28 100644 --- a/smarts/env/gymnasium/wrappers/agent_communication.py +++ b/smarts/env/gymnasium/wrappers/agent_communication.py @@ -19,8 +19,11 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. -from typing import Any, Dict, NamedTuple, Optional, Sequence, Tuple -from enum import Enum +import enum +from collections import defaultdict +from enum import Enum, IntFlag, unique +from functools import lru_cache +from typing import Any, Dict, List, NamedTuple, Optional, Set, Tuple import gymnasium as gym import numpy as np @@ -28,83 +31,247 @@ from smarts.env.gymnasium.hiway_env_v1 import HiWayEnvV1 +@unique +class Bands(IntFlag): + L0 = enum.auto() + L1 = enum.auto() + L2 = enum.auto() + L3 = enum.auto() + L4 = enum.auto() + L5 = enum.auto() + L6 = enum.auto() + L7 = enum.auto() + L8 = enum.auto() + L9 = enum.auto() + L10 = enum.auto() + L11 = enum.auto() + ALL = L0 | L1 | L2 | L3 | L4 | L5 | L6 | L7 | L8 | L9 | L10 | L11 + + +@unique class Sensitivity(Enum): LOW = 0 STANDARD = 1 HIGH = 2 -class Transmitter(NamedTuple): - """A configuration utility to set up agents to transmit messages.""" +class Header(NamedTuple): + """A header modeled loosely after an email. - # TODO MTA: move this to agent interface. + Args: + channel (str): The channel this message is on. + sender (str): The name of the sender. + sender_type (str): The type of actor the sender is. + cc (Set[str]): + The actors and aliases this should be sent publicly to. All cc recipients see the cc. + bcc (Set[str]): + The actors and aliases this should be sent privately to. The recipient only sees self + in bcc list. + format (str): The format of this message. + """ + + channel: str + sender: str + sender_type: str + cc: Set[str] + bcc: Set[str] + format: str + + +class Message(NamedTuple): + """The message data. + + Args: + content (Any): The content of this message. + """ + + content: Any + # size: uint - bands: Sequence[int] + +class V2XTransmitter(NamedTuple): + """A configuration utility to set up agents to transmit messages.""" + + bands: Bands range: float - mission: bool - position: bool - heading: bool - breaking: bool - throttle: bool - steering: bool - custom_message_blob_size: int + # available_channels: List[str] -class Receiver(NamedTuple): +class V2XReceiver(NamedTuple): """A configuratoin utility to set up agent to receive messages.""" # TODO MTA: move this to agent interface. - bands: Sequence[int] - sensitivity: Sensitivity = Sensitivity.HIGH + bands: Bands + aliases: List[str] + whitelist_channels: Optional[List[str]] = None + blacklist_channels: Set[str] = {} + sensitivity: Sensitivity = Sensitivity.STANDARD class MessagePasser(gym.Wrapper): """This wrapper augments the observations and actions to require passing messages from agents.""" - def __init__(self, env: gym.Env, message_size=125000): + def __init__( + self, + env: gym.Env, + message_config: Dict[str, Tuple[V2XTransmitter, V2XReceiver]], + max_message_bytes=125000, + ): super().__init__(env) + self._message_config = message_config + # map alias to agent ids (multiple agents can be under the same alias) + self._alias_mapping = defaultdict(list) + for a_id, (_, receiver) in message_config.items(): + for alias in receiver.aliases: + self._alias_mapping[alias].append(a_id) + assert isinstance(env, HiWayEnvV1) o_action_space: gym.spaces.Dict = self.env.action_space msg_space = ( - gym.spaces.Box(low=0, high=256, shape=(message_size,), dtype=np.uint8), + gym.spaces.Box(low=0, high=256, shape=(max_message_bytes,), dtype=np.uint8), ) self.action_space = gym.spaces.Dict( { a_id: gym.spaces.Tuple( ( - action_space, + base_action_space, msg_space, ) ) - for a_id, action_space in o_action_space.spaces.items() + for a_id, base_action_space in o_action_space.spaces.items() } ) o_observation_space: gym.spaces.Dict = self.env.observation_space + self._transmission_space = gym.spaces.Sequence( + gym.spaces.Tuple( + ( + gym.spaces.Tuple( + ( + gym.spaces.Text(20), # channel + gym.spaces.Text(30), # sender + gym.spaces.Text(10), # sender_type + gym.spaces.Sequence(gym.spaces.Text(30)), # cc + gym.spaces.Sequence(gym.spaces.Text(30)), # bcc + gym.spaces.Text(10), # format + ) + ), + gym.spaces.Tuple((msg_space,)), + ) + ) + ) self.observation_space = gym.spaces.Dict( { - "agents": o_observation_space, - "messages": gym.spaces.Dict( - {a_id: msg_space for a_id in o_action_space} - ), + a_id: gym.spaces.Dict( + dict( + **obs, + transmissions=self._transmission_space, + ) + ) + for a_id, obs in o_observation_space.items() } ) + @lru_cache + def resolve_alias(self, alias): + return set(self._alias_mapping[alias]) + def step(self, action): - std_actions = {} - msgs = {} - for a_id, ma in action.items(): - std_actions[a_id] = ma[0] - msgs[a_id] = ma[1] + # step + std_actions = {a_id: act for a_id, (act, _) in action} + observations, rewards, terms, truncs, infos = self.env.step(std_actions) + + msgs = defaultdict(list) + + # filter recipients for active + cached_active_filter = lru_cache(lambda a: a.intersection(observations.keys())) + + # filter recipients by band + ## compare transmitter + cached_band_filter = lru_cache( + lambda sender, recipients: ( + r + for r in recipients + if self._message_config[sender][0].bands + | self._message_config[r][1].bands + ) + ) + + # filter recipients that do not listen to the channel + accepts_channel = lru_cache( + lambda channel, recipient: ( + (not self._message_config[recipient][1].whitelist_channels) + or (channel in self._message_config[recipient][1].whitelist_channels) + ) + and channel not in self._message_config[recipient][1].blacklist_channels + ) + cached_channel_filter = lru_cache( + lambda channel, recipients: ( + r for r in recipients if accepts_channel(channel, r) + ) + ) + + ## filter recipients by distance + ## Includes all + ## TODO ensure this works on all formatting types + # cached_dist_comp = lambda sender, receiver: obs[sender]["position"].dot(obs[receiver]["position"]) + # cached_distance_filter = lru_cache(lambda sender, receivers: ( + # r for r in receivers if cached_distance_filter + # )) + + # compress filters + general_filter = lambda header, initial_recipients: ( + cc + for recipients in map(self.resolve_alias, initial_recipients) + for cc in cached_channel_filter( + header.channel, + cached_band_filter(header.sender, cached_active_filter(recipients)), + ) + ) + + # Organise the messages to their recipients + for a_id, (_, msg) in action.items(): + msg: List[Tuple[Header, Message]] = msg + for header, message in msg: + header: Header = header + message: Message = message + + # expand the recipients + cc_recipients = set(general_filter(header, header.cc)) + bcc_recipients = set(general_filter(header, header.bcc)) + cc_header = header._replace(cc=cc_recipients) + + # associate the messages to the recipients + for recipient in ( + cc_recipients - bcc_recipients + ): # privacy takes priority + msgs[recipient].append( + (cc_header._replace(bcc=set()), message) # clear bcc + ) + for recipient in bcc_recipients: + msgs[recipient].append( + ( + cc_header._replace( + bcc={recipient} + ), # leave recipient in bcc + message, + ) + ) - obs, rewards, terms, truncs, infos = self.env.step(std_actions) obs_with_msgs = { - "agents": obs, - "messages": {a_id: msgs for a_id in obs}, + a_id: dict( + **obs, + transmissions=msgs[a_id], + ) + for a_id, obs in observations.items() } return obs_with_msgs, rewards, terms, truncs, infos def reset( self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None ) -> Tuple[Any, Dict[str, Any]]: - obs, info = super().reset(seed=seed, options=options) - return {"agents": obs, "messages": self.observation_space["messages"].sample()} + observations, info = super().reset(seed=seed, options=options) + obs_with_msgs = { + a_id: dict(**obs, transmissions=self._transmission_space.sample(0)) + for a_id, obs in observations.items() + } + return obs_with_msgs, info From 96a878e0086c49fbcc64d16a71e76902d5e564ac Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Thu, 23 Feb 2023 18:17:59 +0000 Subject: [PATCH 08/27] Remove agent interface todo. --- smarts/env/gymnasium/wrappers/agent_communication.py | 1 - 1 file changed, 1 deletion(-) diff --git a/smarts/env/gymnasium/wrappers/agent_communication.py b/smarts/env/gymnasium/wrappers/agent_communication.py index 0ef5f55c28..f84619a0f0 100644 --- a/smarts/env/gymnasium/wrappers/agent_communication.py +++ b/smarts/env/gymnasium/wrappers/agent_communication.py @@ -100,7 +100,6 @@ class V2XTransmitter(NamedTuple): class V2XReceiver(NamedTuple): """A configuratoin utility to set up agent to receive messages.""" - # TODO MTA: move this to agent interface. bands: Bands aliases: List[str] whitelist_channels: Optional[List[str]] = None From b7dc02eb91f5b1ff7658e0c7f5870232f2b6eabe Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Thu, 23 Feb 2023 20:57:52 +0000 Subject: [PATCH 09/27] Fix docstring test. --- .../env/gymnasium/wrappers/agent_communication.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/smarts/env/gymnasium/wrappers/agent_communication.py b/smarts/env/gymnasium/wrappers/agent_communication.py index f84619a0f0..2139067478 100644 --- a/smarts/env/gymnasium/wrappers/agent_communication.py +++ b/smarts/env/gymnasium/wrappers/agent_communication.py @@ -33,6 +33,8 @@ @unique class Bands(IntFlag): + """Communcation bands.""" + L0 = enum.auto() L1 = enum.auto() L2 = enum.auto() @@ -103,12 +105,14 @@ class V2XReceiver(NamedTuple): bands: Bands aliases: List[str] whitelist_channels: Optional[List[str]] = None - blacklist_channels: Set[str] = {} + blacklist_channels: Set[str] = set() sensitivity: Sensitivity = Sensitivity.STANDARD class MessagePasser(gym.Wrapper): - """This wrapper augments the observations and actions to require passing messages from agents.""" + """This wrapper augments the observations and actions to require passing messages from agents. + + It assumes that the underlying environment is :class:`HiWayEnvV1`""" def __init__( self, @@ -116,6 +120,7 @@ def __init__( message_config: Dict[str, Tuple[V2XTransmitter, V2XReceiver]], max_message_bytes=125000, ): + """""" super().__init__(env) self._message_config = message_config # map alias to agent ids (multiple agents can be under the same alias) @@ -172,15 +177,17 @@ def __init__( @lru_cache def resolve_alias(self, alias): + """Resolve the alias to agent ids.""" return set(self._alias_mapping[alias]) def step(self, action): - # step + """Steps the environment using the given action.""" std_actions = {a_id: act for a_id, (act, _) in action} observations, rewards, terms, truncs, infos = self.env.step(std_actions) msgs = defaultdict(list) + # pytype: disable=wrong-arg-types # filter recipients for active cached_active_filter = lru_cache(lambda a: a.intersection(observations.keys())) @@ -208,6 +215,7 @@ def step(self, action): r for r in recipients if accepts_channel(channel, r) ) ) + # pytype: enable=wrong-arg-types ## filter recipients by distance ## Includes all @@ -268,6 +276,7 @@ def step(self, action): def reset( self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None ) -> Tuple[Any, Dict[str, Any]]: + """Resets the environment.""" observations, info = super().reset(seed=seed, options=options) obs_with_msgs = { a_id: dict(**obs, transmissions=self._transmission_space.sample(0)) From e9a22c081377ce0a48d25a6e6bd8ea60b151672d Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Thu, 23 Feb 2023 23:30:08 +0000 Subject: [PATCH 10/27] Add agent communcation example. --- examples/control/agent_communcation.py | 207 ++++++++++++++++++ .../gymnasium/wrappers/agent_communication.py | 32 ++- 2 files changed, 228 insertions(+), 11 deletions(-) create mode 100644 examples/control/agent_communcation.py diff --git a/examples/control/agent_communcation.py b/examples/control/agent_communcation.py new file mode 100644 index 0000000000..f6ad77d2d8 --- /dev/null +++ b/examples/control/agent_communcation.py @@ -0,0 +1,207 @@ +import sys +from pathlib import Path +from typing import Any, Dict, Union + +from smarts.core.agent import Agent +from smarts.core.agent_interface import AgentInterface, AgentType +from smarts.core.utils.episodes import episodes +from smarts.env.gymnasium.hiway_env_v1 import HiWayEnvV1 +from smarts.env.gymnasium.wrappers.agent_communication import ( + Bands, + Header, + Message, + MessagePasser, + V2XReceiver, + V2XTransmitter, +) +from smarts.env.utils.action_conversion import ActionOptions +from smarts.env.utils.observation_conversion import ObservationOptions +from smarts.sstudio.scenario_construction import build_scenarios + +sys.path.insert(0, str(Path(__file__).parents[2].absolute())) +import gymnasium as gym + +from examples.tools.argument_parser import default_argument_parser + +TIMESTEP = 0.1 +BYTES_IN_MEGABIT = 125000 +MESSAGE_MEGABITS_PER_SECOND = 10 +MESSAGE_BYTES = int(BYTES_IN_MEGABIT * MESSAGE_MEGABITS_PER_SECOND / TIMESTEP) + + +def filter_useless(transmissions): + for header, msg in transmissions: + if header.sender in ("parked_agent", "broken_stoplight"): + continue + if header.sender_type in ("advertisement",): + continue + yield header, msg + + +class LaneFollowerAgent(Agent): + def act(self, obs: Dict[Any, Union[Any, Dict]]): + return (obs["waypoint_paths"]["speed_limit"][0][0], 0) + + +class GossiperAgent(Agent): + def __init__(self, id_: str, base_agent: Agent, filter_, friends): + self._filter = filter_ + self._id = id_ + self._friends = friends + self._base_agent = base_agent + + def act(self, obs, **configs): + out_transmissions = [] + for header, msg in self._filter(obs["transmissions"]): + header: Header = header + msg: Message = msg + if not {self._id, "__all__"}.intersection(header.cc | header.bcc): + continue + if header.channel == "position_request": + print() + print("On step: ", obs["steps_completed"]) + print("Gossiper received position request: ", header) + out_transmissions.append( + ( + Header( + channel="position", + sender=self._id, + sender_type="ad_vehicle", + cc={header.sender}, + bcc={*self._friends}, + format="position", + ), # optimize this later + Message( + content=obs["ego_vehicle_state"]["position"], + ), # optimize this later + ) + ) + print("Gossiper sent position: ", out_transmissions[0][1]) + + base_action = self._base_agent.act(obs) + return (base_action, out_transmissions) + + +class SchemerAgent(Agent): + def __init__(self, id_: str, base_agent: Agent, request_freq) -> None: + self._base_agent = base_agent + self._id = id_ + self._request_freq = request_freq + + def act(self, obs, **configs): + out_transmissions = [] + for header, msg in obs["transmissions"]: + header: Header = header + msg: Message = msg + if header.channel == "position": + print() + print("On step: ", obs["steps_completed"]) + print("Schemer received position: ", msg) + + if obs["steps_completed"] % self._request_freq == 0: + print() + print("On step: ", obs["steps_completed"]) + out_transmissions.append( + ( + Header( + channel="position_request", + sender=self._id, + sender_type="ad_vehicle", + cc=set(), + bcc={"__all__"}, + format="position_request", + ), + Message(content=None), + ) + ) + print("Schemer requested position with: ", out_transmissions[0][0]) + + base_action = self._base_agent.act(obs) + return (base_action, out_transmissions) + + +def main(scenarios, headless, num_episodes, max_episode_steps=None): + agent_interface = AgentInterface.from_type( + AgentType.LanerWithSpeed, max_episode_steps=max_episode_steps + ) + hiwayv1env = HiWayEnvV1( + scenarios=scenarios, + agent_interfaces={"gossiper0": agent_interface, "schemer": agent_interface}, + headless=headless, + observation_options=ObservationOptions.multi_agent, + action_options=ActionOptions.default, + ) + # for now + env = MessagePasser( + hiwayv1env, + max_message_bytes=MESSAGE_BYTES, + message_config={ + "gossiper0": ( + V2XTransmitter( + bands=Bands.ALL, + range=100, + # available_channels=["position_request", "position"] + ), + V2XReceiver( + bands=Bands.ALL, + aliases=["tim"], + blacklist_channels={"self_control"}, + ), + ), + "schemer": ( + V2XTransmitter( + bands=Bands.ALL, + range=100, + ), + V2XReceiver( + bands=Bands.ALL, + aliases=[], + ), + ), + }, + ) + agents = { + "gossiper0": GossiperAgent( + "gossiper0", + base_agent=LaneFollowerAgent(), + filter_=filter_useless, + friends={"schemer"}, + ), + "schemer": SchemerAgent( + "schemer", base_agent=LaneFollowerAgent(), request_freq=100 + ), + } + + # then just the standard gym interface with no modifications + for episode in episodes(n=num_episodes): + observation, info = env.reset() + episode.record_scenario(env.scenario_log) + + terminated = {"__all__": False} + while not terminated["__all__"]: + agent_action = { + agent_id: agents[agent_id].act(obs) + for agent_id, obs in observation.items() + } + observation, reward, terminated, truncated, info = env.step(agent_action) + episode.record_step(observation, reward, terminated, info) + + env.close() + + +if __name__ == "__main__": + parser = default_argument_parser("single-agent-example") + args = parser.parse_args() + + if not args.scenarios: + args.scenarios = [ + str(Path(__file__).absolute().parents[2] / "scenarios" / "sumo" / "loop") + ] + + build_scenarios(scenarios=args.scenarios) + + main( + scenarios=args.scenarios, + headless=args.headless, + num_episodes=args.episodes, + ) diff --git a/smarts/env/gymnasium/wrappers/agent_communication.py b/smarts/env/gymnasium/wrappers/agent_communication.py index 2139067478..af2836df15 100644 --- a/smarts/env/gymnasium/wrappers/agent_communication.py +++ b/smarts/env/gymnasium/wrappers/agent_communication.py @@ -128,11 +128,17 @@ def __init__( for a_id, (_, receiver) in message_config.items(): for alias in receiver.aliases: self._alias_mapping[alias].append(a_id) + self._alias_mapping[a_id].append(a_id) + self._alias_mapping["__all__"].append(a_id) - assert isinstance(env, HiWayEnvV1) + assert isinstance(env.unwrapped, HiWayEnvV1) o_action_space: gym.spaces.Dict = self.env.action_space - msg_space = ( - gym.spaces.Box(low=0, high=256, shape=(max_message_bytes,), dtype=np.uint8), + msg_space = gym.spaces.Tuple( + ( + gym.spaces.Box( + low=0, high=256, shape=(max_message_bytes,), dtype=np.uint8 + ), + ) ) self.action_space = gym.spaces.Dict( { @@ -182,19 +188,21 @@ def resolve_alias(self, alias): def step(self, action): """Steps the environment using the given action.""" - std_actions = {a_id: act for a_id, (act, _) in action} + std_actions = {a_id: act for a_id, (act, _) in action.items()} observations, rewards, terms, truncs, infos = self.env.step(std_actions) msgs = defaultdict(list) # pytype: disable=wrong-arg-types # filter recipients for active - cached_active_filter = lru_cache(lambda a: a.intersection(observations.keys())) + cached_active_filter = lru_cache( + lambda a: frozenset(a.intersection(observations.keys())) + ) # filter recipients by band ## compare transmitter cached_band_filter = lru_cache( - lambda sender, recipients: ( + lambda sender, recipients: frozenset( r for r in recipients if self._message_config[sender][0].bands @@ -211,7 +219,7 @@ def step(self, action): and channel not in self._message_config[recipient][1].blacklist_channels ) cached_channel_filter = lru_cache( - lambda channel, recipients: ( + lambda channel, recipients: frozenset( r for r in recipients if accepts_channel(channel, r) ) ) @@ -231,7 +239,9 @@ def step(self, action): for recipients in map(self.resolve_alias, initial_recipients) for cc in cached_channel_filter( header.channel, - cached_band_filter(header.sender, cached_active_filter(recipients)), + cached_band_filter( + header.sender, cached_active_filter(frozenset(recipients)) + ), ) ) @@ -243,8 +253,8 @@ def step(self, action): message: Message = message # expand the recipients - cc_recipients = set(general_filter(header, header.cc)) - bcc_recipients = set(general_filter(header, header.bcc)) + cc_recipients = set(general_filter(header, frozenset(header.cc))) + bcc_recipients = set(general_filter(header, frozenset(header.bcc))) cc_header = header._replace(cc=cc_recipients) # associate the messages to the recipients @@ -279,7 +289,7 @@ def reset( """Resets the environment.""" observations, info = super().reset(seed=seed, options=options) obs_with_msgs = { - a_id: dict(**obs, transmissions=self._transmission_space.sample(0)) + a_id: dict(**obs, transmissions=self._transmission_space.sample((0, ()))) for a_id, obs in observations.items() } return obs_with_msgs, info From cde2091359259e17bb38cb26a396a8644f83ca1c Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Wed, 22 Feb 2023 19:06:18 +0000 Subject: [PATCH 11/27] Add agent communication wrapper. --- .../gymnasium/wrappers/agent_communication.py | 61 +++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 smarts/env/gymnasium/wrappers/agent_communication.py diff --git a/smarts/env/gymnasium/wrappers/agent_communication.py b/smarts/env/gymnasium/wrappers/agent_communication.py new file mode 100644 index 0000000000..92796561a2 --- /dev/null +++ b/smarts/env/gymnasium/wrappers/agent_communication.py @@ -0,0 +1,61 @@ +import gymnasium as gym +from typing import Any, Dict, NamedTuple, Optional, Tuple +from smarts.env.gymnasium.hiway_env_v1 import HiWayEnvV1 + +import numpy as np + + +class Transmitter(NamedTuple): + pass + + +class Receiver(NamedTuple): + pass + + +class MessagePasser(gym.Wrapper): + def __init__(self, env: gym.Env): + super().__init__(env) + assert isinstance(env, HiWayEnvV1) + o_action_space: gym.spaces.Dict = self.env.action_space + msg_space = (gym.spaces.Box(low=0, high=256, shape=(125000,), dtype=np.uint8),) + self.action_space = gym.spaces.Dict( + { + a_id: gym.spaces.Tuple( + ( + action_space, + msg_space, + ) + ) + for a_id, action_space in o_action_space.spaces.items() + } + ) + o_observation_space: gym.spaces.Dict = self.env.observation_space + self.observation_space = gym.spaces.Dict( + { + "agents": o_observation_space, + "messages": gym.spaces.Dict( + {a_id: msg_space for a_id in o_action_space} + ), + } + ) + + def step(self, actions): + std_actions = {} + msgs = {} + for a_id, ma in actions.items(): + std_actions[a_id] = ma[0] + msgs[a_id] = ma[1] + + obs, rewards, terms, truncs, infos = self.env.step(std_actions) + obs_with_msgs = { + "agents": obs, + "messages": {a_id: msgs for a_id, ob in obs.items()}, + } + return obs_with_msgs, rewards, terms, truncs, infos + + def reset( + self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None + ) -> Tuple[Any, Dict[str, Any]]: + obs, info = super().reset(seed=seed, options=options) + return {"agents": obs, "messages": self.observation_space["messages"].sample()} From e021a065a80b7a9b9990988b61c0a6811ee07e02 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Wed, 22 Feb 2023 19:28:14 +0000 Subject: [PATCH 12/27] Add docstrings. --- .../gymnasium/wrappers/agent_communication.py | 20 ++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/smarts/env/gymnasium/wrappers/agent_communication.py b/smarts/env/gymnasium/wrappers/agent_communication.py index 92796561a2..94df64e4f3 100644 --- a/smarts/env/gymnasium/wrappers/agent_communication.py +++ b/smarts/env/gymnasium/wrappers/agent_communication.py @@ -6,19 +6,29 @@ class Transmitter(NamedTuple): + """A configuration utility to set up agents to transmit messages.""" + pass class Receiver(NamedTuple): + """A configuratoin utility to set up agent to receive messages.""" + pass class MessagePasser(gym.Wrapper): - def __init__(self, env: gym.Env): + """This wrapper augments the observations and actions to require passing messages from agents.""" + + def __init__(self, env: gym.Env, message_size_bytes=125000): super().__init__(env) assert isinstance(env, HiWayEnvV1) o_action_space: gym.spaces.Dict = self.env.action_space - msg_space = (gym.spaces.Box(low=0, high=256, shape=(125000,), dtype=np.uint8),) + msg_space = ( + gym.spaces.Box( + low=0, high=256, shape=(message_size_bytes,), dtype=np.uint8 + ), + ) self.action_space = gym.spaces.Dict( { a_id: gym.spaces.Tuple( @@ -40,17 +50,17 @@ def __init__(self, env: gym.Env): } ) - def step(self, actions): + def step(self, action): std_actions = {} msgs = {} - for a_id, ma in actions.items(): + for a_id, ma in action.items(): std_actions[a_id] = ma[0] msgs[a_id] = ma[1] obs, rewards, terms, truncs, infos = self.env.step(std_actions) obs_with_msgs = { "agents": obs, - "messages": {a_id: msgs for a_id, ob in obs.items()}, + "messages": {a_id: msgs for a_id in obs}, } return obs_with_msgs, rewards, terms, truncs, infos From 1e7827054ad01c4e8352f8e89db453bdc4f00079 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Wed, 22 Feb 2023 19:36:05 +0000 Subject: [PATCH 13/27] Add header. --- .../gymnasium/wrappers/agent_communication.py | 23 +++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/smarts/env/gymnasium/wrappers/agent_communication.py b/smarts/env/gymnasium/wrappers/agent_communication.py index 94df64e4f3..33fa95e248 100644 --- a/smarts/env/gymnasium/wrappers/agent_communication.py +++ b/smarts/env/gymnasium/wrappers/agent_communication.py @@ -1,6 +1,25 @@ -import gymnasium as gym +# MIT License +# +# Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. from typing import Any, Dict, NamedTuple, Optional, Tuple -from smarts.env.gymnasium.hiway_env_v1 import HiWayEnvV1 import numpy as np From 356a9228b503e2326d661afbc5571515978d0f16 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Wed, 22 Feb 2023 19:36:34 +0000 Subject: [PATCH 14/27] Update configuration. --- smarts/env/gymnasium/wrappers/agent_communication.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/smarts/env/gymnasium/wrappers/agent_communication.py b/smarts/env/gymnasium/wrappers/agent_communication.py index 33fa95e248..dd2dc8989f 100644 --- a/smarts/env/gymnasium/wrappers/agent_communication.py +++ b/smarts/env/gymnasium/wrappers/agent_communication.py @@ -21,8 +21,11 @@ # THE SOFTWARE. from typing import Any, Dict, NamedTuple, Optional, Tuple +import gymnasium as gym import numpy as np +from smarts.env.gymnasium.hiway_env_v1 import HiWayEnvV1 + class Transmitter(NamedTuple): """A configuration utility to set up agents to transmit messages.""" @@ -39,14 +42,12 @@ class Receiver(NamedTuple): class MessagePasser(gym.Wrapper): """This wrapper augments the observations and actions to require passing messages from agents.""" - def __init__(self, env: gym.Env, message_size_bytes=125000): + def __init__(self, env: gym.Env, message_size=125000): super().__init__(env) assert isinstance(env, HiWayEnvV1) o_action_space: gym.spaces.Dict = self.env.action_space msg_space = ( - gym.spaces.Box( - low=0, high=256, shape=(message_size_bytes,), dtype=np.uint8 - ), + gym.spaces.Box(low=0, high=256, shape=(message_size,), dtype=np.uint8), ) self.action_space = gym.spaces.Dict( { From 4decba071f5c6203977a4f1ab9cbac3d667424af Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Wed, 22 Feb 2023 20:05:26 +0000 Subject: [PATCH 15/27] Mock up interface for message --- .../gymnasium/wrappers/agent_communication.py | 32 +++++++++++++++---- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/smarts/env/gymnasium/wrappers/agent_communication.py b/smarts/env/gymnasium/wrappers/agent_communication.py index dd2dc8989f..c91c012222 100644 --- a/smarts/env/gymnasium/wrappers/agent_communication.py +++ b/smarts/env/gymnasium/wrappers/agent_communication.py @@ -1,17 +1,17 @@ # MIT License -# +# # Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved. -# +# # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: -# +# # The above copyright notice and this permission notice shall be included in # all copies or substantial portions of the Software. -# +# # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE @@ -19,7 +19,8 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. -from typing import Any, Dict, NamedTuple, Optional, Tuple +from typing import Any, Dict, NamedTuple, Optional, Sequence, Tuple +from enum import Enum import gymnasium as gym import numpy as np @@ -27,16 +28,33 @@ from smarts.env.gymnasium.hiway_env_v1 import HiWayEnvV1 +class Sensitivity(Enum): + LOW = 0 + STANDARD = 1 + HIGH = 2 + + class Transmitter(NamedTuple): """A configuration utility to set up agents to transmit messages.""" - pass + # TODO MTA: move this to agent interface. + + bands: Sequence[int] + range: float + mission: bool + position: bool + heading: bool + breaking: bool + throttle: bool + steering: bool class Receiver(NamedTuple): """A configuratoin utility to set up agent to receive messages.""" - pass + # TODO MTA: move this to agent interface. + bands: Sequence[int] + sensitivity: Sensitivity = Sensitivity.HIGH class MessagePasser(gym.Wrapper): From 8b0ed1941e5fb74501ab0b2723f817ebd802d5b1 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Wed, 22 Feb 2023 20:09:47 +0000 Subject: [PATCH 16/27] Add custom block to interface option. --- smarts/env/gymnasium/wrappers/agent_communication.py | 1 + 1 file changed, 1 insertion(+) diff --git a/smarts/env/gymnasium/wrappers/agent_communication.py b/smarts/env/gymnasium/wrappers/agent_communication.py index c91c012222..9f2fe33115 100644 --- a/smarts/env/gymnasium/wrappers/agent_communication.py +++ b/smarts/env/gymnasium/wrappers/agent_communication.py @@ -47,6 +47,7 @@ class Transmitter(NamedTuple): breaking: bool throttle: bool steering: bool + custom_message_blob_size: int class Receiver(NamedTuple): From 8960f5276347d9959d166aaa72695d25eb90aeb5 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Thu, 23 Feb 2023 18:16:09 +0000 Subject: [PATCH 17/27] Bring wrapper close to completion. --- .../gymnasium/wrappers/agent_communication.py | 235 +++++++++++++++--- 1 file changed, 201 insertions(+), 34 deletions(-) diff --git a/smarts/env/gymnasium/wrappers/agent_communication.py b/smarts/env/gymnasium/wrappers/agent_communication.py index 9f2fe33115..0ef5f55c28 100644 --- a/smarts/env/gymnasium/wrappers/agent_communication.py +++ b/smarts/env/gymnasium/wrappers/agent_communication.py @@ -19,8 +19,11 @@ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # THE SOFTWARE. -from typing import Any, Dict, NamedTuple, Optional, Sequence, Tuple -from enum import Enum +import enum +from collections import defaultdict +from enum import Enum, IntFlag, unique +from functools import lru_cache +from typing import Any, Dict, List, NamedTuple, Optional, Set, Tuple import gymnasium as gym import numpy as np @@ -28,83 +31,247 @@ from smarts.env.gymnasium.hiway_env_v1 import HiWayEnvV1 +@unique +class Bands(IntFlag): + L0 = enum.auto() + L1 = enum.auto() + L2 = enum.auto() + L3 = enum.auto() + L4 = enum.auto() + L5 = enum.auto() + L6 = enum.auto() + L7 = enum.auto() + L8 = enum.auto() + L9 = enum.auto() + L10 = enum.auto() + L11 = enum.auto() + ALL = L0 | L1 | L2 | L3 | L4 | L5 | L6 | L7 | L8 | L9 | L10 | L11 + + +@unique class Sensitivity(Enum): LOW = 0 STANDARD = 1 HIGH = 2 -class Transmitter(NamedTuple): - """A configuration utility to set up agents to transmit messages.""" +class Header(NamedTuple): + """A header modeled loosely after an email. - # TODO MTA: move this to agent interface. + Args: + channel (str): The channel this message is on. + sender (str): The name of the sender. + sender_type (str): The type of actor the sender is. + cc (Set[str]): + The actors and aliases this should be sent publicly to. All cc recipients see the cc. + bcc (Set[str]): + The actors and aliases this should be sent privately to. The recipient only sees self + in bcc list. + format (str): The format of this message. + """ + + channel: str + sender: str + sender_type: str + cc: Set[str] + bcc: Set[str] + format: str + + +class Message(NamedTuple): + """The message data. + + Args: + content (Any): The content of this message. + """ + + content: Any + # size: uint - bands: Sequence[int] + +class V2XTransmitter(NamedTuple): + """A configuration utility to set up agents to transmit messages.""" + + bands: Bands range: float - mission: bool - position: bool - heading: bool - breaking: bool - throttle: bool - steering: bool - custom_message_blob_size: int + # available_channels: List[str] -class Receiver(NamedTuple): +class V2XReceiver(NamedTuple): """A configuratoin utility to set up agent to receive messages.""" # TODO MTA: move this to agent interface. - bands: Sequence[int] - sensitivity: Sensitivity = Sensitivity.HIGH + bands: Bands + aliases: List[str] + whitelist_channels: Optional[List[str]] = None + blacklist_channels: Set[str] = {} + sensitivity: Sensitivity = Sensitivity.STANDARD class MessagePasser(gym.Wrapper): """This wrapper augments the observations and actions to require passing messages from agents.""" - def __init__(self, env: gym.Env, message_size=125000): + def __init__( + self, + env: gym.Env, + message_config: Dict[str, Tuple[V2XTransmitter, V2XReceiver]], + max_message_bytes=125000, + ): super().__init__(env) + self._message_config = message_config + # map alias to agent ids (multiple agents can be under the same alias) + self._alias_mapping = defaultdict(list) + for a_id, (_, receiver) in message_config.items(): + for alias in receiver.aliases: + self._alias_mapping[alias].append(a_id) + assert isinstance(env, HiWayEnvV1) o_action_space: gym.spaces.Dict = self.env.action_space msg_space = ( - gym.spaces.Box(low=0, high=256, shape=(message_size,), dtype=np.uint8), + gym.spaces.Box(low=0, high=256, shape=(max_message_bytes,), dtype=np.uint8), ) self.action_space = gym.spaces.Dict( { a_id: gym.spaces.Tuple( ( - action_space, + base_action_space, msg_space, ) ) - for a_id, action_space in o_action_space.spaces.items() + for a_id, base_action_space in o_action_space.spaces.items() } ) o_observation_space: gym.spaces.Dict = self.env.observation_space + self._transmission_space = gym.spaces.Sequence( + gym.spaces.Tuple( + ( + gym.spaces.Tuple( + ( + gym.spaces.Text(20), # channel + gym.spaces.Text(30), # sender + gym.spaces.Text(10), # sender_type + gym.spaces.Sequence(gym.spaces.Text(30)), # cc + gym.spaces.Sequence(gym.spaces.Text(30)), # bcc + gym.spaces.Text(10), # format + ) + ), + gym.spaces.Tuple((msg_space,)), + ) + ) + ) self.observation_space = gym.spaces.Dict( { - "agents": o_observation_space, - "messages": gym.spaces.Dict( - {a_id: msg_space for a_id in o_action_space} - ), + a_id: gym.spaces.Dict( + dict( + **obs, + transmissions=self._transmission_space, + ) + ) + for a_id, obs in o_observation_space.items() } ) + @lru_cache + def resolve_alias(self, alias): + return set(self._alias_mapping[alias]) + def step(self, action): - std_actions = {} - msgs = {} - for a_id, ma in action.items(): - std_actions[a_id] = ma[0] - msgs[a_id] = ma[1] + # step + std_actions = {a_id: act for a_id, (act, _) in action} + observations, rewards, terms, truncs, infos = self.env.step(std_actions) + + msgs = defaultdict(list) + + # filter recipients for active + cached_active_filter = lru_cache(lambda a: a.intersection(observations.keys())) + + # filter recipients by band + ## compare transmitter + cached_band_filter = lru_cache( + lambda sender, recipients: ( + r + for r in recipients + if self._message_config[sender][0].bands + | self._message_config[r][1].bands + ) + ) + + # filter recipients that do not listen to the channel + accepts_channel = lru_cache( + lambda channel, recipient: ( + (not self._message_config[recipient][1].whitelist_channels) + or (channel in self._message_config[recipient][1].whitelist_channels) + ) + and channel not in self._message_config[recipient][1].blacklist_channels + ) + cached_channel_filter = lru_cache( + lambda channel, recipients: ( + r for r in recipients if accepts_channel(channel, r) + ) + ) + + ## filter recipients by distance + ## Includes all + ## TODO ensure this works on all formatting types + # cached_dist_comp = lambda sender, receiver: obs[sender]["position"].dot(obs[receiver]["position"]) + # cached_distance_filter = lru_cache(lambda sender, receivers: ( + # r for r in receivers if cached_distance_filter + # )) + + # compress filters + general_filter = lambda header, initial_recipients: ( + cc + for recipients in map(self.resolve_alias, initial_recipients) + for cc in cached_channel_filter( + header.channel, + cached_band_filter(header.sender, cached_active_filter(recipients)), + ) + ) + + # Organise the messages to their recipients + for a_id, (_, msg) in action.items(): + msg: List[Tuple[Header, Message]] = msg + for header, message in msg: + header: Header = header + message: Message = message + + # expand the recipients + cc_recipients = set(general_filter(header, header.cc)) + bcc_recipients = set(general_filter(header, header.bcc)) + cc_header = header._replace(cc=cc_recipients) + + # associate the messages to the recipients + for recipient in ( + cc_recipients - bcc_recipients + ): # privacy takes priority + msgs[recipient].append( + (cc_header._replace(bcc=set()), message) # clear bcc + ) + for recipient in bcc_recipients: + msgs[recipient].append( + ( + cc_header._replace( + bcc={recipient} + ), # leave recipient in bcc + message, + ) + ) - obs, rewards, terms, truncs, infos = self.env.step(std_actions) obs_with_msgs = { - "agents": obs, - "messages": {a_id: msgs for a_id in obs}, + a_id: dict( + **obs, + transmissions=msgs[a_id], + ) + for a_id, obs in observations.items() } return obs_with_msgs, rewards, terms, truncs, infos def reset( self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None ) -> Tuple[Any, Dict[str, Any]]: - obs, info = super().reset(seed=seed, options=options) - return {"agents": obs, "messages": self.observation_space["messages"].sample()} + observations, info = super().reset(seed=seed, options=options) + obs_with_msgs = { + a_id: dict(**obs, transmissions=self._transmission_space.sample(0)) + for a_id, obs in observations.items() + } + return obs_with_msgs, info From 5268ebdf26f149f1b7564a678af07db1ca1c2849 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Thu, 23 Feb 2023 18:17:59 +0000 Subject: [PATCH 18/27] Remove agent interface todo. --- smarts/env/gymnasium/wrappers/agent_communication.py | 1 - 1 file changed, 1 deletion(-) diff --git a/smarts/env/gymnasium/wrappers/agent_communication.py b/smarts/env/gymnasium/wrappers/agent_communication.py index 0ef5f55c28..f84619a0f0 100644 --- a/smarts/env/gymnasium/wrappers/agent_communication.py +++ b/smarts/env/gymnasium/wrappers/agent_communication.py @@ -100,7 +100,6 @@ class V2XTransmitter(NamedTuple): class V2XReceiver(NamedTuple): """A configuratoin utility to set up agent to receive messages.""" - # TODO MTA: move this to agent interface. bands: Bands aliases: List[str] whitelist_channels: Optional[List[str]] = None From af425d01fd86514b11e3803ceddbcb0a546300a2 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Thu, 23 Feb 2023 20:57:52 +0000 Subject: [PATCH 19/27] Fix docstring test. --- .../env/gymnasium/wrappers/agent_communication.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/smarts/env/gymnasium/wrappers/agent_communication.py b/smarts/env/gymnasium/wrappers/agent_communication.py index f84619a0f0..2139067478 100644 --- a/smarts/env/gymnasium/wrappers/agent_communication.py +++ b/smarts/env/gymnasium/wrappers/agent_communication.py @@ -33,6 +33,8 @@ @unique class Bands(IntFlag): + """Communcation bands.""" + L0 = enum.auto() L1 = enum.auto() L2 = enum.auto() @@ -103,12 +105,14 @@ class V2XReceiver(NamedTuple): bands: Bands aliases: List[str] whitelist_channels: Optional[List[str]] = None - blacklist_channels: Set[str] = {} + blacklist_channels: Set[str] = set() sensitivity: Sensitivity = Sensitivity.STANDARD class MessagePasser(gym.Wrapper): - """This wrapper augments the observations and actions to require passing messages from agents.""" + """This wrapper augments the observations and actions to require passing messages from agents. + + It assumes that the underlying environment is :class:`HiWayEnvV1`""" def __init__( self, @@ -116,6 +120,7 @@ def __init__( message_config: Dict[str, Tuple[V2XTransmitter, V2XReceiver]], max_message_bytes=125000, ): + """""" super().__init__(env) self._message_config = message_config # map alias to agent ids (multiple agents can be under the same alias) @@ -172,15 +177,17 @@ def __init__( @lru_cache def resolve_alias(self, alias): + """Resolve the alias to agent ids.""" return set(self._alias_mapping[alias]) def step(self, action): - # step + """Steps the environment using the given action.""" std_actions = {a_id: act for a_id, (act, _) in action} observations, rewards, terms, truncs, infos = self.env.step(std_actions) msgs = defaultdict(list) + # pytype: disable=wrong-arg-types # filter recipients for active cached_active_filter = lru_cache(lambda a: a.intersection(observations.keys())) @@ -208,6 +215,7 @@ def step(self, action): r for r in recipients if accepts_channel(channel, r) ) ) + # pytype: enable=wrong-arg-types ## filter recipients by distance ## Includes all @@ -268,6 +276,7 @@ def step(self, action): def reset( self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None ) -> Tuple[Any, Dict[str, Any]]: + """Resets the environment.""" observations, info = super().reset(seed=seed, options=options) obs_with_msgs = { a_id: dict(**obs, transmissions=self._transmission_space.sample(0)) From f6355a6e8458e951f1fda288fd7c5b5838e03a06 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Thu, 23 Feb 2023 23:30:08 +0000 Subject: [PATCH 20/27] Add agent communcation example. --- examples/control/agent_communcation.py | 207 ++++++++++++++++++ .../gymnasium/wrappers/agent_communication.py | 32 ++- 2 files changed, 228 insertions(+), 11 deletions(-) create mode 100644 examples/control/agent_communcation.py diff --git a/examples/control/agent_communcation.py b/examples/control/agent_communcation.py new file mode 100644 index 0000000000..f6ad77d2d8 --- /dev/null +++ b/examples/control/agent_communcation.py @@ -0,0 +1,207 @@ +import sys +from pathlib import Path +from typing import Any, Dict, Union + +from smarts.core.agent import Agent +from smarts.core.agent_interface import AgentInterface, AgentType +from smarts.core.utils.episodes import episodes +from smarts.env.gymnasium.hiway_env_v1 import HiWayEnvV1 +from smarts.env.gymnasium.wrappers.agent_communication import ( + Bands, + Header, + Message, + MessagePasser, + V2XReceiver, + V2XTransmitter, +) +from smarts.env.utils.action_conversion import ActionOptions +from smarts.env.utils.observation_conversion import ObservationOptions +from smarts.sstudio.scenario_construction import build_scenarios + +sys.path.insert(0, str(Path(__file__).parents[2].absolute())) +import gymnasium as gym + +from examples.tools.argument_parser import default_argument_parser + +TIMESTEP = 0.1 +BYTES_IN_MEGABIT = 125000 +MESSAGE_MEGABITS_PER_SECOND = 10 +MESSAGE_BYTES = int(BYTES_IN_MEGABIT * MESSAGE_MEGABITS_PER_SECOND / TIMESTEP) + + +def filter_useless(transmissions): + for header, msg in transmissions: + if header.sender in ("parked_agent", "broken_stoplight"): + continue + if header.sender_type in ("advertisement",): + continue + yield header, msg + + +class LaneFollowerAgent(Agent): + def act(self, obs: Dict[Any, Union[Any, Dict]]): + return (obs["waypoint_paths"]["speed_limit"][0][0], 0) + + +class GossiperAgent(Agent): + def __init__(self, id_: str, base_agent: Agent, filter_, friends): + self._filter = filter_ + self._id = id_ + self._friends = friends + self._base_agent = base_agent + + def act(self, obs, **configs): + out_transmissions = [] + for header, msg in self._filter(obs["transmissions"]): + header: Header = header + msg: Message = msg + if not {self._id, "__all__"}.intersection(header.cc | header.bcc): + continue + if header.channel == "position_request": + print() + print("On step: ", obs["steps_completed"]) + print("Gossiper received position request: ", header) + out_transmissions.append( + ( + Header( + channel="position", + sender=self._id, + sender_type="ad_vehicle", + cc={header.sender}, + bcc={*self._friends}, + format="position", + ), # optimize this later + Message( + content=obs["ego_vehicle_state"]["position"], + ), # optimize this later + ) + ) + print("Gossiper sent position: ", out_transmissions[0][1]) + + base_action = self._base_agent.act(obs) + return (base_action, out_transmissions) + + +class SchemerAgent(Agent): + def __init__(self, id_: str, base_agent: Agent, request_freq) -> None: + self._base_agent = base_agent + self._id = id_ + self._request_freq = request_freq + + def act(self, obs, **configs): + out_transmissions = [] + for header, msg in obs["transmissions"]: + header: Header = header + msg: Message = msg + if header.channel == "position": + print() + print("On step: ", obs["steps_completed"]) + print("Schemer received position: ", msg) + + if obs["steps_completed"] % self._request_freq == 0: + print() + print("On step: ", obs["steps_completed"]) + out_transmissions.append( + ( + Header( + channel="position_request", + sender=self._id, + sender_type="ad_vehicle", + cc=set(), + bcc={"__all__"}, + format="position_request", + ), + Message(content=None), + ) + ) + print("Schemer requested position with: ", out_transmissions[0][0]) + + base_action = self._base_agent.act(obs) + return (base_action, out_transmissions) + + +def main(scenarios, headless, num_episodes, max_episode_steps=None): + agent_interface = AgentInterface.from_type( + AgentType.LanerWithSpeed, max_episode_steps=max_episode_steps + ) + hiwayv1env = HiWayEnvV1( + scenarios=scenarios, + agent_interfaces={"gossiper0": agent_interface, "schemer": agent_interface}, + headless=headless, + observation_options=ObservationOptions.multi_agent, + action_options=ActionOptions.default, + ) + # for now + env = MessagePasser( + hiwayv1env, + max_message_bytes=MESSAGE_BYTES, + message_config={ + "gossiper0": ( + V2XTransmitter( + bands=Bands.ALL, + range=100, + # available_channels=["position_request", "position"] + ), + V2XReceiver( + bands=Bands.ALL, + aliases=["tim"], + blacklist_channels={"self_control"}, + ), + ), + "schemer": ( + V2XTransmitter( + bands=Bands.ALL, + range=100, + ), + V2XReceiver( + bands=Bands.ALL, + aliases=[], + ), + ), + }, + ) + agents = { + "gossiper0": GossiperAgent( + "gossiper0", + base_agent=LaneFollowerAgent(), + filter_=filter_useless, + friends={"schemer"}, + ), + "schemer": SchemerAgent( + "schemer", base_agent=LaneFollowerAgent(), request_freq=100 + ), + } + + # then just the standard gym interface with no modifications + for episode in episodes(n=num_episodes): + observation, info = env.reset() + episode.record_scenario(env.scenario_log) + + terminated = {"__all__": False} + while not terminated["__all__"]: + agent_action = { + agent_id: agents[agent_id].act(obs) + for agent_id, obs in observation.items() + } + observation, reward, terminated, truncated, info = env.step(agent_action) + episode.record_step(observation, reward, terminated, info) + + env.close() + + +if __name__ == "__main__": + parser = default_argument_parser("single-agent-example") + args = parser.parse_args() + + if not args.scenarios: + args.scenarios = [ + str(Path(__file__).absolute().parents[2] / "scenarios" / "sumo" / "loop") + ] + + build_scenarios(scenarios=args.scenarios) + + main( + scenarios=args.scenarios, + headless=args.headless, + num_episodes=args.episodes, + ) diff --git a/smarts/env/gymnasium/wrappers/agent_communication.py b/smarts/env/gymnasium/wrappers/agent_communication.py index 2139067478..af2836df15 100644 --- a/smarts/env/gymnasium/wrappers/agent_communication.py +++ b/smarts/env/gymnasium/wrappers/agent_communication.py @@ -128,11 +128,17 @@ def __init__( for a_id, (_, receiver) in message_config.items(): for alias in receiver.aliases: self._alias_mapping[alias].append(a_id) + self._alias_mapping[a_id].append(a_id) + self._alias_mapping["__all__"].append(a_id) - assert isinstance(env, HiWayEnvV1) + assert isinstance(env.unwrapped, HiWayEnvV1) o_action_space: gym.spaces.Dict = self.env.action_space - msg_space = ( - gym.spaces.Box(low=0, high=256, shape=(max_message_bytes,), dtype=np.uint8), + msg_space = gym.spaces.Tuple( + ( + gym.spaces.Box( + low=0, high=256, shape=(max_message_bytes,), dtype=np.uint8 + ), + ) ) self.action_space = gym.spaces.Dict( { @@ -182,19 +188,21 @@ def resolve_alias(self, alias): def step(self, action): """Steps the environment using the given action.""" - std_actions = {a_id: act for a_id, (act, _) in action} + std_actions = {a_id: act for a_id, (act, _) in action.items()} observations, rewards, terms, truncs, infos = self.env.step(std_actions) msgs = defaultdict(list) # pytype: disable=wrong-arg-types # filter recipients for active - cached_active_filter = lru_cache(lambda a: a.intersection(observations.keys())) + cached_active_filter = lru_cache( + lambda a: frozenset(a.intersection(observations.keys())) + ) # filter recipients by band ## compare transmitter cached_band_filter = lru_cache( - lambda sender, recipients: ( + lambda sender, recipients: frozenset( r for r in recipients if self._message_config[sender][0].bands @@ -211,7 +219,7 @@ def step(self, action): and channel not in self._message_config[recipient][1].blacklist_channels ) cached_channel_filter = lru_cache( - lambda channel, recipients: ( + lambda channel, recipients: frozenset( r for r in recipients if accepts_channel(channel, r) ) ) @@ -231,7 +239,9 @@ def step(self, action): for recipients in map(self.resolve_alias, initial_recipients) for cc in cached_channel_filter( header.channel, - cached_band_filter(header.sender, cached_active_filter(recipients)), + cached_band_filter( + header.sender, cached_active_filter(frozenset(recipients)) + ), ) ) @@ -243,8 +253,8 @@ def step(self, action): message: Message = message # expand the recipients - cc_recipients = set(general_filter(header, header.cc)) - bcc_recipients = set(general_filter(header, header.bcc)) + cc_recipients = set(general_filter(header, frozenset(header.cc))) + bcc_recipients = set(general_filter(header, frozenset(header.bcc))) cc_header = header._replace(cc=cc_recipients) # associate the messages to the recipients @@ -279,7 +289,7 @@ def reset( """Resets the environment.""" observations, info = super().reset(seed=seed, options=options) obs_with_msgs = { - a_id: dict(**obs, transmissions=self._transmission_space.sample(0)) + a_id: dict(**obs, transmissions=self._transmission_space.sample((0, ()))) for a_id, obs in observations.items() } return obs_with_msgs, info From 842f60e4f37ec147b97df127b90e414b38a9785e Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Thu, 2 Mar 2023 17:29:34 +0000 Subject: [PATCH 21/27] Improve action space. --- .../gymnasium/wrappers/agent_communication.py | 88 +++++++++++-------- 1 file changed, 51 insertions(+), 37 deletions(-) diff --git a/smarts/env/gymnasium/wrappers/agent_communication.py b/smarts/env/gymnasium/wrappers/agent_communication.py index af2836df15..225c04a0bb 100644 --- a/smarts/env/gymnasium/wrappers/agent_communication.py +++ b/smarts/env/gymnasium/wrappers/agent_communication.py @@ -97,6 +97,7 @@ class V2XTransmitter(NamedTuple): bands: Bands range: float # available_channels: List[str] + # max_message_bytes: int = 125000 class V2XReceiver(NamedTuple): @@ -133,55 +134,68 @@ def __init__( assert isinstance(env.unwrapped, HiWayEnvV1) o_action_space: gym.spaces.Dict = self.env.action_space - msg_space = gym.spaces.Tuple( + o_observation_space: gym.spaces.Dict = self.env.observation_space + + header_space = gym.spaces.Tuple( ( - gym.spaces.Box( - low=0, high=256, shape=(max_message_bytes,), dtype=np.uint8 - ), + gym.spaces.Text(20), # channel + gym.spaces.Text(30), # sender + gym.spaces.Text(10), # sender_type + gym.spaces.Sequence(gym.spaces.Text(30)), # cc + gym.spaces.Sequence(gym.spaces.Text(30)), # bcc + gym.spaces.Text(10), # format ) ) - self.action_space = gym.spaces.Dict( - { - a_id: gym.spaces.Tuple( - ( - base_action_space, - msg_space, - ) - ) - for a_id, base_action_space in o_action_space.spaces.items() - } - ) - o_observation_space: gym.spaces.Dict = self.env.observation_space - self._transmission_space = gym.spaces.Sequence( - gym.spaces.Tuple( + + def gen_msg_body_space(max_message_bytes: int): + return gym.spaces.Tuple( ( - gym.spaces.Tuple( - ( - gym.spaces.Text(20), # channel - gym.spaces.Text(30), # sender - gym.spaces.Text(10), # sender_type - gym.spaces.Sequence(gym.spaces.Text(30)), # cc - gym.spaces.Sequence(gym.spaces.Text(30)), # bcc - gym.spaces.Text(10), # format - ) + gym.spaces.Box( + low=0, high=256, shape=(max_message_bytes,), dtype=np.uint8 ), - gym.spaces.Tuple((msg_space,)), ) ) - ) - self.observation_space = gym.spaces.Dict( - { - a_id: gym.spaces.Dict( + + def gen_msg_space(max_message_bytes: int): + return gym.spaces.Tuple( + ( + header_space, + gen_msg_body_space(max_message_bytes), + ) + ) + + def gen_transmission_space(max_message_bytes: int): + return gym.spaces.Sequence(gen_msg_space(max_message_bytes)) + + _action_space = {} + for a_id, base_action_space in o_action_space.spaces.items(): + if a_id not in message_config: + _action_space[a_id] = base_action_space + else: + _action_space[a_id] = gym.spaces.Tuple( + ( + base_action_space, + gen_transmission_space(max_message_bytes=max_message_bytes), + ) + ) + self.action_space = gym.spaces.Dict(_action_space) + + _observation_space = {} + for a_id, obs in o_observation_space.spaces.items(): + if a_id not in message_config: + _observation_space[a_id] = obs + else: + _observation_space[a_id] = gym.spaces.Dict( dict( **obs, - transmissions=self._transmission_space, + transmissions=gen_transmission_space( + max_message_bytes=max_message_bytes + ), ) ) - for a_id, obs in o_observation_space.items() - } - ) + self.observation_space = gym.spaces.Dict(_observation_space) - @lru_cache + @lru_cache() def resolve_alias(self, alias): """Resolve the alias to agent ids.""" return set(self._alias_mapping[alias]) From f98350fd801fff2905a6dd6b4725b124609ff817 Mon Sep 17 00:00:00 2001 From: Tucker Date: Fri, 3 Mar 2023 08:51:11 -0500 Subject: [PATCH 22/27] Add vehicle targetting communication wrapper. --- smarts/core/smarts.py | 4 +- smarts/core/vehicle_index.py | 2 +- .../gymnasium/wrappers/agent_communication.py | 203 +++++++++++------- .../env/gymnasium/wrappers/vehicle_watch.py | 88 ++++++++ 4 files changed, 219 insertions(+), 78 deletions(-) create mode 100644 smarts/env/gymnasium/wrappers/vehicle_watch.py diff --git a/smarts/core/smarts.py b/smarts/core/smarts.py index f1ea6b697a..e7be01b990 100644 --- a/smarts/core/smarts.py +++ b/smarts/core/smarts.py @@ -21,7 +21,7 @@ import logging import os import warnings -from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple +from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple, TypeVar import numpy as np from scipy.spatial.distance import cdist @@ -1098,7 +1098,7 @@ def providers(self) -> List[Provider]: """The current providers controlling actors within the simulation.""" return self._providers - def get_provider_by_type(self, requested_type) -> Optional[Provider]: + def get_provider_by_type(self, requested_type: type) -> Optional[Provider]: """Get The first provider that matches the requested type.""" self._check_valid() for provider in self._providers: diff --git a/smarts/core/vehicle_index.py b/smarts/core/vehicle_index.py index 92520e3e54..8b7ac8120e 100644 --- a/smarts/core/vehicle_index.py +++ b/smarts/core/vehicle_index.py @@ -280,7 +280,7 @@ def vehicleitems(self) -> Iterator[Tuple[str, Vehicle]]: """A list of all vehicle IDs paired with their vehicle.""" return map(lambda x: (self._2id_to_id[x[0]], x[1]), self._vehicles.items()) - def vehicle_by_id(self, vehicle_id, default=...): + def vehicle_by_id(self, vehicle_id, default=...) -> Vehicle: """Get a vehicle by its id.""" vehicle_id = _2id(vehicle_id) if default is ...: diff --git a/smarts/env/gymnasium/wrappers/agent_communication.py b/smarts/env/gymnasium/wrappers/agent_communication.py index 225c04a0bb..e6bbf7df1b 100644 --- a/smarts/env/gymnasium/wrappers/agent_communication.py +++ b/smarts/env/gymnasium/wrappers/agent_communication.py @@ -22,8 +22,8 @@ import enum from collections import defaultdict from enum import Enum, IntFlag, unique -from functools import lru_cache -from typing import Any, Dict, List, NamedTuple, Optional, Set, Tuple +from functools import lru_cache, partial +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple import gymnasium as gym import numpy as np @@ -104,12 +104,77 @@ class V2XReceiver(NamedTuple): """A configuratoin utility to set up agent to receive messages.""" bands: Bands - aliases: List[str] - whitelist_channels: Optional[List[str]] = None + aliases: Set[str] + whitelist_channels: Optional[Set[str]] = None blacklist_channels: Set[str] = set() sensitivity: Sensitivity = Sensitivity.STANDARD +@lru_cache +def active_filter(a: frozenset, a_in_observations): + return frozenset(a.intersection(a_in_observations)) + + +# filter recipients by band +## compare transmitter +def band_filter( + sender, recipients, message_config: Dict[str, Tuple[V2XTransmitter, V2XReceiver]] +): + return frozenset( + r + for r in recipients + if message_config[sender][0].bands | message_config[r][1].bands + ) + + +# filter recipients that do not listen to the channel +def accepts_channel(channel, receiver: V2XReceiver): + return ( + (not receiver.whitelist_channels) or (channel in receiver.whitelist_channels) + ) and channel not in receiver.blacklist_channels + + +def channel_filter( + channel, recipients, message_config: Dict[str, Tuple[V2XTransmitter, V2XReceiver]] +): + return frozenset( + r for r in recipients if accepts_channel(channel, message_config[r][1]) + ) + + +# pytype: enable=wrong-arg-types + +## filter recipients by distance +## Includes all +## TODO ensure this works on all formatting types +# cached_dist_comp = lambda sender, receiver: obs[sender]["position"].dot(obs[receiver]["position"]) +# cached_distance_filter = lru_cache(lambda sender, receivers: ( +# r for r in receivers if cached_distance_filter +# )) + +# compress filters +def general_filter( + header, + initial_recipients, + observations, + message_config, + alias_resolver: Callable[[frozenset], frozenset], +): + return ( + cc + for recipients in map(alias_resolver, initial_recipients) + for cc in channel_filter( + header.channel, + band_filter( + header.sender, + active_filter(frozenset(recipients), frozenset(observations.keys())), + message_config, + ), + message_config, + ) + ) + + class MessagePasser(gym.Wrapper): """This wrapper augments the observations and actions to require passing messages from agents. @@ -122,7 +187,9 @@ def __init__( max_message_bytes=125000, ): """""" + assert isinstance(env.unwrapped, HiWayEnvV1) super().__init__(env) + self._max_message_bytes = max_message_bytes self._message_config = message_config # map alias to agent ids (multiple agents can be under the same alias) self._alias_mapping = defaultdict(list) @@ -132,7 +199,6 @@ def __init__( self._alias_mapping[a_id].append(a_id) self._alias_mapping["__all__"].append(a_id) - assert isinstance(env.unwrapped, HiWayEnvV1) o_action_space: gym.spaces.Dict = self.env.action_space o_observation_space: gym.spaces.Dict = self.env.observation_space @@ -167,6 +233,10 @@ def gen_msg_space(max_message_bytes: int): def gen_transmission_space(max_message_bytes: int): return gym.spaces.Sequence(gen_msg_space(max_message_bytes)) + self._transmission_space = gen_transmission_space( + max_message_bytes=max_message_bytes + ) + _action_space = {} for a_id, base_action_space in o_action_space.spaces.items(): if a_id not in message_config: @@ -175,7 +245,7 @@ def gen_transmission_space(max_message_bytes: int): _action_space[a_id] = gym.spaces.Tuple( ( base_action_space, - gen_transmission_space(max_message_bytes=max_message_bytes), + self._transmission_space, ) ) self.action_space = gym.spaces.Dict(_action_space) @@ -188,98 +258,90 @@ def gen_transmission_space(max_message_bytes: int): _observation_space[a_id] = gym.spaces.Dict( dict( **obs, - transmissions=gen_transmission_space( - max_message_bytes=max_message_bytes - ), + transmissions=self._transmission_space, ) ) self.observation_space = gym.spaces.Dict(_observation_space) - @lru_cache() + @property + def message_config(self): + """The current message config. + + Returns: + Dict[str, Tuple[V2XTransmitter, V2XReceiver]]: The message configuration. + """ + return self._message_config.copy() + + @property + def max_message_bytes(self): + """The max message bytes size. + + Returns: + int: The max size of a message. + """ + return self._max_message_bytes + + @lru_cache def resolve_alias(self, alias): """Resolve the alias to agent ids.""" - return set(self._alias_mapping[alias]) + return frozenset(self._alias_mapping[alias]) def step(self, action): """Steps the environment using the given action.""" std_actions = {a_id: act for a_id, (act, _) in action.items()} observations, rewards, terms, truncs, infos = self.env.step(std_actions) - msgs = defaultdict(list) + obs_with_msgs = { + a_id: dict( + **obs, + transmissions=[], + ) + if a_id in self._message_config + else obs + for a_id, obs in observations.items() + } - # pytype: disable=wrong-arg-types - # filter recipients for active - cached_active_filter = lru_cache( - lambda a: frozenset(a.intersection(observations.keys())) + self.augment_observations( + messages=(msg for (_, msg) in action.values()), observations=obs_with_msgs ) - # filter recipients by band - ## compare transmitter - cached_band_filter = lru_cache( - lambda sender, recipients: frozenset( - r - for r in recipients - if self._message_config[sender][0].bands - | self._message_config[r][1].bands - ) - ) + return obs_with_msgs, rewards, terms, truncs, infos - # filter recipients that do not listen to the channel - accepts_channel = lru_cache( - lambda channel, recipient: ( - (not self._message_config[recipient][1].whitelist_channels) - or (channel in self._message_config[recipient][1].whitelist_channels) - ) - and channel not in self._message_config[recipient][1].blacklist_channels - ) - cached_channel_filter = lru_cache( - lambda channel, recipients: frozenset( - r for r in recipients if accepts_channel(channel, r) - ) - ) - # pytype: enable=wrong-arg-types - - ## filter recipients by distance - ## Includes all - ## TODO ensure this works on all formatting types - # cached_dist_comp = lambda sender, receiver: obs[sender]["position"].dot(obs[receiver]["position"]) - # cached_distance_filter = lru_cache(lambda sender, receivers: ( - # r for r in receivers if cached_distance_filter - # )) - - # compress filters - general_filter = lambda header, initial_recipients: ( - cc - for recipients in map(self.resolve_alias, initial_recipients) - for cc in cached_channel_filter( - header.channel, - cached_band_filter( - header.sender, cached_active_filter(frozenset(recipients)) - ), - ) + def augment_observations( + self, + messages, + observations, + ): + f = partial( + general_filter, + observations=observations, + message_config=self._message_config, + alias_resolver=self.resolve_alias, ) - - # Organise the messages to their recipients - for a_id, (_, msg) in action.items(): + for msg in messages: msg: List[Tuple[Header, Message]] = msg for header, message in msg: header: Header = header message: Message = message # expand the recipients - cc_recipients = set(general_filter(header, frozenset(header.cc))) - bcc_recipients = set(general_filter(header, frozenset(header.bcc))) + cc_recipients = set( + f(header=header, initial_recipients=frozenset(header.cc)) + ) + bcc_recipients = set( + f(header=header, initial_recipients=frozenset(header.bcc)) + ) cc_header = header._replace(cc=cc_recipients) # associate the messages to the recipients for recipient in ( cc_recipients - bcc_recipients ): # privacy takes priority - msgs[recipient].append( + observations[recipient]["transmissions"].append( (cc_header._replace(bcc=set()), message) # clear bcc ) for recipient in bcc_recipients: - msgs[recipient].append( + observations[recipient]["transmissions"].append( ( cc_header._replace( bcc={recipient} @@ -288,15 +350,6 @@ def step(self, action): ) ) - obs_with_msgs = { - a_id: dict( - **obs, - transmissions=msgs[a_id], - ) - for a_id, obs in observations.items() - } - return obs_with_msgs, rewards, terms, truncs, infos - def reset( self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None ) -> Tuple[Any, Dict[str, Any]]: diff --git a/smarts/env/gymnasium/wrappers/vehicle_watch.py b/smarts/env/gymnasium/wrappers/vehicle_watch.py new file mode 100644 index 0000000000..0886de0bc7 --- /dev/null +++ b/smarts/env/gymnasium/wrappers/vehicle_watch.py @@ -0,0 +1,88 @@ +import enum +from collections import defaultdict +from enum import Enum, IntFlag, unique +from functools import lru_cache +from typing import ( + Any, + Callable, + Dict, + List, + NamedTuple, + Optional, + Sequence, + Set, + SupportsFloat, + Tuple, + TypeVar, +) + +import gymnasium as gym +import numpy as np + +from smarts.core.smarts import SMARTS +from smarts.core.sumo_traffic_simulation import SumoTrafficSimulation +from smarts.env.gymnasium.hiway_env_v1 import HiWayEnvV1 + +from .agent_communication import ( + Bands, + Header, + Message, + MessagePasser, + V2XReceiver, + V2XTransmitter, +) + +C = TypeVar("C", Callable[[str, SMARTS], Sequence[Tuple[Header, Message]]]) + + +class VehicleWatch(gym.Wrapper): + def __init__(self, env, vehicle_watches: Dict[str, Tuple[V2XTransmitter, C]]): + super().__init__(env) + assert isinstance(self.env, MessagePasser) + + # action space should remain the same + self._vehicle_watches = vehicle_watches + + def reset( + self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None + ) -> Tuple[Any, Dict[str, Any]]: + return super().reset(seed=seed, options=options) + + def step( + self, action: Any + ) -> Tuple[Any, SupportsFloat, bool, bool, Dict[str, Any]]: + observation, reward, term, trunc, info = super().step(action) + env: MessagePasser = self.env + + smarts: SMARTS = env.smarts + msgs = [ + watch(target, smarts) + for target, (_, watch) in self._vehicle_watches.values() + ] + env.augment_observations(msgs, observation) + + return observation, reward, term, trunc, info + + def programmed_sumo_device( + self, target: str, smarts: SMARTS + ) -> Sequence[Tuple[Header, Message]]: + traffic_sim: SumoTrafficSimulation = smarts.get_provider_by_type( + SumoTrafficSimulation + ) + + if not target in traffic_sim.actor_ids: + return [] + + return [ + ( + Header( + channel="position", + sender=target, + sender_type="leader", + cc={"__all__"}, + bcc={}, + format="position", + ), + Message(smarts.vehicle_index.vehicle_by_id(target).position), + ) + ] From 931800f4b6d2f8785afe8ad71fa505aef4dc691e Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Fri, 10 Mar 2023 14:38:28 +0000 Subject: [PATCH 23/27] make gen-header --- .../env/gymnasium/wrappers/vehicle_watch.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/smarts/env/gymnasium/wrappers/vehicle_watch.py b/smarts/env/gymnasium/wrappers/vehicle_watch.py index 0886de0bc7..c31ab57ced 100644 --- a/smarts/env/gymnasium/wrappers/vehicle_watch.py +++ b/smarts/env/gymnasium/wrappers/vehicle_watch.py @@ -1,3 +1,24 @@ +# MIT License +# +# Copyright (C) 2023. Huawei Technologies Co., Ltd. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +# THE SOFTWARE. import enum from collections import defaultdict from enum import Enum, IntFlag, unique From b51978576dce08aeeeb744082706c1b04d647182 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Fri, 10 Mar 2023 15:16:06 +0000 Subject: [PATCH 24/27] Fix docstring test --- examples/control/agent_communcation.py | 7 ++-- .../gymnasium/wrappers/agent_communication.py | 27 +++++++++------ .../env/gymnasium/wrappers/vehicle_watch.py | 34 ++++++++++++++++++- 3 files changed, 55 insertions(+), 13 deletions(-) diff --git a/examples/control/agent_communcation.py b/examples/control/agent_communcation.py index f6ad77d2d8..f318127777 100644 --- a/examples/control/agent_communcation.py +++ b/examples/control/agent_communcation.py @@ -1,6 +1,6 @@ import sys from pathlib import Path -from typing import Any, Dict, Union +from typing import Any, Dict, Generator, List, Tuple, Union from smarts.core.agent import Agent from smarts.core.agent_interface import AgentInterface, AgentType @@ -29,7 +29,10 @@ MESSAGE_BYTES = int(BYTES_IN_MEGABIT * MESSAGE_MEGABITS_PER_SECOND / TIMESTEP) -def filter_useless(transmissions): +def filter_useless( + transmissions: List[Header, Message] +) -> Generator[Tuple[Header, Message], None, None]: + """A primitive example filter that takes in transmissions and outputs filtered transmissions.""" for header, msg in transmissions: if header.sender in ("parked_agent", "broken_stoplight"): continue diff --git a/smarts/env/gymnasium/wrappers/agent_communication.py b/smarts/env/gymnasium/wrappers/agent_communication.py index e6bbf7df1b..6970ff89a5 100644 --- a/smarts/env/gymnasium/wrappers/agent_communication.py +++ b/smarts/env/gymnasium/wrappers/agent_communication.py @@ -33,8 +33,10 @@ @unique class Bands(IntFlag): - """Communcation bands.""" + """Communcation bands. The receiver must have the band the transmitter has in order to receive + information from the transmitter.""" + # TODO augment range with standard bands. L0 = enum.auto() L1 = enum.auto() L2 = enum.auto() @@ -52,6 +54,8 @@ class Bands(IntFlag): @unique class Sensitivity(Enum): + """Sensitivity models for the receiver. Will be modelled after wave attenuation.""" + LOW = 0 STANDARD = 1 HIGH = 2 @@ -112,14 +116,14 @@ class V2XReceiver(NamedTuple): @lru_cache def active_filter(a: frozenset, a_in_observations): + """A filter for active agents in observations.""" return frozenset(a.intersection(a_in_observations)) -# filter recipients by band -## compare transmitter def band_filter( sender, recipients, message_config: Dict[str, Tuple[V2XTransmitter, V2XReceiver]] ): + """A filter to only include recipients whose receivers match the transmitter band.""" return frozenset( r for r in recipients @@ -127,8 +131,8 @@ def band_filter( ) -# filter recipients that do not listen to the channel -def accepts_channel(channel, receiver: V2XReceiver): +def _accepts_channel(channel, receiver: V2XReceiver): + """Evaluates if the current receiver accepts the given channel.""" return ( (not receiver.whitelist_channels) or (channel in receiver.whitelist_channels) ) and channel not in receiver.blacklist_channels @@ -137,8 +141,9 @@ def accepts_channel(channel, receiver: V2XReceiver): def channel_filter( channel, recipients, message_config: Dict[str, Tuple[V2XTransmitter, V2XReceiver]] ): + """Filters recipients by channel.""" return frozenset( - r for r in recipients if accepts_channel(channel, message_config[r][1]) + r for r in recipients if _accepts_channel(channel, message_config[r][1]) ) @@ -154,12 +159,13 @@ def channel_filter( # compress filters def general_filter( - header, - initial_recipients, - observations, - message_config, + header: Header, + initial_recipients: frozenset, # frozenset[str] + observations: Dict, + message_config: Dict[str, Tuple[V2XTransmitter, V2XReceiver]], alias_resolver: Callable[[frozenset], frozenset], ): + """Compresses all used filters.""" return ( cc for recipients in map(alias_resolver, initial_recipients) @@ -312,6 +318,7 @@ def augment_observations( messages, observations, ): + """Updates the given observations to send out the given messages.""" f = partial( general_filter, observations=observations, diff --git a/smarts/env/gymnasium/wrappers/vehicle_watch.py b/smarts/env/gymnasium/wrappers/vehicle_watch.py index c31ab57ced..a69982e4b5 100644 --- a/smarts/env/gymnasium/wrappers/vehicle_watch.py +++ b/smarts/env/gymnasium/wrappers/vehicle_watch.py @@ -57,7 +57,20 @@ class VehicleWatch(gym.Wrapper): - def __init__(self, env, vehicle_watches: Dict[str, Tuple[V2XTransmitter, C]]): + """A wrapper that augments the MessagePasser wrapper to allow programmable messages. + + These messages are configured through vehicle_watches. + + Args: + env (MessagePasser): The base environment. This must be a MessagePasser. + vehicle_watches (Dict[str, Tuple[V2XTransmitter, C]]): + The configurable message generator. The first part is the transmitter config, the + second is the transmission generator callable. + """ + + def __init__( + self, env: MessagePasser, vehicle_watches: Dict[str, Tuple[V2XTransmitter, C]] + ): super().__init__(env) assert isinstance(self.env, MessagePasser) @@ -67,11 +80,21 @@ def __init__(self, env, vehicle_watches: Dict[str, Tuple[V2XTransmitter, C]]): def reset( self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None ) -> Tuple[Any, Dict[str, Any]]: + """Resets the environment + + Args: + seed (int, optional): The environment seed. Defaults to None. + options (Dict[str, Any], optional): The options to the environment. Defaults to None. + + Returns: + Tuple[Any, Dict[str, Any]]: The observations and infos. + """ return super().reset(seed=seed, options=options) def step( self, action: Any ) -> Tuple[Any, SupportsFloat, bool, bool, Dict[str, Any]]: + """The gym step function.""" observation, reward, term, trunc, info = super().step(action) env: MessagePasser = self.env @@ -87,6 +110,15 @@ def step( def programmed_sumo_device( self, target: str, smarts: SMARTS ) -> Sequence[Tuple[Header, Message]]: + """An example transmission method. + + Args: + target (str): The intended sender of the transmission. + smarts (SMARTS): The smarts instance to grab relevant information from. + + Returns: + Sequence[Tuple[Header, Message]]: A new set of transmissions. + """ traffic_sim: SumoTrafficSimulation = smarts.get_provider_by_type( SumoTrafficSimulation ) From b1a03efcd0716ed581b72c19da06f16f695ff6b8 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Fri, 10 Mar 2023 16:04:55 +0000 Subject: [PATCH 25/27] Fix type test. --- smarts/env/gymnasium/wrappers/agent_communication.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/smarts/env/gymnasium/wrappers/agent_communication.py b/smarts/env/gymnasium/wrappers/agent_communication.py index 6970ff89a5..c1842ef33b 100644 --- a/smarts/env/gymnasium/wrappers/agent_communication.py +++ b/smarts/env/gymnasium/wrappers/agent_communication.py @@ -114,7 +114,7 @@ class V2XReceiver(NamedTuple): sensitivity: Sensitivity = Sensitivity.STANDARD -@lru_cache +@lru_cache(128) def active_filter(a: frozenset, a_in_observations): """A filter for active agents in observations.""" return frozenset(a.intersection(a_in_observations)) @@ -287,7 +287,7 @@ def max_message_bytes(self): """ return self._max_message_bytes - @lru_cache + @lru_cache(128) def resolve_alias(self, alias): """Resolve the alias to agent ids.""" return frozenset(self._alias_mapping[alias]) From 85c7582e5d8e560623ba7e8c37bba817e1c7631a Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Fri, 10 Mar 2023 16:23:04 +0000 Subject: [PATCH 26/27] Fix remaining typing issues. --- examples/control/agent_communcation.py | 6 +++--- smarts/env/gymnasium/wrappers/vehicle_watch.py | 14 +++++++++----- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/examples/control/agent_communcation.py b/examples/control/agent_communcation.py index f318127777..b042ebbac1 100644 --- a/examples/control/agent_communcation.py +++ b/examples/control/agent_communcation.py @@ -30,7 +30,7 @@ def filter_useless( - transmissions: List[Header, Message] + transmissions: List[Tuple[Header, Message]] ) -> Generator[Tuple[Header, Message], None, None]: """A primitive example filter that takes in transmissions and outputs filtered transmissions.""" for header, msg in transmissions: @@ -147,7 +147,7 @@ def main(scenarios, headless, num_episodes, max_episode_steps=None): ), V2XReceiver( bands=Bands.ALL, - aliases=["tim"], + aliases={"tim"}, blacklist_channels={"self_control"}, ), ), @@ -158,7 +158,7 @@ def main(scenarios, headless, num_episodes, max_episode_steps=None): ), V2XReceiver( bands=Bands.ALL, - aliases=[], + aliases=set(), ), ), }, diff --git a/smarts/env/gymnasium/wrappers/vehicle_watch.py b/smarts/env/gymnasium/wrappers/vehicle_watch.py index a69982e4b5..e828868b99 100644 --- a/smarts/env/gymnasium/wrappers/vehicle_watch.py +++ b/smarts/env/gymnasium/wrappers/vehicle_watch.py @@ -40,6 +40,7 @@ import gymnasium as gym import numpy as np +from smarts.core.provider import Provider from smarts.core.smarts import SMARTS from smarts.core.sumo_traffic_simulation import SumoTrafficSimulation from smarts.env.gymnasium.hiway_env_v1 import HiWayEnvV1 @@ -53,7 +54,7 @@ V2XTransmitter, ) -C = TypeVar("C", Callable[[str, SMARTS], Sequence[Tuple[Header, Message]]]) +CALLABLE = Callable[[str, SMARTS], Sequence[Tuple[Header, Message]]] class VehicleWatch(gym.Wrapper): @@ -69,7 +70,9 @@ class VehicleWatch(gym.Wrapper): """ def __init__( - self, env: MessagePasser, vehicle_watches: Dict[str, Tuple[V2XTransmitter, C]] + self, + env: MessagePasser, + vehicle_watches: Dict[str, Tuple[V2XTransmitter, CALLABLE]], ): super().__init__(env) assert isinstance(self.env, MessagePasser) @@ -101,7 +104,7 @@ def step( smarts: SMARTS = env.smarts msgs = [ watch(target, smarts) - for target, (_, watch) in self._vehicle_watches.values() + for target, (_, watch) in self._vehicle_watches.items() ] env.augment_observations(msgs, observation) @@ -119,9 +122,10 @@ def programmed_sumo_device( Returns: Sequence[Tuple[Header, Message]]: A new set of transmissions. """ - traffic_sim: SumoTrafficSimulation = smarts.get_provider_by_type( + traffic_sim: Optional[Provider] = smarts.get_provider_by_type( SumoTrafficSimulation ) + assert isinstance(traffic_sim, SumoTrafficSimulation) if not target in traffic_sim.actor_ids: return [] @@ -133,7 +137,7 @@ def programmed_sumo_device( sender=target, sender_type="leader", cc={"__all__"}, - bcc={}, + bcc=set(), format="position", ), Message(smarts.vehicle_index.vehicle_by_id(target).position), From 2cef72e174ffe4cf0451af4af5bd6c765a78a7b2 Mon Sep 17 00:00:00 2001 From: Montgomery Alban Date: Fri, 10 Mar 2023 16:26:27 +0000 Subject: [PATCH 27/27] Remove unused import. --- smarts/core/smarts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/smarts/core/smarts.py b/smarts/core/smarts.py index e7be01b990..d958614828 100644 --- a/smarts/core/smarts.py +++ b/smarts/core/smarts.py @@ -21,7 +21,7 @@ import logging import os import warnings -from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple, TypeVar +from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Tuple import numpy as np from scipy.spatial.distance import cdist