-
Notifications
You must be signed in to change notification settings - Fork 218
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #231 from SylphAI-Inc/main
[update the docs mainly] Check back later for the bedrock integration
- Loading branch information
Showing
43 changed files
with
2,654 additions
and
2,478 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -36,4 +36,5 @@ index.faiss | |
*.svg | ||
# ignore the softlink to adalflow cache | ||
*.adalflow | ||
.idea | ||
extend/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
154 changes: 154 additions & 0 deletions
154
adalflow/adalflow/components/model_client/bedrock_client.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,154 @@ | ||
"""AWS Bedrock ModelClient integration.""" | ||
|
||
import os | ||
from typing import Dict, Optional, Any, Callable | ||
import backoff | ||
import logging | ||
|
||
from adalflow.core.model_client import ModelClient | ||
from adalflow.core.types import ModelType, CompletionUsage, GeneratorOutput | ||
|
||
import boto3 | ||
from botocore.config import Config | ||
|
||
# Module-wide logger, named after this module.
log = logging.getLogger(__name__)

# Exception namespace of a bedrock-runtime client. The client is created once,
# at import time, only so that its ``.exceptions`` attribute can be referenced
# at module level. The region is read from AWS_REGION_NAME and falls back to
# "us-east-1" when that variable is not set.
bedrock_runtime_exceptions = boto3.client(
    "bedrock-runtime",
    region_name=os.environ.get("AWS_REGION_NAME", "us-east-1"),
).exceptions
|
||
|
||
def get_first_message_content(completion: Dict) -> str:
    r"""Return the text of the first content block of the completion message.

    This is the default parser for chat completion.

    Args:
        completion: Response dictionary; it is read at
            ``completion['output']['message']['content'][0]['text']``.

    Returns:
        The value stored under ``'text'`` in the first content block.
    """
    message = completion['output']['message']
    first_block = message['content'][0]
    return first_block['text']
