Skip to content

Commit

Permalink
Add agent communcation example.
Browse files Browse the repository at this point in the history
  • Loading branch information
Gamenot committed Mar 1, 2023
1 parent af425d0 commit f6355a6
Show file tree
Hide file tree
Showing 2 changed files with 228 additions and 11 deletions.
207 changes: 207 additions & 0 deletions examples/control/agent_communcation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
import sys
from pathlib import Path
from typing import Any, Dict, Union

from smarts.core.agent import Agent
from smarts.core.agent_interface import AgentInterface, AgentType
from smarts.core.utils.episodes import episodes
from smarts.env.gymnasium.hiway_env_v1 import HiWayEnvV1
from smarts.env.gymnasium.wrappers.agent_communication import (
Bands,
Header,
Message,
MessagePasser,
V2XReceiver,
V2XTransmitter,
)
from smarts.env.utils.action_conversion import ActionOptions
from smarts.env.utils.observation_conversion import ObservationOptions
from smarts.sstudio.scenario_construction import build_scenarios

sys.path.insert(0, str(Path(__file__).parents[2].absolute()))
import gymnasium as gym

from examples.tools.argument_parser import default_argument_parser

TIMESTEP = 0.1
BYTES_IN_MEGABIT = 125000
MESSAGE_MEGABITS_PER_SECOND = 10
MESSAGE_BYTES = int(BYTES_IN_MEGABIT * MESSAGE_MEGABITS_PER_SECOND / TIMESTEP)


def filter_useless(transmissions):
for header, msg in transmissions:
if header.sender in ("parked_agent", "broken_stoplight"):
continue
if header.sender_type in ("advertisement",):
continue
yield header, msg


class LaneFollowerAgent(Agent):
def act(self, obs: Dict[Any, Union[Any, Dict]]):
return (obs["waypoint_paths"]["speed_limit"][0][0], 0)


class GossiperAgent(Agent):
def __init__(self, id_: str, base_agent: Agent, filter_, friends):
self._filter = filter_
self._id = id_
self._friends = friends
self._base_agent = base_agent

def act(self, obs, **configs):
out_transmissions = []
for header, msg in self._filter(obs["transmissions"]):
header: Header = header
msg: Message = msg
if not {self._id, "__all__"}.intersection(header.cc | header.bcc):
continue
if header.channel == "position_request":
print()
print("On step: ", obs["steps_completed"])
print("Gossiper received position request: ", header)
out_transmissions.append(
(
Header(
channel="position",
sender=self._id,
sender_type="ad_vehicle",
cc={header.sender},
bcc={*self._friends},
format="position",
), # optimize this later
Message(
content=obs["ego_vehicle_state"]["position"],
), # optimize this later
)
)
print("Gossiper sent position: ", out_transmissions[0][1])

base_action = self._base_agent.act(obs)
return (base_action, out_transmissions)


class SchemerAgent(Agent):
def __init__(self, id_: str, base_agent: Agent, request_freq) -> None:
self._base_agent = base_agent
self._id = id_
self._request_freq = request_freq

def act(self, obs, **configs):
out_transmissions = []
for header, msg in obs["transmissions"]:
header: Header = header
msg: Message = msg
if header.channel == "position":
print()
print("On step: ", obs["steps_completed"])
print("Schemer received position: ", msg)

if obs["steps_completed"] % self._request_freq == 0:
print()
print("On step: ", obs["steps_completed"])
out_transmissions.append(
(
Header(
channel="position_request",
sender=self._id,
sender_type="ad_vehicle",
cc=set(),
bcc={"__all__"},
format="position_request",
),
Message(content=None),
)
)
print("Schemer requested position with: ", out_transmissions[0][0])

base_action = self._base_agent.act(obs)
return (base_action, out_transmissions)


def main(scenarios, headless, num_episodes, max_episode_steps=None):
agent_interface = AgentInterface.from_type(
AgentType.LanerWithSpeed, max_episode_steps=max_episode_steps
)
hiwayv1env = HiWayEnvV1(
scenarios=scenarios,
agent_interfaces={"gossiper0": agent_interface, "schemer": agent_interface},
headless=headless,
observation_options=ObservationOptions.multi_agent,
action_options=ActionOptions.default,
)
# for now
env = MessagePasser(
hiwayv1env,
max_message_bytes=MESSAGE_BYTES,
message_config={
"gossiper0": (
V2XTransmitter(
bands=Bands.ALL,
range=100,
# available_channels=["position_request", "position"]
),
V2XReceiver(
bands=Bands.ALL,
aliases=["tim"],
blacklist_channels={"self_control"},
),
),
"schemer": (
V2XTransmitter(
bands=Bands.ALL,
range=100,
),
V2XReceiver(
bands=Bands.ALL,
aliases=[],
),
),
},
)
agents = {
"gossiper0": GossiperAgent(
"gossiper0",
base_agent=LaneFollowerAgent(),
filter_=filter_useless,
friends={"schemer"},
),
"schemer": SchemerAgent(
"schemer", base_agent=LaneFollowerAgent(), request_freq=100
),
}

