Skip to content

Commit

Permalink
remove unnecessary dependencies
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 661214283
Change-Id: I785210ae245f4c2f5c618c44dd09f51a8e9f03ff
  • Loading branch information
jagapiou authored and copybara-github committed Aug 9, 2024
1 parent 5514bad commit db59748
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 77 deletions.
53 changes: 17 additions & 36 deletions examples/modular/launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
import argparse
import datetime
import importlib
import os
import pathlib
import sys

Expand All @@ -56,47 +55,29 @@
from concordia.language_model import ollama_model
from concordia.language_model import pytorch_gemma_model
from concordia.utils import measurements as measurements_lib
import openai
import sentence_transformers


def language_model_setup(args):
"""Get the wrapped language model."""
if not args.disable_language_model:
# By default this script uses GPT-4, so you must provide an API key.
# Note that it is also possible to use local models or other API models,
# simply replace the following with the correct initialization for the model
# you want to use.
if args.api_type == 'amazon_bedrock':
return amazon_bedrock_model.AmazonBedrockLanguageModel(
model_name=args.model_name)
elif args.api_type == 'google_aistudio_model':
return google_aistudio_model.GoogleAIStudioLanguageModel(
model_name=args.model_name)
elif args.api_type == 'langchain_ollama':
return langchain_ollama_model.LangchainOllamaLanguageModel(
model_name=args.model_name)
elif args.api_type == 'mistral':
mistral_api_key = os.environ['MISTRAL_API_KEY']
if not mistral_api_key:
raise ValueError('Mistral api_key is required.')
return mistral_model.MistralLanguageModel(api_key=mistral_api_key,
model_name=args.model_name)
elif args.api_type == 'ollama':
return ollama_model.OllamaLanguageModel(model_name=args.model_name)
elif args.api_type == 'openai':
openai.api_key = os.environ['OPENAI_API_KEY']
if not openai.api_key:
raise ValueError('OpenAI api_key is required.')
return gpt_model.GptLanguageModel(api_key=openai.api_key,
model_name=args.model_name)
elif args.api_type == 'pytorch_gemma':
return pytorch_gemma_model.PyTorchGemmaLanguageModel(
model_name=args.model_name)
else:
raise ValueError(f'Unrecognized api type: {args.api_type}')
else:
if args.disable_language_model:
return no_language_model.NoLanguageModel()
elif args.api_type == 'amazon_bedrock':
return amazon_bedrock_model.AmazonBedrockLanguageModel(args.model_name)
elif args.api_type == 'google_aistudio_model':
return google_aistudio_model.GoogleAIStudioLanguageModel(args.model_name)
elif args.api_type == 'langchain_ollama':
return langchain_ollama_model.LangchainOllamaLanguageModel(args.model_name)
elif args.api_type == 'mistral':
return mistral_model.MistralLanguageModel(args.model_name)
elif args.api_type == 'ollama':
return ollama_model.OllamaLanguageModel(args.model_name)
elif args.api_type == 'openai':
return gpt_model.GptLanguageModel(args.model_name)
elif args.api_type == 'pytorch_gemma':
return pytorch_gemma_model.PyTorchGemmaLanguageModel(args.model_name)
else:
raise ValueError(f'Unrecognized api type: {args.api_type}')


