diff --git a/google/genai/_automatic_function_calling_util.py b/google/genai/_automatic_function_calling_util.py index ec7f9a702..ef4d8c01d 100644 --- a/google/genai/_automatic_function_calling_util.py +++ b/google/genai/_automatic_function_calling_util.py @@ -13,6 +13,7 @@ # limitations under the License. # +import enum import inspect import sys import types as builtin_types @@ -152,6 +153,29 @@ def _parse_schema_from_parameter( # type: ignore[return] f' {func_name} is not compatible with the parameter annotation' f' {param.annotation}.' ) + if inspect.isclass(param.annotation) and issubclass(param.annotation, enum.Enum): + member_values = [member.value for member in param.annotation] + if all(isinstance(value, str) for value in member_values): + schema.type = _py_builtin_type_to_schema_type[str] + schema.enum = member_values + elif all(isinstance(value, int) for value in member_values): + schema.type = _py_builtin_type_to_schema_type[int] + schema.enum = [str(value) for value in member_values] + else: + raise ValueError( + f'Enum type {param.annotation} must have members that are all' + ' strings or all integers.' + ) + + if param.default is not inspect.Parameter.empty: + default_value = param.default + if isinstance(default_value, param.annotation): + default_value = default_value.value + + if default_value not in member_values: + raise ValueError(default_value_error_msg) + schema.default = default_value + return schema if _is_builtin_primitive_or_compound(param.annotation): if param.default is not inspect.Parameter.empty: if not _is_default_value_compatible(param.default, param.annotation): diff --git a/google/genai/tests/types/test_types.py b/google/genai/tests/types/test_types.py index 099ff2a2d..ae6c93f45 100644 --- a/google/genai/tests/types/test_types.py +++ b/google/genai/tests/types/test_types.py @@ -18,6 +18,7 @@ import json import sys import typing +from enum import IntEnum from typing import Optional, assert_never import PIL.Image import pydantic @@ -2542,6 +2543,43 @@ def func_under_test(a: tuple[int, int]) -> str: assert actual_schema_vertex.parameters_json_schema == expected_parameters_json_schema +def test_function_with_int_enum_parameter(): + + class DaysEnum(IntEnum): + ONE = 1 + FIVE = 5 + TEN = 10 + + def func_under_test(days: DaysEnum) -> str: + """test IntEnum parameter.""" + return '' + + expected_schema = types.FunctionDeclaration( + name='func_under_test', + parameters=types.Schema( + type='OBJECT', + properties={ + 'days': types.Schema( + type='INTEGER', + enum=['1', '5', '10'], + ), + }, + required=['days'], + ), + description='test IntEnum parameter.', + ) + + actual_schema_mldev = types.FunctionDeclaration.from_callable( + client=mldev_client, callable=func_under_test + ) + actual_schema_vertex = types.FunctionDeclaration.from_callable( + client=vertex_client, callable=func_under_test + ) + + assert actual_schema_mldev == expected_schema + assert actual_schema_vertex == expected_schema + + def test_function_gemini_api(monkeypatch): api_key = 'google_api_key' monkeypatch.setenv('GOOGLE_API_KEY', api_key)