Skip to content

Commit

Permalink
Merge pull request #231 from SylphAI-Inc/main
Browse files Browse the repository at this point in the history
[update the docs mainly] Check back later for the bedrock integration
  • Loading branch information
Sylph-AI authored Oct 21, 2024
2 parents 363f535 + b068eaf commit 794319d
Show file tree
Hide file tree
Showing 43 changed files with 2,654 additions and 2,478 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,5 @@ index.faiss
*.svg
# ignore the softlink to adalflow cache
*.adalflow
.idea
extend/
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,8 @@ Just define it as a ``Parameter`` and pass it to AdalFlow's ``Generator``.
``AdalComponent`` acts as the 'interpreter' between task pipeline and the trainer, defining training and validation steps, optimizers, evaluators, loss functions, backward engine for textual gradients or tracing the demonstrations, the teacher generator.

<p align="center">
<img src="https://raw.githubusercontent.com/SylphAI-Inc/LightRAG/main/docs/source/_static/images/trainer.png" alt="AdalFlow AdalComponent & Trainer">
<img src="https://raw.githubusercontent.com/SylphAI-Inc/AdalFlow/main/docs/source/_static/images/trainer.png" alt="AdalFlow AdalComponent & Trainer">

</p>

# Quick Install
Expand Down
3 changes: 2 additions & 1 deletion adalflow/adalflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from adalflow.optim.grad_component import GradComponent
from adalflow.core.generator import Generator


from adalflow.core.types import (
GeneratorOutput,
EmbedderOutput,
Expand Down Expand Up @@ -56,6 +55,7 @@
TransformersClient,
AnthropicAPIClient,
CohereAPIClient,
BedrockAPIClient,
)

__all__ = [
Expand Down Expand Up @@ -113,4 +113,5 @@
"TransformersClient",
"AnthropicAPIClient",
"CohereAPIClient",
"BedrockAPIClient",
]
6 changes: 5 additions & 1 deletion adalflow/adalflow/components/model_client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@
"adalflow.components.model_client.anthropic_client.AnthropicAPIClient",
OptionalPackages.ANTHROPIC,
)
BedrockAPIClient = LazyImport(
"adalflow.components.model_client.bedrock_client.BedrockAPIClient",
OptionalPackages.BEDROCK,
)
GroqAPIClient = LazyImport(
"adalflow.components.model_client.groq_client.GroqAPIClient",
OptionalPackages.GROQ,
Expand Down Expand Up @@ -61,14 +65,14 @@
OptionalPackages.OPENAI,
)