# Setup for command line arguments
Expand Down
53 changes: 17 additions & 36 deletions examples/modular/notebook.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,6 @@
"from concordia.language_model import ollama_model\n",
"from concordia.language_model import pytorch_gemma_model\n",
"from concordia.utils import measurements as measurements_lib\n",
"\n",
"import openai\n",
"import sentence_transformers"
]
},
Expand Down Expand Up @@ -170,41 +168,24 @@
"source": [
"# @title Language Model setup\n",
"\n",
"if not DISABLE_LANGUAGE_MODEL:\n",
" # By default this script uses GPT-4, so you must provide an API key.\n",
" # Note that it is also possible to use local models or other API models,\n",
" # simply replace the following with the correct initialization for the model\n",
" # you want to use.\n",
" if API_TYPE == 'amazon_bedrock':\n",
" model = amazon_bedrock_model.AmazonBedrockLanguageModel(\n",
" model_name=MODEL_NAME)\n",
" elif API_TYPE == 'google_aistudio_model':\n",
" model = google_aistudio_model.GoogleAIStudioLanguageModel(\n",
" model_name=MODEL_NAME)\n",
" elif API_TYPE == 'langchain_ollama':\n",
" model = langchain_ollama_model.LangchainOllamaLanguageModel(\n",
" model_name=MODEL_NAME)\n",
" elif API_TYPE == 'mistral':\n",
" mistral_api_key = os.environ['MISTRAL_API_KEY']\n",
" if not mistral_api_key:\n",
" raise ValueError('Mistral api_key is required.')\n",
" model = mistral_model.MistralLanguageModel(api_key=mistral_api_key,\n",
" model_name=MODEL_NAME)\n",
" elif API_TYPE == 'ollama':\n",
" model = ollama_model.OllamaLanguageModel(model_name=MODEL_NAME)\n",
" elif API_TYPE == 'openai':\n",
" openai.api_key = os.environ['OPENAI_API_KEY']\n",
" if not openai.api_key:\n",
" raise ValueError('OpenAI api_key is required.')\n",
" model = gpt_model.GptLanguageModel(api_key=openai.api_key,\n",
" model_name=MODEL_NAME)\n",
" elif API_TYPE == 'pytorch_gemma':\n",
" model = pytorch_gemma_model.PyTorchGemmaLanguageModel(\n",
" model_name=MODEL_NAME)\n",
" else:\n",
" raise ValueError(f'Unrecognized api type: {API_TYPE}')\n",
"if DISABLE_LANGUAGE_MODEL:\n",
" model = no_language_model.NoLanguageModel()\n",
"elif API_TYPE == 'amazon_bedrock':\n",
" model = amazon_bedrock_model.AmazonBedrockLanguageModel(MODEL_NAME)\n",
"elif API_TYPE == 'google_aistudio_model':\n",
" model = google_aistudio_model.GoogleAIStudioLanguageModel(MODEL_NAME)\n",
"elif API_TYPE == 'langchain_ollama':\n",
" model = langchain_ollama_model.LangchainOllamaLanguageModel(MODEL_NAME)\n",
"elif API_TYPE == 'mistral':\n",
" model = mistral_model.MistralLanguageModel(MODEL_NAME)\n",
"elif API_TYPE == 'ollama':\n",
" model = ollama_model.OllamaLanguageModel(MODEL_NAME)\n",
"elif API_TYPE == 'openai':\n",
" model = gpt_model.GptLanguageModel(MODEL_NAME)\n",
"elif API_TYPE == 'pytorch_gemma':\n",
" model = pytorch_gemma_model.PyTorchGemmaLanguageModel(MODEL_NAME)\n",
"else:\n",
" model = no_language_model.NoLanguageModel()"
" raise ValueError(f'Unrecognized api type: {API_TYPE}')"
]
},
{
Expand Down
2 changes: 0 additions & 2 deletions examples/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,5 @@ IPython
matplotlib
ml_collections
numpy
openai
pandas
sentence_transformers
termcolor
1 change: 0 additions & 1 deletion examples/tutorials/agent_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@
"import random\n",
"\n",
"import numpy as np\n",
"import pandas as pd\n",
"import sentence_transformers\n",
"\n",
"from IPython import display\n",
Expand Down
2 changes: 0 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ def _remove_excluded(description: str) -> str:
package_data={},
python_requires='>=3.11',
install_requires=(
# TODO: b/312199199 - remove some requirements.
'absl-py',
'boto3',
'google-cloud-aiplatform',
Expand All @@ -83,7 +82,6 @@ def _remove_excluded(description: str) -> str:
'python-dateutil',
'reactivex',
'retry',
'scipy',
'termcolor',
'transformers',
'typing-extensions',
Expand Down

0 comments on commit db59748

Please sign in to comment.