diff --git a/README.md b/README.md index f33ae2e..e63c1e7 100644 --- a/README.md +++ b/README.md @@ -46,37 +46,12 @@ make test `make dev` installs the project, syncs development dependencies, and sets up [`prek`](https://prek.j178.dev/) Git hooks. -## Run the Example Workflow +## Examples -The repository includes a minimal runnable example at -[`examples/graphon_openai_slim`](examples/graphon_openai_slim). +Runnable examples live under [`examples/`](examples/). -It builds and executes this workflow: - -```text -start -> llm -> output -``` - -To run it: - -```bash -make dev -source .venv/bin/activate -cd examples/graphon_openai_slim -cp .env.example .env -python3 workflow.py "Explain Graphon in one short sentence." -``` - -Before running the example, fill in the required values in `.env`. - -The example currently expects: - -- an `OPENAI_API_KEY` -- a `SLIM_PLUGIN_ID` -- a local `dify-plugin-daemon-slim` setup or equivalent Slim runtime - -For the exact environment variables and runtime notes, see -[examples/graphon_openai_slim/README.md](examples/graphon_openai_slim/README.md). +Each example is self-contained in its own subdirectory and includes its own +setup instructions, environment template, and `workflow.py` entrypoint. ## How Graphon Fits Together @@ -88,8 +63,8 @@ At a high level, Graphon usage looks like this: 4. Run `GraphEngine` and consume emitted graph events. 5. Read final outputs from runtime state. -The bundled example follows exactly that path. The execution loop is centered -around `GraphEngine.run()`: +The examples under [`examples/`](examples/) follow exactly that path. The +execution loop is centered around `GraphEngine.run()`: ```python engine = GraphEngine( @@ -103,10 +78,8 @@ for event in engine.run(): ... ``` -See -[examples/graphon_openai_slim/workflow.py](examples/graphon_openai_slim/workflow.py) -for the full example, including `SlimRuntime`, `SlimPreparedLLM`, graph -construction, input seeding, and streamed output handling. 
+See [`examples/`](examples/) for the current runnable workflows and their +example-specific setup notes. ## Project Layout @@ -126,8 +99,7 @@ construction, input seeding, and streamed output handling. ## Internal Docs - [CONTRIBUTING.md](CONTRIBUTING.md): contributor workflow, CI, commit/PR rules -- [examples/graphon_openai_slim/README.md](examples/graphon_openai_slim/README.md): - runnable example setup +- [examples/](examples/): runnable examples and per-example setup notes - [src/graphon/model_runtime/README.md](src/graphon/model_runtime/README.md): model runtime overview - [src/graphon/graph_engine/layers/README.md](src/graphon/graph_engine/layers/README.md): diff --git a/examples/graphon_openai_slim/README.md b/examples/graphon_openai_slim/README.md deleted file mode 100644 index 7c1a22d..0000000 --- a/examples/graphon_openai_slim/README.md +++ /dev/null @@ -1,63 +0,0 @@ -# Graphon OpenAI Slim Example - -This example runs a minimal Graphon workflow: - -`start -> llm -> output` - -It uses: - -- Graphon as the Python package import surface -- `dify-plugin-daemon-slim` as the local model runtime bridge -- the Dify OpenAI plugin package -- the `gpt-5.4` model - -## Files - -- `workflow.py`: runnable example script -- `.env.example`: template configuration -- `.env`: local configuration file for this example only - -## Run - -1. Change into this directory: - -```bash -cd examples/graphon_openai_slim -``` - -2. Copy the template: - -```bash -cp .env.example .env -``` - -3. Fill in the required values in `.env`. - -4. Run the example: - -```bash -python3 workflow.py -``` - -The CLI streams LLM text to stdout as chunks arrive. - -You can also pass a custom prompt: - -```bash -python3 workflow.py "Explain graph sparsity in one sentence." -``` - -## Notes - -- `workflow.py` first tries to import an installed `graphon` package. -- If `graphon` is not installed, it falls back to the local repository `src/` - directory automatically. 
That lets you run the example directly from this - checkout without setting `PYTHONPATH`. -- If your current interpreter is missing runtime dependencies but the repository - `.venv` exists, `workflow.py` will re-exec itself with that local virtualenv - interpreter automatically. -- Path-like variables in `.env` are resolved relative to this example - directory, not relative to your shell's current working directory. -- By default, `SLIM_PLUGIN_FOLDER` resolves to the repository-root - `.slim/plugins` cache. That keeps generated plugin files out of this example - directory while still letting you run `python3 workflow.py` from here. diff --git a/examples/graphon_openai_slim/__init__.py b/examples/graphon_openai_slim/__init__.py deleted file mode 100644 index 8f0f62d..0000000 --- a/examples/graphon_openai_slim/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""OpenAI Slim workflow example for Graphon.""" diff --git a/examples/graphon_openai_slim/.env.example b/examples/openai_slim_minimal/.env.example similarity index 64% rename from examples/graphon_openai_slim/.env.example rename to examples/openai_slim_minimal/.env.example index 8d7719e..8cfe5c6 100644 --- a/examples/graphon_openai_slim/.env.example +++ b/examples/openai_slim_minimal/.env.example @@ -1,7 +1,7 @@ -# Example configuration for `examples/graphon_openai_slim/workflow.py`. +# Example configuration for `examples/openai_slim_minimal/workflow.py`. # -# The example loads `examples/graphon_openai_slim/.env` automatically. Copy this file to `.env` -# in the same directory and fill in the required values. +# The example loads `.env` from this directory automatically. Copy this file to +# `.env` in the same directory and fill in the required values. # Required: OpenAI API key used by the OpenAI Slim plugin. 
OPENAI_API_KEY= @@ -13,17 +13,16 @@ OPENAI_API_KEY= SLIM_PLUGIN_ID=langgenius/openai:0.3.0@99770a45f77910fe0f64c985524f4fe2294fc6ea25cbf1053ba6bddd7604d850 # Optional: path to the local `dify-plugin-daemon-slim` binary. -# If empty, Graphon will look for `dify-plugin-daemon-slim` in `PATH`. -SLIM_BINARY_PATH= +# Recommended Unix default: a user-local install under `~/.local/bin`. +SLIM_BINARY_PATH=~/.local/bin/dify-plugin-daemon-slim # Optional: provider name inside the plugin package. # For this example we only support OpenAI, so this should stay `openai`. SLIM_PROVIDER=openai # Optional: local folder where Slim stores downloaded/extracted plugins. -# The default points at the repository-root `.slim/plugins` cache so this -# example directory does not accumulate generated plugin code. -SLIM_PLUGIN_FOLDER=../../.slim/plugins +# Recommended Unix default: a user-local plugin cache under `~/.local/share`. +SLIM_PLUGIN_FOLDER=~/.local/share/graphon/slim/plugins # Optional: path to an already unpacked local plugin directory. # If set, Slim uses this directory directly and skips marketplace download. diff --git a/examples/openai_slim_minimal/README.md b/examples/openai_slim_minimal/README.md new file mode 100644 index 0000000..51c8cf4 --- /dev/null +++ b/examples/openai_slim_minimal/README.md @@ -0,0 +1,29 @@ +# OpenAI Slim Minimal Example + +A tiny Graphon workflow: + +`start -> llm -> output` + +## What You Need + +- `workflow.py`: runnable example +- `.env.example`: template settings +- `.env`: your local copy of the template + +## Run + +```bash +cd examples/openai_slim_minimal +cp .env.example .env +python3 workflow.py +``` + +Fill in `.env` before running. The script reads `.env` from this directory. + +## Custom Prompt + +```bash +python3 workflow.py "Explain graph sparsity in one sentence." +``` + +The example streams text to stdout as it arrives. If nothing is streamed, it prints the final answer at the end. 
diff --git a/examples/openai_slim_minimal/__init__.py b/examples/openai_slim_minimal/__init__.py new file mode 100644 index 0000000..3c2c522 --- /dev/null +++ b/examples/openai_slim_minimal/__init__.py @@ -0,0 +1 @@ +"""Minimal OpenAI Slim workflow example for Graphon.""" diff --git a/examples/graphon_openai_slim/workflow.py b/examples/openai_slim_minimal/workflow.py similarity index 50% rename from examples/graphon_openai_slim/workflow.py rename to examples/openai_slim_minimal/workflow.py index d9146d8..167661b 100644 --- a/examples/graphon_openai_slim/workflow.py +++ b/examples/openai_slim_minimal/workflow.py @@ -4,83 +4,38 @@ python3 workflow.py "Explain Graphon in one short sentence." -The script automatically loads `examples/graphon_openai_slim/.env`. +The script automatically loads `.env` from this example directory. Existing environment variables take precedence over `.env` values. - -Required environment variables: -- `OPENAI_API_KEY` -- `SLIM_PLUGIN_ID` - -Optional environment variables: -- `SLIM_BINARY_PATH` points at a custom `dify-plugin-daemon-slim` binary -- `SLIM_PROVIDER` defaults to `openai` -- `SLIM_PLUGIN_FOLDER` defaults to the repository `.slim/plugins` cache -- `SLIM_PLUGIN_ROOT` points at an already unpacked local plugin directory """ from __future__ import annotations import argparse -import importlib.util -import os import sys import time -from collections.abc import Sequence from pathlib import Path from typing import IO -EXAMPLE_DIR = Path(__file__).resolve().parent -REPO_ROOT = EXAMPLE_DIR.parents[1] -LOCAL_SRC_DIR = REPO_ROOT / "src" -LOCAL_VENV_PYTHON = REPO_ROOT / ".venv" / "bin" / "python" -DEFAULT_ENV_FILE = EXAMPLE_DIR / ".env" -BOOTSTRAP_ENV_VAR = "GRAPHON_EXAMPLE_BOOTSTRAPPED" -RUNTIME_MODULES = ("pydantic", "httpx", "yaml") -MIN_QUOTED_VALUE_LENGTH = 2 - - -def bootstrap_local_python() -> None: - if os.environ.get(BOOTSTRAP_ENV_VAR) == "1": - return - if all(importlib.util.find_spec(module) is not None for module in 
RUNTIME_MODULES): - return - if not LOCAL_VENV_PYTHON.is_file(): - return - - env = dict(os.environ) - env[BOOTSTRAP_ENV_VAR] = "1" - os.execve( # noqa: S606 - str(LOCAL_VENV_PYTHON), - [str(LOCAL_VENV_PYTHON), str(Path(__file__).resolve()), *sys.argv[1:]], - env, - ) +EXAMPLE_FILE = Path(__file__).resolve() +EXAMPLE_DIR = EXAMPLE_FILE.parent +REPO_ROOT = EXAMPLE_FILE.parents[2] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) -bootstrap_local_python() +from examples import openai_slim_support -if importlib.util.find_spec("graphon") is None and str(LOCAL_SRC_DIR) not in sys.path: - sys.path.insert(0, str(LOCAL_SRC_DIR)) +openai_slim_support.prepare_example_imports(EXAMPLE_FILE) # ruff: noqa: E402 from graphon.entities.graph_init_params import GraphInitParams -from graphon.file.enums import FileType -from graphon.file.models import File from graphon.graph.graph import Graph from graphon.graph_engine.command_channels import InMemoryChannel from graphon.graph_engine.graph_engine import GraphEngine from graphon.graph_events.node import NodeRunStreamChunkEvent from graphon.model_runtime.entities.llm_entities import LLMMode -from graphon.model_runtime.entities.message_entities import ( - PromptMessage, - PromptMessageRole, -) -from graphon.model_runtime.slim import ( - SlimConfig, - SlimLocalSettings, - SlimPreparedLLM, - SlimProviderBinding, - SlimRuntime, -) +from graphon.model_runtime.entities.message_entities import PromptMessageRole +from graphon.model_runtime.slim import SlimPreparedLLM, SlimRuntime from graphon.nodes.answer.answer_node import AnswerNode from graphon.nodes.answer.entities import AnswerNodeData from graphon.nodes.llm import ( @@ -89,160 +44,17 @@ def bootstrap_local_python() -> None: LLMNodeData, ModelConfig, ) -from graphon.nodes.llm.entities import ContextConfig from graphon.nodes.start import StartNode from graphon.nodes.start.entities import StartNodeData from graphon.runtime.graph_runtime_state import GraphRuntimeState 
from graphon.runtime.variable_pool import VariablePool from graphon.variables.input_entities import VariableEntity, VariableEntityType -ALLOWED_ENV_VARS: dict[str, str] = { - "OPENAI_API_KEY": "", - "SLIM_PLUGIN_ID": "", - "SLIM_BINARY_PATH": "", - "SLIM_PROVIDER": "openai", - "SLIM_PLUGIN_FOLDER": "../../.slim/plugins", - "SLIM_PLUGIN_ROOT": "", -} -PATH_ENV_VARS = { - "SLIM_BINARY_PATH", - "SLIM_PLUGIN_FOLDER", - "SLIM_PLUGIN_ROOT", -} STREAM_SELECTOR = ("llm", "text") -def load_default_env_file() -> None: - if DEFAULT_ENV_FILE.is_file(): - load_env_file(DEFAULT_ENV_FILE) - - -def load_env_file(path: Path) -> None: - env_dir = path.resolve().parent - for line_number, raw_line in enumerate( - path.read_text(encoding="utf-8").splitlines(), - start=1, - ): - line = raw_line.strip() - if not line or line.startswith("#"): - continue - if line.startswith("export "): - line = line.removeprefix("export ").strip() - if "=" not in line: - msg = f"Invalid .env line {line_number} in {path}: {raw_line}" - raise ValueError(msg) - - key, value = line.split("=", 1) - key = key.strip() - if not key: - msg = f"Invalid .env key on line {line_number} in {path}" - raise ValueError(msg) - if key not in ALLOWED_ENV_VARS: - msg = f"Unsupported .env key {key!r} on line {line_number} in {path}" - raise ValueError(msg) - - os.environ.setdefault( - key, - normalize_env_value( - key, - strip_optional_quotes(value.strip()), - base_dir=env_dir, - ), - ) - - -def strip_optional_quotes(value: str) -> str: - if ( - len(value) >= MIN_QUOTED_VALUE_LENGTH - and value[0] == value[-1] - and value[0] in {'"', "'"} - ): - return value[1:-1] - return value - - -def normalize_env_value(name: str, value: str, *, base_dir: Path) -> str: - if name not in PATH_ENV_VARS or not value: - return value - - path_value = Path(value).expanduser() - if not path_value.is_absolute(): - path_value = (base_dir / path_value).resolve() - else: - path_value = path_value.resolve() - return str(path_value) - - -class 
PassthroughPromptMessageSerializer: - def serialize( - self, - *, - model_mode: LLMMode, - prompt_messages: Sequence[PromptMessage], - ) -> object: - _ = model_mode - return list(prompt_messages) - - -class TextOnlyFileSaver: - def save_binary_string( - self, - data: bytes, - mime_type: str, - file_type: FileType, - extension_override: str | None = None, - ) -> File: - _ = data, mime_type, file_type, extension_override - msg = "This example only supports text responses." - raise RuntimeError(msg) - - def save_remote_url(self, url: str, file_type: FileType) -> File: - _ = url, file_type - msg = "This example only supports text responses." - raise RuntimeError(msg) - - -def require_env(name: str) -> str: - value = env_value(name) - if value: - return value - msg = f"{name} is required." - raise ValueError(msg) - - -def env_value(name: str) -> str: - raw_value = os.environ.get(name) - if raw_value is not None: - return raw_value.strip() - return normalize_env_value( - name, - ALLOWED_ENV_VARS[name], - base_dir=EXAMPLE_DIR, - ).strip() - - -def optional_path(name: str) -> Path | None: - value = env_value(name) - return Path(value).expanduser() if value else None - - def build_runtime() -> tuple[SlimRuntime, str]: - provider = env_value("SLIM_PROVIDER") - plugin_folder = Path(env_value("SLIM_PLUGIN_FOLDER")).expanduser() - plugin_root = optional_path("SLIM_PLUGIN_ROOT") - - runtime = SlimRuntime( - SlimConfig( - bindings=[ - SlimProviderBinding( - plugin_id=require_env("SLIM_PLUGIN_ID"), - provider=provider, - plugin_root=plugin_root, - ), - ], - local=SlimLocalSettings(folder=plugin_folder), - ), - ) + runtime, provider = openai_slim_support.build_runtime(example_dir=EXAMPLE_DIR) return runtime, provider @@ -294,14 +106,15 @@ def build_graph( text="{{#start.query#}}", ), ], - context=ContextConfig(enabled=False), ), }, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, model_instance=prepared_llm, - llm_file_saver=TextOnlyFileSaver(), - 
prompt_message_serializer=PassthroughPromptMessageSerializer(), + llm_file_saver=openai_slim_support.TextOnlyFileSaver(), + prompt_message_serializer=( + openai_slim_support.PassthroughPromptMessageSerializer() + ), ) output_node = AnswerNode( @@ -343,7 +156,7 @@ def _execute_workflow( *, stream_output: IO[str] | None = None, ) -> tuple[str, bool]: - load_default_env_file() + openai_slim_support.load_default_env_file(EXAMPLE_DIR) runtime, provider = build_runtime() workflow_id = "example-start-llm-output" graph_init_params = GraphInitParams( @@ -362,7 +175,12 @@ def _execute_workflow( runtime=runtime, provider=provider, model_name="gpt-5.4", - credentials={"openai_api_key": require_env("OPENAI_API_KEY")}, + credentials={ + "openai_api_key": openai_slim_support.require_env( + "OPENAI_API_KEY", + example_dir=EXAMPLE_DIR, + ), + }, parameters={}, ) graph = build_graph( diff --git a/examples/openai_slim_parallel_translation/.env.example b/examples/openai_slim_parallel_translation/.env.example new file mode 100644 index 0000000..cea2c58 --- /dev/null +++ b/examples/openai_slim_parallel_translation/.env.example @@ -0,0 +1,31 @@ +# Example configuration for +# `examples/openai_slim_parallel_translation/workflow.py`. +# +# The example loads `.env` from this directory automatically. Copy this file to +# `.env` in the same directory and fill in the required values. + +# Required: OpenAI API key used by the OpenAI Slim plugin. +OPENAI_API_KEY= + +# Required: Dify marketplace plugin unique identifier. +# This must match the exact OpenAI plugin package/version you want Slim to use. +# Example format: +# publisher/plugin:version@digest +SLIM_PLUGIN_ID=langgenius/openai:0.3.0@99770a45f77910fe0f64c985524f4fe2294fc6ea25cbf1053ba6bddd7604d850 + +# Optional: path to the local `dify-plugin-daemon-slim` binary. +# Recommended Unix default: a user-local install under `~/.local/bin`. 
+SLIM_BINARY_PATH=~/.local/bin/dify-plugin-daemon-slim + +# Optional: provider name inside the plugin package. +# For this example we only support OpenAI, so this should stay `openai`. +SLIM_PROVIDER=openai + +# Optional: local folder where Slim stores downloaded/extracted plugins. +# Recommended Unix default: a user-local plugin cache under `~/.local/share`. +SLIM_PLUGIN_FOLDER=~/.local/share/graphon/slim/plugins + +# Optional: path to an already unpacked local plugin directory. +# If set, Slim uses this directory directly and skips marketplace download. +# Leave empty in the normal case. +SLIM_PLUGIN_ROOT= diff --git a/examples/openai_slim_parallel_translation/README.md b/examples/openai_slim_parallel_translation/README.md new file mode 100644 index 0000000..f9b3a3b --- /dev/null +++ b/examples/openai_slim_parallel_translation/README.md @@ -0,0 +1,34 @@ +# OpenAI Slim Parallel Translation Example + +A fan-out / fan-in workflow: + +`start -> 3 llm -> end` + +The `start` node takes `content`. Three LLM nodes translate it into Chinese, +English, and Japanese in parallel. The `end` node returns all three +translations. + +## What You Need + +- `workflow.py`: runnable example +- `.env.example`: template settings +- `.env`: your local copy of the template + +## Run + +```bash +cd examples/openai_slim_parallel_translation +cp .env.example .env +python3 workflow.py "Graph execution is a coordination problem." +``` + +Fill in `.env` before running. The script reads `.env` from this directory. + +## Useful Flags + +```bash +python3 workflow.py --no-stream "Graph execution is a coordination problem." +``` + +By default, the example streams each translation as it becomes available, then +prints the final structured outputs. 
diff --git a/examples/openai_slim_parallel_translation/__init__.py b/examples/openai_slim_parallel_translation/__init__.py new file mode 100644 index 0000000..5fa2b7c --- /dev/null +++ b/examples/openai_slim_parallel_translation/__init__.py @@ -0,0 +1 @@ +"""Parallel translation OpenAI Slim workflow example for Graphon.""" diff --git a/examples/openai_slim_parallel_translation/workflow.py b/examples/openai_slim_parallel_translation/workflow.py new file mode 100644 index 0000000..90b73db --- /dev/null +++ b/examples/openai_slim_parallel_translation/workflow.py @@ -0,0 +1,271 @@ +"""Parallel translation workflow built with the sequential WorkflowBuilder API. + +Run from this directory: + + python3 workflow.py \ + "Graph execution is a coordination problem." + +The script automatically loads `.env` from this example directory. +""" + +from __future__ import annotations + +import argparse +import sys +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import IO + +EXAMPLE_FILE = Path(__file__).resolve() +EXAMPLE_DIR = EXAMPLE_FILE.parent +REPO_ROOT = EXAMPLE_FILE.parents[2] + +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +from examples import openai_slim_support + +openai_slim_support.prepare_example_imports(EXAMPLE_FILE) + +# ruff: noqa: E402 +from graphon.entities.graph_init_params import GraphInitParams +from graphon.graph.graph import Graph +from graphon.graph_engine.command_channels import InMemoryChannel +from graphon.graph_engine.graph_engine import GraphEngine +from graphon.graph_events.node import NodeRunStreamChunkEvent +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.slim import SlimPreparedLLM +from graphon.nodes.end.entities import EndNodeData +from graphon.nodes.llm import LLMNodeData, ModelConfig +from graphon.nodes.start.entities import StartNodeData +from graphon.runtime.graph_runtime_state import GraphRuntimeState +from 
graphon.runtime.variable_pool import VariablePool +from graphon.workflow_builder import WorkflowBuilder, paragraph_input, system, user + +TARGET_LANGUAGES: tuple[tuple[str, str, str], ...] = ( + ("translate_zh", "Chinese", "chinese"), + ("translate_en", "English", "english"), + ("translate_ja", "Japanese", "japanese"), +) +STREAM_LABEL_BY_SELECTOR = { + (node_id, "text"): language_name + for node_id, language_name, _output_name in TARGET_LANGUAGES +} + + +@dataclass(slots=True) +class TranslationStreamWriter: + stream_output: IO[str] + seen_selectors: set[tuple[str, str]] = field(default_factory=set) + active_selector: tuple[str, str] | None = None + + def write_event(self, event: object) -> bool: + if not isinstance(event, NodeRunStreamChunkEvent): + return False + + selector = tuple(event.selector) + label = STREAM_LABEL_BY_SELECTOR.get(selector) + if label is None: + return False + + if selector not in self.seen_selectors: + self.stream_output.write(f"{label}: ") + self.seen_selectors.add(selector) + self.active_selector = selector + elif self.active_selector is None: + self.active_selector = selector + + if event.chunk: + self.stream_output.write(event.chunk) + if event.is_final: + self.stream_output.write("\n") + self.active_selector = None + + self.stream_output.flush() + return bool(event.chunk) or event.is_final + + +def build_graph( + *, + provider: str, + prepared_llm: SlimPreparedLLM, + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, +) -> Graph: + workflow = WorkflowBuilder( + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + prepared_llm=prepared_llm, + ) + + start = workflow.root( + "start", + StartNodeData( + title="Start", + variables=[paragraph_input("content", required=True)], + ), + ) + + model = ModelConfig( + provider=provider, + name="gpt-5.4", + mode=LLMMode.CHAT, + ) + + translation_nodes = [] + for node_id, language_name, _output_name in TARGET_LANGUAGES: + 
translation_nodes.append( + start.then( + node_id, + LLMNodeData( + title=f"Translate to {language_name}", + model=model, + prompt_template=[ + system( + "Translate the following content to ", + language_name, + ". Return only the translated text.", + ), + user(start.ref("content")), + ], + ), + ), + ) + + output = translation_nodes[0].then( + "output", + EndNodeData( + title="Output", + outputs=[ + node.ref("text").output(output_name) + for node, (_, _, output_name) in zip( + translation_nodes, + TARGET_LANGUAGES, + strict=True, + ) + ], + ), + ) + + for node in translation_nodes[1:]: + node.connect(output) + + return workflow.build() + + +def write_stream_chunk( + event: object, + *, + stream_writer: TranslationStreamWriter, +) -> bool: + return stream_writer.write_event(event) + + +def _execute_workflow( + content: str, + *, + stream_output: IO[str] | None = None, +) -> tuple[dict[str, str], bool]: + openai_slim_support.load_default_env_file(EXAMPLE_DIR) + runtime, provider = openai_slim_support.build_runtime(example_dir=EXAMPLE_DIR) + workflow_id = "parallel-translation-workflow" + graph_init_params = GraphInitParams( + workflow_id=workflow_id, + graph_config={"nodes": [], "edges": []}, + run_context={}, + call_depth=0, + ) + graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(), + start_at=time.time(), + ) + graph_runtime_state.variable_pool.add(("start", "content"), content) + + prepared_llm = SlimPreparedLLM( + runtime=runtime, + provider=provider, + model_name="gpt-5.4", + credentials={ + "openai_api_key": openai_slim_support.require_env( + "OPENAI_API_KEY", + example_dir=EXAMPLE_DIR, + ), + }, + parameters={}, + ) + + graph = build_graph( + provider=provider, + prepared_llm=prepared_llm, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + engine = GraphEngine( + workflow_id=workflow_id, + graph=graph, + graph_runtime_state=graph_runtime_state, + command_channel=InMemoryChannel(), + ) + + streamed = 
False + stream_writer = ( + TranslationStreamWriter(stream_output=stream_output) + if stream_output is not None + else None + ) + for event in engine.run(): + if stream_writer is not None and write_stream_chunk( + event, + stream_writer=stream_writer, + ): + streamed = True + + outputs: dict[str, str] = {} + for _node_id, _language_name, output_name in TARGET_LANGUAGES: + value = graph_runtime_state.get_output(output_name) + if not isinstance(value, str): + msg = f"Workflow did not produce output {output_name!r}." + raise TypeError(msg) + outputs[output_name] = value + + return outputs, streamed + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Run a parallel translation workflow built with WorkflowBuilder.", + ) + parser.add_argument( + "content", + nargs="?", + default="Graph execution is a coordination problem.", + help="Input content to translate.", + ) + parser.add_argument( + "--stream", + action=argparse.BooleanOptionalAction, + default=True, + help="Stream translations as they are produced.", + ) + return parser.parse_args() + + +def main() -> int: + args = parse_args() + stream_output = sys.stdout if args.stream else None + if args.stream: + sys.stdout.write("Streaming translations:\n") + sys.stdout.flush() + outputs, streamed = _execute_workflow(args.content, stream_output=stream_output) + if args.stream and streamed: + sys.stdout.write("\n") + sys.stdout.write("Structured outputs:\n") + for _node_id, language_name, output_name in TARGET_LANGUAGES: + sys.stdout.write(f"- {language_name}: {outputs[output_name]}\n") + sys.stdout.flush() + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/examples/openai_slim_support.py b/examples/openai_slim_support.py new file mode 100644 index 0000000..69f7c8c --- /dev/null +++ b/examples/openai_slim_support.py @@ -0,0 +1,227 @@ +"""Shared bootstrap, environment, and Slim runtime helpers for examples.""" + +from __future__ import annotations + 
+import importlib +import importlib.util +import os +import sys +from collections.abc import Sequence +from pathlib import Path + +ALLOWED_ENV_VARS: dict[str, str] = { + "OPENAI_API_KEY": "", + "SLIM_PLUGIN_ID": "", + "SLIM_BINARY_PATH": "", + "SLIM_PROVIDER": "openai", + "SLIM_PLUGIN_FOLDER": "../../.slim/plugins", + "SLIM_PLUGIN_ROOT": "", +} +PATH_ENV_VARS = { + "SLIM_BINARY_PATH", + "SLIM_PLUGIN_FOLDER", + "SLIM_PLUGIN_ROOT", +} +BOOTSTRAP_ENV_VAR = "GRAPHON_EXAMPLE_BOOTSTRAPPED" +RUNTIME_MODULES = ("pydantic", "httpx", "yaml") +MIN_QUOTED_VALUE_LENGTH = 2 + + +def repo_root_for(example_file: Path) -> Path: + return example_file.resolve().parents[2] + + +def local_src_dir_for(example_file: Path) -> Path: + return repo_root_for(example_file) / "src" + + +def local_venv_python_for(example_file: Path) -> Path: + return repo_root_for(example_file) / ".venv" / "bin" / "python" + + +def prepare_example_imports( + example_file: Path, + *, + argv: Sequence[str] | None = None, +) -> None: + bootstrap_local_python(example_file, argv=argv) + ensure_local_src_on_path(example_file) + + +def bootstrap_local_python( + example_file: Path, + *, + argv: Sequence[str] | None = None, +) -> None: + if os.environ.get(BOOTSTRAP_ENV_VAR) == "1": + return + if all(importlib.util.find_spec(module) is not None for module in RUNTIME_MODULES): + return + + local_python = local_venv_python_for(example_file) + if not local_python.is_file(): + return + + env = dict(os.environ) + env[BOOTSTRAP_ENV_VAR] = "1" + os.execve( # noqa: S606 + str(local_python), + [ + str(local_python), + str(example_file.resolve()), + *(argv if argv is not None else sys.argv[1:]), + ], + env, + ) + + +def ensure_local_src_on_path(example_file: Path) -> None: + local_src_dir = local_src_dir_for(example_file) + if ( + importlib.util.find_spec("graphon") is None + and str(local_src_dir) not in sys.path + ): + sys.path.insert(0, str(local_src_dir)) + + +def load_default_env_file(example_dir: Path) -> None: + env_file = 
example_dir / ".env" + if env_file.is_file(): + load_env_file(env_file) + + +def load_env_file(path: Path) -> None: + env_dir = path.resolve().parent + for line_number, raw_line in enumerate( + path.read_text(encoding="utf-8").splitlines(), + start=1, + ): + line = raw_line.strip() + if not line or line.startswith("#"): + continue + if line.startswith("export "): + line = line.removeprefix("export ").strip() + if "=" not in line: + msg = f"Invalid .env line {line_number} in {path}: {raw_line}" + raise ValueError(msg) + + key, value = line.split("=", 1) + key = key.strip() + if not key: + msg = f"Invalid .env key on line {line_number} in {path}" + raise ValueError(msg) + if key not in ALLOWED_ENV_VARS: + msg = f"Unsupported .env key {key!r} on line {line_number} in {path}" + raise ValueError(msg) + + os.environ.setdefault( + key, + normalize_env_value( + key, + strip_optional_quotes(value.strip()), + base_dir=env_dir, + ), + ) + + +def strip_optional_quotes(value: str) -> str: + if ( + len(value) >= MIN_QUOTED_VALUE_LENGTH + and value[0] == value[-1] + and value[0] in {'"', "'"} + ): + return value[1:-1] + return value + + +def normalize_env_value(name: str, value: str, *, base_dir: Path) -> str: + if name not in PATH_ENV_VARS or not value: + return value + + path_value = Path(value).expanduser() + if not path_value.is_absolute(): + path_value = (base_dir / path_value).resolve() + else: + path_value = path_value.resolve() + return str(path_value) + + +def env_value(name: str, *, example_dir: Path) -> str: + raw_value = os.environ.get(name) + if raw_value is not None: + return raw_value.strip() + return normalize_env_value( + name, + ALLOWED_ENV_VARS[name], + base_dir=example_dir, + ).strip() + + +def require_env(name: str, *, example_dir: Path) -> str: + value = env_value(name, example_dir=example_dir) + if value: + return value + msg = f"{name} is required." 
+ raise ValueError(msg) + + +def optional_path(name: str, *, example_dir: Path) -> Path | None: + value = env_value(name, example_dir=example_dir) + return Path(value).expanduser() if value else None + + +def build_runtime(*, example_dir: Path) -> tuple[object, str]: + slim_module = importlib.import_module("graphon.model_runtime.slim") + slim_config = slim_module.SlimConfig + slim_local_settings = slim_module.SlimLocalSettings + slim_provider_binding = slim_module.SlimProviderBinding + slim_runtime = slim_module.SlimRuntime + + provider = env_value("SLIM_PROVIDER", example_dir=example_dir) + plugin_folder = Path( + env_value("SLIM_PLUGIN_FOLDER", example_dir=example_dir), + ).expanduser() + plugin_root = optional_path("SLIM_PLUGIN_ROOT", example_dir=example_dir) + + runtime = slim_runtime( + slim_config( + bindings=[ + slim_provider_binding( + plugin_id=require_env("SLIM_PLUGIN_ID", example_dir=example_dir), + provider=provider, + plugin_root=plugin_root, + ), + ], + local=slim_local_settings(folder=plugin_folder), + ), + ) + return runtime, provider + + +class PassthroughPromptMessageSerializer: + def serialize( + self, + *, + model_mode: object, + prompt_messages: Sequence[object], + ) -> object: + _ = model_mode + return list(prompt_messages) + + +class TextOnlyFileSaver: + def save_binary_string( + self, + data: bytes, + mime_type: str, + file_type: object, + extension_override: str | None = None, + ) -> object: + _ = data, mime_type, file_type, extension_override + msg = "This example only supports text responses." + raise RuntimeError(msg) + + def save_remote_url(self, url: str, file_type: object) -> object: + _ = url, file_type + msg = "This example only supports text responses." 
+        raise RuntimeError(msg)
diff --git a/src/graphon/nodes/llm/entities.py b/src/graphon/nodes/llm/entities.py
index f0fc6a2..473c171 100644
--- a/src/graphon/nodes/llm/entities.py
+++ b/src/graphon/nodes/llm/entities.py
@@ -23,7 +23,7 @@ class ModelConfig(BaseModel):
 
 
 class ContextConfig(BaseModel):
-    enabled: bool
+    enabled: bool = False
     variable_selector: list[str] | None = None
 
 
@@ -72,7 +72,7 @@ class LLMNodeData(BaseNodeData):
     )
     prompt_config: PromptConfig = Field(default_factory=PromptConfig)
     memory: MemoryConfig | None = None
-    context: ContextConfig
+    context: ContextConfig = Field(default_factory=ContextConfig)
     vision: VisionConfig = Field(default_factory=VisionConfig)
     structured_output: Mapping[str, Any] | None = None
     # We used 'structured_output_enabled' in the past, but it's not a good name.
diff --git a/src/graphon/workflow_builder.py b/src/graphon/workflow_builder.py
new file mode 100644
index 0000000..b470777
--- /dev/null
+++ b/src/graphon/workflow_builder.py
@@ -0,0 +1,359 @@
+from __future__ import annotations
+
+from collections.abc import Callable, Mapping, Sequence
+from dataclasses import dataclass
+from typing import Any, final
+
+from graphon.entities.base_node_data import BaseNodeData
+from graphon.entities.graph_init_params import GraphInitParams
+from graphon.enums import BuiltinNodeTypes
+from graphon.file.enums import FileType
+from graphon.file.models import File
+from graphon.graph.graph import Graph
+from graphon.model_runtime.entities.llm_entities import LLMMode
+from graphon.model_runtime.entities.message_entities import (
+    PromptMessage,
+    PromptMessageRole,
+)
+from graphon.nodes.answer.answer_node import AnswerNode
+from graphon.nodes.base.entities import OutputVariableEntity, OutputVariableType
+from graphon.nodes.base.node import Node
+from graphon.nodes.end.end_node import EndNode
+from graphon.nodes.llm import LLMNode, LLMNodeChatModelMessage
+from graphon.nodes.llm.file_saver import LLMFileSaver
+from graphon.nodes.llm.runtime_protocols import (
+    PreparedLLMProtocol,
+    PromptMessageSerializerProtocol,
+)
+from graphon.nodes.start import StartNode
+from graphon.runtime.graph_runtime_state import GraphRuntimeState
+from graphon.variables.input_entities import VariableEntity, VariableEntityType
+
+type TemplatePart = object
+type NodeBuilder = Callable[[NodeBuildContext], Node]
+
+
+@dataclass(frozen=True, slots=True)
+class NodeBuildContext:
+    node_id: str
+    data: BaseNodeData
+    graph_init_params: GraphInitParams
+    graph_runtime_state: GraphRuntimeState
+
+
+@dataclass(frozen=True, slots=True)
+class NodeOutputRef:
+    node_id: str
+    output_name: str
+
+    @property
+    def selector(self) -> tuple[str, str]:
+        return (self.node_id, self.output_name)
+
+    def as_template(self) -> str:
+        return "{{#" + ".".join(self.selector) + "#}}"
+
+    def output(
+        self,
+        variable: str | None = None,
+        *,
+        value_type: OutputVariableType = OutputVariableType.ANY,
+    ) -> OutputVariableEntity:
+        return OutputVariableEntity(
+            variable=variable or self.output_name,
+            value_type=value_type,
+            value_selector=self.selector,
+        )
+
+    def __str__(self) -> str:
+        return self.as_template()
+
+
+@dataclass(frozen=True, slots=True)
+class NodeHandle:
+    _builder: WorkflowBuilder
+    node_id: str
+
+    def then(
+        self,
+        node_id: str,
+        data: BaseNodeData,
+        *,
+        source_handle: str = "source",
+    ) -> NodeHandle:
+        return self._builder.add_node(
+            node_id=node_id,
+            data=data,
+            from_node_id=self.node_id,
+            source_handle=source_handle,
+        )
+
+    def connect(
+        self,
+        target: NodeHandle,
+        *,
+        source_handle: str = "source",
+    ) -> NodeHandle:
+        return self._builder.connect(
+            tail=self,
+            head=target,
+            source_handle=source_handle,
+        )
+
+    def ref(self, output_name: str) -> NodeOutputRef:
+        return NodeOutputRef(node_id=self.node_id, output_name=output_name)
+
+
+@final
+class _PassthroughPromptMessageSerializer:
+    def serialize(
+        self,
+        *,
+        model_mode: LLMMode,
+        prompt_messages: Sequence[PromptMessage],
+    ) -> object:
+        _ = model_mode
+        return list(prompt_messages)
+
+
+@final
+class _TextOnlyFileSaver:
+    def save_binary_string(
+        self,
+        data: bytes,
+        mime_type: str,
+        file_type: FileType,
+        extension_override: str | None = None,
+    ) -> File:
+        _ = data, mime_type, file_type, extension_override
+        msg = "WorkflowBuilder default saver only supports text outputs."
+        raise RuntimeError(msg)
+
+    def save_remote_url(self, url: str, file_type: FileType) -> File:
+        _ = url, file_type
+        msg = "WorkflowBuilder default saver only supports text outputs."
+        raise RuntimeError(msg)
+
+
+class WorkflowBuilder:
+    def __init__(
+        self,
+        *,
+        graph_init_params: GraphInitParams,
+        graph_runtime_state: GraphRuntimeState,
+        prepared_llm: PreparedLLMProtocol | None = None,
+        llm_file_saver: LLMFileSaver | None = None,
+        prompt_message_serializer: PromptMessageSerializerProtocol | None = None,
+        node_builders: Mapping[str, NodeBuilder] | None = None,
+    ) -> None:
+        self._graph_init_params = graph_init_params
+        self._graph_runtime_state = graph_runtime_state
+        self._prepared_llm = prepared_llm
+        self._llm_file_saver = llm_file_saver or _TextOnlyFileSaver()
+        self._prompt_message_serializer = (
+            prompt_message_serializer or _PassthroughPromptMessageSerializer()
+        )
+        self._graph_builder = Graph.new()
+        self._handles: dict[str, NodeHandle] = {}
+        self._node_builders: dict[str, NodeBuilder] = dict(node_builders or {})
+
+    def register_node_builder(self, node_type: str, builder: NodeBuilder) -> None:
+        self._node_builders[node_type] = builder
+
+    def register_node_class(
+        self,
+        node_cls: type[Node],
+        *,
+        extra_kwargs_factory: (
+            Callable[[NodeBuildContext], Mapping[str, Any]] | None
+        ) = None,
+    ) -> None:
+        def _builder(context: NodeBuildContext) -> Node:
+            extra_kwargs = (
+                dict(extra_kwargs_factory(context))
+                if extra_kwargs_factory is not None
+                else {}
+            )
+            return node_cls(
+                **self._base_node_kwargs(context),
+                **extra_kwargs,
+            )
+
+        self.register_node_builder(node_cls.node_type, _builder)
+
+    def root(self, node_id: str, data: BaseNodeData) -> NodeHandle:
+        node = self._build_node(node_id=node_id, data=data)
+        self._graph_builder.add_root(node)
+        return self._remember_handle(node_id)
+
+    def add_node(
+        self,
+        *,
+        node_id: str,
+        data: BaseNodeData,
+        from_node_id: str,
+        source_handle: str = "source",
+    ) -> NodeHandle:
+        node = self._build_node(node_id=node_id, data=data)
+        self._graph_builder.add_node(
+            node,
+            from_node_id=from_node_id,
+            source_handle=source_handle,
+        )
+        return self._remember_handle(node_id)
+
+    def connect(
+        self,
+        *,
+        tail: NodeHandle,
+        head: NodeHandle,
+        source_handle: str = "source",
+    ) -> NodeHandle:
+        self._ensure_owned_handle(tail)
+        self._ensure_owned_handle(head)
+        self._graph_builder.connect(
+            tail=tail.node_id,
+            head=head.node_id,
+            source_handle=source_handle,
+        )
+        return head
+
+    def handle(self, node_id: str) -> NodeHandle:
+        try:
+            return self._handles[node_id]
+        except KeyError as error:
+            msg = f"Unknown node id {node_id!r}."
+            raise KeyError(msg) from error
+
+    def build(self) -> Graph:
+        return self._graph_builder.build()
+
+    def _remember_handle(self, node_id: str) -> NodeHandle:
+        handle = NodeHandle(_builder=self, node_id=node_id)
+        self._handles[node_id] = handle
+        return handle
+
+    def _build_node(self, *, node_id: str, data: BaseNodeData) -> Node:
+        context = NodeBuildContext(
+            node_id=node_id,
+            data=data,
+            graph_init_params=self._graph_init_params,
+            graph_runtime_state=self._graph_runtime_state,
+        )
+
+        custom_builder = self._node_builders.get(data.type)
+        if custom_builder is not None:
+            return custom_builder(context)
+
+        if data.type == BuiltinNodeTypes.START:
+            return StartNode(**self._base_node_kwargs(context))
+        if data.type == BuiltinNodeTypes.ANSWER:
+            return AnswerNode(**self._base_node_kwargs(context))
+        if data.type == BuiltinNodeTypes.END:
+            return EndNode(**self._base_node_kwargs(context))
+        if data.type == BuiltinNodeTypes.LLM:
+            if self._prepared_llm is None:
+                msg = "LLM nodes require `prepared_llm` when using WorkflowBuilder."
+                raise ValueError(msg)
+            return LLMNode(
+                **self._base_node_kwargs(context),
+                model_instance=self._prepared_llm,
+                llm_file_saver=self._llm_file_saver,
+                prompt_message_serializer=self._prompt_message_serializer,
+            )
+
+        msg = (
+            f"No node builder registered for node type {data.type!r}. "
+            "Use `register_node_builder()` or `register_node_class()`."
+        )
+        raise ValueError(msg)
+
+    def _base_node_kwargs(self, context: NodeBuildContext) -> dict[str, object]:
+        return {
+            "node_id": context.node_id,
+            "config": {"id": context.node_id, "data": context.data},
+            "graph_init_params": context.graph_init_params,
+            "graph_runtime_state": context.graph_runtime_state,
+        }
+
+    def _ensure_owned_handle(self, handle: NodeHandle) -> None:
+        if handle._builder is not self:
+            msg = "NodeHandle belongs to a different WorkflowBuilder instance."
+            raise ValueError(msg)
+
+
+def template(*parts: TemplatePart) -> str:
+    rendered: list[str] = []
+    for part in parts:
+        if isinstance(part, NodeOutputRef):
+            rendered.append(part.as_template())
+        else:
+            rendered.append(str(part))
+    return "".join(rendered)
+
+
+def chat_message(
+    role: PromptMessageRole,
+    *parts: TemplatePart,
+) -> LLMNodeChatModelMessage:
+    return LLMNodeChatModelMessage(role=role, text=template(*parts))
+
+
+def system(*parts: TemplatePart) -> LLMNodeChatModelMessage:
+    return chat_message(PromptMessageRole.SYSTEM, *parts)
+
+
+def user(*parts: TemplatePart) -> LLMNodeChatModelMessage:
+    return chat_message(PromptMessageRole.USER, *parts)
+
+
+def assistant(*parts: TemplatePart) -> LLMNodeChatModelMessage:
+    return chat_message(PromptMessageRole.ASSISTANT, *parts)
+
+
+def input_variable(
+    variable: str,
+    *,
+    label: str | None = None,
+    variable_type: VariableEntityType = VariableEntityType.PARAGRAPH,
+    required: bool = False,
+    **kwargs: Any,
+) -> VariableEntity:
+    return VariableEntity(
+        variable=variable,
+        label=label or variable.replace("_", " ").title(),
+        type=variable_type,
+        required=required,
+        **kwargs,
+    )
+
+
+def paragraph_input(
+    variable: str,
+    *,
+    label: str | None = None,
+    required: bool = False,
+    **kwargs: Any,
+) -> VariableEntity:
+    return input_variable(
+        variable,
+        label=label,
+        variable_type=VariableEntityType.PARAGRAPH,
+        required=required,
+        **kwargs,
+    )
+
+
+__all__ = [
+    "NodeBuildContext",
+    "NodeHandle",
+    "NodeOutputRef",
+    "WorkflowBuilder",
+    "assistant",
+    "chat_message",
+    "input_variable",
+    "paragraph_input",
+    "system",
+    "template",
+    "user",
+]
diff --git a/tests/examples/test_graphon_openai_slim.py b/tests/examples/test_openai_slim_examples.py
similarity index 76%
rename from tests/examples/test_graphon_openai_slim.py
rename to tests/examples/test_openai_slim_examples.py
index 4eb925c..b6048bc 100644
--- a/tests/examples/test_graphon_openai_slim.py
+++ b/tests/examples/test_openai_slim_examples.py
@@ -6,11 +6,8 @@
 
 import pytest
 
-from examples.graphon_openai_slim.workflow import (
-    ALLOWED_ENV_VARS,
-    load_env_file,
-    write_stream_chunk,
-)
+from examples.openai_slim_minimal.workflow import write_stream_chunk
+from examples.openai_slim_support import ALLOWED_ENV_VARS, load_env_file
 from graphon.enums import BuiltinNodeTypes
 from graphon.graph_events.node import NodeRunStreamChunkEvent
 
@@ -72,11 +69,18 @@ def test_load_env_file_rejects_unknown_key(tmp_path: Path) -> None:
         load_env_file(env_file)
 
 
-def test_env_example_matches_allowed_env_vars() -> None:
+@pytest.mark.parametrize(
+    "example_dir_name",
+    [
+        "openai_slim_minimal",
+        "openai_slim_parallel_translation",
+    ],
+)
+def test_env_example_matches_allowed_env_vars(example_dir_name: str) -> None:
     env_example = (
         Path(__file__).resolve().parents[2]
         / "examples"
-        / "graphon_openai_slim"
+        / example_dir_name
         / ".env.example"
     )
     keys = {
@@ -88,11 +92,30 @@
     assert keys == set(ALLOWED_ENV_VARS)
 
 
-def test_env_example_leaves_slim_binary_path_empty() -> None:
+@pytest.mark.parametrize(
+    ("example_dir_name", "expected_binary_path", "expected_plugin_folder"),
+    [
+        (
+            "openai_slim_minimal",
+            "~/.local/bin/dify-plugin-daemon-slim",
+            "~/.local/share/graphon/slim/plugins",
+        ),
+        (
+            "openai_slim_parallel_translation",
+            "~/.local/bin/dify-plugin-daemon-slim",
+            "~/.local/share/graphon/slim/plugins",
+        ),
+    ],
+)
+def test_env_example_sets_recommended_unix_defaults(
+    example_dir_name: str,
+    expected_binary_path: str,
+    expected_plugin_folder: str,
+) -> None:
     env_example = (
         Path(__file__).resolve().parents[2]
         / "examples"
-        / "graphon_openai_slim"
+        / example_dir_name
        / ".env.example"
     )
     values = {
@@ -102,7 +125,8 @@
         for key, value in [line.removeprefix("export ").split("=", 1)]
     }
 
-    assert not values["SLIM_BINARY_PATH"]
+    assert values["SLIM_BINARY_PATH"] == expected_binary_path
+    assert values["SLIM_PLUGIN_FOLDER"] == expected_plugin_folder
 
 
 def test_write_stream_chunk_writes_llm_text_chunks() -> None:
diff --git a/tests/examples/test_parallel_translation_workflow.py b/tests/examples/test_parallel_translation_workflow.py
new file mode 100644
index 0000000..8e63922
--- /dev/null
+++ b/tests/examples/test_parallel_translation_workflow.py
@@ -0,0 +1,71 @@
+from __future__ import annotations
+
+from io import StringIO
+
+from examples.openai_slim_parallel_translation.workflow import (
+    TranslationStreamWriter,
+    write_stream_chunk,
+)
+from graphon.enums import BuiltinNodeTypes
+from graphon.graph_events.node import NodeRunStreamChunkEvent
+
+
+def test_write_stream_chunk_formats_translation_sections() -> None:
+    output = StringIO()
+    writer = TranslationStreamWriter(stream_output=output)
+
+    zh_chunk = NodeRunStreamChunkEvent(
+        id="event-1",
+        node_id="translate_zh",
+        node_type=BuiltinNodeTypes.LLM,
+        selector=["translate_zh", "text"],
+        chunk="你好",
+        is_final=False,
+    )
+    zh_final = NodeRunStreamChunkEvent(
+        id="event-2",
+        node_id="translate_zh",
+        node_type=BuiltinNodeTypes.LLM,
+        selector=["translate_zh", "text"],
+        chunk="",
+        is_final=True,
+    )
+    en_chunk = NodeRunStreamChunkEvent(
+        id="event-3",
+        node_id="translate_en",
+        node_type=BuiltinNodeTypes.LLM,
+        selector=["translate_en", "text"],
+        chunk="hello",
+        is_final=False,
+    )
+    en_final = NodeRunStreamChunkEvent(
+        id="event-4",
+        node_id="translate_en",
+        node_type=BuiltinNodeTypes.LLM,
+        selector=["translate_en", "text"],
+        chunk="",
+        is_final=True,
+    )
+
+    assert write_stream_chunk(zh_chunk, stream_writer=writer) is True
+    assert write_stream_chunk(zh_final, stream_writer=writer) is True
+    assert write_stream_chunk(en_chunk, stream_writer=writer) is True
+    assert write_stream_chunk(en_final, stream_writer=writer) is True
+
+    assert output.getvalue() == "Chinese: 你好\nEnglish: hello\n"
+
+
+def test_write_stream_chunk_ignores_non_translation_selectors() -> None:
+    output = StringIO()
+    writer = TranslationStreamWriter(stream_output=output)
+    event = NodeRunStreamChunkEvent(
+        id="event-5",
+        node_id="output",
+        node_type=BuiltinNodeTypes.END,
+        selector=["output", "answer"],
+        chunk="\n",
+        is_final=False,
+    )
+
+    assert write_stream_chunk(event, stream_writer=writer) is False
+    assert not output.getvalue()
diff --git a/tests/test_workflow_builder.py b/tests/test_workflow_builder.py
new file mode 100644
index 0000000..ccd3bb5
--- /dev/null
+++ b/tests/test_workflow_builder.py
@@ -0,0 +1,134 @@
+from __future__ import annotations
+
+import time
+from typing import cast
+
+from graphon.entities.graph_init_params import GraphInitParams
+from graphon.model_runtime.entities.llm_entities import LLMMode
+from graphon.nodes.end.end_node import EndNode
+from graphon.nodes.end.entities import EndNodeData
+from graphon.nodes.llm import LLMNodeData, ModelConfig
+from graphon.nodes.llm.runtime_protocols import PreparedLLMProtocol
+from graphon.nodes.start.entities import StartNodeData
+from graphon.runtime.graph_runtime_state import GraphRuntimeState
+from graphon.runtime.variable_pool import VariablePool
+from graphon.workflow_builder import (
+    WorkflowBuilder,
+    paragraph_input,
+    system,
+    user,
+)
+
+
+def test_llm_node_data_defaults_context_to_disabled() -> None:
+    node_data = LLMNodeData(
+        model=ModelConfig(
+            provider="mock",
+            name="mock-chat",
+            mode=LLMMode.CHAT,
+        ),
+        prompt_template=[system("Translate this text.")],
+    )
+
+    assert node_data.context.enabled is False
+    assert node_data.context.variable_selector is None
+
+
+def test_workflow_builder_builds_parallel_translation_workflow() -> None:
+    graph_init_params = GraphInitParams(
+        workflow_id="parallel-translation",
+        graph_config={"nodes": [], "edges": []},
+        run_context={},
+        call_depth=0,
+    )
+    graph_runtime_state = GraphRuntimeState(
+        variable_pool=VariablePool(),
+        start_at=time.time(),
+    )
+    builder = WorkflowBuilder(
+        graph_init_params=graph_init_params,
+        graph_runtime_state=graph_runtime_state,
+        prepared_llm=cast(PreparedLLMProtocol, object()),
+    )
+
+    start = builder.root(
+        "start",
+        StartNodeData(
+            variables=[paragraph_input("content", required=True)],
+        ),
+    )
+
+    translation_model = ModelConfig(
+        provider="mock",
+        name="mock-chat",
+        mode=LLMMode.CHAT,
+    )
+
+    chinese = start.then(
+        "translate_zh",
+        LLMNodeData(
+            model=translation_model,
+            prompt_template=[
+                system("Translate the following text to Chinese."),
+                user(start.ref("content")),
+            ],
+        ),
+    )
+    english = start.then(
+        "translate_en",
+        LLMNodeData(
+            model=translation_model,
+            prompt_template=[
+                system("Translate the following text to English."),
+                user(start.ref("content")),
+            ],
+        ),
+    )
+    japanese = start.then(
+        "translate_ja",
+        LLMNodeData(
+            model=translation_model,
+            prompt_template=[
+                system("Translate the following text to Japanese."),
+                user(start.ref("content")),
+            ],
+        ),
+    )
+
+    output = chinese.then(
+        "output",
+        EndNodeData(
+            outputs=[
+                chinese.ref("text").output("chinese"),
+                english.ref("text").output("english"),
+                japanese.ref("text").output("japanese"),
+            ],
+        ),
+    )
+    english.connect(output)
+    japanese.connect(output)
+
+    graph = builder.build()
+
+    assert graph.root_node.id == "start"
+    assert isinstance(graph.nodes["output"], EndNode)
+    assert sorted((edge.tail, edge.head) for edge in graph.edges.values()) == [
+        ("start", "translate_en"),
+        ("start", "translate_ja"),
+        ("start", "translate_zh"),
+        ("translate_en", "output"),
+        ("translate_ja", "output"),
+        ("translate_zh", "output"),
+    ]
+
+    output_node = cast(EndNode, graph.nodes["output"])
+    assert [item.variable for item in output_node.node_data.outputs] == [
+        "chinese",
+        "english",
+        "japanese",
+    ]
+    assert [tuple(item.value_selector) for item in output_node.node_data.outputs] == [
+        ("translate_zh", "text"),
+        ("translate_en", "text"),
+        ("translate_ja", "text"),
+    ]
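Reviewer note: the handle/ref templating that `workflow_builder.py` introduces reduces to a small pattern worth sanity-checking in isolation. A minimal standalone sketch follows; `Ref` and `template` here are simplified stand-ins, not the graphon API, and nothing in it imports graphon:

```python
from dataclasses import dataclass


# Simplified stand-in for NodeOutputRef: a (node, output) pair that
# knows how to render itself as a "{{#node.output#}}" placeholder.
@dataclass(frozen=True)
class Ref:
    node_id: str
    output_name: str

    def as_template(self) -> str:
        return "{{#" + self.node_id + "." + self.output_name + "#}}"


def template(*parts: object) -> str:
    # Refs render as placeholders; every other part is stringified,
    # mirroring the template() helper in the diff.
    return "".join(
        p.as_template() if isinstance(p, Ref) else str(p) for p in parts
    )


content = Ref("start", "content")
prompt = template("Translate the following text to Chinese.\n\n", content)
print(prompt)
```

This is why `user(start.ref("content"))` in the tests produces a prompt message whose text embeds a variable selector rather than a literal value: substitution is deferred to the engine's variable pool at run time.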