|
||
|
||
# Names exported by ``from <module> import *``.
__all__ = [
    "BedrockAPIClient",
    "get_first_message_content",
    "bedrock_runtime_exceptions",
]
|
||
|
||
class BedrockAPIClient(ModelClient):
    __doc__ = r"""A component wrapper for the Bedrock API client.

    Visit https://docs.aws.amazon.com/bedrock/latest/APIReference/welcome.html for more api details.

    Credential and region arguments that are left unset fall back to the
    ``AWS_PROFILE_NAME``, ``AWS_REGION_NAME``, ``AWS_ACCESS_KEY_ID``,
    ``AWS_SECRET_ACCESS_KEY`` and ``AWS_SESSION_TOKEN`` environment variables.

    Args:
        aws_profile_name: Name of the AWS profile to use for the session.
        aws_region_name: AWS region of the session.
        aws_access_key_id: AWS access key id.
        aws_secret_access_key: AWS secret access key.
        aws_session_token: AWS session token.
        aws_connection_timeout: Connection timeout in seconds.
        aws_read_timeout: Read timeout in seconds.
        chat_completion_parser: Callable that extracts the output from a
            completion; defaults to ``get_first_message_content``.
    """

    def __init__(
        self,
        aws_profile_name: Optional[str] = None,
        aws_region_name: Optional[str] = None,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_session_token: Optional[str] = None,
        aws_connection_timeout: Optional[float] = None,
        aws_read_timeout: Optional[float] = None,
        chat_completion_parser: Optional[Callable] = None,
    ):
        super().__init__()
        self._aws_profile_name = aws_profile_name
        self._aws_region_name = aws_region_name
        self._aws_access_key_id = aws_access_key_id
        self._aws_secret_access_key = aws_secret_access_key
        self._aws_session_token = aws_session_token
        self._aws_connection_timeout = aws_connection_timeout
        self._aws_read_timeout = aws_read_timeout

        self.session = None
        self.sync_client = self.init_sync_client()
        self.chat_completion_parser = (
            chat_completion_parser or get_first_message_content
        )

    def init_sync_client(self):
        """Create the boto3 ``bedrock-runtime`` client used by :meth:`call`.

        Each credential/region value is taken from the constructor argument
        when it is set, otherwise from the matching environment variable.

        There is no need to pass both a profile and an access key / secret key
        pair; pass one of them. If the compute environment assumes a role that
        has access to Bedrock, there is no need to pass anything.

        Returns:
            The client returned by ``session.client("bedrock-runtime")``.
        """
        aws_profile_name = self._aws_profile_name or os.getenv("AWS_PROFILE_NAME")
        aws_region_name = self._aws_region_name or os.getenv("AWS_REGION_NAME")
        aws_access_key_id = self._aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID")
        aws_secret_access_key = self._aws_secret_access_key or os.getenv(
            "AWS_SECRET_ACCESS_KEY"
        )
        aws_session_token = self._aws_session_token or os.getenv("AWS_SESSION_TOKEN")

        config = None
        if self._aws_connection_timeout or self._aws_read_timeout:
            # Forward only the timeouts that were actually supplied. Passing an
            # explicit None for the other one would be recorded by botocore as
            # a user-provided value instead of leaving its default in place.
            timeout_options = {}
            if self._aws_connection_timeout is not None:
                # Connection timeout in seconds
                timeout_options["connect_timeout"] = self._aws_connection_timeout
            if self._aws_read_timeout is not None:
                # Read timeout in seconds
                timeout_options["read_timeout"] = self._aws_read_timeout
            config = Config(**timeout_options)

        session = boto3.Session(
            profile_name=aws_profile_name,
            region_name=aws_region_name,
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
        )
        bedrock_runtime = session.client(service_name="bedrock-runtime", config=config)
        return bedrock_runtime

    def init_async_client(self):
        """Async clients are not supported yet.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Async call not implemented yet.")

    def parse_chat_completion(self, completion):
        """Turn a completion into a ``GeneratorOutput``.

        The output extracted by ``self.chat_completion_parser`` is stored in
        ``raw_response`` and the token counts in ``usage``; ``data`` is left as
        ``None``. Any exception raised while parsing is logged and reported
        through the ``error`` field instead of being propagated.

        Args:
            completion: The response returned by :meth:`call`.

        Returns:
            A ``GeneratorOutput`` for either the success or the error case.
        """
        log.debug(f"completion: {completion}")
        try:
            # Use the configured parser so a custom ``chat_completion_parser``
            # passed to the constructor actually takes effect.
            data = self.chat_completion_parser(completion)
            usage = self.track_completion_usage(completion)
            return GeneratorOutput(data=None, usage=usage, raw_response=data)
        except Exception as e:
            log.error(f"Error parsing completion: {e}")
            return GeneratorOutput(
                data=None, error=str(e), raw_response=str(completion)
            )

    def track_completion_usage(self, completion: Dict) -> CompletionUsage:
        r"""Track the completion usage.

        Args:
            completion: Response dictionary with a ``'usage'`` entry holding
                ``'outputTokens'``, ``'inputTokens'`` and ``'totalTokens'``.

        Returns:
            A ``CompletionUsage`` built from those three counters.
        """
        usage = completion['usage']
        return CompletionUsage(
            completion_tokens=usage['outputTokens'],
            prompt_tokens=usage['inputTokens'],
            total_tokens=usage['totalTokens'],
        )

    def convert_inputs_to_api_kwargs(
        self,
        input: Optional[Any] = None,
        model_kwargs: Optional[Dict] = None,
        model_type: ModelType = ModelType.UNDEFINED,
    ):
        """Build the keyword arguments for the converse API call.

        check the converse api doc here:
        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-runtime/client/converse.html

        Args:
            input: The user input, sent as the text of a single user message.
            model_kwargs: Extra keyword arguments for the API call. The dict is
                copied, never modified.
            model_type: Only ``ModelType.LLM`` is supported.

        Returns:
            A copy of ``model_kwargs`` with a ``"messages"`` entry added.

        Raises:
            ValueError: If ``model_type`` is not ``ModelType.LLM``.
        """
        api_kwargs = model_kwargs.copy() if model_kwargs else {}
        if model_type == ModelType.LLM:
            api_kwargs["messages"] = [
                {"role": "user", "content": [{"text": input}]},
            ]
        else:
            raise ValueError(f"Model type {model_type} not supported")
        return api_kwargs

    @backoff.on_exception(
        backoff.expo,
        (
            bedrock_runtime_exceptions.ThrottlingException,
            bedrock_runtime_exceptions.ModelTimeoutException,
            bedrock_runtime_exceptions.InternalServerException,
            bedrock_runtime_exceptions.ModelErrorException,
            # NOTE(review): a validation error is presumably not transient, so
            # retrying it may only delay the failure - confirm before removing.
            bedrock_runtime_exceptions.ValidationException,
        ),
        max_time=5,
    )
    def call(
        self,
        api_kwargs: Optional[Dict] = None,
        model_type: ModelType = ModelType.UNDEFINED,
    ):
        """Call the converse API, retrying the listed Bedrock exceptions.

        Args:
            api_kwargs: The combined input and model_kwargs, as produced by
                :meth:`convert_inputs_to_api_kwargs`.
            model_type: Only ``ModelType.LLM`` is supported.

        Returns:
            The response of ``self.sync_client.converse``.

        Raises:
            ValueError: If ``model_type`` is not ``ModelType.LLM``.
        """
        if model_type == ModelType.LLM:
            return self.sync_client.converse(**(api_kwargs or {}))
        else:
            raise ValueError(f"model_type {model_type} is not supported")

    async def acall(
        self,
        api_kwargs: Optional[Dict] = None,
        model_type: ModelType = ModelType.UNDEFINED,
    ):
        """Async calls are not supported yet.

        Accepts the same arguments as :meth:`call` so that callers get the
        intended ``NotImplementedError`` rather than a ``TypeError``.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Async call not implemented yet.")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.