Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions google/genai/_automatic_function_calling_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
#

import enum
import inspect
import sys
import types as builtin_types
Expand Down Expand Up @@ -152,6 +153,29 @@ def _parse_schema_from_parameter( # type: ignore[return]
f' {func_name} is not compatible with the parameter annotation'
f' {param.annotation}.'
)
if inspect.isclass(param.annotation) and issubclass(param.annotation, enum.Enum):
member_values = [member.value for member in param.annotation]
if all(isinstance(value, str) for value in member_values):
schema.type = _py_builtin_type_to_schema_type[str]
schema.enum = member_values
elif all(isinstance(value, int) for value in member_values):
schema.type = _py_builtin_type_to_schema_type[int]
schema.enum = [str(value) for value in member_values]
else:
raise ValueError(
f'Enum type {param.annotation} must have members that are all'
' strings or all integers.'
)

if param.default is not inspect.Parameter.empty:
default_value = param.default
if isinstance(default_value, param.annotation):
default_value = default_value.value

if default_value not in member_values:
raise ValueError(default_value_error_msg)
schema.default = default_value
return schema
if _is_builtin_primitive_or_compound(param.annotation):
if param.default is not inspect.Parameter.empty:
if not _is_default_value_compatible(param.default, param.annotation):
Expand Down
38 changes: 38 additions & 0 deletions google/genai/tests/types/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import json
import sys
import typing
from enum import IntEnum
from typing import Optional, assert_never
import PIL.Image
import pydantic
Expand Down Expand Up @@ -2542,6 +2543,43 @@ def func_under_test(a: tuple[int, int]) -> str:
assert actual_schema_vertex.parameters_json_schema == expected_parameters_json_schema


def test_function_with_int_enum_parameter():

class DaysEnum(IntEnum):
ONE = 1
FIVE = 5
TEN = 10

def func_under_test(days: DaysEnum) -> str:
"""test IntEnum parameter."""
return ''

expected_schema = types.FunctionDeclaration(
name='func_under_test',
parameters=types.Schema(
type='OBJECT',
properties={
'days': types.Schema(
type='INTEGER',
enum=['1', '5', '10'],
),
},
required=['days'],
),
description='test IntEnum parameter.',
)

actual_schema_mldev = types.FunctionDeclaration.from_callable(
client=mldev_client, callable=func_under_test
)
actual_schema_vertex = types.FunctionDeclaration.from_callable(
client=vertex_client, callable=func_under_test
)

assert actual_schema_mldev == expected_schema
assert actual_schema_vertex == expected_schema


def test_function_gemini_api(monkeypatch):
api_key = 'google_api_key'
monkeypatch.setenv('GOOGLE_API_KEY', api_key)
Expand Down