Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/visual_gen/quickstart_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


def main():
visual_gen = VisualGen(model_path="Wan-AI/Wan2.1-T2V-1.3B-Diffusers")
visual_gen = VisualGen(model="Wan-AI/Wan2.1-T2V-1.3B-Diffusers")
params = VisualGenParams(
height=480,
width=832,
Expand Down
4 changes: 2 additions & 2 deletions examples/visual_gen/visual_gen_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,8 @@ def main():

logger.info(f"Initializing VisualGen: ulysses_size={diffusion_args.parallel.dit_ulysses_size}")
visual_gen = VisualGen(
model_path=args.model_path,
diffusion_args=diffusion_args,
model=args.model_path,
args=diffusion_args,
)

try:
Expand Down
4 changes: 2 additions & 2 deletions examples/visual_gen/visual_gen_ltx2.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,8 @@ def main():
f"Initializing VisualGen (LTX2): cfg_size={args.cfg_size}, ulysses_size={args.ulysses_size}"
)
visual_gen = VisualGen(
model_path=args.model_path,
diffusion_args=diffusion_args,
model=args.model_path,
args=diffusion_args,
)

try:
Expand Down
4 changes: 2 additions & 2 deletions examples/visual_gen/visual_gen_wan_i2v.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,8 @@ def main():
f"ulysses_size={diffusion_args.parallel.dit_ulysses_size}"
)
visual_gen = VisualGen(
model_path=args.model_path,
diffusion_args=diffusion_args,
model=args.model_path,
args=diffusion_args,
)

try:
Expand Down
4 changes: 2 additions & 2 deletions examples/visual_gen/visual_gen_wan_t2v.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,8 @@ def main():
f"ulysses_size={diffusion_args.parallel.dit_ulysses_size}"
)
visual_gen = VisualGen(
model_path=args.model_path,
diffusion_args=diffusion_args,
model=args.model_path,
args=diffusion_args,
)

try:
Expand Down
6 changes: 4 additions & 2 deletions tensorrt_llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ def _setup_vendored_triton_kernels():

