Skip to content

Commit

Permalink
Add support for scene-wise computation of metrics
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 696152643
Change-Id: I530337d334b8b1d9b5ba7b2776300acfdb111e76
  • Loading branch information
jzleibo authored and copybara-github committed Nov 13, 2024
1 parent 56807c9 commit 853abec
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 3 deletions.
17 changes: 16 additions & 1 deletion concordia/environment/scenes/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@

"""Scene runner."""

from collections.abc import Sequence
from collections.abc import Mapping, Sequence

from concordia.agents import deprecated_agent
from concordia.environment import game_master
from concordia.typing import clock as game_clock
from concordia.typing import logging as logging_lib
from concordia.typing import scene as scene_lib
from concordia.utils import helper_functions


def _get_interscene_messages(
Expand Down Expand Up @@ -61,6 +63,7 @@ def run_scenes(
players: Sequence[deprecated_agent.BasicAgent],
clock: game_clock.GameClock,
verbose: bool = False,
compute_metrics: Mapping[str, logging_lib.Metric] | None = None,
) -> None:
"""Run a sequence of scenes.
Expand All @@ -70,6 +73,7 @@ def run_scenes(
players: full list of players (a subset may participate in each scene)
clock: the game clock which may be advanced between scenes
verbose: if true then print intermediate outputs
compute_metrics: Optionally, a function to compute metrics.
"""
players_by_name = {player.name: player for player in players}
if len(players_by_name) != len(players):
Expand Down Expand Up @@ -135,3 +139,14 @@ def run_scenes(
print(f'{participant.name} -- conclusion: {message}')
participant.observe(message)
this_scene_game_master_memory.add(message)

# Branch off a metric scene if applicable
if scene.scene_type.save_after_each_scene:
serialized_agents = {}
for participant in participants:
serialized_agents = {}
json_representation = helper_functions.save_to_json(participant)
serialized_agents[participant.name] = json_representation

if compute_metrics is not None:
compute_metrics(serialized_agents)
5 changes: 4 additions & 1 deletion concordia/factory/environment/basic_game_master.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""A Generic Environment Factory."""

from collections.abc import Callable, Sequence
from collections.abc import Callable, Mapping, Sequence
import operator

from concordia import components as generic_components
Expand Down Expand Up @@ -295,6 +295,7 @@ def run_simulation(
scenes: Sequence[scene_lib.SceneSpec],
secondary_environments: Sequence[game_master.GameMaster] = tuple(),
summarize_entire_episode_in_log: bool = True,
compute_metrics: Callable[[Mapping[str, str]], None] | None = None,
) -> str:
"""Run a simulation.
Expand All @@ -307,6 +308,7 @@ def run_simulation(
secondary_environments: Sequence of secondary game masters for scenes.
summarize_entire_episode_in_log: Optionally, include summaries of the full
episode in the log.
compute_metrics: Optionally, a function to compute metrics.
Returns:
an HTML string log of the simulation.
Expand All @@ -317,6 +319,7 @@ def run_simulation(
scenes=scenes,
players=players,
clock=clock,
compute_metrics=compute_metrics,
)
result_html_log = create_html_log(
model=model,
Expand Down
12 changes: 11 additions & 1 deletion concordia/typing/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,19 @@

"""Types for logging."""

from collections.abc import Mapping
from collections.abc import Mapping, Sequence
import dataclasses
from typing import Any, Callable
from concordia.typing import entity as entity_lib

LoggingChannel = Callable[[Mapping[str, Any]], None]

NoOpLoggingChannel = lambda x: None


@dataclasses.dataclass(frozen=True)
class Metric:
question: str
output_type: entity_lib.OutputType
options: Sequence[str] | None = None
context: str | None = None
3 changes: 3 additions & 0 deletions concordia/typing/scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ class SceneTypeSpec:
the game master to ask the agents to produce during steps of this scene.
override_game_master: optionally specify a game master to use instead of the
default one.
save_after_each_scene: optionally specify whether to save the agent's state
after each scene.
"""

name: str
Expand All @@ -52,6 +54,7 @@ class SceneTypeSpec:
| None
) = None
override_game_master: game_master.GameMaster | None = None
save_after_each_scene: bool = False


@dataclasses.dataclass(frozen=True)
Expand Down
42 changes: 42 additions & 0 deletions concordia/utils/helper_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@
from collections.abc import Iterable, Sequence
import datetime
import functools
import json

from concordia.agents import entity_agent_with_logging
from concordia.document import interactive_document
from concordia.language_model import language_model
from concordia.typing import component
from concordia.typing import entity_component
from concordia.utils import concurrency


Expand Down Expand Up @@ -144,3 +147,42 @@ def apply_recursively(
getattr(parent_component, function_name)()
else:
getattr(parent_component, function_name)(function_arg)


def save_to_json(
agent: entity_agent_with_logging.EntityAgentWithLogging,
) -> str:
"""Saves an agent to JSON data.
This function saves the agent's state to a JSON string, which can be loaded
afterwards with `rebuild_from_json`. The JSON data
includes the state of the agent's context components, act component, memory,
agent name and the initial config. The clock, model and embedder are not
saved and will have to be provided when the agent is rebuilt. The agent must
be in the `READY` phase to be saved.
Args:
agent: The agent to save.
Returns:
A JSON string representing the agent's state.
Raises:
ValueError: If the agent is not in the READY phase.
"""

if agent.get_phase() != entity_component.Phase.READY:
raise ValueError('The agent must be in the `READY` phase to be saved.')

data = {
component_name: agent.get_component(component_name).get_state()
for component_name in agent.get_all_context_components()
}

data['act_component'] = agent.get_act_component().get_state()

config = agent.get_config()
if config is not None:
data['agent_config'] = config.to_dict()

return json.dumps(data)

0 comments on commit 853abec

Please sign in to comment.