Skip to content

Commit

Permalink
feat: Use TypeAliasType to define aliases for union types in genera…
Browse files Browse the repository at this point in the history
…tive models

This is based on the original PR in #4701, just wrapping the typealiases in a try-catch block.

PiperOrigin-RevId: 704506046
  • Loading branch information
yeesian authored and copybara-github committed Dec 19, 2024
1 parent 5a4e9c0 commit 344600f
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 3 deletions.
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@

genai_requires = (
"pydantic < 3",
"typing_extensions",
"docstring_parser < 1",
)

Expand All @@ -143,7 +144,8 @@
"google-cloud-trace < 2",
"opentelemetry-sdk < 2",
"opentelemetry-exporter-gcp-trace < 2",
"pydantic >= 2.6.3, < 2.10",
"pydantic >= 2.6.3, < 3",
"typing_extensions",
]

evaluation_extra_require = [
Expand Down
3 changes: 1 addition & 2 deletions testing/constraints-langchain.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
langchain
langchain-core
langchain-google-vertexai
pydantic<2.10
langchain-google-vertexai
45 changes: 45 additions & 0 deletions vertexai/generative_models/_generative_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,51 @@
],
]

try:
# For Pydantic to resolve the forward references inside these aliases.
from typing_extensions import TypeAliasType

PartsType = TypeAliasType(
"PartsType",
Union[
str,
"Image",
"Part",
List[Union[str, "Image", "Part"]],
],
)
ContentsType = TypeAliasType(
"ContentsType",
Union[
List["Content"],
List[ContentDict],
str,
"Image",
"Part",
List[Union[str, "Image", "Part"]],
],
)
GenerationConfigType = TypeAliasType(
"GenerationConfigType",
Union[
"GenerationConfig",
GenerationConfigDict,
],
)
SafetySettingsType = TypeAliasType(
"SafetySettingsType",
Union[
List["SafetySetting"],
Dict[
gapic_content_types.HarmCategory,
gapic_content_types.SafetySetting.HarmBlockThreshold,
],
],
)
except ImportError:
# Use existing definitions if typing_extensions is not available.
pass


def _reconcile_model_name(model_name: str, project: str, location: str) -> str:
"""Returns a model name that's one of the following:
Expand Down

0 comments on commit 344600f

Please sign in to comment.