Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions google/genai/_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,19 @@ def append_library_version_headers(headers: dict[str, str]) -> None:
library_label = f'google-genai-sdk/{version.__version__}'
language_label = 'gl-python/' + sys.version.split()[0]
version_header_value = f'{library_label} {language_label}'
user_agent_key = next(
(key for key in headers if key.lower() == 'user-agent'),
'User-Agent',
)
if (
'user-agent' in headers
and version_header_value not in headers['user-agent']
user_agent_key in headers
and version_header_value not in headers[user_agent_key]
):
headers['user-agent'] = f'{version_header_value} ' + headers['user-agent']
elif 'user-agent' not in headers:
headers['user-agent'] = version_header_value
headers[user_agent_key] = (
f'{version_header_value} ' + headers[user_agent_key]
)
elif user_agent_key not in headers:
headers[user_agent_key] = version_header_value
if (
'x-goog-api-client' in headers
and version_header_value not in headers['x-goog-api-client']
Expand Down
2 changes: 1 addition & 1 deletion google/genai/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1296,7 +1296,7 @@ def _GenerateContentConfig_to_mldev(
)

if getv(from_object, ['labels']) is not None:
raise ValueError('labels parameter is not supported in Gemini API.')
setv(parent_object, ['labels'], getv(from_object, ['labels']))

if getv(from_object, ['cached_content']) is not None:
setv(
Expand Down
14 changes: 14 additions & 0 deletions google/genai/tests/interactions/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from unittest import mock
import httpx
import pytest
from ... import client as client_lib

Expand Down Expand Up @@ -82,3 +83,16 @@ async def test_async_client_timeout():
max_retries=mock.ANY,
client_adapter=mock.ANY,
)


def test_interactions_default_headers_use_single_user_agent():
client = client_lib.Client(
api_key="placeholder",
http_options={"api_version": "v1alpha"},
)

headers = httpx.Headers(client.interactions._client.default_headers)

assert len(headers.get_list("user-agent")) == 1
assert "google-genai-sdk/" in headers["user-agent"]
assert "gl-python/" in headers["user-agent"]
15 changes: 15 additions & 0 deletions google/genai/tests/models/test_generate_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
from .. import pytest_helper
from enum import Enum

from ... import models as models_module

GEMINI_FLASH_LATEST = 'gemini-2.5-flash'
GEMINI_FLASH_2_0 = 'gemini-2.0-flash-001'
GEMINI_FLASH_IMAGE_LATEST = 'gemini-2.5-flash-image'
Expand Down Expand Up @@ -64,6 +66,19 @@ class InstrumentEnum(Enum):
KEYBOARD = 'Keyboard'


def test_generate_content_labels_are_serialized_for_mldev():
request = models_module._GenerateContentConfig_to_mldev(
{
'labels': {'purpose': 'exploration', 'environment': 'development'},
}
)

assert request['labels'] == {
'purpose': 'exploration',
'environment': 'development',
}


test_table: list[pytest_helper.TestTableItem] = [
pytest_helper.TestTableItem(
name='test_http_options_in_method',
Expand Down
56 changes: 56 additions & 0 deletions google/genai/tests/types/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2739,6 +2739,62 @@ def test_unknown_enum_value_in_nested_dict():
assert schema.category.value == 'NEW_CATEGORY'


def test_generate_content_response_surfaces_json_parse_error():
class Recipe(pydantic.BaseModel):
name: str

response = types.GenerateContentResponse._from_response(
response={
'candidates': [
{
'content': {
'parts': [{'text': '{"name": "Soup"'}],
'role': 'model',
}
}
]
},
kwargs={
'config': {
'response_schema': Recipe,
}
},
)

assert response.parsed is None
assert response.parsed_error is not None
assert 'ValidationError' in response.parsed_error
assert '{"name": "Soup"' in response.parsed_error


def test_generate_content_response_surfaces_validation_error():
class Recipe(pydantic.BaseModel):
name: str

response = types.GenerateContentResponse._from_response(
response={
'candidates': [
{
'content': {
'parts': [{'text': '{"title": "Soup"}'}],
'role': 'model',
}
}
]
},
kwargs={
'config': {
'response_schema': Recipe,
}
},
)

assert response.parsed is None
assert response.parsed_error is not None
assert 'ValidationError' in response.parsed_error
assert '{"title": "Soup"}' in response.parsed_error


# Tests that TypedDict types from types.py are compatible with pydantic
# pydantic requires TypedDict from typing_extensions for Python <3.12
def test_typed_dict_pydantic_field():
Expand Down
52 changes: 34 additions & 18 deletions google/genai/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -7923,6 +7923,10 @@ class GenerateContentResponse(_common.BaseModel):
default=None,
description="""First candidate from the parsed response if response_schema is provided. Not available for streaming.""",
)
parsed_error: Optional[str] = Field(
default=None,
description="""Details about why structured response parsing failed, if a response schema was provided.""",
)

def _get_text(self) -> Optional[str]:
"""Returns the concatenation of all text parts in the response.
Expand Down Expand Up @@ -8098,6 +8102,13 @@ def _from_response(
) -> T:
result = super()._from_response(response=response, kwargs=kwargs)

def _set_parsed_error(exc: Exception, result_text: Optional[str]) -> None:
details = f'{type(exc).__name__}: {exc}'
if result_text is not None:
result.parsed_error = f'{details}. Response text: {result_text}'
else:
result.parsed_error = details

# Handles response schema.
response_schema = _common.get_value_by_path(
kwargs, ['config', 'response_schema']
Expand All @@ -8121,10 +8132,10 @@ def _from_response(
if result_text is not None:
result.parsed = response_schema.model_validate_json(result_text)
# may not be a valid json per stream response
except pydantic.ValidationError:
pass
except json.decoder.JSONDecodeError:
pass
except pydantic.ValidationError as e:
_set_parsed_error(e, result_text if 'result_text' in locals() else None)
except json.decoder.JSONDecodeError as e:
_set_parsed_error(e, result_text if 'result_text' in locals() else None)
elif (
isinstance(response_schema, EnumMeta) and result._get_text() is not None
):
Expand All @@ -8140,8 +8151,8 @@ def _from_response(
and response_schema.__name__ == 'PlaceholderLiteralEnum'
):
result.parsed = str(response_schema(enum_value).name) # type: ignore
except ValueError:
pass
except ValueError as e:
_set_parsed_error(e, result_text)
elif isinstance(response_schema, builtin_types.GenericAlias) or isinstance(
response_schema, type
):
Expand All @@ -8155,10 +8166,10 @@ class Placeholder(pydantic.BaseModel):
parsed = {'placeholder': json.loads(result_text)}
placeholder = Placeholder.model_validate(parsed)
result.parsed = placeholder.placeholder
except json.decoder.JSONDecodeError:
pass
except pydantic.ValidationError:
pass
except json.decoder.JSONDecodeError as e:
_set_parsed_error(e, result_text if 'result_text' in locals() else None)
except pydantic.ValidationError as e:
_set_parsed_error(e, result_text if 'result_text' in locals() else None)

elif isinstance(response_schema, dict) or isinstance(
response_schema, Schema
Expand All @@ -8171,11 +8182,12 @@ class Placeholder(pydantic.BaseModel):
if result_text is not None:
result.parsed = json.loads(result_text)
# may not be a valid json per stream response
except json.decoder.JSONDecodeError:
pass
except json.decoder.JSONDecodeError as e:
_set_parsed_error(e, result_text if 'result_text' in locals() else None)
elif typing.get_origin(response_schema) in _UNION_TYPES:
# Union schema.
union_types = typing.get_args(response_schema)
union_errors: list[tuple[Exception, Optional[str]]] = []
for union_type in union_types:
if issubclass(union_type, pydantic.BaseModel):
try:
Expand All @@ -8188,18 +8200,22 @@ class Placeholder(pydantic.BaseModel): # type: ignore[no-redef]
parsed = {'placeholder': json.loads(result_text)}
placeholder = Placeholder.model_validate(parsed)
result.parsed = placeholder.placeholder
except json.decoder.JSONDecodeError:
pass
except pydantic.ValidationError:
pass
except json.decoder.JSONDecodeError as e:
union_errors.append((e, result_text if 'result_text' in locals() else None))
except pydantic.ValidationError as e:
union_errors.append((e, result_text if 'result_text' in locals() else None))
else:
try:
result_text = result._get_text()
if result_text is not None:
result.parsed = json.loads(result_text)
# may not be a valid json per stream response
except json.decoder.JSONDecodeError:
pass
except json.decoder.JSONDecodeError as e:
union_errors.append((e, result_text if 'result_text' in locals() else None))
if result.parsed is not None:
break
if result.parsed is None and union_errors:
_set_parsed_error(*union_errors[-1])

return result

Expand Down