diff --git a/pyproject.toml b/pyproject.toml index 1168373ea..143393b7b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,9 @@ aioboto3 = [ "aioboto3>=10.4.0", "types-aioboto3[s3]>=10.4.0", ] +google-gemini = [ + "google-genai>=1.66.0", +] [project.urls] Homepage = "https://github.com/temporalio/sdk-python" diff --git a/temporalio/contrib/google_gemini_sdk/__init__.py b/temporalio/contrib/google_gemini_sdk/__init__.py new file mode 100644 index 000000000..a62d54f9c --- /dev/null +++ b/temporalio/contrib/google_gemini_sdk/__init__.py @@ -0,0 +1,63 @@ +"""First-class Temporal integration for the Google Gemini SDK. + +.. warning:: + This module is experimental and may change in future versions. + Use with caution in production environments. + +This integration lets you use the Gemini SDK's async client with full +automatic function calling (AFC) support, where every API call and every +tool invocation is a **durable Temporal activity**. + +No credentials are fetched in the workflow, and no auth material appears in +Temporal's event history. + +- :class:`GeminiPlugin` — registers the ``gemini_api_client_async_request`` + activity using a caller-provided ``genai.Client`` on the worker side. +- :func:`gemini_client` — call from a workflow to get an ``AsyncClient`` + that routes API calls through activities. +- :func:`activity_as_tool` — convert any ``@activity.defn`` function into a + Gemini tool callable; Gemini's AFC invokes it as a Temporal activity. + +Quickstart:: + + # ---- worker setup (outside sandbox) ---- + client = genai.Client(api_key=os.environ["GOOGLE_API_KEY"]) + plugin = GeminiPlugin(client) + + @activity.defn + async def get_weather(state: str) -> str: ... + + # ---- workflow (sandbox-safe) ---- + @workflow.defn + class AgentWorkflow: + @workflow.run + async def run(self, query: str) -> str: + client = gemini_client() + response = await client.models.generate_content( + model="gemini-2.5-flash", + contents=query, + config=types.GenerateContentConfig( + tools=[ + activity_as_tool( + get_weather, + start_to_close_timeout=timedelta(seconds=30), + ), + ], + ), + ) + return response.text +""" + +from __future__ import annotations + +from temporalio.contrib.google_gemini_sdk._gemini_plugin import GeminiPlugin +from temporalio.contrib.google_gemini_sdk.workflow import ( + activity_as_tool, + gemini_client, +) + +__all__ = [ + "GeminiPlugin", + "activity_as_tool", + "gemini_client", +] diff --git a/temporalio/contrib/google_gemini_sdk/_gemini_activity.py b/temporalio/contrib/google_gemini_sdk/_gemini_activity.py new file mode 100644 index 000000000..58d3ab952 --- /dev/null +++ b/temporalio/contrib/google_gemini_sdk/_gemini_activity.py @@ -0,0 +1,174 @@ +"""Temporal activity that executes Gemini SDK API calls with real credentials. + +The ``TemporalApiClient`` in the workflow dispatches calls here. This +activity holds a user-provided ``genai.Client`` and forwards structured +requests. Credentials are fetched/refreshed only within the activity — +they never appear in workflow event history. +""" + +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any, Callable + +import google.auth.credentials +from google.genai import Client as GeminiClient +from google.genai import types +from google.genai.types import HttpOptions +from google.genai.types import HttpResponse as SdkHttpResponse + +from temporalio import activity +from temporalio.contrib.google_gemini_sdk._models import ( + _GeminiApiRequest, + _GeminiApiResponse, + _GeminiApiStreamedResponse, + _GeminiDownloadFileRequest, + _GeminiRegisterFilesRequest, + _GeminiUploadFileRequest, + _GeminiUploadToFileSearchStoreRequest, +) + + +def _resolve_http_options( + overrides: Any, +) -> HttpOptions | None: + """Reconstruct ``HttpOptions`` from serializable overrides, or None.""" + if overrides is None: + return None + return HttpOptions.model_validate(overrides.model_dump(exclude_none=True)) + + +class GeminiApiCaller: + """Wraps a ``genai.Client`` and exposes Temporal activities for SDK calls. + + The caller owns a reference to the user-provided ``genai.Client``. + All credential management, HTTP client configuration, etc. is the + responsibility of whoever constructs the client. + """ + + def __init__( + self, + client: GeminiClient, + credentials: google.auth.credentials.Credentials | None = None, + ) -> None: + """Initialize with a genai.Client and optional extra credentials.""" + self._client = client + self._credentials = credentials + + def activities(self) -> Sequence[Callable]: + """Return activities that route SDK calls through this client.""" + + @activity.defn(name="gemini_api_client_async_request") + async def gemini_api_client_async_request( + req: _GeminiApiRequest, + ) -> _GeminiApiResponse: + """Execute a Gemini SDK API call with real credentials.""" + response: SdkHttpResponse = ( + await self._client.aio._api_client.async_request( + http_method=req.http_method, + path=req.path, + request_dict=req.request_dict, + http_options=_resolve_http_options(req.http_options_overrides), + ) + ) + return _GeminiApiResponse( + headers=response.headers or {}, + body=response.body or "", + ) + + @activity.defn(name="gemini_api_client_async_request_streamed") + async def gemini_api_client_async_request_streamed( + req: _GeminiApiRequest, + ) -> _GeminiApiStreamedResponse: + """Execute a streamed Gemini SDK API call, collecting all chunks.""" + stream = await self._client.aio._api_client.async_request_streamed( + http_method=req.http_method, + path=req.path, + request_dict=req.request_dict, + http_options=_resolve_http_options(req.http_options_overrides), + ) + chunks = [] + async for chunk in stream: + chunks.append( + _GeminiApiResponse( + headers=chunk.headers or {}, + body=chunk.body or "", + ) + ) + return _GeminiApiStreamedResponse(chunks=chunks) + + @activity.defn(name="gemini_files_upload") + async def gemini_files_upload( + req: _GeminiUploadFileRequest, + ) -> types.File: + """Upload a file using the real genai.Client on the worker.""" + if req.file_bytes is not None: + import io + + file_arg: Any = io.BytesIO(req.file_bytes) + else: + file_arg = req.file_path + + return await self._client.aio.files.upload(file=file_arg, config=req.config) + + @activity.defn(name="gemini_files_download") + async def gemini_files_download( + req: _GeminiDownloadFileRequest, + ) -> bytes: + """Download a file using the real genai.Client on the worker.""" + return await self._client.aio.files.download( + file=req.file, config=req.config + ) + + @activity.defn(name="gemini_files_register") + async def gemini_files_register( + req: _GeminiRegisterFilesRequest, + ) -> types.RegisterFilesResponse: + """Register GCS files using the real genai.Client on the worker. + + Uses ``credentials`` if provided at plugin init, + otherwise falls back to the client's own credentials. + Token refresh happens here on the worker side, so no auth + material enters the workflow event history. + """ + auth = self._credentials or self._client._api_client._credentials + if auth is None: + raise ValueError( + "No credentials available for register_files(). " + "Pass extra_credentials to GeminiPlugin or initialize " + "the genai.Client with credentials." + ) + return await self._client.aio.files.register_files( + auth=auth, + uris=req.uris, + config=req.config, + ) + + @activity.defn(name="gemini_file_search_stores_upload") + async def gemini_file_search_stores_upload( + req: _GeminiUploadToFileSearchStoreRequest, + ) -> types.UploadToFileSearchStoreOperation: + """Upload a file to a file search store on the worker.""" + if req.file_bytes is not None: + import io + + file_arg: Any = io.BytesIO(req.file_bytes) + else: + file_arg = req.file_path + + return ( + await self._client.aio.file_search_stores.upload_to_file_search_store( + file_search_store_name=req.file_search_store_name, + file=file_arg, + config=req.config, + ) + ) + + return [ + gemini_api_client_async_request, + gemini_api_client_async_request_streamed, + gemini_files_upload, + gemini_files_download, + gemini_files_register, + gemini_file_search_stores_upload, + ] diff --git a/temporalio/contrib/google_gemini_sdk/_gemini_plugin.py b/temporalio/contrib/google_gemini_sdk/_gemini_plugin.py new file mode 100644 index 000000000..7ced6bd52 --- /dev/null +++ b/temporalio/contrib/google_gemini_sdk/_gemini_plugin.py @@ -0,0 +1,98 @@ +"""Temporal plugin for Google Gemini SDK integration.""" + +from __future__ import annotations + +import dataclasses + +import google.auth.credentials +from google.genai import Client as GeminiClient + +from temporalio.contrib.google_gemini_sdk._gemini_activity import GeminiApiCaller +from temporalio.contrib.pydantic import PydanticPayloadConverter +from temporalio.converter import DataConverter, DefaultPayloadConverter +from temporalio.plugin import SimplePlugin +from temporalio.worker import WorkflowRunner +from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner + + +def _data_converter(converter: DataConverter | None) -> DataConverter: + if converter is None: + return DataConverter(payload_converter_class=PydanticPayloadConverter) + elif converter.payload_converter_class is DefaultPayloadConverter: + return dataclasses.replace( + converter, payload_converter_class=PydanticPayloadConverter + ) + return converter + + +class GeminiPlugin(SimplePlugin): + """A Temporal Worker Plugin configured for the Google Gemini SDK. + + .. warning:: + This class is experimental and may change in future versions. + Use with caution in production environments. + + This plugin registers the ``gemini_api_client_async_request`` activity + using the provided ``genai.Client`` with real credentials. Workflows use + :func:`~temporalio.contrib.google_gemini_sdk.workflow.gemini_client` to + get an ``AsyncClient`` backed by a ``TemporalApiClient`` that routes all + API calls through this activity. + + No credentials are passed to or from the workflow. Auth material never + appears in Temporal's event history. + + Example (Gemini Developer API):: + + client = genai.Client(api_key=os.environ["GOOGLE_API_KEY"]) + plugin = GeminiPlugin(client) + + Example (Vertex AI):: + + client = genai.Client( + vertexai=True, project="my-project", location="us-central1", + ) + plugin = GeminiPlugin(client) + + Example (with separate GCS credentials for file registration):: + + client = genai.Client(api_key=os.environ["GOOGLE_API_KEY"]) + gcs_creds, _ = google.auth.default() + plugin = GeminiPlugin(client, extra_credentials=gcs_creds) + """ + + def __init__( + self, + client: GeminiClient, + extra_credentials: google.auth.credentials.Credentials | None = None, + ) -> None: + """Initialize the Gemini plugin. + + Args: + client: A fully configured ``genai.Client`` instance. + All credential management, HTTP client configuration, etc. + is the responsibility of the caller. + extra_credentials: Optional Google Cloud credentials used for + operations that require explicit auth (e.g. + ``files.register_files()``). If not provided, the + client's own credentials are used. + """ + self._api_caller = GeminiApiCaller(client, credentials=extra_credentials) + + def workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner: + if not runner: + raise ValueError("No WorkflowRunner provided to GeminiPlugin.") + if isinstance(runner, SandboxedWorkflowRunner): + return dataclasses.replace( + runner, + restrictions=runner.restrictions.with_passthrough_modules( + "google.genai" + ), + ) + return runner + + super().__init__( + name="GeminiPlugin", + data_converter=_data_converter, + activities=self._api_caller.activities(), + workflow_runner=workflow_runner, + ) diff --git a/temporalio/contrib/google_gemini_sdk/_models.py b/temporalio/contrib/google_gemini_sdk/_models.py new file mode 100644 index 000000000..dccbcdb25 --- /dev/null +++ b/temporalio/contrib/google_gemini_sdk/_models.py @@ -0,0 +1,107 @@ +"""Serializable Pydantic models for the Gemini SDK Temporal integration. + +These models cross the activity boundary — they're constructed on the +workflow side and deserialized on the activity side (or vice versa). +""" + +from __future__ import annotations + +from typing import Any + +from google.genai import types +from pydantic import BaseModel + +__all__ = [ + "_GeminiApiRequest", + "_GeminiApiResponse", + "_GeminiApiStreamedResponse", + "_GeminiDownloadFileRequest", + "_GeminiRegisterFilesRequest", + "_GeminiUploadFileRequest", + "_GeminiUploadToFileSearchStoreRequest", + "_SerializableHttpOptions", +] + + +class _SerializableHttpOptions(BaseModel): + """Per-request HTTP options that can be serialized across the activity boundary. + + Non-serializable fields (httpx_client, httpx_async_client, aiohttp_client, + client_args, async_client_args) must be configured at GeminiPlugin init. + + ``timeout`` is excluded because Temporal owns timeouts/retries — configure + via ``ActivityConfig`` instead. + """ + + base_url: str | None = None + base_url_resource_scope: str | None = None + api_version: str | None = None + headers: dict[str, str] | None = None + extra_body: dict[str, Any] | None = None + + +# ── async_request models ────────────────────────────────────────────────── + + +class _GeminiApiRequest(BaseModel): + """Serializable activity input for a Gemini SDK API call.""" + + http_method: str + path: str + request_dict: dict[str, object] + http_options_overrides: _SerializableHttpOptions | None = None + + +class _GeminiApiResponse(BaseModel): + """Serializable activity output for a Gemini SDK API call.""" + + headers: dict[str, str] + body: str + + +class _GeminiApiStreamedResponse(BaseModel): + """Serializable activity output for a batched streamed API call. + + The activity collects all streamed chunks and returns them as a list. + The ``TemporalApiClient`` then yields them one at a time to the SDK. + """ + + chunks: list[_GeminiApiResponse] + + +# ── files upload/download models ────────────────────────────────────────── + + +class _GeminiUploadFileRequest(BaseModel): + """Serializable activity input for a file upload. + + For file path uploads the path is resolved on the worker. For + in-memory uploads the raw bytes are sent across the activity boundary. + """ + + file_bytes: bytes | None = None + file_path: str | None = None + config: types.UploadFileConfig | None = None + + +class _GeminiDownloadFileRequest(BaseModel): + """Serializable activity input for a file download.""" + + file: str + config: types.DownloadFileConfig | None = None + + +class _GeminiRegisterFilesRequest(BaseModel): + """Serializable activity input for registering GCS files.""" + + uris: list[str] + config: types.RegisterFilesConfig | None = None + + +class _GeminiUploadToFileSearchStoreRequest(BaseModel): + """Serializable activity input for uploading a file to a file search store.""" + + file_search_store_name: str + file_bytes: bytes | None = None + file_path: str | None = None + config: types.UploadToFileSearchStoreConfig | None = None diff --git a/temporalio/contrib/google_gemini_sdk/_temporal_api_client.py b/temporalio/contrib/google_gemini_sdk/_temporal_api_client.py new file mode 100644 index 000000000..30d77b9eb --- /dev/null +++ b/temporalio/contrib/google_gemini_sdk/_temporal_api_client.py @@ -0,0 +1,276 @@ +"""Temporal-aware BaseApiClient that routes SDK calls through activities. + +This module provides ``TemporalApiClient``, a ``BaseApiClient`` subclass +whose HTTP methods dispatch through Temporal activities instead of making +direct calls. The real ``genai.Client`` with real credentials only exists +on the worker side inside the activity. + +This ensures: +- No credential fetching or refreshing happens in the workflow. +- No auth material (tokens, API keys) appears in Temporal event history. +- The SDK's AFC (automatic function calling) loop runs in the workflow, + so ``activity_as_tool()`` wrappers work naturally. +""" + +from __future__ import annotations + +from datetime import timedelta +from typing import Any + +from google.genai._api_client import BaseApiClient +from google.genai.types import HttpOptions, HttpOptionsOrDict +from google.genai.types import HttpResponse as SdkHttpResponse + +from temporalio import workflow as temporal_workflow +from temporalio.contrib.google_gemini_sdk._models import ( + _GeminiApiRequest, + _GeminiApiResponse, + _GeminiApiStreamedResponse, + _SerializableHttpOptions, +) +from temporalio.workflow import ActivityConfig + +# Fields on HttpOptions that cannot be serialized or should not be forwarded. +_REJECTED_HTTP_OPTION_FIELDS = frozenset( + { + "httpx_client", + "httpx_async_client", + "aiohttp_client", + "client_args", + "async_client_args", + } +) + + +def _validate_http_options(http_options: HttpOptions | None) -> None: + """Raise if http_options contains non-serializable fields.""" + if http_options is None: + return + bad_fields = [ + f + for f in _REJECTED_HTTP_OPTION_FIELDS + if getattr(http_options, f, None) is not None + ] + if bad_fields: + raise ValueError( + f"http_options cannot include {bad_fields}. " + f"Configure custom HTTP clients at GeminiPlugin init instead." + ) + + +class TemporalApiClient(BaseApiClient): + """A ``BaseApiClient`` that routes all API calls through Temporal activities. + + This client is used on the workflow side. It does NOT initialize HTTP + clients, load credentials, or make any network calls. It only holds the + minimal configuration needed for the SDK's request formatting logic + (e.g., choosing between Vertex AI and ML Dev parameter transformations). + + All actual HTTP calls are dispatched via ``workflow.execute_activity``. + """ + + def __init__( # pyright: ignore[reportMissingSuperCall] + self, + *, + vertexai: bool = False, + project: str | None = None, + location: str | None = None, + activity_config: ActivityConfig | None = None, + ) -> None: + """Initialize without calling super (no HTTP clients needed).""" + # Do NOT call super().__init__() — it creates HTTP clients, loads + # credentials, etc. We only set the properties the SDK's request + # formatting code accesses. + self.vertexai = vertexai + self.project = project + self.location = location + self.api_key: str | None = None + self.custom_base_url: str | None = None + + self._activity_config = activity_config or ActivityConfig( + start_to_close_timeout=__import__("datetime").timedelta(seconds=60), + ) + + def _verify_response(self, response_model: Any) -> None: + """No-op — matches the base implementation.""" + pass + + def close(self) -> None: + """No-op — no HTTP resources to close.""" + pass + + async def aclose(self) -> None: + """No-op — no HTTP resources to close.""" + pass + + def __del__(self) -> None: + """No-op — no HTTP resources to clean up.""" + pass + + @staticmethod + def _process_http_options( + http_options: HttpOptionsOrDict | None, + config: ActivityConfig, + ) -> _SerializableHttpOptions | None: + """Validate and extract serializable per-request HTTP options. + + Rejects non-serializable fields (custom HTTP clients), maps timeout + to the Temporal activity config, and returns the remaining options + for forwarding to the activity. + + Args: + http_options: Per-request options from the SDK call. + config: Mutable activity config dict — timeout is applied here. + + Returns: + Serializable options to forward, or None if nothing to forward. + """ + if http_options is None: + return None + + if isinstance(http_options, HttpOptions): + opts = http_options + else: + opts = HttpOptions.model_validate(http_options) + + _validate_http_options(opts) + + # timeout is owned by Temporal — apply it to the activity config + # rather than forwarding to the underlying HTTP client. + if opts.timeout is not None: + config["start_to_close_timeout"] = timedelta(milliseconds=opts.timeout) + + result = _SerializableHttpOptions( + base_url=opts.base_url, + base_url_resource_scope=( + opts.base_url_resource_scope.value + if opts.base_url_resource_scope + else None + ), + api_version=opts.api_version, + headers=opts.headers, + extra_body=opts.extra_body, + ) + # Only return if there are actual values set + if not result.model_dump(exclude_none=True): + return None + return result + + # ── Async (primary path for workflows) ────────────────────────────── + + async def async_request( + self, + http_method: str, + path: str, + request_dict: dict[str, object], + http_options: HttpOptionsOrDict | None = None, + ) -> SdkHttpResponse: + """Dispatch an async API request through a Temporal activity.""" + config: ActivityConfig = {**self._activity_config} + if "summary" not in config: + # Default summary is the API path (e.g. "models/gemini-2.5-flash:generateContent"). + config["summary"] = f"{http_method.upper()} {path}" + overrides = self._process_http_options(http_options, config) + + resp = await temporal_workflow.execute_activity( + "gemini_api_client_async_request", + _GeminiApiRequest( + http_method=http_method, + path=path, + request_dict=request_dict, + http_options_overrides=overrides, + ), + result_type=_GeminiApiResponse, + **config, + ) + return SdkHttpResponse(headers=resp.headers, body=resp.body) + + # ── Sync (not expected in async workflows, but raise clearly) ─────── + + def request( + self, + http_method: str, + path: str, + request_dict: dict[str, object], + http_options: HttpOptionsOrDict | None = None, + ) -> SdkHttpResponse: + """Raise — sync requests not supported in workflows.""" + raise RuntimeError( + "Synchronous requests are not supported in Temporal workflows. " + "Use the AsyncClient returned by gemini_client() instead." + ) + + def request_streamed( + self, + http_method: str, + path: str, + request_dict: dict[str, object], + http_options: HttpOptionsOrDict | None = None, + ) -> Any: + """Raise — sync streaming not supported in workflows.""" + raise RuntimeError( + "Synchronous streaming is not supported in Temporal workflows. " + "Use the AsyncClient returned by gemini_client() instead." + ) + + async def async_request_streamed( + self, + http_method: str, + path: str, + request_dict: dict[str, object], + http_options: HttpOptionsOrDict | None = None, + ) -> Any: + """Dispatch a streamed request, batching chunks in the activity.""" + config: ActivityConfig = {**self._activity_config} + if "summary" not in config: + config["summary"] = f"{http_method.upper()} {path}" + overrides = self._process_http_options(http_options, config) + + resp = await temporal_workflow.execute_activity( + "gemini_api_client_async_request_streamed", + _GeminiApiRequest( + http_method=http_method, + path=path, + request_dict=request_dict, + http_options_overrides=overrides, + ), + result_type=_GeminiApiStreamedResponse, + **config, + ) + + async def _yield_chunks(): + for chunk in resp.chunks: + yield SdkHttpResponse(headers=chunk.headers, body=chunk.body) + + return _yield_chunks() + + # ── File upload/download ───────────────────────────────────────────── + # File operations are handled at a higher level by TemporalAsyncFiles + # (in _temporal_files.py), which dispatches the entire upload/download + # as a Temporal activity using the real client on the worker side. + # These internal BaseApiClient methods are not called in that path, + # so we raise here to catch any unexpected direct usage. + + def upload_file(self, *args: Any, **kwargs: Any) -> Any: + """Raise — use client.files.upload() instead.""" + raise NotImplementedError( + "Use client.files.upload() instead of the internal upload_file() method." + ) + + async def async_upload_file(self, *args: Any, **kwargs: Any) -> Any: + """Raise — use client.files.upload() instead.""" + raise NotImplementedError( + "Use client.files.upload() instead of the internal async_upload_file() method." + ) + + def download_file(self, *args: Any, **kwargs: Any) -> Any: + """Raise — use client.files.download() instead.""" + raise NotImplementedError( + "Use client.files.download() instead of the internal download_file() method." + ) + + async def async_download_file(self, *args: Any, **kwargs: Any) -> Any: + """Raise — use client.files.download() instead.""" + raise NotImplementedError( + "Use client.files.download() instead of the internal async_download_file() method." + ) diff --git a/temporalio/contrib/google_gemini_sdk/_temporal_async_client.py b/temporalio/contrib/google_gemini_sdk/_temporal_async_client.py new file mode 100644 index 000000000..a4982cdac --- /dev/null +++ b/temporalio/contrib/google_gemini_sdk/_temporal_async_client.py @@ -0,0 +1,47 @@ +"""Temporal-aware AsyncClient shim. + +``TemporalAsyncClient`` is an ``AsyncClient`` subclass that wires up +Temporal-aware replacements for modules that need special handling +(files, file search stores). +""" + +from __future__ import annotations + +from google.genai.client import AsyncClient + +from temporalio.contrib.google_gemini_sdk._temporal_api_client import ( + TemporalApiClient, +) +from temporalio.contrib.google_gemini_sdk._temporal_file_search_stores import ( + TemporalAsyncFileSearchStores, +) +from temporalio.contrib.google_gemini_sdk._temporal_files import ( + TemporalAsyncFiles, +) +from temporalio.workflow import ActivityConfig + + +class TemporalAsyncClient(AsyncClient): + """``AsyncClient`` subclass that uses Temporal-aware modules. + + Replaces ``AsyncFiles`` with ``TemporalAsyncFiles`` and + ``AsyncFileSearchStores`` with ``TemporalAsyncFileSearchStores`` + so that file upload/download operations and file search store uploads + run entirely inside Temporal activities. + + Other modules (models, tunings, caches, batches, live, tokens, + operations) are inherited unchanged and work through + ``TemporalApiClient``'s activity-backed HTTP methods. + """ + + def __init__( + self, + api_client: TemporalApiClient, + activity_config: ActivityConfig | None = None, + ) -> None: + """Initialize with Temporal-aware files and file search stores.""" + super().__init__(api_client) + self._files = TemporalAsyncFiles(api_client, activity_config) + self._file_search_stores = TemporalAsyncFileSearchStores( + api_client, activity_config + ) diff --git a/temporalio/contrib/google_gemini_sdk/_temporal_file_search_stores.py b/temporalio/contrib/google_gemini_sdk/_temporal_file_search_stores.py new file mode 100644 index 000000000..4f93021ca --- /dev/null +++ b/temporalio/contrib/google_gemini_sdk/_temporal_file_search_stores.py @@ -0,0 +1,101 @@ +"""Temporal-aware AsyncFileSearchStores shim. + +``TemporalAsyncFileSearchStores`` is an ``AsyncFileSearchStores`` subclass +whose ``upload_to_file_search_store`` method dispatches through a Temporal +activity so the entire upload (including filesystem access and resumable +upload negotiation) runs on the activity worker. +""" + +from __future__ import annotations + +import io +import os +from datetime import timedelta + +from google.genai import types +from google.genai.file_search_stores import AsyncFileSearchStores + +from temporalio import workflow as temporal_workflow +from temporalio.contrib.google_gemini_sdk._models import ( + _GeminiUploadToFileSearchStoreRequest, +) +from temporalio.contrib.google_gemini_sdk._temporal_api_client import ( + TemporalApiClient, + _validate_http_options, +) +from temporalio.workflow import ActivityConfig + + +class TemporalAsyncFileSearchStores(AsyncFileSearchStores): + """``AsyncFileSearchStores`` subclass that routes ``upload_to_file_search_store`` through an activity. + + The entire upload operation — including filesystem access, resumable + upload negotiation, and chunked transfer — runs inside a Temporal + activity on the worker. All other methods (``create``, ``get``, + ``delete``, ``list``, ``import_file``, ``documents``) are inherited + and already work through the ``TemporalApiClient``'s ``async_request`` + activity. + """ + + def __init__( + self, + api_client: TemporalApiClient, + activity_config: ActivityConfig | None = None, + ) -> None: + """Initialize with activity config for upload timeouts.""" + super().__init__(api_client) + self._activity_config = activity_config or ActivityConfig( + start_to_close_timeout=timedelta(seconds=60), + ) + + async def upload_to_file_search_store( + self, + *, + file_search_store_name: str, + file: str | os.PathLike[str] | io.IOBase, + config: types.UploadToFileSearchStoreConfigOrDict | None = None, + ) -> types.UploadToFileSearchStoreOperation: + """Upload a file to a file search store via a Temporal activity. + + Accepts a file path (resolved on the worker), ``os.PathLike``, or + an ``io.IOBase`` (bytes sent across the activity boundary). + """ + act_config: ActivityConfig = {**self._activity_config} + if "summary" not in act_config: + act_config["summary"] = "file_search_stores.upload" + + upload_config = None + if config is not None: + if isinstance(config, dict): + upload_config = types.UploadToFileSearchStoreConfig.model_validate( + config + ) + else: + upload_config = config + _validate_http_options(upload_config.http_options) + + if isinstance(file, io.IOBase): + req = _GeminiUploadToFileSearchStoreRequest( + file_search_store_name=file_search_store_name, + file_bytes=file.read(), + config=upload_config, + ) + elif isinstance(file, str): + req = _GeminiUploadToFileSearchStoreRequest( + file_search_store_name=file_search_store_name, + file_path=file, + config=upload_config, + ) + else: + req = _GeminiUploadToFileSearchStoreRequest( + file_search_store_name=file_search_store_name, + file_path=file.__fspath__(), + config=upload_config, + ) + + return await temporal_workflow.execute_activity( + "gemini_file_search_stores_upload", + req, + result_type=types.UploadToFileSearchStoreOperation, + **act_config, + ) diff --git a/temporalio/contrib/google_gemini_sdk/_temporal_files.py b/temporalio/contrib/google_gemini_sdk/_temporal_files.py new file mode 100644 index 000000000..d988c4fc6 --- /dev/null +++ b/temporalio/contrib/google_gemini_sdk/_temporal_files.py @@ -0,0 +1,161 @@ +"""Temporal-aware AsyncFiles shim. + +``TemporalAsyncFiles`` is an ``AsyncFiles`` subclass whose ``upload`` +and ``download`` methods dispatch through Temporal activities so the +entire file operation (including filesystem access) runs on the +activity worker. +""" + +from __future__ import annotations + +import io +import os +from datetime import timedelta +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import google.auth.credentials +from google.genai import types +from google.genai.files import AsyncFiles + +from temporalio import workflow as temporal_workflow +from temporalio.contrib.google_gemini_sdk._models import ( + _GeminiDownloadFileRequest, + _GeminiRegisterFilesRequest, + _GeminiUploadFileRequest, +) +from temporalio.contrib.google_gemini_sdk._temporal_api_client import ( + TemporalApiClient, + _validate_http_options, +) +from temporalio.workflow import ActivityConfig + + +class TemporalAsyncFiles(AsyncFiles): + """``AsyncFiles`` subclass that routes ``upload`` and ``download`` through activities. + + The entire file operation — including filesystem access, resumable + upload negotiation, and chunked transfer — runs inside a Temporal + activity on the worker. ``get``, ``delete``, and ``list`` are + inherited from ``AsyncFiles`` and already work through the + ``TemporalApiClient``'s ``async_request`` activity. + """ + + def __init__( + self, + api_client: TemporalApiClient, + activity_config: ActivityConfig | None = None, + ) -> None: + """Initialize with activity config for file operation timeouts.""" + super().__init__(api_client) + self._activity_config = activity_config or ActivityConfig( + start_to_close_timeout=timedelta(seconds=60), + ) + + async def upload( + self, + *, + file: str | os.PathLike[str] | io.IOBase, + config: types.UploadFileConfigOrDict | None = None, + ) -> types.File: + """Upload a file via a Temporal activity. + + Accepts a file path (resolved on the worker), ``os.PathLike``, or + an ``io.IOBase`` (bytes sent across the activity boundary). + """ + act_config: ActivityConfig = {**self._activity_config} + if "summary" not in act_config: + act_config["summary"] = "files.upload" + + upload_config = None + if config is not None: + if isinstance(config, dict): + upload_config = types.UploadFileConfig.model_validate(config) + else: + upload_config = config + _validate_http_options(upload_config.http_options) + + if isinstance(file, io.IOBase): + req = _GeminiUploadFileRequest(file_bytes=file.read(), config=upload_config) + elif isinstance(file, str): + req = _GeminiUploadFileRequest(file_path=file, config=upload_config) + else: + # os.PathLike — convert via __fspath__() to avoid importing os + req = _GeminiUploadFileRequest( + file_path=file.__fspath__(), config=upload_config + ) + + return await temporal_workflow.execute_activity( + "gemini_files_upload", + req, + result_type=types.File, + **act_config, + ) + + async def download( + self, + *, + file: str | types.File, + config: types.DownloadFileConfigOrDict | None = None, + ) -> bytes: + """Download a file via a Temporal activity.""" + act_config: ActivityConfig = {**self._activity_config} + if "summary" not in act_config: + act_config["summary"] = "files.download" + + download_config = None + if config is not None: + if isinstance(config, dict): + download_config = types.DownloadFileConfig.model_validate(config) + else: + download_config = config + _validate_http_options(download_config.http_options) + + if isinstance(file, types.File): + if not file.name: + raise ValueError("File object must have a name to download.") + file_name = file.name + else: + file_name = file + + return await temporal_workflow.execute_activity( + "gemini_files_download", + _GeminiDownloadFileRequest(file=file_name, config=download_config), + result_type=bytes, + **act_config, + ) + + async def register_files( + self, + *, + auth: google.auth.credentials.Credentials, + uris: list[str], + config: types.RegisterFilesConfigOrDict | None = None, + ) -> types.RegisterFilesResponse: + """Register GCS files via a Temporal activity. + + .. note:: + The ``auth`` parameter is **ignored**. The activity uses + ``credentials`` if provided to ``GeminiPlugin``, + otherwise falls back to the ``genai.Client``'s own credentials. + Either way, those credentials must have access to the GCS URIs + being registered. + """ + act_config: ActivityConfig = {**self._activity_config} + if "summary" not in act_config: + act_config["summary"] = "files.register_files" + + register_config = None + if config is not None: + if isinstance(config, dict): + register_config = types.RegisterFilesConfig.model_validate(config) + else: + register_config = config + _validate_http_options(register_config.http_options) + + return await temporal_workflow.execute_activity( + "gemini_files_register", + _GeminiRegisterFilesRequest(uris=uris, config=register_config), + result_type=types.RegisterFilesResponse, + **act_config, + ) diff --git a/temporalio/contrib/google_gemini_sdk/justfile b/temporalio/contrib/google_gemini_sdk/justfile new file mode 100644 index 000000000..1f9c33ea6 --- /dev/null +++ b/temporalio/contrib/google_gemini_sdk/justfile @@ -0,0 +1,46 @@ +set dotenv-filename := ".env.local" +set dotenv-load + +run: + uv run python test_gemini.py + +worker: + uv run python first_class_example/worker.py + +query q="What's the weather right now?": + uv run python first_class_example/start_workflow.py "{{q}}" + +embed-worker: + uv run python embed_example/worker.py + +embed prompt="A serene mountain lake at sunset with pine trees": + uv run python embed_example/start_workflow.py "{{prompt}}" + +vibe-worker: + uv run python vibe_convergence_example/worker.py + +set positional-arguments + +vibe *ARGS: + uv run python vibe_convergence_example/start_workflow.py "$@" + +spy-worker: + uv run python i_spy_example/worker.py + +spy: + uv run python i_spy_example/start_workflow.py + +file-worker: + uv run python file_upload_example/worker.py + +file q="Summarize the key financial metrics and top 3 risks from this report.": + uv run python file_upload_example/start_workflow.py "{{q}}" + +file-store: + uv run python file_upload_example/start_workflow.py --store + +chat-worker: + uv run python chat_example/worker.py + +chat topic="recursion in programming": + uv run python chat_example/start_workflow.py "{{topic}}" diff --git a/temporalio/contrib/google_gemini_sdk/workflow.py b/temporalio/contrib/google_gemini_sdk/workflow.py new file mode 100644 index 000000000..309c07199 --- /dev/null +++ b/temporalio/contrib/google_gemini_sdk/workflow.py @@ -0,0 +1,187 @@ +"""Workflow utilities for Google Gemini SDK integration with Temporal. + +This module provides utilities for using the Google Gemini SDK within Temporal +workflows. The key entry points are: + +- :func:`gemini_client` — returns an ``AsyncClient`` backed by a + ``TemporalApiClient`` that routes all API calls through Temporal activities. +- :func:`activity_as_tool` — converts a Temporal activity into a Gemini tool + callable for use with automatic function calling (AFC). +""" + +from __future__ import annotations + +import functools +import inspect +from collections.abc import Callable +from datetime import timedelta +from typing import Any + +from google.genai.client import AsyncClient + +from temporalio import activity +from temporalio import workflow as temporal_workflow +from temporalio.contrib.google_gemini_sdk._temporal_api_client import ( + TemporalApiClient, +) +from temporalio.contrib.google_gemini_sdk._temporal_async_client import ( + TemporalAsyncClient, +) +from temporalio.exceptions import ApplicationError +from temporalio.workflow import ActivityConfig + + +def activity_as_tool( + fn: Callable, + *, + activity_config: ActivityConfig | None = None, +) -> Callable: + """Convert a Temporal activity into a Gemini-compatible async tool callable. + + .. warning:: + This API is experimental and may change in future versions. + Use with caution in production environments. + + Returns an async callable with the same name, docstring, and type signature as + ``fn``. When Gemini's automatic function calling (AFC) invokes the returned + callable from within a Temporal workflow, the call is executed as a Temporal + activity via :func:`workflow.execute_activity`. Each tool invocation therefore + appears as a separate, durable entry in the workflow event history. + + Because AFC is left **enabled**, the Gemini SDK owns the agentic loop — no + manual ``while`` loop or ``run_agent()`` helper is required. Pass the returned + callable directly to ``GenerateContentConfig(tools=[...])``. + + Args: + fn: A Temporal activity function decorated with ``@activity.defn``. + activity_config: Configuration for the activity execution (timeouts, + retry policy, etc.). Defaults to a 30-second + ``start_to_close_timeout``. + + Returns: + An async callable suitable for use as a Gemini tool. + + Raises: + ApplicationError: If ``fn`` is not decorated with ``@activity.defn`` or + has no activity name. + """ + ret = activity._Definition.from_callable(fn) + if not ret: + raise ApplicationError( + "Bare function without @activity.defn decorator is not supported", + "invalid_tool", + ) + if ret.name is None: + raise ApplicationError( + "Activity must have a name to be used as a Gemini tool", + "invalid_tool", + ) + + config: ActivityConfig = { + **( + activity_config + or ActivityConfig( + start_to_close_timeout=timedelta(seconds=30), + ) + ) + } + if "summary" not in config: + config["summary"] = "tool_call" + + # For class-based activities the first parameter is 'self'. Partially apply + # it so that Gemini inspects only the user-facing parameters when building + # the function-call schema, while the worker resolves the real instance at + # execution time. + params = list(inspect.signature(fn).parameters.keys()) + schema_fn: Callable = fn + if params and params[0] == "self": + partial = functools.partial(fn, None) + setattr(partial, "__name__", fn.__name__) + partial.__annotations__ = getattr(fn, "__annotations__", {}) + setattr( + partial, + "__temporal_activity_definition", + getattr(fn, "__temporal_activity_definition", None), + ) + partial.__doc__ = fn.__doc__ + schema_fn = partial + + activity_name: str = ret.name + + async def wrapper(*args: Any, **kwargs: Any) -> Any: + sig = inspect.signature(schema_fn) + bound = sig.bind(*args, **kwargs) + bound.apply_defaults() + activity_args = list(bound.arguments.values()) + return await temporal_workflow.execute_activity( + activity_name, + args=activity_args, + **config, + ) + + wrapper.__name__ = schema_fn.__name__ # type: ignore + wrapper.__doc__ = schema_fn.__doc__ + setattr(wrapper, "__signature__", inspect.signature(schema_fn)) + wrapper.__annotations__ = getattr(schema_fn, "__annotations__", {}) + + return wrapper + + +def gemini_client( + *, + vertexai: bool = False, + project: str | None = None, + location: str | None = None, + activity_config: ActivityConfig | None = None, +) -> AsyncClient: + """Create a Gemini ``AsyncClient`` that routes API calls through Temporal activities. + + .. warning:: + This API is experimental and may change in future versions. + Use with caution in production environments. + + Returns an ``AsyncClient`` backed by a :class:`TemporalApiClient`. The + SDK's code (including the AFC loop) runs in the workflow; only the actual + HTTP API calls cross into activities. Credentials are never fetched or + stored in the workflow — the activity worker handles authentication + independently. + + Call this from within a workflow ``run`` method: + + .. code-block:: python + + @workflow.defn + class MyWorkflow: + @workflow.run + async def run(self, query: str) -> str: + client = gemini_client() + response = await client.models.generate_content( + model="gemini-2.0-flash", + contents=query, + config=GenerateContentConfig( + tools=[activity_as_tool(my_tool)], + ), + ) + return response.text + + Args: + vertexai: Whether to use Vertex AI API endpoints. Must match the + ``GeminiPlugin`` configuration on the worker side. Defaults to + ``False`` (Gemini Developer API). + project: Google Cloud project ID. Only needed when ``vertexai=True`` + and the SDK's request formatting requires it (e.g., cache + operations). + location: Google Cloud location. Same conditions as ``project``. + activity_config: Override the default activity configuration + (timeouts, retry policy, etc.) for Gemini API call activities. + + Returns: + A ``google.genai.client.AsyncClient`` instance. + """ + temporal_api_client = TemporalApiClient( + vertexai=vertexai, + project=project, + location=location, + activity_config=activity_config, + ) + return TemporalAsyncClient(temporal_api_client, activity_config) diff --git a/tests/contrib/google_gemini_sdk/__init__.py b/tests/contrib/google_gemini_sdk/__init__.py new file mode 100644 index 000000000..eeab413ee --- /dev/null +++ b/tests/contrib/google_gemini_sdk/__init__.py @@ -0,0 +1 @@ +"""Tests for the Google Gemini SDK Temporal integration.""" diff --git a/tests/contrib/google_gemini_sdk/test_gemini.py b/tests/contrib/google_gemini_sdk/test_gemini.py new file mode 100644 index 000000000..7dd0bc9db --- /dev/null +++ b/tests/contrib/google_gemini_sdk/test_gemini.py @@ -0,0 +1,1301 @@ +"""Integration tests for the Google Gemini SDK Temporal integration. + +Tests cover: +- Basic generate_content through workflow +- Tool calling via activity_as_tool (single arg, multi arg, class method) +- Workflow method as a plain tool (runs in-workflow, not as an activity) +- Tool failure propagation +- Multiple sequential tool calls with arg verification +- Batched streaming via generate_content_stream +- Per-request http_options propagation +- File upload (str path + io.BytesIO) and download via TemporalAsyncFiles +- File search store upload via TemporalAsyncFileSearchStores +- Multi-turn chat via client.chats +- TemporalAsyncClient wiring (files, file_search_stores) +- TemporalApiClient edge cases (sync raises) +- activity_as_tool validation and metadata preservation +- gemini_client configuration +""" + +import inspect +import io +import json +import uuid +from datetime import timedelta +from typing import Any +from unittest.mock import AsyncMock + +import pytest +from google.genai import Client as GeminiClient +from google.genai import types +from google.genai.types import HttpResponse as SdkHttpResponse + +from temporalio import activity, workflow +from temporalio.client import Client, WorkflowFailureError +from temporalio.common import RetryPolicy +from temporalio.contrib.google_gemini_sdk import ( + GeminiPlugin, + activity_as_tool, + gemini_client, +) +from temporalio.contrib.google_gemini_sdk._models import ( + _GeminiApiRequest, + _GeminiApiResponse, + _GeminiApiStreamedResponse, + _GeminiDownloadFileRequest, + _GeminiUploadFileRequest, + _GeminiUploadToFileSearchStoreRequest, +) +from temporalio.contrib.google_gemini_sdk._temporal_api_client import ( + TemporalApiClient, +) +from temporalio.contrib.google_gemini_sdk._temporal_async_client import ( + TemporalAsyncClient, +) +from temporalio.contrib.google_gemini_sdk._temporal_file_search_stores import ( + TemporalAsyncFileSearchStores, +) +from temporalio.contrib.google_gemini_sdk._temporal_files import ( + TemporalAsyncFiles, +) +from temporalio.exceptions import ApplicationError +from temporalio.workflow import ActivityConfig +from tests.helpers import new_worker + +# --------------------------------------------------------------------------- +# Mock response helpers +# --------------------------------------------------------------------------- + + +def make_text_response(text: str) -> str: + """Build a JSON body string for a simple text response.""" + return json.dumps( + { + "candidates": [ + { + "content": { + "role": "model", + "parts": [{"text": text}], + }, + "finishReason": "STOP", + } + ], + "usageMetadata": { + "promptTokenCount": 5, + "candidatesTokenCount": 10, + }, + } + ) + + +def make_function_call_response(fn_name: str, args: dict) -> str: + """Build a JSON body string for a function-call response.""" + return json.dumps( + { + "candidates": [ + { + "content": { + "role": "model", + "parts": [{"functionCall": {"name": fn_name, "args": args}}], + }, + "finishReason": "STOP", + } + ], + "usageMetadata": { + "promptTokenCount": 10, + "candidatesTokenCount": 15, + }, + } + ) + + +# --------------------------------------------------------------------------- +# Tool call tracker — records every tool invocation for assertion +# --------------------------------------------------------------------------- + + +class ToolCallTracker: + """Tracks tool invocations across activities and workflow methods. + + Each tool appends (name, args_dict) to ``calls`` so tests can assert + exactly which tools were called, in what order, with what arguments. + """ + + def __init__(self) -> None: + self.calls: list[tuple[str, dict]] = [] + + @activity.defn + async def get_weather(self, city: str) -> str: + """Get the weather for a given city.""" + self.calls.append(("get_weather", {"city": city})) + return f"Weather in {city}: Sunny, 20C" + + @activity.defn + async def get_weather_country(self, city: str, country: str) -> str: + """Get the weather for a given city in a country.""" + self.calls.append(("get_weather_country", {"city": city, "country": country})) + return f"Weather in {city}, {country}: Rainy, 15C" + + @activity.defn + async def get_weather_failure(self, city: str) -> str: + """Activity that always fails.""" + self.calls.append(("get_weather_failure", {"city": city})) + raise ApplicationError("Weather service unavailable", non_retryable=True) + + +# --------------------------------------------------------------------------- +# Test helper: tracking gemini_api_client_async_request activity +# --------------------------------------------------------------------------- + + +class GeminiApiCallTracker: + """A test replacement for the gemini_api_client activities. + + Records every ``_GeminiApiRequest`` received and returns canned + ``_GeminiApiResponse`` bodies in order. After the workflow completes, + inspect ``requests`` to verify exactly what the integration sent. + + For streamed requests, the mock response is split into per-line chunks + to simulate multiple streamed chunks. + + The real ``GeminiPlugin`` is still used for its data converter, sandbox + passthrough, and workflow runner configuration — only its activity + registration is suppressed so this tracker can take its place. + """ + + def __init__(self, mock_responses: list[str]) -> None: + self._mock_responses = mock_responses + self.requests: list[_GeminiApiRequest] = [] + self.file_upload_requests: list[_GeminiUploadFileRequest] = [] + self.file_download_requests: list[_GeminiDownloadFileRequest] = [] + self.file_search_store_upload_requests: list[ + _GeminiUploadToFileSearchStoreRequest + ] = [] + self._call_index = 0 + + def _next_response(self, req: _GeminiApiRequest) -> str: + self.requests.append(req) + idx = self._call_index + self._call_index += 1 + if idx >= len(self._mock_responses): + raise ApplicationError( + f"No more mock responses (called {idx + 1} times, " + f"have {len(self._mock_responses)})", + non_retryable=True, + ) + return self._mock_responses[idx] + + @activity.defn(name="gemini_api_client_async_request") + async def gemini_api_client_async_request( + self, req: _GeminiApiRequest + ) -> _GeminiApiResponse: + return _GeminiApiResponse( + headers={"content-type": "application/json"}, + body=self._next_response(req), + ) + + @activity.defn(name="gemini_api_client_async_request_streamed") + async def gemini_api_client_async_request_streamed( + self, req: _GeminiApiRequest + ) -> _GeminiApiStreamedResponse: + body = self._next_response(req) + # Split the response text into word-level chunks so tests can + # verify that multiple chunks are yielded back to the workflow. + parsed = json.loads(body) + full_text = ( + parsed.get("candidates", [{}])[0] + .get("content", {}) + .get("parts", [{}])[0] + .get("text", "") + ) + words = full_text.split() + chunks = [] + for word in words: + chunks.append( + _GeminiApiResponse( + headers={"content-type": "application/json"}, + body=make_text_response(word), + ) + ) + return _GeminiApiStreamedResponse(chunks=chunks) + + @activity.defn(name="gemini_files_upload") + async def gemini_files_upload(self, req: _GeminiUploadFileRequest) -> types.File: + self.file_upload_requests.append(req) + return types.File( + name="files/test-uploaded-file", + uri="https://fake.uri/files/test-uploaded-file", + size_bytes=len(req.file_bytes) if req.file_bytes else 0, + ) + + @activity.defn(name="gemini_files_download") + async def gemini_files_download(self, req: _GeminiDownloadFileRequest) -> bytes: + self.file_download_requests.append(req) + return b"fake file content" + + @activity.defn(name="gemini_file_search_stores_upload") + async def gemini_file_search_stores_upload( + self, req: _GeminiUploadToFileSearchStoreRequest + ) -> types.UploadToFileSearchStoreOperation: + self.file_search_store_upload_requests.append(req) + return types.UploadToFileSearchStoreOperation.model_construct( + name="operations/test-op", + ) + + +def apply_plugin( + client: Client, mock_responses: list[str] +) -> tuple[Client, GeminiApiCallTracker]: + """Create a real GeminiPlugin whose activities include a tracking fake. + + Monkey-patches ``GeminiApiCaller.activities`` so that when the plugin + constructs itself, it registers our tracking activity instead of + the real ones. Everything else — data converter, sandbox passthrough, + workflow runner — is the real plugin code. + + Returns the configured Temporal client and the tracker. + """ + from temporalio.contrib.google_gemini_sdk._gemini_activity import GeminiApiCaller + + tracker = GeminiApiCallTracker(mock_responses) + original_activities = GeminiApiCaller.activities + GeminiApiCaller.activities = lambda self: [ # type: ignore[method-assign] + tracker.gemini_api_client_async_request, + tracker.gemini_api_client_async_request_streamed, + tracker.gemini_files_upload, + tracker.gemini_files_download, + tracker.gemini_file_search_stores_upload, + ] + try: + gemini = GeminiClient(api_key="fake-test-key") + plugin = GeminiPlugin(gemini) + finally: + GeminiApiCaller.activities = original_activities # type: ignore[method-assign] + + config = client.config() + config["plugins"] = [plugin] + return Client(**config), tracker + + +# --------------------------------------------------------------------------- +# Workflows +# --------------------------------------------------------------------------- + + +@workflow.defn +class SimpleGenerateWorkflow: + """Workflow that does a simple generate_content call.""" + + @workflow.run + async def run(self, prompt: str) -> str: + client = gemini_client() + response = await client.models.generate_content( + model="gemini-2.5-flash", + contents=prompt, + ) + return response.text or "" + + +@workflow.defn +class SingleArgToolWorkflow: + """Workflow that uses activity_as_tool for a single-arg tool.""" + + @workflow.run + async def run(self, prompt: str) -> str: + client = gemini_client() + response = await client.models.generate_content( + model="gemini-2.5-flash", + contents=prompt, + config=types.GenerateContentConfig( + tools=[ + activity_as_tool( + ToolCallTracker.get_weather, + activity_config=ActivityConfig( + start_to_close_timeout=timedelta(seconds=10), + ), + ), + ], + ), + ) + return response.text or "" + + +@workflow.defn +class MultiArgToolWorkflow: + """Workflow with multi-arg tool.""" + + @workflow.run + async def run(self, prompt: str) -> str: + client = gemini_client() + response = await client.models.generate_content( + model="gemini-2.5-flash", + contents=prompt, + config=types.GenerateContentConfig( + tools=[ + activity_as_tool( + ToolCallTracker.get_weather_country, + activity_config=ActivityConfig( + start_to_close_timeout=timedelta(seconds=10), + ), + ), + ], + ), + ) + return response.text or "" + + +@workflow.defn +class ToolFailureWorkflow: + """Workflow with a tool that always fails.""" + + @workflow.run + async def run(self, prompt: str) -> str: + client = gemini_client() + response = await client.models.generate_content( + model="gemini-2.5-flash", + contents=prompt, + config=types.GenerateContentConfig( + tools=[ + activity_as_tool( + ToolCallTracker.get_weather_failure, + activity_config=ActivityConfig( + start_to_close_timeout=timedelta(seconds=10), + retry_policy=RetryPolicy(maximum_attempts=1), + ), + ), + ], + ), + ) + return response.text or "" + + +@workflow.defn +class MultipleToolsWorkflow: + """Workflow with multiple tools that are called in sequence.""" + + @workflow.run + async def run(self, prompt: str) -> str: + client = gemini_client() + response = await client.models.generate_content( + model="gemini-2.5-flash", + contents=prompt, + config=types.GenerateContentConfig( + tools=[ + activity_as_tool( + ToolCallTracker.get_weather, + activity_config=ActivityConfig( + start_to_close_timeout=timedelta(seconds=10), + ), + ), + activity_as_tool( + ToolCallTracker.get_weather_country, + activity_config=ActivityConfig( + start_to_close_timeout=timedelta(seconds=10), + ), + ), + ], + ), + ) + return response.text or "" + + +@workflow.defn +class WorkflowMethodToolWorkflow: + """Workflow that passes a plain method as a tool (runs in-workflow, not as an activity).""" + + def __init__(self) -> None: + self.tool_calls: list[tuple[str, dict]] = [] + + @workflow.run + async def run(self, prompt: str) -> str: + client = gemini_client() + response = await client.models.generate_content( + model="gemini-2.5-flash", + contents=prompt, + config=types.GenerateContentConfig( + tools=[self.lookup_city], + ), + ) + return response.text or "" + + async def lookup_city(self, city: str) -> str: + """Look up info about a city.""" + self.tool_calls.append(("lookup_city", {"city": city})) + return f"{city} is a great place to visit" + + @workflow.query + def get_tool_calls(self) -> list[tuple[str, dict]]: + return self.tool_calls + + +@workflow.defn +class StreamedGenerateWorkflow: + """Workflow that uses generate_content_stream.""" + + @workflow.run + async def run(self, prompt: str) -> list[str]: + client = gemini_client() + chunks: list[str] = [] + async for chunk in await client.models.generate_content_stream( + model="gemini-2.5-flash", + contents=prompt, + ): + if chunk.text: + chunks.append(chunk.text) + return chunks + + +@workflow.defn +class HttpOptionsWorkflow: + """Workflow that passes per-request http_options through generate_content.""" + + @workflow.run + async def run(self, prompt: str, http_options: types.HttpOptionsDict) -> str: + client = gemini_client() + response = await client.models.generate_content( + model="gemini-2.5-flash", + contents=prompt, + config=types.GenerateContentConfig( + http_options=types.HttpOptions.model_validate(http_options), + ), + ) + return response.text or "" + + +@workflow.defn +class FullIntegrationWorkflow: + """Exercises every activity path in a single workflow run. + + Uses the real GeminiPlugin activities (not the tracker), so this + tests the actual activity implementations end-to-end with a mocked + genai.Client. + """ + + @workflow.run + async def run(self, prompt: str) -> dict[str, Any]: + client = gemini_client() + results: dict[str, Any] = {} + + # 1. generate_content (async_request activity) + response = await client.models.generate_content( + model="gemini-2.5-flash", + contents=prompt, + ) + results["generate"] = response.text or "" + + # 2. generate_content_stream (async_request_streamed activity) + chunks: list[str] = [] + async for chunk in await client.models.generate_content_stream( + model="gemini-2.5-flash", + contents=prompt, + ): + if chunk.text: + chunks.append(chunk.text) + results["stream_chunks"] = chunks + + # 3. files.upload (gemini_files_upload activity) + uploaded = await client.files.upload( + file="/tmp/fake.txt", + config=types.UploadFileConfig(display_name="Integration Test"), + ) + results["upload_name"] = uploaded.name or "" + + # 4. files.download (gemini_files_download activity) + data = await client.files.download(file="files/some-file") + results["download"] = data.decode() if isinstance(data, bytes) else str(data) + + # 5. file_search_stores.upload_to_file_search_store activity + store_name = "fileSearchStores/test" + op = await client.file_search_stores.upload_to_file_search_store( + file_search_store_name=store_name, + file="/tmp/doc.txt", + ) + results["fss_upload_op"] = op.name or "" + + # 6. generate_content grounded with file_search tool (RAG query) + rag_response = await client.models.generate_content( + model="gemini-2.5-flash", + contents="What does the document say?", + config=types.GenerateContentConfig( + tools=[ + types.Tool( + file_search=types.FileSearch( + file_search_store_names=[store_name], + ), + ), + ], + ), + ) + results["rag"] = rag_response.text or "" + + # 7. Clean up the file search store + await client.file_search_stores.delete( + name=store_name, + config=types.DeleteFileSearchStoreConfig(force=True), + ) + results["store_deleted"] = True + + return results + + +@workflow.defn +class FileUploadStrWorkflow: + """Workflow that uploads a file via str path.""" + + @workflow.run + async def run(self, file_path: str) -> str: + client = gemini_client() + uploaded = await client.files.upload( + file=file_path, + config=types.UploadFileConfig( + display_name="Test File", + mime_type="text/plain", + ), + ) + return uploaded.name or "" + + +@workflow.defn +class FileUploadBytesWorkflow: + """Workflow that uploads a file via io.BytesIO.""" + + @workflow.run + async def run(self, data: bytes) -> str: + client = gemini_client() + uploaded = await client.files.upload( + file=io.BytesIO(data), + config=types.UploadFileConfig( + display_name="Bytes File", + mime_type="text/plain", + ), + ) + return uploaded.name or "" + + +@workflow.defn +class FileDownloadWorkflow: + """Workflow that downloads a file by name.""" + + @workflow.run + async def run(self, file_name: str) -> bytes: + client = gemini_client() + return await client.files.download(file=file_name) + + +@workflow.defn +class FileSearchStoreUploadWorkflow: + """Workflow that uploads to a file search store.""" + + @workflow.run + async def run(self, store_name: str, file_path: str) -> str: + client = gemini_client() + op = await client.file_search_stores.upload_to_file_search_store( + file_search_store_name=store_name, + file=file_path, + config=types.UploadToFileSearchStoreConfig( + display_name="Test Doc", + mime_type="text/plain", + ), + ) + return op.name or "" + + +@workflow.defn +class RegisterFilesWorkflow: + """Workflow that calls files.register_files.""" + + @workflow.run + async def run(self, uris: list[str]) -> str: + client = gemini_client() + # auth arg is ignored by TemporalAsyncFiles — the activity uses + # credentials from GeminiPlugin init. We pass a dummy here; + # can't import google.auth.credentials in the sandbox so we + # use a sentinel that satisfies the type at runtime. + resp = await client.files.register_files( + auth=None, # type: ignore[arg-type] + uris=uris, + ) + return str(len(resp.files or [])) + + +@workflow.defn +class ChatWorkflow: + """Workflow that uses client.chats for multi-turn conversation.""" + + @workflow.run + async def run(self, prompt: str) -> list[str]: + client = gemini_client() + chat = client.chats.create( + model="gemini-2.5-flash", + ) + r1 = await chat.send_message(prompt) + r2 = await chat.send_message("Follow up question") + return [r1.text or "", r2.text or ""] + + +# =========================================================================== +# Integration tests — run workflows against a real Temporal test server +# =========================================================================== + + +async def test_simple_generate_content(client: Client): + """Basic generate_content returns text through a workflow.""" + new_client, _ = apply_plugin(client, [make_text_response("Hello from Gemini!")]) + + async with new_worker(new_client, SimpleGenerateWorkflow) as worker: + result = await new_client.execute_workflow( + SimpleGenerateWorkflow.run, + "Say hello", + id=f"gemini-simple-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=10), + ) + + assert result == "Hello from Gemini!" + + +async def test_tool_call_single_arg(client: Client): + """Tool calling with a single-argument activity via AFC.""" + tool_tracker = ToolCallTracker() + new_client, _ = apply_plugin( + client, + [ + make_function_call_response("get_weather", {"city": "Tokyo"}), + make_text_response("The weather in Tokyo is sunny and 20C."), + ], + ) + + async with new_worker( + new_client, + SingleArgToolWorkflow, + activities=[tool_tracker.get_weather], + ) as worker: + result = await new_client.execute_workflow( + SingleArgToolWorkflow.run, + "What's the weather in Tokyo?", + id=f"gemini-tool-single-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=10), + ) + + assert tool_tracker.calls == [("get_weather", {"city": "Tokyo"})] + assert result == "The weather in Tokyo is sunny and 20C." + + +async def test_tool_call_multi_arg(client: Client): + """Tool calling with a multi-argument activity.""" + tool_tracker = ToolCallTracker() + new_client, _ = apply_plugin( + client, + [ + make_function_call_response( + "get_weather_country", {"city": "Paris", "country": "France"} + ), + make_text_response("Paris, France: Rainy, 15C."), + ], + ) + + async with new_worker( + new_client, + MultiArgToolWorkflow, + activities=[tool_tracker.get_weather_country], + ) as worker: + result = await new_client.execute_workflow( + MultiArgToolWorkflow.run, + "What's the weather in Paris, France?", + id=f"gemini-tool-multi-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=10), + ) + + assert tool_tracker.calls == [ + ("get_weather_country", {"city": "Paris", "country": "France"}) + ] + assert result == "Paris, France: Rainy, 15C." + + +async def test_tool_failure_propagation(client: Client): + """Tool activity failure causes the workflow to fail.""" + tool_tracker = ToolCallTracker() + new_client, _ = apply_plugin( + client, + [ + make_function_call_response("get_weather_failure", {"city": "Nowhere"}), + ], + ) + + async with new_worker( + new_client, + ToolFailureWorkflow, + activities=[tool_tracker.get_weather_failure], + ) as worker: + with pytest.raises(WorkflowFailureError): + await new_client.execute_workflow( + ToolFailureWorkflow.run, + "Weather in Nowhere?", + id=f"gemini-tool-fail-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=10), + ) + + assert tool_tracker.calls == [("get_weather_failure", {"city": "Nowhere"})] + + +async def test_multiple_tools_sequential(client: Client): + """Multiple tools called in sequence within one generate_content call.""" + tool_tracker = ToolCallTracker() + new_client, _ = apply_plugin( + client, + [ + make_function_call_response("get_weather", {"city": "Tokyo"}), + make_function_call_response( + "get_weather_country", {"city": "Paris", "country": "France"} + ), + make_text_response("Tokyo is sunny; Paris is rainy."), + ], + ) + + async with new_worker( + new_client, + MultipleToolsWorkflow, + activities=[ + tool_tracker.get_weather, + tool_tracker.get_weather_country, + ], + ) as worker: + result = await new_client.execute_workflow( + MultipleToolsWorkflow.run, + "Compare Tokyo and Paris weather", + id=f"gemini-multi-tools-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=15), + ) + + assert tool_tracker.calls == [ + ("get_weather", {"city": "Tokyo"}), + ("get_weather_country", {"city": "Paris", "country": "France"}), + ] + assert result == "Tokyo is sunny; Paris is rainy." + + +async def test_workflow_method_as_tool(client: Client): + """A plain workflow method (not an activity) used as a tool runs in-workflow.""" + new_client, _ = apply_plugin( + client, + [ + make_function_call_response("lookup_city", {"city": "Berlin"}), + make_text_response("Berlin is wonderful."), + ], + ) + + async with new_worker(new_client, WorkflowMethodToolWorkflow) as worker: + handle = await new_client.start_workflow( + WorkflowMethodToolWorkflow.run, + "Tell me about Berlin", + id=f"gemini-wf-method-tool-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=10), + ) + result = await handle.result() + # Query must happen while worker is alive + tool_calls = await handle.query(WorkflowMethodToolWorkflow.get_tool_calls) + + assert tool_calls == [("lookup_city", {"city": "Berlin"})] + assert result == "Berlin is wonderful." + + +async def test_streamed_generate_content(client: Client): + """generate_content_stream collects batched chunks from the activity.""" + new_client, _ = apply_plugin( + client, [make_text_response("The quick brown fox jumps over the lazy dog")] + ) + + async with new_worker(new_client, StreamedGenerateWorkflow) as worker: + result = await new_client.execute_workflow( + StreamedGenerateWorkflow.run, + "Say something", + id=f"gemini-streamed-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=10), + ) + + # The tracker splits the text into per-word chunks + assert len(result) == 9 + assert " ".join(result) == "The quick brown fox jumps over the lazy dog" + + +# =========================================================================== +# http_options propagation tests - per request overrides +# =========================================================================== + + +async def test_http_options_headers_propagate(client: Client): + """Custom headers passed via http_options arrive at the activity.""" + new_client, api_tracker = apply_plugin(client, [make_text_response("ok")]) + + async with new_worker(new_client, HttpOptionsWorkflow) as worker: + await new_client.execute_workflow( + HttpOptionsWorkflow.run, + args=["hi", {"headers": {"X-Custom": "test-value"}}], + id=f"gemini-http-headers-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=10), + ) + + assert len(api_tracker.requests) == 1 + opts = api_tracker.requests[0].http_options_overrides + assert opts is not None + assert opts.headers == {"X-Custom": "test-value"} + + +async def test_http_options_api_version_propagates(client: Client): + """api_version passed via http_options arrives at the activity.""" + new_client, api_tracker = apply_plugin(client, [make_text_response("ok")]) + + async with new_worker(new_client, HttpOptionsWorkflow) as worker: + await new_client.execute_workflow( + HttpOptionsWorkflow.run, + args=["hi", {"api_version": "v1"}], + id=f"gemini-http-version-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=10), + ) + + assert len(api_tracker.requests) == 1 + opts = api_tracker.requests[0].http_options_overrides + assert opts is not None + assert opts.api_version == "v1" + + +async def test_http_options_base_url_propagates(client: Client): + """base_url passed via http_options arrives at the activity.""" + new_client, api_tracker = apply_plugin(client, [make_text_response("ok")]) + + async with new_worker(new_client, HttpOptionsWorkflow) as worker: + await new_client.execute_workflow( + HttpOptionsWorkflow.run, + args=["hi", {"base_url": "https://custom.example.com"}], + id=f"gemini-http-base-url-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=10), + ) + + assert len(api_tracker.requests) == 1 + opts = api_tracker.requests[0].http_options_overrides + assert opts is not None + assert opts.base_url == "https://custom.example.com" + + +async def test_http_options_multiple_fields_propagate(client: Client): + """Multiple http_options fields propagate together to the activity.""" + new_client, api_tracker = apply_plugin(client, [make_text_response("ok")]) + + async with new_worker(new_client, HttpOptionsWorkflow) as worker: + await new_client.execute_workflow( + HttpOptionsWorkflow.run, + args=[ + "hi", + { + "api_version": "v1beta", + "headers": {"X-Foo": "bar"}, + "base_url": "https://other.example.com", + }, + ], + id=f"gemini-http-multi-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=10), + ) + + assert len(api_tracker.requests) == 1 + opts = api_tracker.requests[0].http_options_overrides + assert opts is not None + assert opts.api_version == "v1beta" + assert opts.headers == {"X-Foo": "bar"} + assert opts.base_url == "https://other.example.com" + + +async def test_no_http_options_passes_none(client: Client): + """When no per-request http_options are set, None reaches the activity.""" + new_client, api_tracker = apply_plugin(client, [make_text_response("ok")]) + + async with new_worker(new_client, SimpleGenerateWorkflow) as worker: + await new_client.execute_workflow( + SimpleGenerateWorkflow.run, + "hi", + id=f"gemini-http-none-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=10), + ) + + assert len(api_tracker.requests) == 1 + assert api_tracker.requests[0].http_options_overrides is None + + +# =========================================================================== +# File upload/download tests +# =========================================================================== + + +async def test_file_upload_str_path(client: Client): + """Upload a file via str path dispatches through the activity.""" + new_client, api_tracker = apply_plugin(client, []) + + async with new_worker(new_client, FileUploadStrWorkflow) as worker: + result = await new_client.execute_workflow( + FileUploadStrWorkflow.run, + "/tmp/test.txt", + id=f"gemini-file-upload-str-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=10), + ) + + assert len(api_tracker.file_upload_requests) == 1 + req = api_tracker.file_upload_requests[0] + assert req.file_path == "/tmp/test.txt" + assert req.file_bytes is None + assert req.config is not None + assert req.config.display_name == "Test File" + assert result == "files/test-uploaded-file" + + +async def test_file_upload_bytes(client: Client): + """Upload a file via io.BytesIO sends bytes through the activity.""" + new_client, api_tracker = apply_plugin(client, []) + + async with new_worker(new_client, FileUploadBytesWorkflow) as worker: + result = await new_client.execute_workflow( + FileUploadBytesWorkflow.run, + b"hello world", + id=f"gemini-file-upload-bytes-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=10), + ) + + assert len(api_tracker.file_upload_requests) == 1 + req = api_tracker.file_upload_requests[0] + assert req.file_bytes == b"hello world" + assert req.file_path is None + assert req.config is not None + assert req.config.display_name == "Bytes File" + assert result == "files/test-uploaded-file" + + +async def test_file_download(client: Client): + """Download a file dispatches through the activity and returns bytes.""" + new_client, api_tracker = apply_plugin(client, []) + + async with new_worker(new_client, FileDownloadWorkflow) as worker: + result = await new_client.execute_workflow( + FileDownloadWorkflow.run, + "files/some-file", + id=f"gemini-file-download-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=10), + ) + + assert len(api_tracker.file_download_requests) == 1 + assert api_tracker.file_download_requests[0].file == "files/some-file" + assert result == b"fake file content" + + +# =========================================================================== +# File search store upload tests +# =========================================================================== + + +async def test_file_search_store_upload(client: Client): + """Upload to file search store dispatches through the activity.""" + new_client, api_tracker = apply_plugin(client, []) + + async with new_worker(new_client, FileSearchStoreUploadWorkflow) as worker: + result = await new_client.execute_workflow( + FileSearchStoreUploadWorkflow.run, + args=["fileSearchStores/my-store", "/tmp/doc.txt"], + id=f"gemini-fss-upload-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=10), + ) + + assert len(api_tracker.file_search_store_upload_requests) == 1 + req = api_tracker.file_search_store_upload_requests[0] + assert req.file_search_store_name == "fileSearchStores/my-store" + assert req.file_path == "/tmp/doc.txt" + assert req.config is not None + assert req.config.display_name == "Test Doc" + assert result == "operations/test-op" + + +# =========================================================================== +# Multi-turn chat tests +# =========================================================================== + + +async def test_chat_multi_turn(client: Client): + """Multi-turn chat sends multiple requests through the activity.""" + new_client, api_tracker = apply_plugin( + client, + [ + make_text_response("First answer"), + make_text_response("Second answer"), + ], + ) + + async with new_worker(new_client, ChatWorkflow) as worker: + result = await new_client.execute_workflow( + ChatWorkflow.run, + "Hello", + id=f"gemini-chat-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=10), + ) + + assert len(api_tracker.requests) == 2 + assert result == ["First answer", "Second answer"] + + +# =========================================================================== +# Full integration test — real activities, mocked client +# =========================================================================== + + +def _apply_plugin_with_mock_client(client: Client, mock_responses: list[str]) -> Client: + """Create a real GeminiPlugin with real activities but a mocked client. + + Unlike ``apply_plugin``, this does NOT replace the activities. The + real ``GeminiApiCaller.activities()`` are registered, exercising the + full activity code path. The underlying ``genai.Client`` HTTP layer + and high-level file methods are mocked so no network calls are made. + """ + gemini = GeminiClient(api_key="fake-test-key") + + call_state = {"index": 0} + + async def fake_async_request(*_args: Any, **_kwargs: Any) -> SdkHttpResponse: + idx = call_state["index"] + call_state["index"] += 1 + if idx >= len(mock_responses): + raise RuntimeError( + f"No more mock responses (called {idx + 1} times, " + f"have {len(mock_responses)})" + ) + return SdkHttpResponse( + headers={"content-type": "application/json"}, + body=mock_responses[idx], + ) + + async def fake_async_request_streamed(*_args: Any, **_kwargs: Any) -> Any: + idx = call_state["index"] + call_state["index"] += 1 + if idx >= len(mock_responses): + raise RuntimeError( + f"No more mock responses (called {idx + 1} times, " + f"have {len(mock_responses)})" + ) + + async def _gen(): + yield SdkHttpResponse( + headers={"content-type": "application/json"}, + body=mock_responses[idx], + ) + + return _gen() + + gemini._api_client.async_request = fake_async_request # type: ignore[assignment] + gemini._api_client.async_request_streamed = fake_async_request_streamed # type: ignore[assignment] + + # Mock file operations at the high-level SDK interface (these are what + # the real activities call). + gemini.aio.files.upload = AsyncMock( # type: ignore[method-assign] + return_value=types.File( + name="files/mock-uploaded", + uri="https://fake.uri/files/mock-uploaded", + size_bytes=42, + ) + ) + gemini.aio.files.download = AsyncMock(return_value=b"mock download content") # type: ignore[method-assign] + gemini.aio.file_search_stores.upload_to_file_search_store = AsyncMock( # type: ignore[method-assign] + return_value=types.UploadToFileSearchStoreOperation.model_construct( + name="operations/mock-op" + ) + ) + + plugin = GeminiPlugin(gemini) + config = client.config() + config["plugins"] = [plugin] + return Client(**config) + + +async def test_full_integration_with_mock_client(client: Client): + """Run a workflow through real activities with a mocked genai.Client. + + This is the only test that exercises the actual activity implementations + in _gemini_activity.py. Every other test uses the GeminiApiCallTracker + which replaces the activities entirely. + """ + # Mock responses are consumed in order by the async_request and + # async_request_streamed mocks. Steps 3-5 (file upload, download, + # store upload) are mocked separately at the SDK level and don't + # consume from this list. + new_client = _apply_plugin_with_mock_client( + client, + [ + make_text_response("Real activity response"), # generate_content + make_text_response("Streamed via real activity"), # generate_content_stream + make_text_response("Grounded RAG answer"), # RAG query with file_search + make_text_response(""), # file_search_stores.delete + ], + ) + + async with new_worker(new_client, FullIntegrationWorkflow) as worker: + result = await new_client.execute_workflow( + FullIntegrationWorkflow.run, + "test prompt", + id=f"gemini-full-integration-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=15), + ) + + assert result["generate"] == "Real activity response" + assert len(result["stream_chunks"]) > 0 + assert "Streamed" in " ".join(result["stream_chunks"]) + assert result["upload_name"] == "files/mock-uploaded" + assert result["download"] == "mock download content" + assert result["fss_upload_op"] == "operations/mock-op" + assert result["rag"] == "Grounded RAG answer" + assert result["store_deleted"] is True + + +async def test_register_files_without_credentials_fails(client: Client): + """register_files raises when no credentials are available.""" + # _apply_plugin_with_mock_client uses api_key auth with no + # extra_credentials, so the activity should raise ValueError. + new_client = _apply_plugin_with_mock_client(client, []) + + async with new_worker(new_client, RegisterFilesWorkflow) as worker: + with pytest.raises(WorkflowFailureError) as exc_info: + await new_client.execute_workflow( + RegisterFilesWorkflow.run, + ["gs://bucket/file.txt"], + id=f"gemini-register-no-creds-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=10), + ) + + # The error is nested: WorkflowFailureError → ActivityError → ApplicationError + cause = exc_info.value.cause + while cause.__cause__ is not None: + cause = cause.__cause__ + assert "No credentials available for register_files" in str(cause) + + +# =========================================================================== +# TemporalAsyncClient wiring tests +# =========================================================================== + + +def test_temporal_async_client_has_temporal_files(): + """gemini_client() returns a client with TemporalAsyncFiles.""" + client = gemini_client() + assert isinstance(client, TemporalAsyncClient) + assert isinstance(client.files, TemporalAsyncFiles) + + +def test_temporal_async_client_has_temporal_file_search_stores(): + """gemini_client() returns a client with TemporalAsyncFileSearchStores.""" + client = gemini_client() + assert isinstance(client.file_search_stores, TemporalAsyncFileSearchStores) + + +# =========================================================================== +# Unit tests for TemporalApiClient +# =========================================================================== + + +def test_sync_request_raises(): + """Synchronous request() raises RuntimeError.""" + api_client = TemporalApiClient() + with pytest.raises(RuntimeError, match="Synchronous requests are not supported"): + api_client.request("GET", "/test", {}) + + +def test_sync_request_streamed_raises(): + """Synchronous request_streamed() raises RuntimeError.""" + api_client = TemporalApiClient() + with pytest.raises(RuntimeError, match="Synchronous streaming is not supported"): + api_client.request_streamed("GET", "/test", {}) + + +def test_upload_file_raises(): + """Low-level upload_file() raises NotImplementedError.""" + api_client = TemporalApiClient() + with pytest.raises(NotImplementedError, match="client.files.upload"): + api_client.upload_file() + + +def test_download_file_raises(): + """Low-level download_file() raises NotImplementedError.""" + api_client = TemporalApiClient() + with pytest.raises(NotImplementedError, match="client.files.download"): + api_client.download_file() + + +# =========================================================================== +# Unit tests for activity_as_tool +# =========================================================================== + + +def test_activity_as_tool_bare_function_raises(): + """activity_as_tool rejects a function without @activity.defn.""" + + async def not_an_activity(x: str) -> str: + return x + + with pytest.raises(ApplicationError, match="@activity.defn"): + activity_as_tool(not_an_activity) + + +def test_activity_as_tool_preserves_name(): + """Returned wrapper keeps the original function name.""" + wrapper = activity_as_tool(ToolCallTracker.get_weather) + assert wrapper.__name__ == "get_weather" + + +def test_activity_as_tool_preserves_doc(): + """Returned wrapper keeps the original docstring.""" + wrapper = activity_as_tool(ToolCallTracker.get_weather) + assert wrapper.__doc__ == "Get the weather for a given city." + + +def test_activity_as_tool_preserves_signature(): + """Returned wrapper has the correct parameter signature (self hidden).""" + wrapper = activity_as_tool(ToolCallTracker.get_weather) + sig = inspect.signature(wrapper) + params = list(sig.parameters.keys()) + assert params == ["city"] + + +def test_activity_as_tool_multi_arg_signature(): + """Multi-arg activity preserves all parameter names (self hidden).""" + wrapper = activity_as_tool(ToolCallTracker.get_weather_country) + sig = inspect.signature(wrapper) + params = list(sig.parameters.keys()) + assert params == ["city", "country"] + + +def test_activity_as_tool_is_async_callable(): + """Returned wrapper is an async callable.""" + wrapper = activity_as_tool(ToolCallTracker.get_weather) + assert inspect.iscoroutinefunction(wrapper) + + +# =========================================================================== +# Unit tests for gemini_client +# =========================================================================== + + +def test_gemini_client_vertexai_config(): + """gemini_client() forwards Vertex AI configuration to the TemporalApiClient.""" + result = gemini_client(vertexai=True, project="proj", location="us-central1") + assert result._api_client.vertexai is True + assert result._api_client.project == "proj" + assert result._api_client.location == "us-central1" diff --git a/uv.lock b/uv.lock index 4fba27fc0..e1f5fb443 100644 --- a/uv.lock +++ b/uv.lock @@ -9,7 +9,7 @@ resolution-markers = [ ] [options] -exclude-newer = "2026-04-13T21:30:54.856039Z" +exclude-newer = "2026-04-22T20:17:40.986866Z" exclude-newer-span = "P1W" [options.exclude-newer-package] @@ -5161,6 +5161,9 @@ aioboto3 = [ google-adk = [ { name = "google-adk" }, ] +google-gemini = [ + { name = "google-genai" }, +] grpc = [ { name = "grpcio" }, ] @@ -5231,6 +5234,7 @@ dev = [ requires-dist = [ { name = "aioboto3", marker = "extra == 'aioboto3'", specifier = ">=10.4.0" }, { name = "google-adk", marker = "extra == 'google-adk'", specifier = ">=1.27.0,<2" }, + { name = "google-genai", marker = "extra == 'google-gemini'", specifier = ">=1.66.0" }, { name = "grpcio", marker = "extra == 'grpc'", specifier = ">=1.48.2,<2" }, { name = "langgraph", marker = "extra == 'langgraph'", specifier = ">=1.1.0" }, { name = "langsmith", marker = "extra == 'langsmith'", specifier = ">=0.7.0,<0.8" }, @@ -5251,7 +5255,7 @@ requires-dist = [ { name = "types-protobuf", specifier = ">=3.20,<7.0.0" }, { name = "typing-extensions", specifier = ">=4.2.0,<5" }, ] -provides-extras = ["grpc", "opentelemetry", "pydantic", "openai-agents", "google-adk", "langgraph", "langsmith", "lambda-worker-otel", "aioboto3"] +provides-extras = ["grpc", "opentelemetry", "pydantic", "openai-agents", "google-adk", "langgraph", "langsmith", "lambda-worker-otel", "aioboto3", "google-gemini"] [package.metadata.requires-dev] dev = [