__all__ = [
"CohereAPIClient",
"TransformerReranker",
"TransformerEmbedder",
"TransformerLLM",
"TransformersClient",
"AnthropicAPIClient",
"BedrockAPIClient",
"GroqAPIClient",
"OpenAIClient",
"GoogleGenAIClient",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,4 +167,4 @@ async def acall(
elif model_type == ModelType.LLM:
return await self.async_client.messages.create(**api_kwargs)
else:
raise ValueError(f"model_type {model_type} is not supported")
raise ValueError(f"model_type {model_type} is not supported")
154 changes: 154 additions & 0 deletions adalflow/adalflow/components/model_client/bedrock_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
"""AWS Bedrock ModelClient integration."""

import os
from typing import Dict, Optional, Any, Callable
import backoff
import logging

from adalflow.core.model_client import ModelClient
from adalflow.core.types import ModelType, CompletionUsage, GeneratorOutput

import boto3
from botocore.config import Config

log = logging.getLogger(__name__)

# Exception classes for the bedrock-runtime service, used by the @backoff
# retry decorator on BedrockAPIClient.call below. boto3 only exposes these
# via a client instance, so a throwaway client is built at import time.
# NOTE(review): this performs boto3 client construction as a module-level
# side effect — confirm the import-time cost/credential lookup is acceptable.
bedrock_runtime_exceptions = boto3.client(
    service_name="bedrock-runtime",
    region_name=os.getenv("AWS_REGION_NAME", "us-east-1")
).exceptions


def get_first_message_content(completion: Dict) -> str:
    r"""Return the text of the first content item in the output message.

    This is the default chat-completion parser for a Bedrock ``converse``
    response dictionary.
    """
    message = completion["output"]["message"]
    first_content_item = message["content"][0]
    return first_content_item["text"]


# Public API of this module.
__all__ = ["BedrockAPIClient", "get_first_message_content", "bedrock_runtime_exceptions"]


class BedrockAPIClient(ModelClient):
    __doc__ = r"""A component wrapper for the Bedrock API client.
    Visit https://docs.aws.amazon.com/bedrock/latest/APIReference/welcome.html for more api details.
    """

    def __init__(
        self,
        aws_profile_name=None,
        aws_region_name=None,
        aws_access_key_id=None,
        aws_secret_access_key=None,
        aws_session_token=None,
        aws_connection_timeout=None,
        aws_read_timeout=None,
        chat_completion_parser: Callable = None,
    ):
        """Initialize the Bedrock client.

        All AWS arguments are optional; any that are omitted fall back to the
        corresponding environment variables in ``init_sync_client``, and if
        nothing is supplied boto3's normal credential resolution applies
        (e.g. an IAM role attached to the compute instance).

        Args:
            aws_profile_name: Named AWS profile to use.
            aws_region_name: AWS region for the bedrock-runtime endpoint.
            aws_access_key_id / aws_secret_access_key / aws_session_token:
                Explicit credentials; pass either a profile OR keys, not both.
            aws_connection_timeout: Connect timeout in seconds.
            aws_read_timeout: Read timeout in seconds.
            chat_completion_parser: Callable applied to the raw ``converse``
                response to extract text; defaults to
                ``get_first_message_content``.
        """
        super().__init__()
        self._aws_profile_name = aws_profile_name
        self._aws_region_name = aws_region_name
        self._aws_access_key_id = aws_access_key_id
        self._aws_secret_access_key = aws_secret_access_key
        self._aws_session_token = aws_session_token
        self._aws_connection_timeout = aws_connection_timeout
        self._aws_read_timeout = aws_read_timeout

        self.session = None
        self.sync_client = self.init_sync_client()
        self.chat_completion_parser = (
            chat_completion_parser or get_first_message_content
        )

    def init_sync_client(self):
        """Create and return a boto3 ``bedrock-runtime`` client.

        There is no need to pass both a profile and an access/secret key pair;
        pass one of them. If the compute instance assumes a role that has
        access to Bedrock, nothing needs to be passed at all.
        """
        aws_profile_name = self._aws_profile_name or os.getenv("AWS_PROFILE_NAME")
        aws_region_name = self._aws_region_name or os.getenv("AWS_REGION_NAME")
        aws_access_key_id = self._aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID")
        aws_secret_access_key = self._aws_secret_access_key or os.getenv("AWS_SECRET_ACCESS_KEY")
        aws_session_token = self._aws_session_token or os.getenv("AWS_SESSION_TOKEN")

        # Only build a custom Config when a timeout was requested, so boto3's
        # defaults apply otherwise.
        config = None
        if self._aws_connection_timeout or self._aws_read_timeout:
            config = Config(
                connect_timeout=self._aws_connection_timeout,  # Connection timeout in seconds
                read_timeout=self._aws_read_timeout  # Read timeout in seconds
            )

        session = boto3.Session(
            profile_name=aws_profile_name,
            region_name=aws_region_name,
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
        )
        bedrock_runtime = session.client(service_name="bedrock-runtime", config=config)
        return bedrock_runtime

    def init_async_client(self):
        raise NotImplementedError("Async call not implemented yet.")

    def parse_chat_completion(self, completion):
        """Parse a raw ``converse`` response into a ``GeneratorOutput``.

        Fix: use the configured ``chat_completion_parser`` (set in
        ``__init__``, default ``get_first_message_content``) instead of
        re-inlining the extraction — previously a custom parser passed to
        the constructor was silently ignored. With the default parser the
        extracted value is identical to the old inline expression.
        """
        log.debug(f"completion: {completion}")
        try:
            data = self.chat_completion_parser(completion)
            usage = self.track_completion_usage(completion)
            # Convention here: parsed text goes to raw_response; data stays None.
            return GeneratorOutput(data=None, usage=usage, raw_response=data)
        except Exception as e:
            # Best-effort: surface the parsing error instead of raising, so the
            # caller still receives the raw completion for debugging.
            log.error(f"Error parsing completion: {e}")
            return GeneratorOutput(
                data=None, error=str(e), raw_response=str(completion)
            )

    def track_completion_usage(self, completion: Dict) -> CompletionUsage:
        r"""Track the completion usage reported in the ``converse`` response."""
        usage = completion['usage']
        return CompletionUsage(
            completion_tokens=usage['outputTokens'],
            prompt_tokens=usage['inputTokens'],
            total_tokens=usage['totalTokens']
        )

    def convert_inputs_to_api_kwargs(
        self,
        input: Optional[Any] = None,
        model_kwargs: Optional[Dict] = None,
        model_type: ModelType = ModelType.UNDEFINED
    ):
        """Combine the prompt input and model kwargs into converse-API kwargs.

        Check the converse api doc here:
        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/bedrock-runtime/client/converse.html

        Raises:
            ValueError: if ``model_type`` is not ``ModelType.LLM``.
        """
        # Avoid a mutable default argument; copy so the caller's dict is never mutated.
        api_kwargs = dict(model_kwargs) if model_kwargs else {}
        if model_type == ModelType.LLM:
            api_kwargs["messages"] = [
                {"role": "user", "content": [{"text": input}]},
            ]
        else:
            raise ValueError(f"Model type {model_type} not supported")
        return api_kwargs

    @backoff.on_exception(
        backoff.expo,
        (
            bedrock_runtime_exceptions.ThrottlingException,
            bedrock_runtime_exceptions.ModelTimeoutException,
            bedrock_runtime_exceptions.InternalServerException,
            bedrock_runtime_exceptions.ModelErrorException,
            bedrock_runtime_exceptions.ValidationException
        ),
        max_time=5,
    )
    def call(self, api_kwargs: Optional[Dict] = None, model_type: ModelType = ModelType.UNDEFINED):
        """Invoke the Bedrock ``converse`` API with retry on transient errors.

        ``api_kwargs`` is the combined input and model_kwargs produced by
        ``convert_inputs_to_api_kwargs``.

        Raises:
            ValueError: if ``model_type`` is not ``ModelType.LLM``.
        """
        # Avoid a mutable default argument.
        api_kwargs = api_kwargs or {}
        if model_type == ModelType.LLM:
            return self.sync_client.converse(**api_kwargs)
        else:
            raise ValueError(f"model_type {model_type} is not supported")

    async def acall(self):
        # NOTE(review): signature differs from the sync call / other clients'
        # acall(api_kwargs, model_type) — align it when implementing async.
        raise NotImplementedError("Async call not implemented yet.")
8 changes: 8 additions & 0 deletions adalflow/adalflow/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
from .config import new_components_from_config, new_component
from .lazy_import import LazyImport, OptionalPackages, safe_import
from .setup_env import setup_env
from .data import DataLoader, Dataset, Subset
from .global_config import get_adalflow_default_root_path
from .cache import CachedEngine


__all__ = [
Expand All @@ -43,4 +46,9 @@
"write_list_to_jsonl",
"safe_import",
"setup_env",
"DataLoader",
"Dataset",
"Subset",
"get_adalflow_default_root_path",
"CachedEngine",
]
4 changes: 2 additions & 2 deletions adalflow/adalflow/utils/lazy_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import logging
from types import ModuleType


from enum import Enum

log = logging.getLogger(__name__)
Expand All @@ -19,6 +18,7 @@ class OptionalPackages(Enum):
GROQ = ("groq", "Please install groq with: pip install groq")
OPENAI = ("openai", "Please install openai with: pip install openai")
ANTHROPIC = ("anthropic", "Please install anthropic with: pip install anthropic")
BEDROCK = ("bedrock", "Please install boto3 with: pip install boto3")
GOOGLE_GENERATIVEAI = (
"google.generativeai",
"Please install google-generativeai with: pip install google-generativeai",
Expand Down Expand Up @@ -78,7 +78,7 @@ class LazyImport:
"""

def __init__(
self, import_path: str, optional_package: OptionalPackages, *args, **kwargs
self, import_path: str, optional_package: OptionalPackages, *args, **kwargs
):
if args or kwargs:
raise TypeError(
Expand Down
Loading

0 comments on commit 794319d

Please sign in to comment.