Skip to content
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
6226c55
Fix(Flux2): Correct guidance_embed, add guidance support for Klein 9B…
Pfannkuchensack Mar 26, 2026
edbc705
test(flux2): cover Klein guidance gating, scheduler metadata, and rec…
Pfannkuchensack Apr 7, 2026
96d860a
Chore pnpm fix
Pfannkuchensack Apr 7, 2026
1b60bb1
Merge branch 'main' into fix/flux2-klein-guidance-and-metadata
Pfannkuchensack Apr 9, 2026
e4f46f7
Merge branch 'main' of https://github.com/invoke-ai/InvokeAI into fix…
Pfannkuchensack Apr 9, 2026
f5866d0
Update version to 1.5.0 in flux2_denoise.py
Pfannkuchensack Apr 9, 2026
4d9e2aa
Update condition for rendering ParamFluxScheduler
Pfannkuchensack Apr 9, 2026
06acaa9
Merge branch 'main' into fix/flux2-klein-guidance-and-metadata
JPPhoto Apr 9, 2026
1dd40d1
Merge branch 'main' into fix/flux2-klein-guidance-and-metadata
Pfannkuchensack Apr 9, 2026
26a7798
Merge branch 'main' into fix/flux2-klein-guidance-and-metadata
JPPhoto Apr 10, 2026
18d0a23
Merge branch 'main' into fix/flux2-klein-guidance-and-metadata
Pfannkuchensack Apr 12, 2026
224ae1b
Merge branch 'main' into fix/flux2-klein-guidance-and-metadata
Pfannkuchensack Apr 16, 2026
95bf0e0
Merge branch 'main' into fix/flux2-klein-guidance-and-metadata
JPPhoto Apr 17, 2026
7b42491
Merge remote-tracking branch 'origin/fix/flux2-klein-guidance-and-met…
Pfannkuchensack Apr 20, 2026
d3d915d
Merge remote-tracking branch 'upstream/main' into fix/flux2-klein-gui…
Pfannkuchensack Apr 20, 2026
6cf6368
feat(flux2): add Klein4BBase variant for FLUX.2 Klein Base 4B models
Pfannkuchensack Apr 21, 2026
6ff70c4
Merge remote-tracking branch 'upstream/main' into fix/flux2-klein-gui…
Pfannkuchensack Apr 21, 2026
7de35e3
Change Wrong Comment
Pfannkuchensack Apr 21, 2026
b8c8084
Merge branch 'main' into fix/flux2-klein-guidance-and-metadata
Pfannkuchensack Apr 22, 2026
c286c01
Merge branch 'main' of https://github.com/invoke-ai/InvokeAI into fix…
Pfannkuchensack Apr 22, 2026
28aa217
Merge branch 'fix/flux2-klein-guidance-and-metadata' of https://githu…
Pfannkuchensack Apr 22, 2026
f6ecc1a
refactor(flux2): remove inert guidance UI/metadata for FLUX.2 Klein
Pfannkuchensack Apr 22, 2026
af0bd50
Chore typegen
Pfannkuchensack Apr 22, 2026
88dcf04
Merge branch 'fix/flux2-klein-guidance-and-metadata' of https://githu…
Pfannkuchensack Apr 23, 2026
3f742d3
fix test
Pfannkuchensack Apr 23, 2026
7612eb1
fix(flux2): skip Guidance metadata recall for legacy FLUX.2 images
Pfannkuchensack Apr 23, 2026
2c0bc59
Merge branch 'main' into fix/flux2-klein-guidance-and-metadata
Pfannkuchensack Apr 23, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion invokeai/app/invocations/flux2_denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
title="FLUX2 Denoise",
tags=["image", "flux", "flux2", "klein", "denoise"],
category="image",
version="1.4.0",
version="1.5.0",
classification=Classification.Prototype,
)
class Flux2DenoiseInvocation(BaseInvocation):
Expand Down Expand Up @@ -101,6 +101,13 @@ class Flux2DenoiseInvocation(BaseInvocation):
description="Negative conditioning tensor. Can be None if cfg_scale is 1.0.",
input=Input.Connection,
)
guidance: float = InputField(
default=4.0,
ge=0,
le=20,
description="The guidance strength. Only used by undistilled models (Klein 9B Base). "
"Ignored by distilled models (Klein 4B, Klein 9B).",
)
cfg_scale: float = InputField(
default=1.0,
description=FieldDescriptions.cfg_scale,
Expand Down Expand Up @@ -467,6 +474,7 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
txt_ids=txt_ids,
timesteps=timesteps,
step_callback=self._build_step_callback(context),
guidance=self.guidance,
cfg_scale=cfg_scale_list,
neg_txt=neg_txt,
neg_txt_ids=neg_txt_ids,
Expand Down
17 changes: 16 additions & 1 deletion invokeai/backend/flux/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,26 @@ def get_flux_ae_params() -> AutoEncoderParams:
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=True,
guidance_embed=False,
),
# Flux2 Klein 9B uses Qwen3 8B text encoder with stacked embeddings from layers [9, 18, 27]
# The context_in_dim is 3 * hidden_size of Qwen3 (3 * 4096 = 12288)
Flux2VariantType.Klein9B: FluxParams(
in_channels=64,
vec_in_dim=4096, # Qwen3-8B hidden size (used for pooled output)
context_in_dim=12288, # 3 layers * 4096 = 12288 for Qwen3-8B
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=False,
),
# Flux2 Klein 9B Base is the undistilled foundation model with guidance_embeds=True
Flux2VariantType.Klein9BBase: FluxParams(
in_channels=64,
vec_in_dim=4096, # Qwen3-8B hidden size (used for pooled output)
context_in_dim=12288, # 3 layers * 4096 = 12288 for Qwen3-8B
Expand Down
22 changes: 14 additions & 8 deletions invokeai/backend/flux2/denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def denoise(
# sampling parameters
timesteps: list[float],
step_callback: Callable[[PipelineIntermediateState], None],
guidance: float,
cfg_scale: list[float],
# Negative conditioning for CFG
neg_txt: torch.Tensor | None = None,
Expand All @@ -45,7 +46,9 @@ def denoise(
This is a simplified denoise function for FLUX.2 Klein models that uses
the diffusers Flux2Transformer2DModel interface.

Note: FLUX.2 Klein has guidance_embeds=False, so no guidance parameter is used.
Distilled models (Klein 4B, Klein 9B) have guidance_embeds=False, so the guidance
value is passed but ignored by the model. Undistilled models (Klein 9B Base) have
guidance_embeds=True and use the guidance value for generation.
CFG is applied externally using negative conditioning when cfg_scale != 1.0.

Args:
Expand All @@ -56,6 +59,8 @@ def denoise(
txt_ids: Text position IDs tensor.
timesteps: List of timesteps for denoising schedule (linear sigmas from 1.0 to 1/n).
step_callback: Callback function for progress updates.
guidance: Guidance strength. Used by undistilled models (Klein 9B Base),
ignored by distilled models (Klein 4B, Klein 9B).
cfg_scale: List of CFG scale values per step.
neg_txt: Negative text embeddings for CFG (optional).
neg_txt_ids: Negative text position IDs (optional).
Expand All @@ -76,9 +81,10 @@ def denoise(
img = torch.cat([img, img_cond_seq], dim=1)
img_ids = torch.cat([img_ids, img_cond_seq_ids], dim=1)

# Klein has guidance_embeds=False, but the transformer forward() still requires a guidance tensor
# We pass a dummy value (1.0) since it won't affect the output when guidance_embeds=False
guidance = torch.full((img.shape[0],), 1.0, device=img.device, dtype=img.dtype)
# The transformer forward() requires a guidance tensor.
# For distilled models (guidance_embeds=False), this value is ignored by the model.
# For undistilled models (Klein 9B Base, guidance_embeds=True), it controls guidance strength.
guidance_vec = torch.full((img.shape[0],), guidance, device=img.device, dtype=img.dtype)

# Use scheduler if provided
use_scheduler = scheduler is not None
Expand Down Expand Up @@ -121,7 +127,7 @@ def denoise(
timestep=t_vec,
img_ids=img_ids,
txt_ids=txt_ids,
guidance=guidance,
guidance=guidance_vec,
return_dict=False,
)

Expand All @@ -141,7 +147,7 @@ def denoise(
timestep=t_vec,
img_ids=img_ids,
txt_ids=neg_txt_ids if neg_txt_ids is not None else txt_ids,
guidance=guidance,
guidance=guidance_vec,
return_dict=False,
)

Expand Down Expand Up @@ -222,7 +228,7 @@ def denoise(
timestep=t_vec,
img_ids=img_ids,
txt_ids=txt_ids,
guidance=guidance,
guidance=guidance_vec,
return_dict=False,
)

Expand All @@ -242,7 +248,7 @@ def denoise(
timestep=t_vec,
img_ids=img_ids,
txt_ids=neg_txt_ids if neg_txt_ids is not None else txt_ids,
guidance=guidance,
guidance=guidance_vec,
return_dict=False,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ export const ImageMetadataActions = memo((props: Props) => {
<SingleMetadataDatum metadata={metadata} handler={ImageMetadataHandlers.RefinerSteps} />
<SingleMetadataDatum metadata={metadata} handler={ImageMetadataHandlers.CanvasLayers} />
<CollectionMetadataDatum metadata={metadata} handler={ImageMetadataHandlers.RefImages} />
<SingleMetadataDatum metadata={metadata} handler={ImageMetadataHandlers.KleinVAEModel} />
<SingleMetadataDatum metadata={metadata} handler={ImageMetadataHandlers.KleinQwen3EncoderModel} />
<CollectionMetadataDatum metadata={metadata} handler={ImageMetadataHandlers.LoRAs} />
</Flex>
);
Expand Down
128 changes: 128 additions & 0 deletions invokeai/frontend/web/src/features/metadata/parsing.test.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
import type { AppStore } from 'app/store/store';
import type * as paramsSliceModule from 'features/controlLayers/store/paramsSlice';
import { ImageMetadataHandlers } from 'features/metadata/parsing';
import type * as modelsApiModule from 'services/api/endpoints/models';
import { beforeEach, describe, expect, it, vi } from 'vitest';

// ---------------------------------------------------------------------------
// Module mocks
//
// We are testing only the *gating* logic of the model-related metadata
// handlers (`VAEModel`, `KleinVAEModel`, `KleinQwen3EncoderModel`). The actual
// model lookup goes through `parseModelIdentifier`, which dispatches RTK
// Query thunks. We stub the models endpoint so that any lookup resolves to a
// canned model identifier — the parse step then succeeds and the assertions
// inside each handler become observable.
// ---------------------------------------------------------------------------

let currentBase: string | null = 'flux2';

vi.mock('features/controlLayers/store/paramsSlice', async (importOriginal) => {
const mod = await importOriginal<typeof paramsSliceModule>();
return { ...mod, selectBase: () => currentBase };
});

const fakeModel = (type: 'vae' | 'qwen3_encoder', base: string) => ({
key: `${type}-key`,
hash: 'hash',
name: `Some ${type}`,
base,
type,
});

let nextResolved: ReturnType<typeof fakeModel> = fakeModel('vae', 'flux2');

vi.mock('services/api/endpoints/models', async (importOriginal) => {
const mod = await importOriginal<typeof modelsApiModule>();
return {
...mod,
modelsApi: {
...mod.modelsApi,
endpoints: {
...mod.modelsApi.endpoints,
getModelConfig: { initiate: (key: string) => ({ type: 'rtkq/initiate', key }) },
},
},
};
});

const makeStore = (): AppStore =>
({
dispatch: vi.fn(() => ({
unwrap: () => Promise.resolve(nextResolved),
})),
getState: () => ({}),
}) as unknown as AppStore;

beforeEach(() => {
currentBase = 'flux2';
nextResolved = fakeModel('vae', 'flux2');
});

describe('ImageMetadataHandlers — Klein recall gating', () => {
describe('KleinVAEModel', () => {
it('parses metadata.vae when the current main model is FLUX.2 Klein', async () => {
currentBase = 'flux2';
nextResolved = fakeModel('vae', 'flux2');
const store = makeStore();

const parsed = await ImageMetadataHandlers.KleinVAEModel.parse({ vae: nextResolved }, store);

expect(parsed.key).toBe('vae-key');
expect(parsed.type).toBe('vae');
});

it('rejects parsing when the current main model is not FLUX.2 Klein', async () => {
currentBase = 'sdxl';
nextResolved = fakeModel('vae', 'flux2');
const store = makeStore();

await expect(ImageMetadataHandlers.KleinVAEModel.parse({ vae: nextResolved }, store)).rejects.toThrow();
});
});

describe('KleinQwen3EncoderModel', () => {
it('parses metadata.qwen3_encoder when the current main model is FLUX.2 Klein', async () => {
currentBase = 'flux2';
nextResolved = fakeModel('qwen3_encoder', 'flux2');
const store = makeStore();

const parsed = await ImageMetadataHandlers.KleinQwen3EncoderModel.parse({ qwen3_encoder: nextResolved }, store);

expect(parsed.key).toBe('qwen3_encoder-key');
expect(parsed.type).toBe('qwen3_encoder');
});

it('rejects parsing when the current main model is not FLUX.2 Klein', async () => {
currentBase = 'sdxl';
nextResolved = fakeModel('qwen3_encoder', 'flux2');
const store = makeStore();

await expect(
ImageMetadataHandlers.KleinQwen3EncoderModel.parse({ qwen3_encoder: nextResolved }, store)
).rejects.toThrow();
});
});

describe('VAEModel (generic)', () => {
// The generic VAEModel handler must NOT also fire for FLUX.2 / Z-Image
// images, otherwise the metadata viewer renders duplicate VAE rows next
// to the dedicated KleinVAEModel / ZImageVAEModel handlers.
it.each(['flux2', 'z-image'])('rejects parsing when current base is %s', async (base) => {
currentBase = base;
nextResolved = fakeModel('vae', base);
const store = makeStore();

await expect(ImageMetadataHandlers.VAEModel.parse({ vae: nextResolved }, store)).rejects.toThrow();
});

it('parses successfully for non-Klein, non-Z-Image bases', async () => {
currentBase = 'sdxl';
nextResolved = fakeModel('vae', 'sdxl');
const store = makeStore();

const parsed = await ImageMetadataHandlers.VAEModel.parse({ vae: nextResolved }, store);
expect(parsed.key).toBe('vae-key');
});
});
});
3 changes: 3 additions & 0 deletions invokeai/frontend/web/src/features/metadata/parsing.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,9 @@ const VAEModel: SingleMetadataHandler<ParameterVAEModel> = {
const parsed = await parseModelIdentifier(raw, store, 'vae');
assert(parsed.type === 'vae');
assert(isCompatibleWithMainModel(parsed, store));
// Z-Image and FLUX.2 Klein have dedicated VAE handlers; avoid rendering a duplicate row.
const base = selectBase(store.getState());
assert(base !== 'z-image' && base !== 'flux2', 'VAEModel handler does not apply to Z-Image or FLUX.2 Klein');
return Promise.resolve(parsed);
},
recall: (value, store) => {
Expand Down
Loading
Loading