Skip to content
Open
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
6226c55
Fix(Flux2): Correct guidance_embed, add guidance support for Klein 9B…
Pfannkuchensack Mar 26, 2026
edbc705
test(flux2): cover Klein guidance gating, scheduler metadata, and rec…
Pfannkuchensack Apr 7, 2026
96d860a
Chore pnpm fix
Pfannkuchensack Apr 7, 2026
1b60bb1
Merge branch 'main' into fix/flux2-klein-guidance-and-metadata
Pfannkuchensack Apr 9, 2026
e4f46f7
Merge branch 'main' of https://github.com/invoke-ai/InvokeAI into fix…
Pfannkuchensack Apr 9, 2026
f5866d0
Update version to 1.5.0 in flux2_denoise.py
Pfannkuchensack Apr 9, 2026
4d9e2aa
Update condition for rendering ParamFluxScheduler
Pfannkuchensack Apr 9, 2026
06acaa9
Merge branch 'main' into fix/flux2-klein-guidance-and-metadata
JPPhoto Apr 9, 2026
1dd40d1
Merge branch 'main' into fix/flux2-klein-guidance-and-metadata
Pfannkuchensack Apr 9, 2026
26a7798
Merge branch 'main' into fix/flux2-klein-guidance-and-metadata
JPPhoto Apr 10, 2026
18d0a23
Merge branch 'main' into fix/flux2-klein-guidance-and-metadata
Pfannkuchensack Apr 12, 2026
224ae1b
Merge branch 'main' into fix/flux2-klein-guidance-and-metadata
Pfannkuchensack Apr 16, 2026
95bf0e0
Merge branch 'main' into fix/flux2-klein-guidance-and-metadata
JPPhoto Apr 17, 2026
7b42491
Merge remote-tracking branch 'origin/fix/flux2-klein-guidance-and-met…
Pfannkuchensack Apr 20, 2026
d3d915d
Merge remote-tracking branch 'upstream/main' into fix/flux2-klein-gui…
Pfannkuchensack Apr 20, 2026
6cf6368
feat(flux2): add Klein4BBase variant for FLUX.2 Klein Base 4B models
Pfannkuchensack Apr 21, 2026
6ff70c4
Merge remote-tracking branch 'upstream/main' into fix/flux2-klein-gui…
Pfannkuchensack Apr 21, 2026
7de35e3
Change Wrong Comment
Pfannkuchensack Apr 21, 2026
b8c8084
Merge branch 'main' into fix/flux2-klein-guidance-and-metadata
Pfannkuchensack Apr 22, 2026
c286c01
Merge branch 'main' of https://github.com/invoke-ai/InvokeAI into fix…
Pfannkuchensack Apr 22, 2026
28aa217
Merge branch 'fix/flux2-klein-guidance-and-metadata' of https://githu…
Pfannkuchensack Apr 22, 2026
f6ecc1a
refactor(flux2): remove inert guidance UI/metadata for FLUX.2 Klein
Pfannkuchensack Apr 22, 2026
af0bd50
Chore typegen
Pfannkuchensack Apr 22, 2026
88dcf04
Merge branch 'fix/flux2-klein-guidance-and-metadata' of https://githu…
Pfannkuchensack Apr 23, 2026
3f742d3
fix test
Pfannkuchensack Apr 23, 2026
7612eb1
fix(flux2): skip Guidance metadata recall for legacy FLUX.2 images
Pfannkuchensack Apr 23, 2026
2c0bc59
Merge branch 'main' into fix/flux2-klein-guidance-and-metadata
Pfannkuchensack Apr 23, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions invokeai/app/invocations/flux2_denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@
"flux2_denoise",
title="FLUX2 Denoise",
tags=["image", "flux", "flux2", "klein", "denoise"],
category="latents",
version="1.4.0",
category="image",
version="1.5.0",
classification=Classification.Prototype,
)
class Flux2DenoiseInvocation(BaseInvocation):
Expand Down Expand Up @@ -101,6 +101,14 @@ class Flux2DenoiseInvocation(BaseInvocation):
description="Negative conditioning tensor. Can be None if cfg_scale is 1.0.",
input=Input.Connection,
)
guidance: float = InputField(
default=4.0,
ge=0,
le=20,
description="Guidance strength for distilled guidance-embedding models. "
"Inert for all current FLUX.2 Klein variants (their guidance_embeds weights are absent/zero); "
"kept for node-graph compatibility and future guidance-embedded models.",
)
cfg_scale: float = InputField(
default=1.0,
description=FieldDescriptions.cfg_scale,
Expand Down Expand Up @@ -467,6 +475,7 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
txt_ids=txt_ids,
timesteps=timesteps,
step_callback=self._build_step_callback(context),
guidance=self.guidance,
cfg_scale=cfg_scale_list,
neg_txt=neg_txt,
neg_txt_ids=neg_txt_ids,
Expand Down
4 changes: 2 additions & 2 deletions invokeai/app/invocations/flux2_klein_model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,9 +207,9 @@ def _validate_qwen3_encoder_variant(self, context: InvocationContext, main_confi
flux2_variant = main_config.variant

# Validate the variants match
# Klein4B requires Qwen3_4B, Klein9B/Klein9BBase requires Qwen3_8B
# Klein4B/Klein4BBase require Qwen3_4B; Klein9B/Klein9BBase require Qwen3_8B
expected_qwen3_variant = None
if flux2_variant == Flux2VariantType.Klein4B:
if flux2_variant in (Flux2VariantType.Klein4B, Flux2VariantType.Klein4BBase):
expected_qwen3_variant = Qwen3VariantType.Qwen3_4B
elif flux2_variant in (Flux2VariantType.Klein9B, Flux2VariantType.Klein9BBase):
expected_qwen3_variant = Qwen3VariantType.Qwen3_8B
Expand Down
38 changes: 36 additions & 2 deletions invokeai/backend/flux/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,24 @@ def get_flux_ae_params() -> AutoEncoderParams:
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=True,
guidance_embed=False,
),
# Flux2 Klein 4B Base is the undistilled foundation model. It shares the same
# architecture as Klein 4B (distilled) and reports guidance_embeds=False in its
# HF transformer config - classical CFG (external negative pass) is the guidance mechanism.
Flux2VariantType.Klein4BBase: FluxParams(
in_channels=64,
vec_in_dim=2560, # Qwen3-4B hidden size (used for pooled output)
context_in_dim=7680, # 3 layers * 2560 = 7680 for Qwen3-4B
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=False,
),
# Flux2 Klein 9B uses Qwen3 8B text encoder with stacked embeddings from layers [9, 18, 27]
# The context_in_dim is 3 * hidden_size of Qwen3 (3 * 4096 = 12288)
Expand All @@ -149,7 +166,24 @@ def get_flux_ae_params() -> AutoEncoderParams:
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=True,
guidance_embed=False,
),
# Flux2 Klein 9B Base is the undistilled foundation model. It shares the same
# architecture as Klein 9B (distilled) and reports guidance_embeds=False in its
# HF transformer config - the guidance scalar is inert for all Klein variants.
Flux2VariantType.Klein9BBase: FluxParams(
in_channels=64,
vec_in_dim=4096, # Qwen3-8B hidden size (used for pooled output)
context_in_dim=12288, # 3 layers * 4096 = 12288 for Qwen3-8B
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=False,
),
}

