diff --git a/concordia/environment/scenes/runner.py b/concordia/environment/scenes/runner.py index 6f2d3c33..b3bdfed1 100644 --- a/concordia/environment/scenes/runner.py +++ b/concordia/environment/scenes/runner.py @@ -14,12 +14,14 @@ """Scene runner.""" -from collections.abc import Sequence +from collections.abc import Mapping, Sequence from concordia.agents import deprecated_agent from concordia.environment import game_master from concordia.typing import clock as game_clock +from concordia.typing import logging as logging_lib from concordia.typing import scene as scene_lib +from concordia.utils import helper_functions def _get_interscene_messages( @@ -61,6 +63,7 @@ def run_scenes( players: Sequence[deprecated_agent.BasicAgent], clock: game_clock.GameClock, verbose: bool = False, + compute_metrics: Mapping[str, logging_lib.Metric] | None = None, ) -> None: """Run a sequence of scenes. @@ -70,6 +73,7 @@ def run_scenes( players: full list of players (a subset may participate in each scene) clock: the game clock which may be advanced between scenes verbose: if true then print intermediate outputs + compute_metrics: Optionally, a function to compute metrics. """ players_by_name = {player.name: player for player in players} if len(players_by_name) != len(players): @@ -135,3 +139,14 @@ def run_scenes( print(f'{participant.name} -- conclusion: {message}') participant.observe(message) this_scene_game_master_memory.add(message) + + # Branch off a metric scene if applicable + if scene.scene_type.save_after_each_scene: + serialized_agents = {} + for participant in participants: + serialized_agents = {} + json_representation = helper_functions.save_to_json(participant) + serialized_agents[participant.name] = json_representation + + if compute_metrics is not None: + compute_metrics(serialized_agents) diff --git a/concordia/factory/environment/basic_game_master.py b/concordia/factory/environment/basic_game_master.py index 81894deb..570934a9 100644 --- a/concordia/factory/environment/basic_game_master.py +++ b/concordia/factory/environment/basic_game_master.py @@ -14,7 +14,7 @@ """A Generic Environment Factory.""" -from collections.abc import Callable, Sequence +from collections.abc import Callable, Mapping, Sequence import operator from concordia import components as generic_components @@ -295,6 +295,7 @@ def run_simulation( scenes: Sequence[scene_lib.SceneSpec], secondary_environments: Sequence[game_master.GameMaster] = tuple(), summarize_entire_episode_in_log: bool = True, + compute_metrics: Callable[[Mapping[str, str]], None] | None = None, ) -> str: """Run a simulation. @@ -307,6 +308,7 @@ def run_simulation( secondary_environments: Sequence of secondary game masters for scenes. summarize_entire_episode_in_log: Optionally, include summaries of the full episode in the log. + compute_metrics: Optionally, a function to compute metrics. Returns: an HTML string log of the simulation. @@ -317,6 +319,7 @@ def run_simulation( scenes=scenes, players=players, clock=clock, + compute_metrics=compute_metrics, ) result_html_log = create_html_log( model=model, diff --git a/concordia/typing/logging.py b/concordia/typing/logging.py index ebd72612..b3435903 100644 --- a/concordia/typing/logging.py +++ b/concordia/typing/logging.py @@ -14,9 +14,19 @@ """Types for logging.""" -from collections.abc import Mapping +from collections.abc import Mapping, Sequence +import dataclasses from typing import Any, Callable +from concordia.typing import entity as entity_lib LoggingChannel = Callable[[Mapping[str, Any]], None] NoOpLoggingChannel = lambda x: None + + +@dataclasses.dataclass(frozen=True) +class Metric: + question: str + output_type: entity_lib.OutputType + options: Sequence[str] | None = None + context: str | None = None diff --git a/concordia/typing/scene.py b/concordia/typing/scene.py index ad062000..c071d341 100644 --- a/concordia/typing/scene.py +++ b/concordia/typing/scene.py @@ -39,6 +39,8 @@ class SceneTypeSpec: the game master to ask the agents to produce during steps of this scene. override_game_master: optionally specify a game master to use instead of the default one. + save_after_each_scene: optionally specify whether to save the agent's state + after each scene. """ name: str @@ -52,6 +54,7 @@ class SceneTypeSpec: | None ) = None override_game_master: game_master.GameMaster | None = None + save_after_each_scene: bool = False @dataclasses.dataclass(frozen=True) diff --git a/concordia/utils/helper_functions.py b/concordia/utils/helper_functions.py index 920f4d8c..a28e67ed 100644 --- a/concordia/utils/helper_functions.py +++ b/concordia/utils/helper_functions.py @@ -17,10 +17,13 @@ from collections.abc import Iterable, Sequence import datetime import functools +import json +from concordia.agents import entity_agent_with_logging from concordia.document import interactive_document from concordia.language_model import language_model from concordia.typing import component +from concordia.typing import entity_component from concordia.utils import concurrency @@ -144,3 +147,42 @@ def apply_recursively( getattr(parent_component, function_name)() else: getattr(parent_component, function_name)(function_arg) + + +def save_to_json( + agent: entity_agent_with_logging.EntityAgentWithLogging, +) -> str: + """Saves an agent to JSON data. + + This function saves the agent's state to a JSON string, which can be loaded + afterwards with `rebuild_from_json`. The JSON data + includes the state of the agent's context components, act component, memory, + agent name and the initial config. The clock, model and embedder are not + saved and will have to be provided when the agent is rebuilt. The agent must + be in the `READY` phase to be saved. + + Args: + agent: The agent to save. + + Returns: + A JSON string representing the agent's state. + + Raises: + ValueError: If the agent is not in the READY phase. + """ + + if agent.get_phase() != entity_component.Phase.READY: + raise ValueError('The agent must be in the `READY` phase to be saved.') + + data = { + component_name: agent.get_component(component_name).get_state() + for component_name in agent.get_all_context_components() + } + + data['act_component'] = agent.get_act_component().get_state() + + config = agent.get_config() + if config is not None: + data['agent_config'] = config.to_dict() + + return json.dumps(data)