From 2ff56571ddbefec4db38f7dcea72ab150822baa0 Mon Sep 17 00:00:00 2001 From: "Joel Z. Leibo" Date: Thu, 5 Sep 2024 05:00:07 -0700 Subject: [PATCH] Fix language model setup util PiperOrigin-RevId: 671327977 Change-Id: Idfa61f10f587149059549ff94d08f209cc9966e0 --- concordia/language_model/utils.py | 5 +++++ examples/modular/launch.py | 12 ++++++++++-- .../modular/launch_concordia_challenge_evaluation.py | 11 ++++++++++- examples/modular/launch_resident_visitor.py | 11 ++++++++++- 4 files changed, 35 insertions(+), 4 deletions(-) diff --git a/concordia/language_model/utils.py b/concordia/language_model/utils.py index e9457763..a1599b73 100644 --- a/concordia/language_model/utils.py +++ b/concordia/language_model/utils.py @@ -49,6 +49,11 @@ def language_model_setup( if disable_language_model: return no_language_model.NoLanguageModel() elif api_type == 'amazon_bedrock': + if api_key is not None: + raise ValueError( + 'Explicitly passing the API key is not supported for Amazon Bedrock ' + 'models. Please use an environment variable instead.' + ) return amazon_bedrock_model.AmazonBedrockLanguageModel(model_name) elif api_type == 'google_aistudio_model': return google_aistudio_model.GoogleAIStudioLanguageModel( diff --git a/examples/modular/launch.py b/examples/modular/launch.py index b8be4a83..645dcfc4 100644 --- a/examples/modular/launch.py +++ b/examples/modular/launch.py @@ -52,7 +52,6 @@ from concordia.utils import measurements as measurements_lib import sentence_transformers - # Setup for command line arguments parser = argparse.ArgumentParser(description='Run a GDM-Concordia simulation.') parser.add_argument('--agent', @@ -75,6 +74,10 @@ action='store', default='all-mpnet-base-v2', dest='embedder_name') +parser.add_argument('--api_key', + action='store', + default=None, + dest='api_key') parser.add_argument('--disable_language_model', action='store_true', help=('replace the language model with a null model. This ' @@ -97,7 +100,12 @@ f'{IMPORT_ENV_BASE_DIR}.{command_line_args.environment_name}') # Language Model setup -model = utils.language_model_setup(**vars(command_line_args)) +model = utils.language_model_setup( + api_type=command_line_args.api_type, + model_name=command_line_args.model_name, + api_key=command_line_args.api_key, + disable_language_model=command_line_args.disable_language_model, +) # Setup sentence encoder st_model = sentence_transformers.SentenceTransformer( f'sentence-transformers/{command_line_args.embedder_name}') diff --git a/examples/modular/launch_concordia_challenge_evaluation.py b/examples/modular/launch_concordia_challenge_evaluation.py index 87af3356..931beb1b 100644 --- a/examples/modular/launch_concordia_challenge_evaluation.py +++ b/examples/modular/launch_concordia_challenge_evaluation.py @@ -113,6 +113,10 @@ default=2, dest='num_repetitions_per_scenario', ) +parser.add_argument('--api_key', + action='store', + default=None, + dest='api_key') parser.add_argument( '--disable_language_model', action='store_true', @@ -172,7 +176,12 @@ for scenario_name, scenario_config in scenarios_lib.SCENARIO_CONFIGS.items(): print(f'Running scenario: {scenario_name}') # Language Model setup - model = utils.language_model_setup(**vars(args)) + model = utils.language_model_setup( + api_type=args.api_type, + model_name=args.model_name, + api_key=args.api_key, + disable_language_model=args.disable_language_model, + ) # Setup sentence encoder embedder = lambda x: st_model.encode(x, show_progress_bar=False) diff --git a/examples/modular/launch_resident_visitor.py b/examples/modular/launch_resident_visitor.py index 7849b42c..657cd9e9 100644 --- a/examples/modular/launch_resident_visitor.py +++ b/examples/modular/launch_resident_visitor.py @@ -79,6 +79,10 @@ action='store', default='all-mpnet-base-v2', dest='embedder_name') +parser.add_argument('--api_key', + action='store', + default=None, + dest='api_key') parser.add_argument('--disable_language_model', action='store_true', help=('replace the language model with a null model. This ' @@ -103,7 +107,12 @@ f'{IMPORT_ENV_BASE_DIR}.{args.environment_name}') # Language Model setup -model = utils.language_model_setup(**vars(args)) +model = utils.language_model_setup( + api_type=args.api_type, + model_name=args.model_name, + api_key=args.api_key, + disable_language_model=args.disable_language_model, +) # Setup sentence encoder st_model = sentence_transformers.SentenceTransformer(