Expand Down
23 changes: 15 additions & 8 deletions invokeai/backend/flux2/denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def denoise(
# sampling parameters
timesteps: list[float],
step_callback: Callable[[PipelineIntermediateState], None],
guidance: float,
cfg_scale: list[float],
# Negative conditioning for CFG
neg_txt: torch.Tensor | None = None,
Expand All @@ -45,7 +46,10 @@ def denoise(
This is a simplified denoise function for FLUX.2 Klein models that uses
the diffusers Flux2Transformer2DModel interface.

Note: FLUX.2 Klein has guidance_embeds=False, so no guidance parameter is used.
All current FLUX.2 Klein variants (4B, 4B Base, 9B, 9B Base) either report guidance_embeds=False
in their HF transformer config or have absent/zeroed guidance projection weights, so the guidance
value is passed to the model but has no effect on its output. The argument is retained for
node-graph compatibility and for future variants that may ship trained guidance projections.
CFG is applied externally using negative conditioning when cfg_scale != 1.0.

Args:
Expand All @@ -56,6 +60,8 @@ def denoise(
txt_ids: Text position IDs tensor.
timesteps: List of timesteps for denoising schedule (linear sigmas from 1.0 to 1/n).
step_callback: Callback function for progress updates.
guidance: Guidance strength. Inert for all current FLUX.2 Klein variants
(their guidance_embeds projection weights are absent/zero).
cfg_scale: List of CFG scale values per step.
neg_txt: Negative text embeddings for CFG (optional).
neg_txt_ids: Negative text position IDs (optional).
Expand All @@ -76,9 +82,10 @@ def denoise(
img = torch.cat([img, img_cond_seq], dim=1)
img_ids = torch.cat([img_ids, img_cond_seq_ids], dim=1)

# Klein has guidance_embeds=False, but the transformer forward() still requires a guidance tensor
# We pass a dummy value (1.0) since it won't affect the output when guidance_embeds=False
guidance = torch.full((img.shape[0],), 1.0, device=img.device, dtype=img.dtype)
# The transformer forward() requires a guidance tensor even when guidance_embeds=False,
# because the Flux2TimestepGuidanceEmbeddings forward signature takes it unconditionally.
# All current Klein variants have guidance_embeds=False, so the value is ignored internally.
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)

# Use scheduler if provided
use_scheduler = scheduler is not None
Expand Down Expand Up @@ -121,7 +128,7 @@ def denoise(
timestep=t_vec,
img_ids=img_ids,
txt_ids=txt_ids,
guidance=guidance,
guidance=guidance_vec,
return_dict=False,
)

Expand All @@ -141,7 +148,7 @@ def denoise(
timestep=t_vec,
img_ids=img_ids,
txt_ids=neg_txt_ids if neg_txt_ids is not None else txt_ids,
guidance=guidance,
guidance=guidance_vec,
return_dict=False,
)

Expand Down Expand Up @@ -222,7 +229,7 @@ def denoise(
timestep=t_vec,
img_ids=img_ids,
txt_ids=txt_ids,
guidance=guidance,
guidance=guidance_vec,
return_dict=False,
)

Expand All @@ -242,7 +249,7 @@ def denoise(
timestep=t_vec,
img_ids=img_ids,
txt_ids=neg_txt_ids if neg_txt_ids is not None else txt_ids,
guidance=guidance,
guidance=guidance_vec,
return_dict=False,
)

Expand Down
24 changes: 15 additions & 9 deletions invokeai/backend/model_manager/configs/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ def from_base(
return cls(steps=35, cfg_scale=4.5, width=1024, height=1024)
case BaseModelType.Flux2:
# Different defaults based on variant
if variant == Flux2VariantType.Klein9BBase:
# Undistilled base model needs more steps
if variant in (Flux2VariantType.Klein4BBase, Flux2VariantType.Klein9BBase):
# Undistilled base models need more steps
return cls(steps=28, cfg_scale=1.0, width=1024, height=1024)
else:
# Distilled models (Klein 4B, Klein 9B) use fewer steps
Expand Down Expand Up @@ -389,6 +389,7 @@ def _get_flux2_variant(state_dict: dict[str | int, Any]) -> Flux2VariantType | N
# Default to Klein9B - callers use filename heuristics to detect Klein9BBase
return Flux2VariantType.Klein9B
elif context_in_dim == KLEIN_4B_CONTEXT_DIM:
# Default to Klein4B - callers use filename heuristics to detect Klein4BBase
return Flux2VariantType.Klein4B
elif context_in_dim > 4096:
# Unknown FLUX.2 variant, default to 4B
Expand Down Expand Up @@ -573,10 +574,12 @@ def _get_variant_or_raise(cls, mod: ModelOnDisk) -> Flux2VariantType:
if variant is None:
raise NotAMatchError("unable to determine FLUX.2 model variant from state dict")

# Klein 9B Base and Klein 9B have identical architectures.
# Use filename heuristic to detect the Base (undistilled) variant.
# Base (undistilled) and distilled variants share identical architectures, so the state dict
# alone cannot tell them apart. Use a filename heuristic to detect the Base variant.
if variant == Flux2VariantType.Klein9B and _filename_suggests_base(mod.name):
return Flux2VariantType.Klein9BBase
if variant == Flux2VariantType.Klein4B and _filename_suggests_base(mod.name):
return Flux2VariantType.Klein4BBase

return variant

Expand Down Expand Up @@ -745,10 +748,12 @@ def _get_variant_or_raise(cls, mod: ModelOnDisk) -> Flux2VariantType:
if variant is None:
raise NotAMatchError("unable to determine FLUX.2 model variant from state dict")

# Klein 9B Base and Klein 9B have identical architectures.
# Use filename heuristic to detect the Base (undistilled) variant.
# Base (undistilled) and distilled variants share identical architectures, so the state dict
# alone cannot tell them apart. Use a filename heuristic to detect the Base variant.
if variant == Flux2VariantType.Klein9B and _filename_suggests_base(mod.name):
return Flux2VariantType.Klein9BBase
if variant == Flux2VariantType.Klein4B and _filename_suggests_base(mod.name):
return Flux2VariantType.Klein4BBase

return variant

Expand Down Expand Up @@ -856,11 +861,10 @@ def _get_variant_or_raise(cls, mod: ModelOnDisk) -> Flux2VariantType:
"""Determine the FLUX.2 variant from the transformer config.

FLUX.2 Klein uses Qwen3 text encoder with larger joint_attention_dim:
- Klein 4B: joint_attention_dim = 7680 (3×Qwen3-4B hidden size)
- Klein 4B/4B Base: joint_attention_dim = 7680 (3×Qwen3-4B hidden size)
- Klein 9B/9B Base: joint_attention_dim = 12288 (3×Qwen3-8B hidden size)

Klein 9B (distilled) and Klein 9B Base (undistilled) have identical architectures
and both have guidance_embeds=False. We use a filename heuristic to detect Base models.
Distilled and Base variants share identical architectures. We use a filename heuristic to detect Base models.
"""
KLEIN_4B_CONTEXT_DIM = 7680 # 3 × 2560
KLEIN_9B_CONTEXT_DIM = 12288 # 3 × 4096
Expand All @@ -875,6 +879,8 @@ def _get_variant_or_raise(cls, mod: ModelOnDisk) -> Flux2VariantType:
return Flux2VariantType.Klein9BBase
return Flux2VariantType.Klein9B
elif joint_attention_dim == KLEIN_4B_CONTEXT_DIM:
if _filename_suggests_base(mod.name):
return Flux2VariantType.Klein4BBase
return Flux2VariantType.Klein4B
elif joint_attention_dim > 4096:
# Unknown FLUX.2 variant, default to 4B
Expand Down
5 changes: 4 additions & 1 deletion invokeai/backend/model_manager/taxonomy.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,10 @@ class Flux2VariantType(str, Enum):
"""FLUX.2 model variants."""

Klein4B = "klein_4b"
"""Flux2 Klein 4B variant using Qwen3 4B text encoder."""
"""Flux2 Klein 4B variant using Qwen3 4B text encoder (distilled)."""

Klein4BBase = "klein_4b_base"
"""Flux2 Klein 4B Base variant - undistilled foundation model using Qwen3 4B text encoder."""

Klein9B = "klein_9b"
"""Flux2 Klein 9B variant using Qwen3 8B text encoder (distilled)."""
Expand Down
2 changes: 1 addition & 1 deletion invokeai/frontend/web/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -19924,7 +19924,7 @@
},
"Flux2VariantType": {
"type": "string",
"enum": ["klein_4b", "klein_9b", "klein_9b_base"],
"enum": ["klein_4b", "klein_4b_base", "klein_9b", "klein_9b_base"],
"title": "Flux2VariantType",
"description": "FLUX.2 model variants."
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ export const ImageMetadataActions = memo((props: Props) => {
<SingleMetadataDatum metadata={metadata} handler={ImageMetadataHandlers.QwenImageShift} />
<SingleMetadataDatum metadata={metadata} handler={ImageMetadataHandlers.CanvasLayers} />
<CollectionMetadataDatum metadata={metadata} handler={ImageMetadataHandlers.RefImages} />
<SingleMetadataDatum metadata={metadata} handler={ImageMetadataHandlers.KleinVAEModel} />
<SingleMetadataDatum metadata={metadata} handler={ImageMetadataHandlers.KleinQwen3EncoderModel} />
<CollectionMetadataDatum metadata={metadata} handler={ImageMetadataHandlers.LoRAs} />
</Flex>
);
Expand Down
Loading
Loading