Skip to content

Commit

Permalink
Fix language model setup util
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 671327977
Change-Id: Idfa61f10f587149059549ff94d08f209cc9966e0
  • Loading branch information
jzleibo authored and copybara-github committed Sep 5, 2024
1 parent 5ac80a9 commit 2ff5657
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 4 deletions.
5 changes: 5 additions & 0 deletions concordia/language_model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ def language_model_setup(
if disable_language_model:
return no_language_model.NoLanguageModel()
elif api_type == 'amazon_bedrock':
if api_key is not None:
raise ValueError(
'Explicitly passing the API key is not supported for Amazon Bedrock '
'models. Please use an environment variable instead.'
)
return amazon_bedrock_model.AmazonBedrockLanguageModel(model_name)
elif api_type == 'google_aistudio_model':
return google_aistudio_model.GoogleAIStudioLanguageModel(
Expand Down
12 changes: 10 additions & 2 deletions examples/modular/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
from concordia.utils import measurements as measurements_lib
import sentence_transformers


# Setup for command line arguments
parser = argparse.ArgumentParser(description='Run a GDM-Concordia simulation.')
parser.add_argument('--agent',
Expand All @@ -75,6 +74,10 @@
action='store',
default='all-mpnet-base-v2',
dest='embedder_name')
parser.add_argument('--api_key',
action='store',
default=None,
dest='api_key')
parser.add_argument('--disable_language_model',
action='store_true',
help=('replace the language model with a null model. This '
Expand All @@ -97,7 +100,12 @@
f'{IMPORT_ENV_BASE_DIR}.{command_line_args.environment_name}')

# Language Model setup
model = utils.language_model_setup(**vars(command_line_args))
model = utils.language_model_setup(
api_type=command_line_args.api_type,
model_name=command_line_args.model_name,
api_key=command_line_args.api_key,
disable_language_model=command_line_args.disable_language_model,
)
# Setup sentence encoder
st_model = sentence_transformers.SentenceTransformer(
f'sentence-transformers/{command_line_args.embedder_name}')
Expand Down
11 changes: 10 additions & 1 deletion examples/modular/launch_concordia_challenge_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,10 @@
default=2,
dest='num_repetitions_per_scenario',
)
parser.add_argument('--api_key',
action='store',
default=None,
dest='api_key')
parser.add_argument(
'--disable_language_model',
action='store_true',
Expand Down Expand Up @@ -172,7 +176,12 @@
for scenario_name, scenario_config in scenarios_lib.SCENARIO_CONFIGS.items():
print(f'Running scenario: {scenario_name}')
# Language Model setup
model = utils.language_model_setup(**vars(args))
model = utils.language_model_setup(
api_type=args.api_type,
model_name=args.model_name,
api_key=args.api_key,
disable_language_model=args.disable_language_model,
)
# Setup sentence encoder
embedder = lambda x: st_model.encode(x, show_progress_bar=False)

Expand Down
11 changes: 10 additions & 1 deletion examples/modular/launch_resident_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@
action='store',
default='all-mpnet-base-v2',
dest='embedder_name')
parser.add_argument('--api_key',
action='store',
default=None,
dest='api_key')
parser.add_argument('--disable_language_model',
action='store_true',
help=('replace the language model with a null model. This '
Expand All @@ -103,7 +107,12 @@
f'{IMPORT_ENV_BASE_DIR}.{args.environment_name}')

# Language Model setup
model = utils.language_model_setup(**vars(args))
model = utils.language_model_setup(
api_type=args.api_type,
model_name=args.model_name,
api_key=args.api_key,
disable_language_model=args.disable_language_model,
)

# Setup sentence encoder
st_model = sentence_transformers.SentenceTransformer(
Expand Down

0 comments on commit 2ff5657

Please sign in to comment.