From eb06a1fb03f35225f02d77391d3bbbde617faf39 Mon Sep 17 00:00:00 2001 From: Sasha Vezhnevets Date: Tue, 29 Oct 2024 04:11:17 -0700 Subject: [PATCH] Add ability to save and load rational and basic agents to/from json. - add methods for saving and loading to components - add save / load to all main agent factories - saving and loading of associative memories PiperOrigin-RevId: 690960131 Change-Id: Ie323681fe63fbc966d70bbd21bcaf9d09f06c48d --- concordia/agents/entity_agent.py | 8 ++ concordia/agents/entity_agent_with_logging.py | 8 ++ .../associative_memory/associative_memory.py | 25 +++++ .../associative_memory/formative_memories.py | 19 +++- .../components/agent/action_spec_ignored.py | 11 +- .../components/agent/all_similar_memories.py | 5 +- .../components/agent/concat_act_component.py | 8 ++ .../components/agent/memory_component.py | 8 ++ concordia/components/agent/plan.py | 12 ++ .../question_of_query_associated_memories.py | 6 + .../agent/question_of_recent_memories.py | 1 + .../components/agent/affect_reflection.py | 90 +++++++++------ concordia/factory/agent/basic_agent.py | 101 +++++++++++++++-- .../factory/agent/basic_agent_without_plan.py | 85 +++++++++++++- .../agent/observe_recall_prompt_agent.py | 85 +++++++++++++- concordia/factory/agent/paranoid_agent.py | 85 +++++++++++++- concordia/factory/agent/rational_agent.py | 106 ++++++++++++++++-- concordia/factory/agent/synthetic_user.py | 83 ++++++++++++++ .../memory_bank/legacy_associative_memory.py | 52 ++++++++- concordia/typing/entity_component.py | 57 +++++++++- concordia/typing/memory.py | 46 ++++++++ 21 files changed, 832 insertions(+), 69 deletions(-) diff --git a/concordia/agents/entity_agent.py b/concordia/agents/entity_agent.py index b8cc2692..9f6c5a91 100644 --- a/concordia/agents/entity_agent.py +++ b/concordia/agents/entity_agent.py @@ -107,6 +107,14 @@ def get_component( component = self._context_components[name] return cast(entity_component.ComponentT, component) + def get_act_component(self) -> entity_component.ActingComponent: + return self._act_component + + def get_all_context_components( + self, + ) -> Mapping[str, entity_component.ContextComponent]: + return types.MappingProxyType(self._context_components) + def _parallel_call_( self, method_name: str, diff --git a/concordia/agents/entity_agent_with_logging.py b/concordia/agents/entity_agent_with_logging.py index ddf4d238..d705417b 100644 --- a/concordia/agents/entity_agent_with_logging.py +++ b/concordia/agents/entity_agent_with_logging.py @@ -15,10 +15,12 @@ """A modular entity agent using the new component system with side logging.""" from collections.abc import Mapping +import copy import types from typing import Any from absl import logging from concordia.agents import entity_agent +from concordia.associative_memory import formative_memories from concordia.typing import agent from concordia.typing import entity_component from concordia.utils import measurements as measurements_lib @@ -39,6 +41,7 @@ def __init__( types.MappingProxyType({}) ), component_logging: measurements_lib.Measurements | None = None, + config: formative_memories.AgentConfig | None = None, ): """Initializes the agent. @@ -56,6 +59,7 @@ def __init__( None, a NoOpContextProcessor will be used. context_components: The ContextComponents that will be used by the agent. component_logging: The channels where components publish events. + config: The agent configuration, used for checkpointing and debug. """ super().__init__(agent_name=agent_name, act_component=act_component, @@ -75,6 +79,7 @@ def __init__( on_error=lambda e: logging.error('Error in component logging: %s', e)) else: self._channel_names = [] + self._config = copy.deepcopy(config) def _set_log(self, log: tuple[Any, ...]) -> None: """Set the logging object to return from get_last_log. @@ -89,3 +94,6 @@ def _set_log(self, log: tuple[Any, ...]) -> None: def get_last_log(self): self._tick.on_next(None) # Trigger the logging. return self._log + + def get_config(self) -> formative_memories.AgentConfig | None: + return copy.deepcopy(self._config) diff --git a/concordia/associative_memory/associative_memory.py b/concordia/associative_memory/associative_memory.py index ba6bd83d..ec030b2b 100644 --- a/concordia/associative_memory/associative_memory.py +++ b/concordia/associative_memory/associative_memory.py @@ -25,9 +25,11 @@ import threading from concordia.associative_memory import importance_function +from concordia.typing import entity_component import numpy as np import pandas as pd + _NUM_TO_RETRIEVE_TO_CONTEXTUALIZE_IMPORTANCE = 25 @@ -79,6 +81,29 @@ def __init__( self._interval = clock_step_size self._stored_hashes = set() + def get_state(self) -> entity_component.ComponentState: + """Converts the AssociativeMemory to a dictionary.""" + + with self._memory_bank_lock: + output = { + 'seed': self._seed, + 'stored_hashes': list(self._stored_hashes), + 'memory_bank': self._memory_bank.to_json(), + } + if self._interval: + output['interval'] = self._interval.total_seconds() + return output + + def set_state(self, state: entity_component.ComponentState) -> None: + """Sets the AssociativeMemory from a dictionary.""" + + with self._memory_bank_lock: + self._seed = state['seed'] + self._stored_hashes = set(state['stored_hashes']) + self._memory_bank = pd.read_json(state['memory_bank']) + if 'interval' in state: + self._interval = datetime.timedelta(seconds=state['interval']) + def add( self, text: str, diff --git a/concordia/associative_memory/formative_memories.py b/concordia/associative_memory/formative_memories.py index e0759582..fd9fb3fb 100644 --- a/concordia/associative_memory/formative_memories.py +++ b/concordia/associative_memory/formative_memories.py @@ -15,7 +15,7 @@ """This is a factory for generating memories for concordia agents.""" -from collections.abc import Callable, Iterable, Sequence +from collections.abc import Callable, Collection, Sequence import dataclasses import datetime import logging @@ -58,9 +58,24 @@ class AgentConfig: specific_memories: str = '' goal: str = '' date_of_birth: datetime.datetime = DEFAULT_DOB - formative_ages: Iterable[int] = DEFAULT_FORMATIVE_AGES + formative_ages: Collection[int] = DEFAULT_FORMATIVE_AGES extras: dict[str, Any] = dataclasses.field(default_factory=dict) + def to_dict(self) -> dict[str, Any]: + """Converts the AgentConfig to a dictionary.""" + result = dataclasses.asdict(self) + result['date_of_birth'] = self.date_of_birth.isoformat() + return result + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> 'AgentConfig': + """Initializes an AgentConfig from a dictionary.""" + date_of_birth = datetime.datetime.fromisoformat( + data['date_of_birth'] + ) + data = data | {'date_of_birth': date_of_birth} + return cls(**data) + class FormativeMemoryFactory: """Generator of formative memories.""" diff --git a/concordia/components/agent/action_spec_ignored.py b/concordia/components/agent/action_spec_ignored.py index d8fc4170..a6f1c804 100644 --- a/concordia/components/agent/action_spec_ignored.py +++ b/concordia/components/agent/action_spec_ignored.py @@ -16,10 +16,11 @@ import abc import threading -from typing import Final +from typing import Final, Any from concordia.typing import entity as entity_lib from concordia.typing import entity_component +from typing_extensions import override class ActionSpecIgnored( @@ -89,3 +90,11 @@ def get_named_component_pre_act_value(self, component_name: str) -> str: """Returns the pre-act value of a named component of the parent entity.""" return self.get_entity().get_component( component_name, type_=ActionSpecIgnored).get_pre_act_value() + + @override + def set_state(self, state: entity_component.ComponentState) -> Any: + return None + + @override + def get_state(self) -> entity_component.ComponentState: + return {} diff --git a/concordia/components/agent/all_similar_memories.py b/concordia/components/agent/all_similar_memories.py index 05f40db1..a2ad72c5 100644 --- a/concordia/components/agent/all_similar_memories.py +++ b/concordia/components/agent/all_similar_memories.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Return all memories similar to a prompt and filter them for relevance. -""" +"""Return all memories similar to a prompt and filter them for relevance.""" -from collections.abc import Mapping import types +from typing import Mapping from concordia.components.agent import action_spec_ignored from concordia.components.agent import memory_component diff --git a/concordia/components/agent/concat_act_component.py b/concordia/components/agent/concat_act_component.py index 5eb13681..ae303778 100644 --- a/concordia/components/agent/concat_act_component.py +++ b/concordia/components/agent/concat_act_component.py @@ -164,3 +164,11 @@ def _log(self, 'Value': result, 'Prompt': prompt.view().text().splitlines(), }) + + def get_state(self) -> entity_component.ComponentState: + """Converts the component to a dictionary.""" + return {} + + def set_state(self, state: entity_component.ComponentState) -> None: + pass + diff --git a/concordia/components/agent/memory_component.py b/concordia/components/agent/memory_component.py index 187eac99..66489c87 100644 --- a/concordia/components/agent/memory_component.py +++ b/concordia/components/agent/memory_component.py @@ -57,6 +57,14 @@ def _check_phase(self) -> None: 'You can only access the memory outside of the `UPDATE` phase.' ) + def get_state(self) -> Mapping[str, Any]: + with self._lock: + return self._memory.get_state() + + def set_state(self, state: Mapping[str, Any]) -> None: + with self._lock: + self._memory.set_state(state) + def retrieve( self, query: str = '', diff --git a/concordia/components/agent/plan.py b/concordia/components/agent/plan.py index 926f2bf8..13821f8c 100644 --- a/concordia/components/agent/plan.py +++ b/concordia/components/agent/plan.py @@ -163,3 +163,15 @@ def _make_pre_act_value(self) -> str: }) return result + + def get_state(self) -> entity_component.ComponentState: + """Converts the component to JSON data.""" + with self._lock: + return { + 'current_plan': self._current_plan, + } + + def set_state(self, state: entity_component.ComponentState) -> None: + """Sets the component state from JSON data.""" + with self._lock: + self._current_plan = state['current_plan'] diff --git a/concordia/components/agent/question_of_query_associated_memories.py b/concordia/components/agent/question_of_query_associated_memories.py index b86117d2..4e647e00 100644 --- a/concordia/components/agent/question_of_query_associated_memories.py +++ b/concordia/components/agent/question_of_query_associated_memories.py @@ -189,6 +189,12 @@ def pre_act( def update(self) -> None: self._component.update() + def get_state(self) -> entity_component.ComponentState: + return self._component.get_state() + + def set_state(self, state: entity_component.ComponentState) -> None: + self._component.set_state(state) + class Identity(QuestionOfQueryAssociatedMemories): """Identity component containing a few characteristics. diff --git a/concordia/components/agent/question_of_recent_memories.py b/concordia/components/agent/question_of_recent_memories.py index 81b2f5a6..c364b52a 100644 --- a/concordia/components/agent/question_of_recent_memories.py +++ b/concordia/components/agent/question_of_recent_memories.py @@ -193,6 +193,7 @@ class AvailableOptionsPerception(QuestionOfRecentMemories): """This component answers the question 'what actions are available to me?'.""" def __init__(self, **kwargs): + super().__init__( question=( 'Given the statements above, what actions are available to ' diff --git a/concordia/contrib/components/agent/affect_reflection.py b/concordia/contrib/components/agent/affect_reflection.py index e8c44048..95e26c65 100644 --- a/concordia/contrib/components/agent/affect_reflection.py +++ b/concordia/contrib/components/agent/affect_reflection.py @@ -68,8 +68,8 @@ def __init__( num_salient_to_retrieve: retrieve this many salient memories. num_questions_to_consider: how many questions to ask self. num_to_retrieve_per_question: how many memories to retrieve per question. - pre_act_key: Prefix to add to the output of the component when called - in `pre_act`. + pre_act_key: Prefix to add to the output of the component when called in + `pre_act`. logging_channel: The channel to use for debug logging. """ super().__init__(pre_act_key) @@ -91,35 +91,39 @@ def _make_pre_act_value(self) -> str: for key, prefix in self._components.items() ]) salience_chain_of_thought = interactive_document.InteractiveDocument( - self._model) + self._model + ) query = f'salient event, period, feeling, or concept for {agent_name}' timed_query = f'[{self._clock.now()}] {query}' memory = self.get_entity().get_component( - self._memory_component_name, - type_=memory_component.MemoryComponent) - mem_retrieved = '\n'.join( - [mem.text for mem in memory.retrieve( + self._memory_component_name, type_=memory_component.MemoryComponent + ) + mem_retrieved = '\n'.join([ + mem.text + for mem in memory.retrieve( query=timed_query, scoring_fn=legacy_associative_memory.RetrieveAssociative( use_recency=True, add_time=True ), - limit=self._num_salient_to_retrieve)] - ) + limit=self._num_salient_to_retrieve, + ) + ]) question_list = [] questions = salience_chain_of_thought.open_question( ( - f'Recent feelings: {self._previous_pre_act_value} \n' + - f"{agent_name}'s relevant memory:\n" + - f'{mem_retrieved}\n' + - f'Current time: {self._clock.now()}\n' + - '\nGiven the thoughts and beliefs above, what are the ' + - f'{self._num_questions_to_consider} most salient high-level '+ - f'questions that can be answered about what {agent_name} ' + - 'might be feeling about the current moment?'), + f'Recent feelings: {self._previous_pre_act_value} \n' + + f"{agent_name}'s relevant memory:\n" + + f'{mem_retrieved}\n' + + f'Current time: {self._clock.now()}\n' + + '\nGiven the thoughts and beliefs above, what are the ' + + f'{self._num_questions_to_consider} most salient high-level ' + + f'questions that can be answered about what {agent_name} ' + + 'might be feeling about the current moment?' + ), answer_prefix='- ', max_tokens=3000, terminators=(), @@ -128,25 +132,33 @@ def _make_pre_act_value(self) -> str: question_related_mems = [] for question in questions: question_list.append(question) - question_related_mems = [mem.text for mem in memory.retrieve( - query=agent_name, - scoring_fn=legacy_associative_memory.RetrieveAssociative( - use_recency=False, add_time=True - ), - limit=self._num_to_retrieve_per_question)] + question_related_mems = [ + mem.text + for mem in memory.retrieve( + query=agent_name, + scoring_fn=legacy_associative_memory.RetrieveAssociative( + use_recency=False, add_time=True + ), + limit=self._num_to_retrieve_per_question, + ) + ] insights = [] question_related_mems = '\n'.join(question_related_mems) chain_of_thought = interactive_document.InteractiveDocument(self._model) insight = chain_of_thought.open_question( - f'Selected memories:\n{question_related_mems}\n' + - f'Recent feelings: {self._previous_pre_act_value} \n\n' + - 'New context:\n' + context + '\n' + - f'Current time: {self._clock.now()}\n' + - 'What high-level insight can be inferred from the above ' + - f'statements about what {agent_name} might be feeling ' + - 'in the current moment?', - max_tokens=2000, terminators=(),) + f'Selected memories:\n{question_related_mems}\n' + + f'Recent feelings: {self._previous_pre_act_value} \n\n' + + 'New context:\n' + + context + + '\n' + + f'Current time: {self._clock.now()}\n' + + 'What high-level insight can be inferred from the above ' + + f'statements about what {agent_name} might be feeling ' + + 'in the current moment?', + max_tokens=2000, + terminators=(), + ) insights.append(insight) result = '\n'.join(insights) @@ -157,9 +169,19 @@ def _make_pre_act_value(self) -> str: 'Key': self.get_pre_act_key(), 'Value': result, 'Salience chain of thought': ( - salience_chain_of_thought.view().text().splitlines()), - 'Chain of thought': ( - chain_of_thought.view().text().splitlines()), + salience_chain_of_thought.view().text().splitlines() + ), + 'Chain of thought': chain_of_thought.view().text().splitlines(), }) return result + + def get_state(self) -> entity_component.ComponentState: + """Converts the component to a dictionary.""" + return { + 'previous_pre_act_value': self._previous_pre_act_value, + } + + def set_state(self, state: entity_component.ComponentState) -> None: + + self._previous_pre_act_value = str(state['previous_pre_act_value']) diff --git a/concordia/factory/agent/basic_agent.py b/concordia/factory/agent/basic_agent.py index cf48ef54..a8b6a3f4 100644 --- a/concordia/factory/agent/basic_agent.py +++ b/concordia/factory/agent/basic_agent.py @@ -13,8 +13,9 @@ # limitations under the License. """A factory implementing the three key questions agent as an entity.""" - +from collections.abc import Callable import datetime +import json from concordia.agents import entity_agent_with_logging from concordia.associative_memory import associative_memory @@ -23,9 +24,12 @@ from concordia.components import agent as agent_components from concordia.language_model import language_model from concordia.memory_bank import legacy_associative_memory +from concordia.typing import entity_component from concordia.utils import measurements as measurements_lib +import numpy as np DEFAULT_PLANNING_HORIZON = 'the rest of the day, focusing most on the near term' +DEFAULT_GOAL_COMPONENT_NAME = 'Goal' def _get_class_name(object_: object) -> str: @@ -38,7 +42,7 @@ def build_agent( model: language_model.LanguageModel, memory: associative_memory.AssociativeMemory, clock: game_clock.MultiIntervalClock, - update_time_interval: datetime.timedelta, + update_time_interval: datetime.timedelta | None = None, ) -> entity_agent_with_logging.EntityAgentWithLogging: """Build an agent. @@ -155,10 +159,12 @@ def build_agent( overarching_goal = agent_components.constant.Constant( state=config.goal, pre_act_key=goal_label, - logging_channel=measurements.get_channel(goal_label).on_next) - plan_components[goal_label] = goal_label + logging_channel=measurements.get_channel( + DEFAULT_GOAL_COMPONENT_NAME + ).on_next, + ) + plan_components[DEFAULT_GOAL_COMPONENT_NAME] = goal_label else: - goal_label = None overarching_goal = None plan_components.update({ @@ -189,7 +195,6 @@ def build_agent( person_by_situation, plan, time_display, - # Components that do not provide pre_act context. identity_characteristics, ) @@ -200,9 +205,9 @@ def build_agent( agent_components.memory_component.MemoryComponent(raw_memory)) component_order = list(components_of_agent.keys()) if overarching_goal is not None: - components_of_agent[goal_label] = overarching_goal + components_of_agent[DEFAULT_GOAL_COMPONENT_NAME] = overarching_goal # Place goal after the instructions. - component_order.insert(1, goal_label) + component_order.insert(1, DEFAULT_GOAL_COMPONENT_NAME) act_component = agent_components.concat_act_component.ConcatActComponent( model=model, @@ -216,6 +221,86 @@ def build_agent( act_component=act_component, context_components=components_of_agent, component_logging=measurements, + config=config, + ) + + return agent + + +def save_to_json( + agent: entity_agent_with_logging.EntityAgentWithLogging, +) -> str: + """Saves an agent to JSON data. + + This function saves the agent's state to a JSON string, which can be loaded + afterwards with `rebuild_from_json`. The JSON data + includes the state of the agent's context components, act component, memory, + agent name and the initial config. The clock, model and embedder are not + saved and will have to be provided when the agent is rebuilt. The agent must + be in the `READY` phase to be saved. + + Args: + agent: The agent to save. + + Returns: + A JSON string representing the agent's state. + + Raises: + ValueError: If the agent is not in the READY phase. + """ + + if agent.get_phase() != entity_component.Phase.READY: + raise ValueError('The agent must be in the `READY` phase to be saved.') + + data = { + component_name: agent.get_component(component_name).get_state() + for component_name in agent.get_all_context_components() + } + + data['act_component'] = agent.get_act_component().get_state() + + config = agent.get_config() + if config is not None: + data['agent_config'] = config.to_dict() + + return json.dumps(data) + + +def rebuild_from_json( + json_data: str, + model: language_model.LanguageModel, + clock: game_clock.MultiIntervalClock, + embedder: Callable[[str], np.ndarray], + memory_importance: Callable[[str], float] | None = None, +) -> entity_agent_with_logging.EntityAgentWithLogging: + """Rebuilds an agent from JSON data.""" + + data = json.loads(json_data) + + new_agent_memory = associative_memory.AssociativeMemory( + sentence_embedder=embedder, + importance=memory_importance, + clock=clock.now, + clock_step_size=clock.get_step_size(), ) + if 'agent_config' not in data: + raise ValueError('The JSON data does not contain the agent config.') + agent_config = formative_memories.AgentConfig.from_dict( + data.pop('agent_config') + ) + + agent = build_agent( + config=agent_config, + model=model, + memory=new_agent_memory, + clock=clock, + ) + + for component_name in agent.get_all_context_components(): + agent.get_component(component_name).set_state(data.pop(component_name)) + + agent.get_act_component().set_state(data.pop('act_component')) + + assert not data, f'Unused data {sorted(data)}' return agent diff --git a/concordia/factory/agent/basic_agent_without_plan.py b/concordia/factory/agent/basic_agent_without_plan.py index cf0601d4..5857c1ab 100644 --- a/concordia/factory/agent/basic_agent_without_plan.py +++ b/concordia/factory/agent/basic_agent_without_plan.py @@ -14,7 +14,9 @@ """A factory implementing the three key questions agent as an entity.""" +from collections.abc import Callable import datetime +import json from concordia.agents import entity_agent_with_logging from concordia.associative_memory import associative_memory @@ -23,7 +25,9 @@ from concordia.components import agent as agent_components from concordia.language_model import language_model from concordia.memory_bank import legacy_associative_memory +from concordia.typing import entity_component from concordia.utils import measurements as measurements_lib +import numpy as np def _get_class_name(object_: object) -> str: @@ -36,7 +40,7 @@ def build_agent( model: language_model.LanguageModel, memory: associative_memory.AssociativeMemory, clock: game_clock.MultiIntervalClock, - update_time_interval: datetime.timedelta, + update_time_interval: datetime.timedelta | None = None, ) -> entity_agent_with_logging.EntityAgentWithLogging: """Build an agent. @@ -199,3 +203,82 @@ def build_agent( ) return agent + + +def save_to_json( + agent: entity_agent_with_logging.EntityAgentWithLogging, +) -> str: + """Saves an agent to JSON data. + + This function saves the agent's state to a JSON string, which can be loaded + afterwards with `rebuild_from_json`. The JSON data + includes the state of the agent's context components, act component, memory, + agent name and the initial config. The clock, model and embedder are not + saved and will have to be provided when the agent is rebuilt. The agent must + be in the `READY` phase to be saved. + + Args: + agent: The agent to save. + + Returns: + A JSON string representing the agent's state. + + Raises: + ValueError: If the agent is not in the READY phase. + """ + + if agent.get_phase() != entity_component.Phase.READY: + raise ValueError('The agent must be in the `READY` phase to be saved.') + + data = { + component_name: agent.get_component(component_name).get_state() + for component_name in agent.get_all_context_components() + } + + data['act_component'] = agent.get_act_component().get_state() + + config = agent.get_config() + if config is not None: + data['agent_config'] = config.to_dict() + + return json.dumps(data) + + +def rebuild_from_json( + json_data: str, + model: language_model.LanguageModel, + clock: game_clock.MultiIntervalClock, + embedder: Callable[[str], np.ndarray], + memory_importance: Callable[[str], float] | None = None, +) -> entity_agent_with_logging.EntityAgentWithLogging: + """Rebuilds an agent from JSON data.""" + + data = json.loads(json_data) + + new_agent_memory = associative_memory.AssociativeMemory( + sentence_embedder=embedder, + importance=memory_importance, + clock=clock.now, + clock_step_size=clock.get_step_size(), + ) + + if 'agent_config' not in data: + raise ValueError('The JSON data does not contain the agent config.') + agent_config = formative_memories.AgentConfig.from_dict( + data.pop('agent_config') + ) + + agent = build_agent( + config=agent_config, + model=model, + memory=new_agent_memory, + clock=clock, + ) + + for component_name in agent.get_all_context_components(): + agent.get_component(component_name).set_state(data.pop(component_name)) + + agent.get_act_component().set_state(data.pop('act_component')) + + assert not data, f'Unused data {sorted(data)}' + return agent diff --git a/concordia/factory/agent/observe_recall_prompt_agent.py b/concordia/factory/agent/observe_recall_prompt_agent.py index e81dbfae..abff6361 100644 --- a/concordia/factory/agent/observe_recall_prompt_agent.py +++ b/concordia/factory/agent/observe_recall_prompt_agent.py @@ -14,7 +14,9 @@ """An Agent Factory.""" +from collections.abc import Callable import datetime +import json from concordia.agents import entity_agent_with_logging from concordia.associative_memory import associative_memory @@ -23,7 +25,9 @@ from concordia.components import agent as agent_components from concordia.language_model import language_model from concordia.memory_bank import legacy_associative_memory +from concordia.typing import entity_component from concordia.utils import measurements as measurements_lib +import numpy as np def _get_class_name(object_: object) -> str: @@ -36,7 +40,7 @@ def build_agent( model: language_model.LanguageModel, memory: associative_memory.AssociativeMemory, clock: game_clock.MultiIntervalClock, - update_time_interval: datetime.timedelta, + update_time_interval: datetime.timedelta | None = None, ) -> entity_agent_with_logging.EntityAgentWithLogging: """Build an agent. @@ -144,3 +148,82 @@ def build_agent( ) return agent + + +def save_to_json( + agent: entity_agent_with_logging.EntityAgentWithLogging, +) -> str: + """Saves an agent to JSON data. + + This function saves the agent's state to a JSON string, which can be loaded + afterwards with `rebuild_from_json`. The JSON data + includes the state of the agent's context components, act component, memory, + agent name and the initial config. The clock, model and embedder are not + saved and will have to be provided when the agent is rebuilt. The agent must + be in the `READY` phase to be saved. + + Args: + agent: The agent to save. + + Returns: + A JSON string representing the agent's state. + + Raises: + ValueError: If the agent is not in the READY phase. + """ + + if agent.get_phase() != entity_component.Phase.READY: + raise ValueError('The agent must be in the `READY` phase to be saved.') + + data = { + component_name: agent.get_component(component_name).get_state() + for component_name in agent.get_all_context_components() + } + + data['act_component'] = agent.get_act_component().get_state() + + config = agent.get_config() + if config is not None: + data['agent_config'] = config.to_dict() + + return json.dumps(data) + + +def rebuild_from_json( + json_data: str, + model: language_model.LanguageModel, + clock: game_clock.MultiIntervalClock, + embedder: Callable[[str], np.ndarray], + memory_importance: Callable[[str], float] | None = None, +) -> entity_agent_with_logging.EntityAgentWithLogging: + """Rebuilds an agent from JSON data.""" + + data = json.loads(json_data) + + new_agent_memory = associative_memory.AssociativeMemory( + sentence_embedder=embedder, + importance=memory_importance, + clock=clock.now, + clock_step_size=clock.get_step_size(), + ) + + if 'agent_config' not in data: + raise ValueError('The JSON data does not contain the agent config.') + agent_config = formative_memories.AgentConfig.from_dict( + data.pop('agent_config') + ) + + agent = build_agent( + config=agent_config, + model=model, + memory=new_agent_memory, + clock=clock, + ) + + for component_name in agent.get_all_context_components(): + agent.get_component(component_name).set_state(data.pop(component_name)) + + agent.get_act_component().set_state(data.pop('act_component')) + + assert not data, f'Unused data {sorted(data)}' + return agent diff --git a/concordia/factory/agent/paranoid_agent.py b/concordia/factory/agent/paranoid_agent.py index 1375edea..243f796e 100644 --- a/concordia/factory/agent/paranoid_agent.py +++ b/concordia/factory/agent/paranoid_agent.py @@ -14,7 +14,9 @@ """An Agent Factory.""" +from collections.abc import Callable import datetime +import json from concordia.agents import entity_agent_with_logging from concordia.associative_memory import associative_memory @@ -23,7 +25,9 @@ from concordia.components import agent as agent_components from concordia.language_model import language_model from concordia.memory_bank import legacy_associative_memory +from concordia.typing import entity_component from concordia.utils import measurements as measurements_lib +import numpy as np def _get_class_name(object_: object) -> str: @@ -36,7 +40,7 @@ def build_agent( model: language_model.LanguageModel, memory: associative_memory.AssociativeMemory, clock: game_clock.MultiIntervalClock, - update_time_interval: datetime.timedelta, + update_time_interval: datetime.timedelta | None = None, ) -> entity_agent_with_logging.EntityAgentWithLogging: """Build an agent. @@ -262,3 +266,82 @@ def build_agent( ) return agent + + +def save_to_json( + agent: entity_agent_with_logging.EntityAgentWithLogging, +) -> str: + """Saves an agent to JSON data. + + This function saves the agent's state to a JSON string, which can be loaded + afterwards with `rebuild_from_json`. The JSON data + includes the state of the agent's context components, act component, memory, + agent name and the initial config. The clock, model and embedder are not + saved and will have to be provided when the agent is rebuilt. The agent must + be in the `READY` phase to be saved. + + Args: + agent: The agent to save. + + Returns: + A JSON string representing the agent's state. + + Raises: + ValueError: If the agent is not in the READY phase. + """ + + if agent.get_phase() != entity_component.Phase.READY: + raise ValueError('The agent must be in the `READY` phase to be saved.') + + data = { + component_name: agent.get_component(component_name).get_state() + for component_name in agent.get_all_context_components() + } + + data['act_component'] = agent.get_act_component().get_state() + + config = agent.get_config() + if config is not None: + data['agent_config'] = config.to_dict() + + return json.dumps(data) + + +def rebuild_from_json( + json_data: str, + model: language_model.LanguageModel, + clock: game_clock.MultiIntervalClock, + embedder: Callable[[str], np.ndarray], + memory_importance: Callable[[str], float] | None = None, +) -> entity_agent_with_logging.EntityAgentWithLogging: + """Rebuilds an agent from JSON data.""" + + data = json.loads(json_data) + + new_agent_memory = associative_memory.AssociativeMemory( + sentence_embedder=embedder, + importance=memory_importance, + clock=clock.now, + clock_step_size=clock.get_step_size(), + ) + + if 'agent_config' not in data: + raise ValueError('The JSON data does not contain the agent config.') + agent_config = formative_memories.AgentConfig.from_dict( + data.pop('agent_config') + ) + + agent = build_agent( + config=agent_config, + model=model, + memory=new_agent_memory, + clock=clock, + ) + + for component_name in agent.get_all_context_components(): + agent.get_component(component_name).set_state(data.pop(component_name)) + + agent.get_act_component().set_state(data.pop('act_component')) + + assert not data, f'Unused data {sorted(data)}' + return agent diff --git a/concordia/factory/agent/rational_agent.py b/concordia/factory/agent/rational_agent.py index 01c1e80c..7e13a9a5 100644 --- a/concordia/factory/agent/rational_agent.py +++ b/concordia/factory/agent/rational_agent.py @@ -14,7 +14,9 @@ """An Agent Factory.""" +from collections.abc import Callable import datetime +import json from concordia.agents import entity_agent_with_logging from concordia.associative_memory import associative_memory @@ -23,7 +25,12 @@ from concordia.components import agent as agent_components from concordia.language_model import language_model from concordia.memory_bank import legacy_associative_memory +from concordia.typing import entity_component from concordia.utils import measurements as measurements_lib +import numpy as np + + +DEFAULT_GOAL_COMPONENT_NAME = 'Goal' def _get_class_name(object_: object) -> str: @@ -36,7 +43,7 @@ def build_agent( model: language_model.LanguageModel, memory: associative_memory.AssociativeMemory, clock: game_clock.MultiIntervalClock, - update_time_interval: datetime.timedelta, + update_time_interval: datetime.timedelta | None = None, ) -> entity_agent_with_logging.EntityAgentWithLogging: """Build an agent. @@ -106,7 +113,7 @@ def build_agent( state=config.goal, pre_act_key=goal_label, logging_channel=measurements.get_channel(goal_label).on_next) - options_perception_components[goal_label] = goal_label + options_perception_components[DEFAULT_GOAL_COMPONENT_NAME] = goal_label else: goal_label = None overarching_goal = None @@ -136,7 +143,7 @@ def build_agent( f'best for {agent_name} to take right now?\nAnswer') best_option_perception = {} if config.goal: - best_option_perception[goal_label] = goal_label + best_option_perception[DEFAULT_GOAL_COMPONENT_NAME] = goal_label best_option_perception.update({ _get_class_name(observation): observation_label, _get_class_name(observation_summary): observation_summary_label, @@ -165,17 +172,18 @@ def build_agent( options_perception, best_option_perception, ) - components_of_agent = {_get_class_name(component): component - for component in entity_components} + components_of_agent = { + _get_class_name(component): component for component in entity_components + } components_of_agent[ - agent_components.memory_component.DEFAULT_MEMORY_COMPONENT_NAME] = ( - agent_components.memory_component.MemoryComponent(raw_memory)) + agent_components.memory_component.DEFAULT_MEMORY_COMPONENT_NAME + ] = agent_components.memory_component.MemoryComponent(raw_memory) component_order = list(components_of_agent.keys()) if overarching_goal is not None: - components_of_agent[goal_label] = overarching_goal + components_of_agent[DEFAULT_GOAL_COMPONENT_NAME] = overarching_goal # Place goal after the instructions. - component_order.insert(1, goal_label) + component_order.insert(1, DEFAULT_GOAL_COMPONENT_NAME) act_component = agent_components.concat_act_component.ConcatActComponent( model=model, @@ -189,6 +197,86 @@ def build_agent( act_component=act_component, context_components=components_of_agent, component_logging=measurements, + config=config, + ) + + return agent + + +def save_to_json( + agent: entity_agent_with_logging.EntityAgentWithLogging, +) -> str: + """Saves an agent to JSON data. + + This function saves the agent's state to a JSON string, which can be loaded + afterwards with `rebuild_from_json`. The JSON data + includes the state of the agent's context components, act component, memory, + agent name and the initial config. The clock, model and embedder are not + saved and will have to be provided when the agent is rebuilt. The agent must + be in the `READY` phase to be saved. + + Args: + agent: The agent to save. + + Returns: + A JSON string representing the agent's state. + + Raises: + ValueError: If the agent is not in the READY phase. + """ + + if agent.get_phase() != entity_component.Phase.READY: + raise ValueError('The agent must be in the `READY` phase to be saved.') + + data = { + component_name: agent.get_component(component_name).get_state() + for component_name in agent.get_all_context_components() + } + + data['act_component'] = agent.get_act_component().get_state() + + config = agent.get_config() + if config is not None: + data['agent_config'] = config.to_dict() + + return json.dumps(data) + + +def rebuild_from_json( + json_data: str, + model: language_model.LanguageModel, + clock: game_clock.MultiIntervalClock, + embedder: Callable[[str], np.ndarray], + memory_importance: Callable[[str], float] | None = None, +) -> entity_agent_with_logging.EntityAgentWithLogging: + """Rebuilds an agent from JSON data.""" + + data = json.loads(json_data) + + new_agent_memory = associative_memory.AssociativeMemory( + sentence_embedder=embedder, + importance=memory_importance, + clock=clock.now, + clock_step_size=clock.get_step_size(), ) + if 'agent_config' not in data: + raise ValueError('The JSON data does not contain the agent config.') + agent_config = formative_memories.AgentConfig.from_dict( + data.pop('agent_config') + ) + + agent = build_agent( + config=agent_config, + model=model, + memory=new_agent_memory, + clock=clock, + ) + + for component_name in agent.get_all_context_components(): + agent.get_component(component_name).set_state(data.pop(component_name)) + + agent.get_act_component().set_state(data.pop('act_component')) + + assert not data, f'Unused data {sorted(data)}' return agent diff --git a/concordia/factory/agent/synthetic_user.py b/concordia/factory/agent/synthetic_user.py index ab3ea276..e6d5c3d0 100644 --- a/concordia/factory/agent/synthetic_user.py +++ b/concordia/factory/agent/synthetic_user.py @@ -14,7 +14,9 @@ """A factory implementing a simulation of a user for a product.""" +from collections.abc import Callable import datetime +import json from concordia.agents import entity_agent_with_logging from concordia.associative_memory import associative_memory @@ -23,7 +25,9 @@ from concordia.components import agent as components from concordia.language_model import language_model from concordia.memory_bank import legacy_associative_memory +from concordia.typing import entity_component from concordia.utils import measurements as measurements_lib +import numpy as np def _get_class_name(object_: object) -> str: @@ -227,3 +231,82 @@ def build_agent( ) return agent + + +def save_to_json( + agent: entity_agent_with_logging.EntityAgentWithLogging, +) -> str: + """Saves an agent to JSON data. + + This function saves the agent's state to a JSON string, which can be loaded + afterwards with `rebuild_from_json`. The JSON data + includes the state of the agent's context components, act component, memory, + agent name and the initial config. The clock, model and embedder are not + saved and will have to be provided when the agent is rebuilt. The agent must + be in the `READY` phase to be saved. + + Args: + agent: The agent to save. + + Returns: + A JSON string representing the agent's state. + + Raises: + ValueError: If the agent is not in the READY phase. + """ + + if agent.get_phase() != entity_component.Phase.READY: + raise ValueError('The agent must be in the `READY` phase to be saved.') + + data = { + component_name: agent.get_component(component_name).get_state() + for component_name in agent.get_all_context_components() + } + + data['act_component'] = agent.get_act_component().get_state() + + config = agent.get_config() + if config is not None: + data['agent_config'] = config.to_dict() + + return json.dumps(data) + + +def rebuild_from_json( + json_data: str, + model: language_model.LanguageModel, + clock: game_clock.MultiIntervalClock, + embedder: Callable[[str], np.ndarray], + memory_importance: Callable[[str], float] | None = None, +) -> entity_agent_with_logging.EntityAgentWithLogging: + """Rebuilds an agent from JSON data.""" + + data = json.loads(json_data) + + new_agent_memory = associative_memory.AssociativeMemory( + sentence_embedder=embedder, + importance=memory_importance, + clock=clock.now, + clock_step_size=clock.get_step_size(), + ) + + if 'agent_config' not in data: + raise ValueError('The JSON data does not contain the agent config.') + agent_config = formative_memories.AgentConfig.from_dict( + data.pop('agent_config') + ) + + agent = build_agent( + config=agent_config, + model=model, + memory=new_agent_memory, + clock=clock, + ) + + for component_name in agent.get_all_context_components(): + agent.get_component(component_name).set_state(data.pop(component_name)) + + agent.get_act_component().set_state(data.pop('act_component')) + + assert not data, f'Unused data {sorted(data)}' + return agent diff --git a/concordia/memory_bank/legacy_associative_memory.py b/concordia/memory_bank/legacy_associative_memory.py index e252e446..a63b376f 100644 --- a/concordia/memory_bank/legacy_associative_memory.py +++ b/concordia/memory_bank/legacy_associative_memory.py @@ -14,14 +14,15 @@ """A memory bank that uses AssociativeMemory to store and retrieve memories.""" - -from collections.abc import Mapping, Sequence +from collections.abc import Sequence import dataclasses import datetime -from typing import Any +from typing import Any, Mapping from concordia.associative_memory import associative_memory +from concordia.typing import entity_component from concordia.typing import memory as memory_lib +from typing_extensions import override # These are dummy scoring functions that will be used to know the appropriate @@ -104,6 +105,51 @@ def __init__(self, memory: associative_memory.AssociativeMemory): def add(self, text: str, metadata: Mapping[str, Any]) -> None: self._memory.add(text, **metadata) + @override + def get_state(self) -> entity_component.ComponentState: + """Returns the state of the memory bank. + + See `set_state` for details. The default implementation returns an empty + dictionary. + """ + return self._memory.get_state() + + @override + def set_state(self, state: entity_component.ComponentState) -> None: + """Sets the state of the memory bank. + + This is used to restore the state of the memory bank. The state is assumed + to be the one returned by `get_state`. + The state does not need to contain any information that is passed in the + initialization of the memory bank (e.g. embedder, clock, imporance etc.) + It is assumed that set_state is called on the memory bank after it was + initialized with the same parameters as the one used to restore it. + The default implementation does nothing, which implies that the memory bank + does not have any state. + + Example (Creating a copy): + obj1 = MemoryBank(**kwargs) + state = obj.get_state() + obj2 = MemoryBank(**kwargs) + obj2.set_state(state) + # obj1 and obj2 will behave identically. + + Example (Restoring previous behavior): + obj = MemoryBank(**kwargs) + state = obj.get_state() + # do more with obj + obj.set_state(state) + # obj will now behave the same as it did before. + + Note that the state does not need to contain any information that is passed + in __init__ (e.g. the embedder, clock, imporance etc.) + + Args: + state: The state of the memory bank. + """ + + self._memory.set_state(state) + def _texts_with_constant_score( self, texts: Sequence[str]) -> Sequence[memory_lib.MemoryResult]: return [memory_lib.MemoryResult(text=t, score=0.0) for t in texts] diff --git a/concordia/typing/entity_component.py b/concordia/typing/entity_component.py index 2af4e904..07c88c35 100644 --- a/concordia/typing/entity_component.py +++ b/concordia/typing/entity_component.py @@ -19,7 +19,6 @@ import enum import functools from typing import TypeVar - from concordia.typing import entity as entity_lib ComponentName = str @@ -27,6 +26,10 @@ ComponentContextMapping = Mapping[ComponentName, ComponentContext] ComponentT = TypeVar("ComponentT", bound="BaseComponent") +_LeafT = str | int | float | None +_ValueT = _LeafT | Collection["_ValueT"] | Mapping[str, "_ValueT"] +ComponentState = Mapping[str, _ValueT] + class Phase(enum.Enum): """Phases of a component entity lifecycle. @@ -37,8 +40,7 @@ class Phase(enum.Enum): PRE_ACT: The agent has received a request to act. Components are being requested for their action context. This will be followed by `POST_ACT`. POST_ACT: The agent has just submitted an action attempt. Components are - being informed of the action attempt. This will be followed by - `UPDATE`. + being informed of the action attempt. This will be followed by `UPDATE`. PRE_OBSERVE: The agent has received an observation. Components are being informed of the observation. This will be followed by `POST_OBSERVE`. POST_OBSERVE: The agent has just observed. Components are given a chance to @@ -47,6 +49,7 @@ class Phase(enum.Enum): UPDATE: The agent is updating its internal state. This will be followed by `READY`. """ + READY = enum.auto() PRE_ACT = enum.auto() POST_ACT = enum.auto() @@ -76,9 +79,7 @@ def successors(self) -> Collection["Phase"]: def check_successor(self, successor: "Phase") -> None: """Raises ValueError if successor is not a valid next phase.""" if successor not in self.successors: - raise ValueError( - f"The transition from {self} to {successor} is invalid." - ) + raise ValueError(f"The transition from {self} to {successor} is invalid.") class BaseComponent: @@ -109,6 +110,50 @@ def get_entity(self) -> "EntityWithComponents": raise RuntimeError("Entity is not set.") return self._entity + def get_state(self) -> ComponentState: + """Returns the state of the component. + + See `set_state` for details. The default implementation returns an empty + dictionary. + """ + return {} + + def set_state(self, state: ComponentState): + """Sets the state of the component. + + This is used to restore the state of the component. The state is assumed to + be the one returned by `get_state`. + The state does not need to contain any information that is passed in the + initialization of the component (e.g. the memory bank, names of other + componets etc.) + It is assumed that set_state is called on the component after it was + initialized with the same parameters as the one used to restore it. + The default implementation does nothing, which implies that the component + does not have any state. + + Example (Creating a copy): + obj1 = Component(**kwargs) + state = obj.get_state() + obj2 = Componet(**kwargs) + obj2.set_state(state) + # obj1 and obj2 will behave identically. + + Example (Restoring previous behavior): + obj = Component(**kwargs) + state = obj.get_state() + # do more with obj + obj.set_state(state) + # obj will now behave the same as it did before. + + Note that the state does not need to contain any information that is passed + in __init__ (e.g. the memory bank, names of other components etc.) + + Args: + state: The state of the component. + """ + del state + return None + class EntityWithComponents(entity_lib.Entity): """An entity that contains components.""" diff --git a/concordia/typing/memory.py b/concordia/typing/memory.py index f98a80ea..78ceb629 100644 --- a/concordia/typing/memory.py +++ b/concordia/typing/memory.py @@ -18,6 +18,7 @@ from collections.abc import Mapping, Sequence import dataclasses from typing import Any, Protocol +from concordia.typing import entity_component class MemoryScorer(Protocol): @@ -74,6 +75,51 @@ def extend(self, texts: Sequence[str], metadata: Mapping[str, Any]) -> None: for text in texts: self.add(text, metadata) + @abc.abstractmethod + def get_state(self) -> entity_component.ComponentState: + """Returns the state of the memory bank. + + See `set_state` for details. The default implementation returns an empty + dictionary. + """ + raise NotImplementedError() + + @abc.abstractmethod + def set_state(self, state: entity_component.ComponentState) -> None: + """Sets the state of the memory bank. + + This is used to restore the state of the memory bank. The state is assumed + to be the one returned by `get_state`. + The state does not need to contain any information that is passed in the + initialization of the memory bank (e.g. embedder, clock, imporance etc.) + It is assumed that set_state is called on the memory bank after it was + initialized with the same parameters as the one used to restore it. + The default implementation does nothing, which implies that the memory bank + does not have any state. + + Example (Creating a copy): + obj1 = MemoryBank(**kwargs) + state = obj.get_state() + obj2 = MemoryBank(**kwargs) + obj2.set_state(state) + # obj1 and obj2 will behave identically. + + Example (Restoring previous behavior): + obj = MemoryBank(**kwargs) + state = obj.get_state() + # do more with obj + obj.set_state(state) + # obj will now behave the same as it did before. + + Note that the state does not need to contain any information that is passed + in __init__ (e.g. the embedder, clock, imporance etc.) + + Args: + state: The state of the memory bank. + """ + + raise NotImplementedError() + @abc.abstractmethod def retrieve( self,