from ._common import _init, default_net, default_trtnet, precision
from ._mnnvl_utils import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo
from ._torch.visual_gen.config import VisualGenArgs
from ._utils import (default_gpus_per_node, local_mpi_rank, local_mpi_size,
mpi_barrier, mpi_comm, mpi_rank, mpi_world_size,
set_mpi_comm, str_dtype_to_torch, str_dtype_to_trt,
Expand All @@ -133,7 +132,8 @@ def _setup_vendored_triton_kernels():
from .python_plugin import PluginBase
from .sampling_params import SamplingParams
from .version import __version__
from .visual_gen import VisualGen, VisualGenParams
from .visual_gen import (VisualGen, VisualGenArgs, VisualGenError,
VisualGenParams, VisualGenResult)

__all__ = [
'AutoConfig',
Expand Down Expand Up @@ -182,6 +182,8 @@ def _setup_vendored_triton_kernels():
'TrtLlmArgs',
'SamplingParams',
'VisualGenArgs',
'VisualGenError',
'VisualGenResult',
'DisaggregatedParams',
'KvCacheConfig',
'math_utils',
Expand Down
13 changes: 0 additions & 13 deletions tensorrt_llm/_torch/visual_gen/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,19 +420,6 @@ def to_mapping(self) -> Mapping:
"""Derive Mapping from ParallelConfig."""
return self.parallel.to_mapping()

def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary."""
return self.model_dump()

@set_api_status("prototype")
@classmethod
def from_dict(cls, config_dict: Dict[str, Any]) -> "VisualGenArgs":
"""Create from dictionary with automatic nested config parsing.

Unknown fields cause a ValidationError (extra="forbid").
"""
return cls(**config_dict)

@set_api_status("prototype")
@classmethod
def from_yaml(cls, yaml_path: Union[str, Path], **overrides: Any) -> "VisualGenArgs":
Expand Down
4 changes: 2 additions & 2 deletions tensorrt_llm/bench/benchmark/visual_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,8 @@ def visual_gen_command(
# Initialize VisualGen
logger.info(f"Initializing VisualGen ({model_path})")
visual_gen = VisualGen(
model_path=model_path,
diffusion_args=diffusion_args,
model=model_path,
args=diffusion_args,
)

try:
Expand Down
11 changes: 4 additions & 7 deletions tensorrt_llm/commands/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,16 +479,13 @@ def launch_visual_gen_server(
"""
logger.info(f"Initializing VisualGen ({model})")

visual_gen_model = VisualGen(model_path=model,
diffusion_args=diffusion_args)
visual_gen_model = VisualGen(model=model, args=diffusion_args)

n_workers = visual_gen_model.diffusion_args.parallel.n_workers
n_workers = visual_gen_model.args.parallel.n_workers
logger.info(f"World size: {n_workers}")
logger.info(f"CFG size: {visual_gen_model.args.parallel.dit_cfg_size}")
logger.info(
f"CFG size: {visual_gen_model.diffusion_args.parallel.dit_cfg_size}")
logger.info(
f"Ulysses size: {visual_gen_model.diffusion_args.parallel.dit_ulysses_size}"
)
f"Ulysses size: {visual_gen_model.args.parallel.dit_ulysses_size}")

server = OpenAIServer(generator=visual_gen_model,
model=model,
Expand Down
12 changes: 10 additions & 2 deletions tensorrt_llm/visual_gen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .visual_gen import MediaOutput, VisualGen, VisualGenParams
from .args import VisualGenArgs
from .visual_gen import MediaOutput, VisualGen, VisualGenError, VisualGenParams, VisualGenResult

__all__ = ["VisualGen", "VisualGenParams", "MediaOutput"]
__all__ = [
"VisualGen",
"VisualGenArgs",
"VisualGenError",
"VisualGenParams",
"VisualGenResult",
"MediaOutput",
]
26 changes: 26 additions & 0 deletions tensorrt_llm/visual_gen/args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Public re-export of VisualGenArgs from its internal home.
# The class definition lives in tensorrt_llm._torch.visual_gen.config so that
# internal code can continue importing it from there without a circular
# dependency. This module provides the canonical public import path:
#
# from tensorrt_llm.visual_gen.args import VisualGenArgs
# from tensorrt_llm.visual_gen import VisualGenArgs
# from tensorrt_llm import VisualGenArgs
from tensorrt_llm._torch.visual_gen.config import VisualGenArgs

__all__ = ["VisualGenArgs"]
64 changes: 41 additions & 23 deletions tensorrt_llm/visual_gen/visual_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.
import asyncio
import atexit
import itertools
import os
import queue
import socket
Expand All @@ -29,11 +30,11 @@
import zmq

from tensorrt_llm._torch.visual_gen import DiffusionRequest, DiffusionResponse
from tensorrt_llm._torch.visual_gen.config import VisualGenArgs
from tensorrt_llm._torch.visual_gen.executor import run_diffusion_worker
from tensorrt_llm._torch.visual_gen.output import MediaOutput
from tensorrt_llm.visual_gen.args import VisualGenArgs

__all__ = ["VisualGen", "VisualGenParams", "MediaOutput"]
__all__ = ["VisualGen", "VisualGenParams", "MediaOutput", "VisualGenError", "VisualGenResult"]
from tensorrt_llm.executor.ipc import ZeroMqQueue
from tensorrt_llm.inputs.data import VisualGenInputs
from tensorrt_llm.llmapi.utils import set_api_status
Expand Down Expand Up @@ -64,15 +65,19 @@ def get_ip_address() -> str:
s.close()


class VisualGenError(RuntimeError):
    """Root exception type for failures reported by VisualGen.

    It subclasses ``RuntimeError``, so handlers written as
    ``except RuntimeError`` also catch it.
    """


class DiffusionRemoteClient:
"""Client proxy for remote DiffusionExecutor in worker processes."""

def __init__(
self,
diffusion_args: VisualGenArgs,
args: VisualGenArgs,
):
self.diffusion_args = diffusion_args
self.n_workers = diffusion_args.parallel.n_workers
self.args = args
self.n_workers = args.parallel.n_workers

# Setup distributed env
self.master_addr = "127.0.0.1"
Expand Down Expand Up @@ -126,7 +131,7 @@ def __init__(
"master_port": self.master_port,
"request_queue_addr": self.req_addr_connect,
"response_queue_addr": self.resp_addr_connect,
"diffusion_args": self.diffusion_args,
"diffusion_args": self.args,
"req_hmac_key": self.req_hmac_key,
"resp_hmac_key": self.resp_hmac_key,
"log_level": logger.level,
Expand Down Expand Up @@ -384,7 +389,7 @@ async def _wait_ready_async(self):
self.response_event.clear()


class DiffusionGenerationResult:
class VisualGenResult:
"""Future-like object for async generation."""

def __init__(self, request_id: int, executor: DiffusionRemoteClient):
Expand All @@ -394,14 +399,19 @@ def __init__(self, request_id: int, executor: DiffusionRemoteClient):
self._finished = False
self._error = None

@property
def done(self) -> bool:
    """Report whether this request has already been resolved.

    Returns:
        The ``_finished`` flag. It starts out ``False`` and ``result()``
        sets it once an output or an error message has been recorded.
    """
    return self._finished

async def result(self, timeout: Optional[float] = None) -> Any:
"""Wait for and return result (async version).

Can be awaited from any async context (e.g., FastAPI background tasks).
"""
if self._finished:
if self._error:
raise RuntimeError(self._error)
raise VisualGenError(self._error)
return self._result

# Use run_coroutine_threadsafe to execute in the background thread's event loop
Expand All @@ -413,15 +423,26 @@ async def result(self, timeout: Optional[float] = None) -> Any:
# Await the future in the current event loop
response = await asyncio.wrap_future(future)

if response is None:
raise VisualGenError("Generation timed out")

if response.error_msg:
self._error = response.error_msg
self._finished = True
raise RuntimeError(f"Generation failed: {response.error_msg}")
raise VisualGenError(f"Generation failed: {response.error_msg}")

self._result = response.output
self._finished = True
return self._result

def result_sync(self, timeout: Optional[float] = None) -> Any:
    """Block the calling thread until ``result()`` completes.

    The ``result()`` coroutine is scheduled on the executor's event loop
    (``self.executor._event_loop``) and this method waits on the
    ``concurrent.futures.Future`` that scheduling returns.

    Args:
        timeout: Seconds to wait. The same value is forwarded to
            ``result()`` and used as the limit for the blocking wait.
            ``None`` waits without a limit.

    Returns:
        Whatever ``result()`` returns.

    Raises:
        TimeoutError: If the blocking wait itself exceeds ``timeout``
            (``concurrent.futures.TimeoutError`` on Python versions where
            the two are distinct).
        Exception: Any exception raised inside ``result()`` is re-raised.
    """
    # NOTE(review): this blocks the caller, so it is presumably intended for
    # threads other than the one running the executor's event loop - confirm.
    coro = self.result(timeout=timeout)
    loop = self.executor._event_loop
    pending = asyncio.run_coroutine_threadsafe(coro, loop)
    return pending.result(timeout=timeout)

def cancel(self):
    """Placeholder for cancelling this request.

    Raises:
        NotImplementedError: Always; cancellation is not implemented yet.
    """
    message = "Cancel request (not yet implemented)."
    raise NotImplementedError(message)

Expand Down Expand Up @@ -495,18 +516,16 @@ class VisualGen:
@set_api_status("prototype")
def __init__(
self,
model_path: Union[str, Path],
diffusion_args: Optional[VisualGenArgs] = None,
model: Union[str, Path],
args: Optional[VisualGenArgs] = None,
):
self.model_path = str(model_path)
self.diffusion_args = (diffusion_args or VisualGenArgs()).model_copy(
update={"checkpoint_path": self.model_path}
)
self.model = str(model)
self.args = (args or VisualGenArgs()).model_copy(update={"checkpoint_path": self.model})

self.executor = DiffusionRemoteClient(
diffusion_args=self.diffusion_args,
args=self.args,
)
self.req_counter = 0
self._req_counter = itertools.count()

atexit.register(VisualGen._atexit_shutdown, weakref.ref(self))

Expand Down Expand Up @@ -535,25 +554,24 @@ def generate(
# Use the sync wrapper to get result
response = self.executor.await_responses_sync(future.request_id, timeout=None)
if response.error_msg:
raise RuntimeError(f"Generation failed: {response.error_msg}")
raise VisualGenError(f"Generation failed: {response.error_msg}")
return response.output

@set_api_status("prototype")
def generate_async(
self,
inputs: VisualGenInputs,
params: VisualGenParams,
) -> DiffusionGenerationResult:
) -> VisualGenResult:
"""Async generation. Returns immediately with future-like object.

Args:
params: Generation parameters.

Returns:
DiffusionGenerationResult: Call result() to get output dict.
VisualGenResult: Call result() to get output dict.
"""
req_id = self.req_counter
self.req_counter += 1
req_id = next(self._req_counter)

# Normalize inputs to (prompt: List[str], negative_prompt: Optional[str])
# so DiffusionRequest.prompt is always a list.
Expand Down Expand Up @@ -619,7 +637,7 @@ def generate_async(
)

self.executor.enqueue_requests([request])
return DiffusionGenerationResult(req_id, self.executor)
return VisualGenResult(req_id, self.executor)

@staticmethod
def _atexit_shutdown(self_ref):
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/defs/examples/test_visual_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ def _generate_ltx2_video(llm_venv, output_subdir, linear_type="default"):
vg_kwargs["parallel"] = {"dit_cfg_size": 2}

diffusion_args = VisualGenArgs(**vg_kwargs)
visual_gen = VisualGen(model_path=model_path, diffusion_args=diffusion_args)
visual_gen = VisualGen(model=model_path, args=diffusion_args)

try:
params = VisualGenParams(
Expand Down
1 change: 1 addition & 0 deletions tests/integration/test_lists/test-db/l0_b200.yml
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ l0_b200:
- kv_cache/test_kv_cache_v2_scheduler.py::TestKVCacheV2LoRA::test_lora_eviction
# ------------- Visual Gen tests ---------------
- unittest/_torch/visual_gen/test_visual_gen_args.py
- unittest/_torch/visual_gen/test_warmup.py
- unittest/_torch/visual_gen/test_teacache.py
- unittest/_torch/visual_gen/test_fused_qkv.py
- unittest/_torch/visual_gen/test_quant_ops.py
Expand Down
2 changes: 1 addition & 1 deletion tests/unittest/_torch/visual_gen/test_model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def test_diffusion_args_from_dict():
old_world = os.environ.get("WORLD_SIZE")
try:
os.environ["WORLD_SIZE"] = "2"
args = VisualGenArgs.from_dict(config_dict)
args = VisualGenArgs(**config_dict)
assert args.checkpoint_path == "/path/to/model"
assert args.quant_config.quant_algo == QuantAlgo.FP8
assert args.dynamic_weight_quant is True
Expand Down
Loading
Loading