# then just the standard gym interface with no modifications
for episode in episodes(n=num_episodes):
observation, info = env.reset()
episode.record_scenario(env.scenario_log)

terminated = {"__all__": False}
while not terminated["__all__"]:
agent_action = {
agent_id: agents[agent_id].act(obs)
for agent_id, obs in observation.items()
}
observation, reward, terminated, truncated, info = env.step(agent_action)
episode.record_step(observation, reward, terminated, info)

env.close()


if __name__ == "__main__":
parser = default_argument_parser("single-agent-example")
args = parser.parse_args()

if not args.scenarios:
args.scenarios = [
str(Path(__file__).absolute().parents[2] / "scenarios" / "sumo" / "loop")
]

build_scenarios(scenarios=args.scenarios)

main(
scenarios=args.scenarios,
headless=args.headless,
num_episodes=args.episodes,
)
32 changes: 21 additions & 11 deletions smarts/env/gymnasium/wrappers/agent_communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,17 @@ def __init__(
for a_id, (_, receiver) in message_config.items():
for alias in receiver.aliases:
self._alias_mapping[alias].append(a_id)
self._alias_mapping[a_id].append(a_id)
self._alias_mapping["__all__"].append(a_id)

assert isinstance(env, HiWayEnvV1)
assert isinstance(env.unwrapped, HiWayEnvV1)
o_action_space: gym.spaces.Dict = self.env.action_space
msg_space = (
gym.spaces.Box(low=0, high=256, shape=(max_message_bytes,), dtype=np.uint8),
msg_space = gym.spaces.Tuple(
(
gym.spaces.Box(
low=0, high=256, shape=(max_message_bytes,), dtype=np.uint8
),
)
)
self.action_space = gym.spaces.Dict(
{
Expand Down Expand Up @@ -182,19 +188,21 @@ def resolve_alias(self, alias):

def step(self, action):
"""Steps the environment using the given action."""
std_actions = {a_id: act for a_id, (act, _) in action}
std_actions = {a_id: act for a_id, (act, _) in action.items()}
observations, rewards, terms, truncs, infos = self.env.step(std_actions)

msgs = defaultdict(list)

# pytype: disable=wrong-arg-types
# filter recipients for active
cached_active_filter = lru_cache(lambda a: a.intersection(observations.keys()))
cached_active_filter = lru_cache(
lambda a: frozenset(a.intersection(observations.keys()))
)

# filter recipients by band
## compare transmitter
cached_band_filter = lru_cache(
lambda sender, recipients: (
lambda sender, recipients: frozenset(
r
for r in recipients
if self._message_config[sender][0].bands
Expand All @@ -211,7 +219,7 @@ def step(self, action):
and channel not in self._message_config[recipient][1].blacklist_channels
)
cached_channel_filter = lru_cache(
lambda channel, recipients: (
lambda channel, recipients: frozenset(
r for r in recipients if accepts_channel(channel, r)
)
)
Expand All @@ -231,7 +239,9 @@ def step(self, action):
for recipients in map(self.resolve_alias, initial_recipients)
for cc in cached_channel_filter(
header.channel,
cached_band_filter(header.sender, cached_active_filter(recipients)),
cached_band_filter(
header.sender, cached_active_filter(frozenset(recipients))
),
)
)

Expand All @@ -243,8 +253,8 @@ def step(self, action):
message: Message = message

# expand the recipients
cc_recipients = set(general_filter(header, header.cc))
bcc_recipients = set(general_filter(header, header.bcc))
cc_recipients = set(general_filter(header, frozenset(header.cc)))
bcc_recipients = set(general_filter(header, frozenset(header.bcc)))
cc_header = header._replace(cc=cc_recipients)

# associate the messages to the recipients
Expand Down Expand Up @@ -279,7 +289,7 @@ def reset(
"""Resets the environment."""
observations, info = super().reset(seed=seed, options=options)
obs_with_msgs = {
a_id: dict(**obs, transmissions=self._transmission_space.sample(0))
a_id: dict(**obs, transmissions=self._transmission_space.sample((0, ())))
for a_id, obs in observations.items()
}
return obs_with_msgs, info

0 comments on commit f6355a6

Please sign in to comment.