From 2c65e2425766e71e244c7790d8c4ee043993d890 Mon Sep 17 00:00:00 2001 From: Alexander Eichhorn Date: Fri, 6 Feb 2026 03:38:06 +0100 Subject: [PATCH 1/6] Fix DyPE high-resolution noise by simplifying NTK scaling - Replace complex YaRN 3-band frequency blending with direct NTK interpolation - Update mscale formula: 1.0 + 0.1 * log(s) / sqrt(s) - Add k_t calculation: dype_scale * (sigma ^ dype_exponent) - Interpolate ntk_factor from 1.0 (late steps) to base_ntk (early steps) - Make mscale timestep-dependent for smoother transitions - Add helper functions: find_correction_factor, find_correction_range, linear_ramp_mask - Update tests for new implementation --- invokeai/backend/flux/dype/base.py | 305 +++++++++++++++++---------- invokeai/backend/flux/dype/embed.py | 16 +- invokeai/backend/flux/dype/rope.py | 5 + tests/backend/flux/dype/test_dype.py | 265 +++++++++++++++++++++-- 4 files changed, 466 insertions(+), 125 deletions(-) diff --git a/invokeai/backend/flux/dype/base.py b/invokeai/backend/flux/dype/base.py index 7b25a7f71f3..087fd069d15 100644 --- a/invokeai/backend/flux/dype/base.py +++ b/invokeai/backend/flux/dype/base.py @@ -1,4 +1,8 @@ -"""DyPE base configuration and utilities.""" +"""DyPE base configuration and utilities. + +Implements Dynamic Position Extrapolation (DyPE) with YaRN-style frequency blending. +Based on ComfyUI-DyPE: https://github.com/wildminder/ComfyUI-DyPE +""" import math from dataclasses import dataclass @@ -7,6 +11,16 @@ import torch from torch import Tensor +# YaRN default parameters for FLUX +# These define the frequency correction ranges for blending +YARN_BETA_0 = 1.25 # Low-frequency ratio (β₀) +YARN_BETA_1 = 0.75 # High-frequency ratio (β₁) +YARN_GAMMA_0 = 16.0 # Original position range (γ₀) +YARN_GAMMA_1 = 2.0 # Extended position range (γ₁) + +# FLUX model constants +FLUX_BASE_PE_LEN = 256 # Base position embedding length for FLUX + @dataclass class DyPEConfig: @@ -20,66 +34,149 @@ class DyPEConfig: dype_start_sigma: float = 1.0 # When DyPE decay starts -def get_mscale(scale: float, mscale_factor: float = 1.0) -> float: - """Calculate magnitude scaling factor. +def get_mscale(scale: float) -> float: + """Calculate magnitude scaling factor (mscale). + + Uses the formula from YaRN paper: mscale = 1 + 0.1 * log(s) / sqrt(s) + This provides better attention score normalization for high-resolution. Args: - scale: The resolution scaling factor - mscale_factor: Adjustment factor for the scaling + scale: The resolution scaling factor (NTK scale) Returns: The magnitude scaling factor """ if scale <= 1.0: return 1.0 - return mscale_factor * math.log(scale) + 1.0 + return 1.0 + 0.1 * math.log(scale) / math.sqrt(scale) -def get_timestep_mscale( - scale: float, +def find_correction_factor( + num_rotations: int, + dim: int, + base: int, + max_position_embeddings: int, +) -> float: + """Calculate correction factor for YaRN frequency masking. + + Finds the dimension index where the wavelength equals a given number of rotations. + Used to determine which frequency components need interpolation. + + Args: + num_rotations: Number of rotations to find the factor for + dim: Embedding dimension + base: RoPE base frequency (theta) + max_position_embeddings: Original maximum position embeddings + + Returns: + The dimension index (can be fractional) where wavelength matches num_rotations + """ + # Wavelength at dimension d = 2π * base^(d/dim) + # We want to find d where: wavelength / max_pe_len = num_rotations + # => 2π * base^(d/dim) = num_rotations * max_pe_len + # => d = dim * log(num_rotations * max_pe_len / 2π) / log(base) + return (dim * math.log(max_position_embeddings / (num_rotations * 2.0 * math.pi))) / ( + 2.0 * math.log(base) + ) + + +def find_correction_range( + low_ratio: float, + high_ratio: float, + dim: int, + base: int, + ori_max_pe_len: int, +) -> tuple[float, float]: + """Find the dimension range for frequency correction. + + Determines the range of dimensions that need interpolation between + different frequency scaling methods. + + Args: + low_ratio: Low frequency ratio (beta or gamma low) + high_ratio: High frequency ratio (beta or gamma high) + dim: Embedding dimension + base: RoPE base frequency (theta) + ori_max_pe_len: Original maximum position embedding length + + Returns: + Tuple of (low_dim, high_dim) indices for the correction range + """ + low = max(find_correction_factor(low_ratio, dim, base, ori_max_pe_len), 0.0) + high = min(find_correction_factor(high_ratio, dim, base, ori_max_pe_len), dim - 1.0) + return low, high + + +def linear_ramp_mask( + min_val: float, + max_val: float, + dim: int, + device: torch.device, + dtype: torch.dtype, +) -> Tensor: + """Create linear interpolation mask between frequency bands. + + Creates a tensor that ramps linearly from 0 to 1 between min_val and max_val, + with values clamped to [0, 1] outside that range. + + Args: + min_val: Dimension index where mask starts (becomes > 0) + max_val: Dimension index where mask ends (becomes 1) + dim: Number of frequency components (half the embedding dimension) + device: Target device for the tensor + dtype: Target dtype for the tensor + + Returns: + Tensor of shape (dim,) with values in [0, 1] + """ + if max_val <= min_val: + # Degenerate case: no interpolation range, return step function + indices = torch.arange(dim, device=device, dtype=dtype) + return (indices >= min_val).to(dtype) + + # Linear ramp: (i - min) / (max - min), clamped to [0, 1] + indices = torch.arange(dim, device=device, dtype=dtype) + mask = (indices - min_val) / (max_val - min_val) + return torch.clamp(mask, 0.0, 1.0) + + +def compute_dype_k_t( current_sigma: float, dype_scale: float, dype_exponent: float, dype_start_sigma: float, ) -> float: - """Calculate timestep-dependent magnitude scaling. + """Compute the DyPE timestep modulation factor k_t. The key insight of DyPE: early steps focus on low frequencies (global structure), - late steps on high frequencies (details). This function modulates the scaling - based on the current timestep/sigma. + late steps on high frequencies (details). This function computes k_t which + modulates the YaRN correction parameters (beta, gamma). + + Formula: k_t = dype_scale * (timestep ^ dype_exponent) + + Where timestep is the normalized sigma value (0 at end, 1 at start of denoising). Args: - scale: Resolution scaling factor current_sigma: Current noise level (1.0 = full noise, 0.0 = clean) - dype_scale: DyPE magnitude (λs) - dype_exponent: DyPE decay speed (λt) + dype_scale: DyPE magnitude (λs, 0.0-8.0) + dype_exponent: DyPE decay speed (λt, 0.0-1000.0) dype_start_sigma: Sigma threshold to start decay Returns: - Timestep-modulated scaling factor + k_t modulation factor - larger values mean stronger extrapolation """ - if scale <= 1.0: - return 1.0 - # Normalize sigma to [0, 1] range relative to start_sigma if current_sigma >= dype_start_sigma: - t_normalized = 1.0 + timestep = 1.0 else: - t_normalized = current_sigma / dype_start_sigma - - # Apply exponential decay: stronger extrapolation early, weaker late - # decay = exp(-λt * (1 - t)) where t=1 is early (high sigma), t=0 is late - decay = math.exp(-dype_exponent * (1.0 - t_normalized)) + timestep = current_sigma / dype_start_sigma - # Base mscale from resolution - base_mscale = get_mscale(scale) + # DyPE formula: k_t = scale * (timestep ^ exponent) + # At timestep=1 (early, high sigma): k_t = dype_scale + # At timestep=0 (late, low sigma): k_t = 0 + k_t = dype_scale * (timestep**dype_exponent) - # Interpolate between base_mscale and 1.0 based on decay and dype_scale - # When decay=1 (early): use scaled value - # When decay=0 (late): use base value - scaled_mscale = 1.0 + (base_mscale - 1.0) * dype_scale * decay - - return scaled_mscale + return k_t def compute_vision_yarn_freqs( @@ -90,18 +187,17 @@ def compute_vision_yarn_freqs( scale_w: float, current_sigma: float, dype_config: DyPEConfig, + ori_max_pe_len: int = FLUX_BASE_PE_LEN, ) -> tuple[Tensor, Tensor]: - """Compute RoPE frequencies using NTK-aware scaling for high-resolution. + """Compute RoPE frequencies using DyPE-modulated NTK scaling. This method extends FLUX's position encoding to handle resolutions beyond - the 1024px training resolution by scaling the base frequency (theta). - - The NTK-aware approach smoothly interpolates frequencies to cover larger - position ranges without breaking the attention patterns. + the 1024px training resolution. Instead of complex YaRN 3-band blending, + it uses a simpler approach that directly modulates the NTK scaling factor + based on the current timestep. - DyPE (Dynamic Position Extrapolation) modulates the NTK scaling based on - the current timestep - stronger extrapolation in early steps (global structure), - weaker in late steps (fine details). + DyPE insight: Early denoising steps focus on global structure (need stronger + extrapolation), late steps focus on fine details (need weaker extrapolation). Args: pos: Position tensor @@ -111,53 +207,69 @@ def compute_vision_yarn_freqs( scale_w: Width scaling factor current_sigma: Current noise level (1.0 = full noise, 0.0 = clean) dype_config: DyPE configuration + ori_max_pe_len: Original maximum position embedding length (unused, kept for API compat) Returns: Tuple of (cos, sin) frequency tensors """ assert dim % 2 == 0 - # Use the larger scale for NTK calculation - scale = max(scale_h, scale_w) - device = pos.device dtype = torch.float64 if device.type != "mps" else torch.float32 - # NTK-aware theta scaling: extends position coverage for high-res - # Formula: theta_scaled = theta * scale^(dim/(dim-2)) - # This increases the wavelength of position encodings proportionally - if scale > 1.0: - ntk_alpha = scale ** (dim / (dim - 2)) - - # Apply timestep-dependent DyPE modulation - # mscale controls how strongly we apply the NTK extrapolation - # Early steps (high sigma): stronger extrapolation for global structure - # Late steps (low sigma): weaker extrapolation for fine details - mscale = get_timestep_mscale( - scale=scale, - current_sigma=current_sigma, - dype_scale=dype_config.dype_scale, - dype_exponent=dype_config.dype_exponent, - dype_start_sigma=dype_config.dype_start_sigma, - ) - - # Modulate NTK alpha by mscale - # When mscale > 1: interpolate towards stronger extrapolation - # When mscale = 1: use base NTK alpha - modulated_alpha = 1.0 + (ntk_alpha - 1.0) * mscale - scaled_theta = theta * modulated_alpha + # Calculate base scale + linear_scale = max(scale_h, scale_w) + + if linear_scale <= 1.0: + # No scaling needed - use base frequencies + freq_seq = torch.arange(0, dim, 2, dtype=dtype, device=device) / dim + freqs = 1.0 / (theta**freq_seq) + angles = torch.einsum("...n,d->...nd", pos.to(dtype), freqs) + cos = torch.cos(angles) + sin = torch.sin(angles) + return cos.to(pos.dtype), sin.to(pos.dtype) + + # === DyPE-modulated NTK scaling === + + # Compute k_t: timestep modulation factor + k_t = compute_dype_k_t( + current_sigma=current_sigma, + dype_scale=dype_config.dype_scale, + dype_exponent=dype_config.dype_exponent, + dype_start_sigma=dype_config.dype_start_sigma, + ) + + # Base NTK factor + base_ntk = linear_scale ** (dim / (dim - 2)) + + # Apply DyPE modulation to NTK factor + # At high k_t (early steps): ntk_factor closer to base_ntk (stronger extrapolation) + # At low k_t (late steps): ntk_factor closer to 1.0 (weaker extrapolation) + # Formula: ntk_factor = 1 + (base_ntk - 1) * k_t / dype_scale + # This interpolates from 1.0 (no scaling) to base_ntk (full NTK) + if dype_config.dype_scale > 0: + blend_factor = min(k_t / dype_config.dype_scale, 1.0) else: - scaled_theta = theta + blend_factor = 1.0 + ntk_factor = 1.0 + (base_ntk - 1.0) * blend_factor + + # Compute scaled theta + scaled_theta = theta * ntk_factor - # Standard RoPE frequency computation + # Compute frequencies freq_seq = torch.arange(0, dim, 2, dtype=dtype, device=device) / dim freqs = 1.0 / (scaled_theta**freq_seq) - # Compute angles = position * frequency + # Compute angles angles = torch.einsum("...n,d->...nd", pos.to(dtype), freqs) - cos = torch.cos(angles) - sin = torch.sin(angles) + # Compute mscale (timestep-modulated) + # mscale goes from get_mscale(ntk_factor) at early steps to 1.0 at late steps + mscale_full = get_mscale(ntk_factor) + mscale = 1.0 + (mscale_full - 1.0) * blend_factor + + cos = torch.cos(angles) * mscale + sin = torch.sin(angles) * mscale return cos.to(pos.dtype), sin.to(pos.dtype) @@ -169,11 +281,11 @@ def compute_yarn_freqs( scale: float, current_sigma: float, dype_config: DyPEConfig, + ori_max_pe_len: int = FLUX_BASE_PE_LEN, ) -> tuple[Tensor, Tensor]: - """Compute RoPE frequencies using YARN/NTK method. + """Compute RoPE frequencies using DyPE-modulated NTK scaling. - Uses NTK-aware theta scaling for high-resolution support with - timestep-dependent DyPE modulation. + Uses the same approach as vision_yarn but with a uniform scale factor. Args: pos: Position tensor @@ -182,43 +294,22 @@ def compute_yarn_freqs( scale: Uniform scaling factor current_sigma: Current noise level (1.0 = full noise, 0.0 = clean) dype_config: DyPE configuration + ori_max_pe_len: Original maximum position embedding length (unused) Returns: Tuple of (cos, sin) frequency tensors """ - assert dim % 2 == 0 - - device = pos.device - dtype = torch.float64 if device.type != "mps" else torch.float32 - - # NTK-aware theta scaling with DyPE modulation - if scale > 1.0: - ntk_alpha = scale ** (dim / (dim - 2)) - - # Apply timestep-dependent DyPE modulation - mscale = get_timestep_mscale( - scale=scale, - current_sigma=current_sigma, - dype_scale=dype_config.dype_scale, - dype_exponent=dype_config.dype_exponent, - dype_start_sigma=dype_config.dype_start_sigma, - ) - - # Modulate NTK alpha by mscale - modulated_alpha = 1.0 + (ntk_alpha - 1.0) * mscale - scaled_theta = theta * modulated_alpha - else: - scaled_theta = theta - - freq_seq = torch.arange(0, dim, 2, dtype=dtype, device=device) / dim - freqs = 1.0 / (scaled_theta**freq_seq) - - angles = torch.einsum("...n,d->...nd", pos.to(dtype), freqs) - - cos = torch.cos(angles) - sin = torch.sin(angles) - - return cos.to(pos.dtype), sin.to(pos.dtype) + # Delegate to vision_yarn with uniform scale + return compute_vision_yarn_freqs( + pos=pos, + dim=dim, + theta=theta, + scale_h=scale, + scale_w=scale, + current_sigma=current_sigma, + dype_config=dype_config, + ori_max_pe_len=ori_max_pe_len, + ) def compute_ntk_freqs( diff --git a/invokeai/backend/flux/dype/embed.py b/invokeai/backend/flux/dype/embed.py index ace6a56ab0f..022591c5b71 100644 --- a/invokeai/backend/flux/dype/embed.py +++ b/invokeai/backend/flux/dype/embed.py @@ -3,9 +3,14 @@ import torch from torch import Tensor, nn -from invokeai.backend.flux.dype.base import DyPEConfig +from invokeai.backend.flux.dype.base import DyPEConfig, FLUX_BASE_PE_LEN from invokeai.backend.flux.dype.rope import rope_dype +# FLUX uses 8x8 patch compression with 2x2 packing +# base_resolution / 8 / 2 = base_patch_grid +FLUX_PATCH_SIZE = 8 +FLUX_PACKING_FACTOR = 2 + class DyPEEmbedND(nn.Module): """N-dimensional position embedding with DyPE support. @@ -17,6 +22,7 @@ class DyPEEmbedND(nn.Module): - Maintains step state (current_sigma, target dimensions) - Uses rope_dype() instead of rope() for frequency computation - Applies timestep-dependent scaling for better high-resolution generation + - Calculates base_patch_grid for proper scale computation """ def __init__( @@ -40,6 +46,13 @@ def __init__( self.axes_dim = axes_dim self.dype_config = dype_config + # Calculate base patch grid from base resolution + # FLUX: 1024 / 8 / 2 = 64 patches per side + self.base_patch_grid = dype_config.base_resolution // FLUX_PATCH_SIZE // FLUX_PACKING_FACTOR + + # Original max position embedding length (for YaRN correction calculation) + self.ori_max_pe_len = FLUX_BASE_PE_LEN + # Step state - updated before each denoising step self._current_sigma: float = 1.0 self._target_height: int = 1024 @@ -83,6 +96,7 @@ def forward(self, ids: Tensor) -> Tensor: target_height=self._target_height, target_width=self._target_width, dype_config=self.dype_config, + ori_max_pe_len=self.ori_max_pe_len, ) embeddings.append(axis_emb) diff --git a/invokeai/backend/flux/dype/rope.py b/invokeai/backend/flux/dype/rope.py index f6a1594f6be..a5202f1520d 100644 --- a/invokeai/backend/flux/dype/rope.py +++ b/invokeai/backend/flux/dype/rope.py @@ -5,6 +5,7 @@ from torch import Tensor from invokeai.backend.flux.dype.base import ( + FLUX_BASE_PE_LEN, DyPEConfig, compute_ntk_freqs, compute_vision_yarn_freqs, @@ -20,6 +21,7 @@ def rope_dype( target_height: int, target_width: int, dype_config: DyPEConfig, + ori_max_pe_len: int = FLUX_BASE_PE_LEN, ) -> Tensor: """Compute RoPE with Dynamic Position Extrapolation. @@ -34,6 +36,7 @@ def rope_dype( target_height: Target image height in pixels target_width: Target image width in pixels dype_config: DyPE configuration + ori_max_pe_len: Original maximum position embedding length for YaRN correction Returns: Rotary position embedding tensor with shape suitable for FLUX attention @@ -62,6 +65,7 @@ def rope_dype( scale_w=scale_w, current_sigma=current_sigma, dype_config=dype_config, + ori_max_pe_len=ori_max_pe_len, ) elif method == "yarn": cos, sin = compute_yarn_freqs( @@ -71,6 +75,7 @@ def rope_dype( scale=scale, current_sigma=current_sigma, dype_config=dype_config, + ori_max_pe_len=ori_max_pe_len, ) elif method == "ntk": cos, sin = compute_ntk_freqs( diff --git a/tests/backend/flux/dype/test_dype.py b/tests/backend/flux/dype/test_dype.py index cc0b99011cd..6e36ffd760c 100644 --- a/tests/backend/flux/dype/test_dype.py +++ b/tests/backend/flux/dype/test_dype.py @@ -1,14 +1,24 @@ """Tests for DyPE (Dynamic Position Extrapolation) module.""" +import math + import torch from invokeai.backend.flux.dype.base import ( + FLUX_BASE_PE_LEN, + YARN_BETA_0, + YARN_BETA_1, + YARN_GAMMA_0, + YARN_GAMMA_1, DyPEConfig, + compute_dype_k_t, compute_ntk_freqs, compute_vision_yarn_freqs, compute_yarn_freqs, + find_correction_factor, + find_correction_range, get_mscale, - get_timestep_mscale, + linear_ramp_mask, ) from invokeai.backend.flux.dype.embed import DyPEEmbedND from invokeai.backend.flux.dype.presets import ( @@ -68,36 +78,153 @@ def test_get_mscale_with_scaling(self): assert mscale_2x > 1.0 assert mscale_4x > mscale_2x - def test_get_timestep_mscale_no_scaling(self): - """When scale <= 1.0, timestep_mscale should be 1.0.""" - result = get_timestep_mscale( - scale=1.0, - current_sigma=0.5, + def test_get_mscale_formula(self): + """Test mscale uses the correct YaRN formula: 1 + 0.1 * log(s) / sqrt(s).""" + scale = 4.0 + expected = 1.0 + 0.1 * math.log(scale) / math.sqrt(scale) + actual = get_mscale(scale) + assert abs(actual - expected) < 1e-10 + + +class TestDyPEKt: + """Tests for DyPE k_t calculation.""" + + def test_compute_dype_k_t_formula(self): + """Test k_t uses correct formula: scale * (timestep ^ exponent).""" + dype_scale = 2.0 + dype_exponent = 2.0 + current_sigma = 0.5 # normalized timestep + dype_start_sigma = 1.0 + + k_t = compute_dype_k_t( + current_sigma=current_sigma, + dype_scale=dype_scale, + dype_exponent=dype_exponent, + dype_start_sigma=dype_start_sigma, + ) + + # Expected: 2.0 * (0.5 ^ 2.0) = 2.0 * 0.25 = 0.5 + expected = dype_scale * (current_sigma**dype_exponent) + assert abs(k_t - expected) < 1e-10 + + def test_compute_dype_k_t_at_start(self): + """At timestep=1.0 (start), k_t should equal dype_scale.""" + k_t = compute_dype_k_t( + current_sigma=1.0, dype_scale=2.0, dype_exponent=2.0, dype_start_sigma=1.0, ) - assert result == 1.0 + assert abs(k_t - 2.0) < 1e-10 - def test_get_timestep_mscale_high_sigma(self): - """Early steps (high sigma) should have stronger scaling.""" - early_mscale = get_timestep_mscale( - scale=2.0, - current_sigma=1.0, # Early step + def test_compute_dype_k_t_at_end(self): + """At timestep=0.0 (end), k_t should be 0.""" + k_t = compute_dype_k_t( + current_sigma=0.0, dype_scale=2.0, dype_exponent=2.0, dype_start_sigma=1.0, ) - late_mscale = get_timestep_mscale( - scale=2.0, - current_sigma=0.1, # Late step + assert k_t == 0.0 + + def test_compute_dype_k_t_decreases_over_time(self): + """k_t should decrease as sigma decreases (denoising progresses).""" + k_t_early = compute_dype_k_t( + current_sigma=1.0, + dype_scale=2.0, + dype_exponent=2.0, + dype_start_sigma=1.0, + ) + k_t_mid = compute_dype_k_t( + current_sigma=0.5, dype_scale=2.0, dype_exponent=2.0, dype_start_sigma=1.0, ) + k_t_late = compute_dype_k_t( + current_sigma=0.1, + dype_scale=2.0, + dype_exponent=2.0, + dype_start_sigma=1.0, + ) + + assert k_t_early > k_t_mid > k_t_late - # Early steps should have larger mscale than late steps - assert early_mscale >= late_mscale + +class TestYaRNHelpers: + """Tests for YaRN helper functions.""" + + def test_find_correction_factor(self): + """Test correction factor calculation.""" + # Basic sanity check - should return a reasonable dimension index + factor = find_correction_factor( + num_rotations=1.25, + dim=32, + base=10000, + max_position_embeddings=256, + ) + assert factor >= 0 + assert factor <= 32 + + def test_find_correction_range(self): + """Test correction range returns valid bounds.""" + low, high = find_correction_range( + low_ratio=YARN_BETA_0, + high_ratio=YARN_BETA_1, + dim=28, # half of 56 (FLUX spatial dim) + base=10000, + ori_max_pe_len=FLUX_BASE_PE_LEN, + ) + + # Low should be <= high + assert low <= high + # Both should be in valid range + assert low >= 0 + assert high <= 27 # dim - 1 + + def test_linear_ramp_mask_shape(self): + """Test linear ramp mask has correct shape.""" + mask = linear_ramp_mask( + min_val=5.0, + max_val=15.0, + dim=28, + device=torch.device("cpu"), + dtype=torch.float32, + ) + assert mask.shape == (28,) + + def test_linear_ramp_mask_values(self): + """Test linear ramp mask has correct values.""" + mask = linear_ramp_mask( + min_val=5.0, + max_val=15.0, + dim=20, + device=torch.device("cpu"), + dtype=torch.float32, + ) + + # Values before min should be 0 + assert mask[0].item() == 0.0 + assert mask[4].item() == 0.0 + + # Values after max should be 1 + assert mask[15].item() == 1.0 + assert mask[19].item() == 1.0 + + # Values in between should be in (0, 1) + assert 0.0 < mask[10].item() < 1.0 + + def test_linear_ramp_mask_degenerate(self): + """Test linear ramp mask handles degenerate case (min >= max).""" + mask = linear_ramp_mask( + min_val=10.0, + max_val=5.0, # max < min + dim=20, + device=torch.device("cpu"), + dtype=torch.float32, + ) + # Should return step function at min_val + assert mask.shape == (20,) class TestRopeDype: @@ -404,3 +531,107 @@ def test_compute_ntk_freqs_shape(self): assert cos.shape == sin.shape assert cos.shape[0] == 1 + + +class TestThreeBandBlending: + """Tests for 3-band frequency blending in YaRN implementation.""" + + def test_different_timesteps_produce_different_freqs(self): + """Different timesteps should produce different frequency outputs.""" + pos = torch.arange(16).unsqueeze(0).float() + config = DyPEConfig() + + cos_early, sin_early = compute_vision_yarn_freqs( + pos=pos, + dim=32, + theta=10000, + scale_h=2.0, + scale_w=2.0, + current_sigma=1.0, # Early step + dype_config=config, + ) + + cos_late, sin_late = compute_vision_yarn_freqs( + pos=pos, + dim=32, + theta=10000, + scale_h=2.0, + scale_w=2.0, + current_sigma=0.1, # Late step + dype_config=config, + ) + + # Outputs should be different due to DyPE timestep modulation + assert not torch.allclose(cos_early, cos_late) + assert not torch.allclose(sin_early, sin_late) + + def test_no_scaling_returns_base_freqs(self): + """When scale <= 1.0, should return base frequencies without mscale.""" + pos = torch.arange(16).unsqueeze(0).float() + config = DyPEConfig() + + cos, sin = compute_vision_yarn_freqs( + pos=pos, + dim=32, + theta=10000, + scale_h=1.0, + scale_w=1.0, + current_sigma=0.5, + dype_config=config, + ) + + # Verify shape is correct + assert cos.shape[0] == 1 + assert cos.shape[1] == 16 + + def test_yarn_freqs_matches_vision_yarn_for_uniform_scale(self): + """yarn and vision_yarn should produce same results for uniform scale.""" + pos = torch.arange(16).unsqueeze(0).float() + config = DyPEConfig() + + cos_vision, sin_vision = compute_vision_yarn_freqs( + pos=pos, + dim=32, + theta=10000, + scale_h=2.0, + scale_w=2.0, + current_sigma=0.5, + dype_config=config, + ) + + cos_yarn, sin_yarn = compute_yarn_freqs( + pos=pos, + dim=32, + theta=10000, + scale=2.0, + current_sigma=0.5, + dype_config=config, + ) + + # Should be very close (same algorithm with same scale) + assert torch.allclose(cos_vision, cos_yarn, atol=1e-6) + assert torch.allclose(sin_vision, sin_yarn, atol=1e-6) + + def test_mscale_applied_to_output(self): + """Verify mscale is applied to cos/sin outputs.""" + pos = torch.arange(16).unsqueeze(0).float() + config = DyPEConfig() + + cos, sin = compute_vision_yarn_freqs( + pos=pos, + dim=32, + theta=10000, + scale_h=4.0, + scale_w=4.0, + current_sigma=0.5, + dype_config=config, + ) + + # With mscale applied, max values can exceed 1.0 + # (pure cos/sin are in [-1, 1], but mscale > 1 stretches them) + ntk_scale = 4.0 ** (32 / (32 - 2)) + mscale = get_mscale(ntk_scale) + assert mscale > 1.0 + + # The actual values depend on the blending, but shape should be correct + assert cos.shape == (1, 16, 16) # (batch, seq_len, dim/2) From a681decebe885f73aeed1b60e9c2f08c42a0b48b Mon Sep 17 00:00:00 2001 From: Alexander Eichhorn Date: Fri, 6 Feb 2026 03:49:38 +0100 Subject: [PATCH 2/6] Chore Ruff --- invokeai/backend/flux/dype/base.py | 4 +--- invokeai/backend/flux/dype/embed.py | 2 +- tests/backend/flux/dype/test_dype.py | 2 -- 3 files changed, 2 insertions(+), 6 deletions(-) diff --git a/invokeai/backend/flux/dype/base.py b/invokeai/backend/flux/dype/base.py index 087fd069d15..3cdbde39118 100644 --- a/invokeai/backend/flux/dype/base.py +++ b/invokeai/backend/flux/dype/base.py @@ -75,9 +75,7 @@ def find_correction_factor( # We want to find d where: wavelength / max_pe_len = num_rotations # => 2π * base^(d/dim) = num_rotations * max_pe_len # => d = dim * log(num_rotations * max_pe_len / 2π) / log(base) - return (dim * math.log(max_position_embeddings / (num_rotations * 2.0 * math.pi))) / ( - 2.0 * math.log(base) - ) + return (dim * math.log(max_position_embeddings / (num_rotations * 2.0 * math.pi))) / (2.0 * math.log(base)) def find_correction_range( diff --git a/invokeai/backend/flux/dype/embed.py b/invokeai/backend/flux/dype/embed.py index 022591c5b71..476f8295982 100644 --- a/invokeai/backend/flux/dype/embed.py +++ b/invokeai/backend/flux/dype/embed.py @@ -3,7 +3,7 @@ import torch from torch import Tensor, nn -from invokeai.backend.flux.dype.base import DyPEConfig, FLUX_BASE_PE_LEN +from invokeai.backend.flux.dype.base import FLUX_BASE_PE_LEN, DyPEConfig from invokeai.backend.flux.dype.rope import rope_dype # FLUX uses 8x8 patch compression with 2x2 packing diff --git a/tests/backend/flux/dype/test_dype.py b/tests/backend/flux/dype/test_dype.py index 6e36ffd760c..7e9c42ed105 100644 --- a/tests/backend/flux/dype/test_dype.py +++ b/tests/backend/flux/dype/test_dype.py @@ -8,8 +8,6 @@ FLUX_BASE_PE_LEN, YARN_BETA_0, YARN_BETA_1, - YARN_GAMMA_0, - YARN_GAMMA_1, DyPEConfig, compute_dype_k_t, compute_ntk_freqs, From bc1818c1943fb439bfded5840a4e074e1de51b43 Mon Sep 17 00:00:00 2001 From: Alexander Eichhorn Date: Fri, 6 Feb 2026 17:58:41 +0100 Subject: [PATCH 3/6] Fix DyPE noise and left-bias with proper YaRN 3-band blending Aligns InvokeAI's DyPE implementation with ComfyUI-DyPE to fix chroma noise and left-biased composition on wide images. Key changes: - Skip DyPE scaling on axis 0 (time/channel), only scale spatial axes - Implement YaRN 3-band frequency blending (base/linear/NTK) with DyPE-modulated beta/gamma correction masks - Use per-axis linear_scale with global ntk_scale for non-square images - Fix FLUX_BASE_PE_LEN (256 -> 64) to match actual spatial positions - Add timestep-dependent mscale matching ComfyUI's _get_mscale - Use floor/ceil in find_correction_range matching reference impl --- invokeai/backend/flux/dype/base.py | 155 +++++++++++++++++---------- invokeai/backend/flux/dype/embed.py | 8 +- invokeai/backend/flux/dype/rope.py | 25 ++++- tests/backend/flux/dype/test_dype.py | 62 ++++++++--- 4 files changed, 176 insertions(+), 74 deletions(-) diff --git a/invokeai/backend/flux/dype/base.py b/invokeai/backend/flux/dype/base.py index 3cdbde39118..053f83275c9 100644 --- a/invokeai/backend/flux/dype/base.py +++ b/invokeai/backend/flux/dype/base.py @@ -19,7 +19,8 @@ YARN_GAMMA_1 = 2.0 # Extended position range (γ₁) # FLUX model constants -FLUX_BASE_PE_LEN = 256 # Base position embedding length for FLUX +# Base position embedding length = base_resolution / patch_size / packing = 1024/8/2 = 64 +FLUX_BASE_PE_LEN = 64 @dataclass @@ -100,9 +101,9 @@ def find_correction_range( Returns: Tuple of (low_dim, high_dim) indices for the correction range """ - low = max(find_correction_factor(low_ratio, dim, base, ori_max_pe_len), 0.0) - high = min(find_correction_factor(high_ratio, dim, base, ori_max_pe_len), dim - 1.0) - return low, high + low = math.floor(find_correction_factor(low_ratio, dim, base, ori_max_pe_len)) + high = math.ceil(find_correction_factor(high_ratio, dim, base, ori_max_pe_len)) + return max(low, 0.0), min(high, dim - 1.0) def linear_ramp_mask( @@ -177,59 +178,87 @@ def compute_dype_k_t( return k_t +def compute_timestep_mscale( + ntk_scale: float, + current_sigma: float, + dype_config: DyPEConfig, +) -> float: + """Compute timestep-dependent magnitude scaling. + + Interpolates from aggressive mscale at early steps to 1.0 at late steps. + Matches ComfyUI-DyPE's _get_mscale behavior. + + Args: + ntk_scale: Global NTK scaling factor + current_sigma: Current noise level (1.0 = full noise, 0.0 = clean) + dype_config: DyPE configuration + + Returns: + Timestep-modulated magnitude scaling factor + """ + if ntk_scale <= 1.0: + return 1.0 + + # Aggressive mscale formula (start value at high sigma) + mscale_start = 0.1 * math.log(ntk_scale) + 1.0 + mscale_end = 1.0 + + # Normalize sigma + if current_sigma >= dype_config.dype_start_sigma: + t_norm = 1.0 + else: + t_norm = current_sigma / dype_config.dype_start_sigma + + # Interpolate: full mscale at early steps, 1.0 at late steps + return mscale_end + (mscale_start - mscale_end) * (t_norm**dype_config.dype_exponent) + + def compute_vision_yarn_freqs( pos: Tensor, dim: int, theta: int, - scale_h: float, - scale_w: float, + linear_scale: float, + ntk_scale: float, current_sigma: float, dype_config: DyPEConfig, ori_max_pe_len: int = FLUX_BASE_PE_LEN, ) -> tuple[Tensor, Tensor]: - """Compute RoPE frequencies using DyPE-modulated NTK scaling. - - This method extends FLUX's position encoding to handle resolutions beyond - the 1024px training resolution. Instead of complex YaRN 3-band blending, - it uses a simpler approach that directly modulates the NTK scaling factor - based on the current timestep. + """Compute RoPE frequencies using DyPE-modulated YaRN 3-band blending. - DyPE insight: Early denoising steps focus on global structure (need stronger - extrapolation), late steps focus on fine details (need weaker extrapolation). + Uses three frequency bands (base, linear, NTK) blended via beta/gamma masks. + DyPE modulates the correction ranges so early denoising steps use stronger + extrapolation (global structure) and late steps use weaker (fine details). Args: pos: Position tensor dim: Embedding dimension theta: RoPE base frequency - scale_h: Height scaling factor - scale_w: Width scaling factor + linear_scale: Per-axis linear scaling factor (height or width ratio) + ntk_scale: Global NTK scaling factor (max of height/width ratios) current_sigma: Current noise level (1.0 = full noise, 0.0 = clean) dype_config: DyPE configuration - ori_max_pe_len: Original maximum position embedding length (unused, kept for API compat) + ori_max_pe_len: Original maximum position embedding length for YaRN correction Returns: Tuple of (cos, sin) frequency tensors """ assert dim % 2 == 0 + n_freqs = dim // 2 device = pos.device dtype = torch.float64 if device.type != "mps" else torch.float32 - # Calculate base scale - linear_scale = max(scale_h, scale_w) + linear_scale = max(linear_scale, 1.0) + ntk_scale = max(ntk_scale, 1.0) - if linear_scale <= 1.0: + if ntk_scale <= 1.0: # No scaling needed - use base frequencies freq_seq = torch.arange(0, dim, 2, dtype=dtype, device=device) / dim freqs = 1.0 / (theta**freq_seq) angles = torch.einsum("...n,d->...nd", pos.to(dtype), freqs) - cos = torch.cos(angles) - sin = torch.sin(angles) - return cos.to(pos.dtype), sin.to(pos.dtype) + return torch.cos(angles).to(pos.dtype), torch.sin(angles).to(pos.dtype) - # === DyPE-modulated NTK scaling === - - # Compute k_t: timestep modulation factor + # === Step 1: Compute DyPE modulation factor k_t === k_t = compute_dype_k_t( current_sigma=current_sigma, dype_scale=dype_config.dype_scale, @@ -237,34 +266,50 @@ def compute_vision_yarn_freqs( dype_start_sigma=dype_config.dype_start_sigma, ) - # Base NTK factor - base_ntk = linear_scale ** (dim / (dim - 2)) - - # Apply DyPE modulation to NTK factor - # At high k_t (early steps): ntk_factor closer to base_ntk (stronger extrapolation) - # At low k_t (late steps): ntk_factor closer to 1.0 (weaker extrapolation) - # Formula: ntk_factor = 1 + (base_ntk - 1) * k_t / dype_scale - # This interpolates from 1.0 (no scaling) to base_ntk (full NTK) - if dype_config.dype_scale > 0: - blend_factor = min(k_t / dype_config.dype_scale, 1.0) - else: - blend_factor = 1.0 - ntk_factor = 1.0 + (base_ntk - 1.0) * blend_factor + # === Step 2: DyPE-modulate YaRN correction parameters === + beta_0: float = YARN_BETA_0 + beta_1: float = YARN_BETA_1 + gamma_0: float = YARN_GAMMA_0 + gamma_1: float = YARN_GAMMA_1 - # Compute scaled theta - scaled_theta = theta * ntk_factor + if dype_config.enable_dype: + beta_0 = beta_0**k_t + beta_1 = beta_1**k_t + gamma_0 = gamma_0**k_t + gamma_1 = gamma_1**k_t - # Compute frequencies + # === Step 3: Compute three frequency bands === freq_seq = torch.arange(0, dim, 2, dtype=dtype, device=device) / dim - freqs = 1.0 / (scaled_theta**freq_seq) - # Compute angles + # Band 1: Base frequencies (original, unscaled) + freqs_base = 1.0 / (theta**freq_seq) + + # Band 2: Linearly scaled frequencies (per-axis) + freqs_linear = freqs_base / linear_scale + + # Band 3: NTK-scaled frequencies (global) + new_base = theta * (ntk_scale ** (dim / (dim - 2))) + freqs_ntk = 1.0 / (new_base**freq_seq) + + # === Step 4: Beta mask - blend linear <-> NTK === + low, high = find_correction_range(beta_0, beta_1, dim, theta, ori_max_pe_len) + low = max(0, low) + high = min(n_freqs, high) + mask_beta = 1.0 - linear_ramp_mask(low, high, n_freqs, device, dtype) + freqs = freqs_linear * (1.0 - mask_beta) + freqs_ntk * mask_beta + + # === Step 5: Gamma mask - blend result <-> base === + low, high = find_correction_range(gamma_0, gamma_1, dim, theta, ori_max_pe_len) + low = max(0, low) + high = min(n_freqs, high) + mask_gamma = 1.0 - linear_ramp_mask(low, high, n_freqs, device, dtype) + freqs = freqs * (1.0 - mask_gamma) + freqs_base * mask_gamma + + # === Step 6: Compute angles === angles = torch.einsum("...n,d->...nd", pos.to(dtype), freqs) - # Compute mscale (timestep-modulated) - # mscale goes from get_mscale(ntk_factor) at early steps to 1.0 at late steps - mscale_full = get_mscale(ntk_factor) - mscale = 1.0 + (mscale_full - 1.0) * blend_factor + # === Step 7: Apply timestep-dependent mscale === + mscale = compute_timestep_mscale(ntk_scale, current_sigma, dype_config) cos = torch.cos(angles) * mscale sin = torch.sin(angles) * mscale @@ -281,29 +326,29 @@ def compute_yarn_freqs( dype_config: DyPEConfig, ori_max_pe_len: int = FLUX_BASE_PE_LEN, ) -> tuple[Tensor, Tensor]: - """Compute RoPE frequencies using DyPE-modulated NTK scaling. + """Compute RoPE frequencies using DyPE-modulated YaRN with uniform scale. - Uses the same approach as vision_yarn but with a uniform scale factor. + Uses the same 3-band blending as vision_yarn but with a uniform scale + factor for both linear and NTK components. Args: pos: Position tensor dim: Embedding dimension theta: RoPE base frequency - scale: Uniform scaling factor + scale: Uniform scaling factor (used for both linear and NTK) current_sigma: Current noise level (1.0 = full noise, 0.0 = clean) dype_config: DyPE configuration - ori_max_pe_len: Original maximum position embedding length (unused) + ori_max_pe_len: Original maximum position embedding length Returns: Tuple of (cos, sin) frequency tensors """ - # Delegate to vision_yarn with uniform scale return compute_vision_yarn_freqs( pos=pos, dim=dim, theta=theta, - scale_h=scale, - scale_w=scale, + linear_scale=scale, + ntk_scale=scale, current_sigma=current_sigma, dype_config=dype_config, ori_max_pe_len=ori_max_pe_len, diff --git a/invokeai/backend/flux/dype/embed.py b/invokeai/backend/flux/dype/embed.py index 476f8295982..8a062d460a6 100644 --- a/invokeai/backend/flux/dype/embed.py +++ b/invokeai/backend/flux/dype/embed.py @@ -3,7 +3,7 @@ import torch from torch import Tensor, nn -from invokeai.backend.flux.dype.base import FLUX_BASE_PE_LEN, DyPEConfig +from invokeai.backend.flux.dype.base import DyPEConfig from invokeai.backend.flux.dype.rope import rope_dype # FLUX uses 8x8 patch compression with 2x2 packing @@ -50,8 +50,9 @@ def __init__( # FLUX: 1024 / 8 / 2 = 64 patches per side self.base_patch_grid = dype_config.base_resolution // FLUX_PATCH_SIZE // FLUX_PACKING_FACTOR - # Original max position embedding length (for YaRN correction calculation) - self.ori_max_pe_len = FLUX_BASE_PE_LEN + # Original max position embedding length = base patch grid + # This is the number of spatial positions per side at base resolution + self.ori_max_pe_len = self.base_patch_grid # Step state - updated before each denoising step self._current_sigma: float = 1.0 @@ -96,6 +97,7 @@ def forward(self, ids: Tensor) -> Tensor: target_height=self._target_height, target_width=self._target_width, dype_config=self.dype_config, + axis_index=i, ori_max_pe_len=self.ori_max_pe_len, ) embeddings.append(axis_emb) diff --git a/invokeai/backend/flux/dype/rope.py b/invokeai/backend/flux/dype/rope.py index a5202f1520d..8af42b041bd 100644 --- a/invokeai/backend/flux/dype/rope.py +++ b/invokeai/backend/flux/dype/rope.py @@ -21,6 +21,7 @@ def rope_dype( target_height: int, target_width: int, dype_config: DyPEConfig, + axis_index: int = 0, ori_max_pe_len: int = FLUX_BASE_PE_LEN, ) -> Tensor: """Compute RoPE with Dynamic Position Extrapolation. @@ -28,6 +29,9 @@ def rope_dype( This is the core DyPE function that replaces the standard rope() function. It applies resolution-aware and timestep-aware scaling to position embeddings. + DyPE scaling is only applied to spatial axes (axis_index > 0). Axis 0 + (time/channel) always uses plain RoPE to avoid distorting temporal attention. + Args: pos: Position indices tensor dim: Embedding dimension per axis @@ -36,6 +40,8 @@ def rope_dype( target_height: Target image height in pixels target_width: Target image width in pixels dype_config: DyPE configuration + axis_index: Which axis this is (0=time/channel, 1=height, 2=width). + Axis 0 always uses plain RoPE without DyPE scaling. ori_max_pe_len: Original maximum position embedding length for YaRN correction Returns: @@ -43,6 +49,10 @@ def rope_dype( """ assert dim % 2 == 0 + # Axis 0 (time/channel) never gets DyPE scaling - only spatial axes do + if axis_index == 0: + return _rope_base(pos, dim, theta) + # Calculate scaling factors base_res = dype_config.base_resolution scale_h = target_height / base_res @@ -53,6 +63,17 @@ def rope_dype( if not dype_config.enable_dype or scale <= 1.0: return _rope_base(pos, dim, theta) + # Compute per-axis linear_scale and global ntk_scale + # linear_scale: the scale for THIS specific axis (height or width) + # ntk_scale: the global scale = max(scale_h, scale_w) + if axis_index == 1: + linear_scale = scale_h + elif axis_index == 2: + linear_scale = scale_w + else: + linear_scale = scale + ntk_scale = scale + # Select method and compute frequencies method = dype_config.method @@ -61,8 +82,8 @@ def rope_dype( pos=pos, dim=dim, theta=theta, - scale_h=scale_h, - scale_w=scale_w, + linear_scale=linear_scale, + ntk_scale=ntk_scale, current_sigma=current_sigma, dype_config=dype_config, ori_max_pe_len=ori_max_pe_len, diff --git a/tests/backend/flux/dype/test_dype.py b/tests/backend/flux/dype/test_dype.py index 7e9c42ed105..0965bb6a35b 100644 --- a/tests/backend/flux/dype/test_dype.py +++ b/tests/backend/flux/dype/test_dype.py @@ -256,7 +256,7 @@ def test_rope_dype_no_scaling(self): config = DyPEConfig(base_resolution=1024) - # No scaling needed + # No scaling needed (spatial axis) result_no_scale = rope_dype( pos=pos, dim=dim, @@ -265,9 +265,10 @@ def test_rope_dype_no_scaling(self): target_height=1024, target_width=1024, dype_config=config, + axis_index=1, # spatial axis ) - # With scaling + # With scaling (spatial axis) result_with_scale = rope_dype( pos=pos, dim=dim, @@ -276,11 +277,44 @@ def test_rope_dype_no_scaling(self): target_height=2048, target_width=2048, dype_config=config, + axis_index=1, # spatial axis ) # Results should be different when scaling is applied assert not torch.allclose(result_no_scale, result_with_scale) + def test_rope_dype_axis0_always_base(self): + """Axis 0 (time/channel) should always return base RoPE regardless of scaling.""" + pos = torch.arange(16).unsqueeze(0).float() + dim = 16 + theta = 10000 + + config = DyPEConfig(base_resolution=1024) + + result_no_scale = rope_dype( + pos=pos, + dim=dim, + theta=theta, + current_sigma=0.5, + target_height=1024, + target_width=1024, + dype_config=config, + axis_index=0, + ) + result_with_scale = rope_dype( + pos=pos, + dim=dim, + theta=theta, + current_sigma=0.5, + target_height=4096, + target_width=4096, + dype_config=config, + axis_index=0, + ) + + # Axis 0 should be identical regardless of target resolution + assert torch.allclose(result_no_scale, result_with_scale) + class TestDyPEEmbedND: """Tests for DyPEEmbedND module.""" @@ -489,8 +523,8 @@ def test_compute_vision_yarn_freqs_shape(self): pos=pos, dim=32, theta=10000, - scale_h=2.0, - scale_w=2.0, + linear_scale=2.0, + ntk_scale=2.0, current_sigma=0.5, dype_config=config, ) @@ -543,8 +577,8 @@ def test_different_timesteps_produce_different_freqs(self): pos=pos, dim=32, theta=10000, - scale_h=2.0, - scale_w=2.0, + linear_scale=2.0, + ntk_scale=2.0, current_sigma=1.0, # Early step dype_config=config, ) @@ -553,8 +587,8 @@ def test_different_timesteps_produce_different_freqs(self): pos=pos, dim=32, theta=10000, - scale_h=2.0, - scale_w=2.0, + linear_scale=2.0, + ntk_scale=2.0, current_sigma=0.1, # Late step dype_config=config, ) @@ -572,8 +606,8 @@ def test_no_scaling_returns_base_freqs(self): pos=pos, dim=32, theta=10000, - scale_h=1.0, - scale_w=1.0, + linear_scale=1.0, + ntk_scale=1.0, current_sigma=0.5, dype_config=config, ) @@ -591,8 +625,8 @@ def test_yarn_freqs_matches_vision_yarn_for_uniform_scale(self): pos=pos, dim=32, theta=10000, - scale_h=2.0, - scale_w=2.0, + linear_scale=2.0, + ntk_scale=2.0, current_sigma=0.5, dype_config=config, ) @@ -619,8 +653,8 @@ def test_mscale_applied_to_output(self): pos=pos, dim=32, theta=10000, - scale_h=4.0, - scale_w=4.0, + linear_scale=4.0, + ntk_scale=4.0, current_sigma=0.5, dype_config=config, ) From 70c7cc66a7313dfaacba9c1fa533088aa4ad9948 Mon Sep 17 00:00:00 2001 From: Alexander Eichhorn Date: Sat, 7 Feb 2026 21:21:49 +0100 Subject: [PATCH 4/6] Fix DyPE Vision YaRN 3-band frequency blending to match ComfyUI-DyPE The previous implementation used a simplified NTK-only approach that destroyed positional information at high resolutions. This restores proper YaRN 3-band blending (base/linear/NTK) with correct mask inversion and timestep-dependent mscale, matching ComfyUI-DyPE's get_1d_dype_yarn_pos_embed implementation. --- invokeai/backend/flux/dype/base.py | 86 ++++++++++++++++++------------ invokeai/backend/flux/dype/rope.py | 4 ++ 2 files changed, 56 insertions(+), 34 deletions(-) diff --git a/invokeai/backend/flux/dype/base.py b/invokeai/backend/flux/dype/base.py index 053f83275c9..7cf950f25b9 100644 --- a/invokeai/backend/flux/dype/base.py +++ b/invokeai/backend/flux/dype/base.py @@ -222,12 +222,22 @@ def compute_vision_yarn_freqs( current_sigma: float, dype_config: DyPEConfig, ori_max_pe_len: int = FLUX_BASE_PE_LEN, + mscale_override: float | None = None, ) -> tuple[Tensor, Tensor]: """Compute RoPE frequencies using DyPE-modulated YaRN 3-band blending. - Uses three frequency bands (base, linear, NTK) blended via beta/gamma masks. - DyPE modulates the correction ranges so early denoising steps use stronger - extrapolation (global structure) and late steps use weaker (fine details). + Implements the Vision YaRN method from ComfyUI-DyPE: three frequency bands + (base, linear, NTK) are blended using beta and gamma correction masks. + DyPE modulates the correction parameters via k_t = scale * (sigma ^ exponent). + + The three bands: + - freqs_base: Original RoPE frequencies (unchanged) + - freqs_linear: Position Interpolation (freqs_base / linear_scale) + - freqs_ntk: NTK-scaled (theta * ntk_alpha for new base) + + Blending order: + 1. Beta mask blends linear <-> NTK (low freqs -> NTK, high freqs -> linear) + 2. Gamma mask blends result <-> base (low freqs -> base, high freqs -> keep blend) Args: pos: Position tensor @@ -237,13 +247,13 @@ def compute_vision_yarn_freqs( ntk_scale: Global NTK scaling factor (max of height/width ratios) current_sigma: Current noise level (1.0 = full noise, 0.0 = clean) dype_config: DyPE configuration - ori_max_pe_len: Original maximum position embedding length for YaRN correction + ori_max_pe_len: Original maximum position embedding length + mscale_override: Optional timestep-dependent mscale (from compute_timestep_mscale) Returns: Tuple of (cos, sin) frequency tensors """ assert dim % 2 == 0 - n_freqs = dim // 2 device = pos.device dtype = torch.float64 if device.type != "mps" else torch.float32 @@ -258,61 +268,69 @@ def compute_vision_yarn_freqs( angles = torch.einsum("...n,d->...nd", pos.to(dtype), freqs) return torch.cos(angles).to(pos.dtype), torch.sin(angles).to(pos.dtype) - # === Step 1: Compute DyPE modulation factor k_t === - k_t = compute_dype_k_t( - current_sigma=current_sigma, - dype_scale=dype_config.dype_scale, - dype_exponent=dype_config.dype_exponent, - dype_start_sigma=dype_config.dype_start_sigma, - ) + half_dim = dim // 2 - # === Step 2: DyPE-modulate YaRN correction parameters === - beta_0: float = YARN_BETA_0 - beta_1: float = YARN_BETA_1 - gamma_0: float = YARN_GAMMA_0 - gamma_1: float = YARN_GAMMA_1 + # === Step 1: YaRN correction parameters, modulated by DyPE k_t === + beta_0: float = YARN_BETA_0 # 1.25 + beta_1: float = YARN_BETA_1 # 0.75 + gamma_0: float = YARN_GAMMA_0 # 16.0 + gamma_1: float = YARN_GAMMA_1 # 2.0 if dype_config.enable_dype: + k_t = compute_dype_k_t( + current_sigma=current_sigma, + dype_scale=dype_config.dype_scale, + dype_exponent=dype_config.dype_exponent, + dype_start_sigma=dype_config.dype_start_sigma, + ) beta_0 = beta_0**k_t beta_1 = beta_1**k_t gamma_0 = gamma_0**k_t gamma_1 = gamma_1**k_t - # === Step 3: Compute three frequency bands === + # === Step 2: Three frequency bands === freq_seq = torch.arange(0, dim, 2, dtype=dtype, device=device) / dim - # Band 1: Base frequencies (original, unscaled) + # Band 1: Base frequencies (original RoPE) freqs_base = 1.0 / (theta**freq_seq) - # Band 2: Linearly scaled frequencies (per-axis) + # Band 2: Linear interpolation (Position Interpolation) freqs_linear = freqs_base / linear_scale - # Band 3: NTK-scaled frequencies (global) + # Band 3: NTK-scaled frequencies new_base = theta * (ntk_scale ** (dim / (dim - 2))) freqs_ntk = 1.0 / (new_base**freq_seq) - # === Step 4: Beta mask - blend linear <-> NTK === + # === Step 3: Beta blending (linear <-> NTK) === + # Low-frequency components -> NTK, high-frequency -> linear low, high = find_correction_range(beta_0, beta_1, dim, theta, ori_max_pe_len) - low = max(0, low) - high = min(n_freqs, high) - mask_beta = 1.0 - linear_ramp_mask(low, high, n_freqs, device, dtype) + low, high = max(0, low), min(half_dim, high) + ramp_beta = linear_ramp_mask(low, high, half_dim, device, dtype) + mask_beta = 1.0 - ramp_beta freqs = freqs_linear * (1.0 - mask_beta) + freqs_ntk * mask_beta - # === Step 5: Gamma mask - blend result <-> base === + # === Step 4: Gamma blending (result <-> base) === + # Low-frequency components -> base (unchanged), high-frequency -> keep blend low, high = find_correction_range(gamma_0, gamma_1, dim, theta, ori_max_pe_len) - low = max(0, low) - high = min(n_freqs, high) - mask_gamma = 1.0 - linear_ramp_mask(low, high, n_freqs, device, dtype) + low, high = max(0, low), min(half_dim, high) + ramp_gamma = linear_ramp_mask(low, high, half_dim, device, dtype) + mask_gamma = 1.0 - ramp_gamma freqs = freqs * (1.0 - mask_gamma) + freqs_base * mask_gamma - # === Step 6: Compute angles === + # === Step 5: Compute angles === angles = torch.einsum("...n,d->...nd", pos.to(dtype), freqs) - # === Step 7: Apply timestep-dependent mscale === - mscale = compute_timestep_mscale(ntk_scale, current_sigma, dype_config) + cos = torch.cos(angles) + sin = torch.sin(angles) + + # === Step 6: Apply mscale === + if mscale_override is not None: + mscale = mscale_override + else: + mscale = 1.0 + 0.1 * math.log(ntk_scale) / math.sqrt(ntk_scale) - cos = torch.cos(angles) * mscale - sin = torch.sin(angles) * mscale + cos = cos * mscale + sin = sin * mscale return cos.to(pos.dtype), sin.to(pos.dtype) diff --git a/invokeai/backend/flux/dype/rope.py b/invokeai/backend/flux/dype/rope.py index 8af42b041bd..611d5feea2f 100644 --- a/invokeai/backend/flux/dype/rope.py +++ b/invokeai/backend/flux/dype/rope.py @@ -8,6 +8,7 @@ FLUX_BASE_PE_LEN, DyPEConfig, compute_ntk_freqs, + compute_timestep_mscale, compute_vision_yarn_freqs, compute_yarn_freqs, ) @@ -78,6 +79,8 @@ def rope_dype( method = dype_config.method if method == "vision_yarn": + # Compute timestep-dependent mscale (matches ComfyUI-DyPE's _get_mscale) + mscale = compute_timestep_mscale(ntk_scale, current_sigma, dype_config) cos, sin = compute_vision_yarn_freqs( pos=pos, dim=dim, @@ -87,6 +90,7 @@ def rope_dype( current_sigma=current_sigma, dype_config=dype_config, ori_max_pe_len=ori_max_pe_len, + mscale_override=mscale, ) elif method == "yarn": cos, sin = compute_yarn_freqs( From fe61243d4f0a3612d3f8da55c5e2d3f32032ddf6 Mon Sep 17 00:00:00 2001 From: Alexander Eichhorn Date: Sun, 15 Mar 2026 06:29:44 +0100 Subject: [PATCH 5/6] Fix DyPE high-res artifacts: cap schedule shift and fix preset scaling - Cap FLUX noise schedule shift (mu) at base resolution when DyPE is active. Previously mu extrapolated linearly with resolution (e.g. mu=3.39 at 2496x1776 vs correct 1.15), causing the model to spend nearly all denoising at high noise levels, producing "crumpled paper" texture artifacts. --- invokeai/app/invocations/flux_denoise.py | 28 +++++++++++++++++------- invokeai/backend/flux/dype/base.py | 5 ++--- invokeai/backend/flux/dype/rope.py | 4 ++-- tests/backend/flux/dype/test_dype.py | 7 +++--- 4 files changed, 28 insertions(+), 16 deletions(-) diff --git a/invokeai/app/invocations/flux_denoise.py b/invokeai/app/invocations/flux_denoise.py index d6102b105b3..c6f7a4b32bf 100644 --- a/invokeai/app/invocations/flux_denoise.py +++ b/invokeai/app/invocations/flux_denoise.py @@ -280,10 +280,28 @@ def _run_diffusion( transformer_config.base is BaseModelType.Flux and transformer_config.variant is FluxVariantType.Schnell ) + # Prepare DyPE config early to adjust schedule if needed. + dype_config = get_dype_config_from_preset( + preset=self.dype_preset, + width=self.width, + height=self.height, + custom_scale=self.dype_scale, + custom_exponent=self.dype_exponent, + ) + # Calculate the timestep schedule. + # When DyPE is active, cap image_seq_len at the base resolution's seq_len + # to prevent excessive mu/shift values. DyPE handles position extrapolation, + # so the scheduler shouldn't also compensate for higher resolution. + schedule_seq_len = packed_h * packed_w + if dype_config is not None: + base_patches = dype_config.base_resolution // 8 // 2 + base_seq_len = base_patches * base_patches + schedule_seq_len = min(schedule_seq_len, base_seq_len) + timesteps = get_schedule( num_steps=self.num_steps, - image_seq_len=packed_h * packed_w, + image_seq_len=schedule_seq_len, shift=not is_schnell, ) @@ -461,14 +479,8 @@ def _run_diffusion( img_cond_seq, img_cond_seq_ids = kontext_extension.kontext_latents, kontext_extension.kontext_ids # Prepare DyPE extension for high-resolution generation + # (dype_config was already computed above for schedule adjustment) dype_extension: DyPEExtension | None = None - dype_config = get_dype_config_from_preset( - preset=self.dype_preset, - width=self.width, - height=self.height, - custom_scale=self.dype_scale, - custom_exponent=self.dype_exponent, - ) if dype_config is not None: dype_extension = DyPEExtension( config=dype_config, diff --git a/invokeai/backend/flux/dype/base.py b/invokeai/backend/flux/dype/base.py index 7cf950f25b9..4f49655aed7 100644 --- a/invokeai/backend/flux/dype/base.py +++ b/invokeai/backend/flux/dype/base.py @@ -53,7 +53,7 @@ def get_mscale(scale: float) -> float: def find_correction_factor( - num_rotations: int, + num_rotations: float, dim: int, base: int, max_position_embeddings: int, @@ -186,7 +186,6 @@ def compute_timestep_mscale( """Compute timestep-dependent magnitude scaling. Interpolates from aggressive mscale at early steps to 1.0 at late steps. - Matches ComfyUI-DyPE's _get_mscale behavior. Args: ntk_scale: Global NTK scaling factor @@ -199,7 +198,7 @@ def compute_timestep_mscale( if ntk_scale <= 1.0: return 1.0 - # Aggressive mscale formula (start value at high sigma) + # Aggressive mscale formula mscale_start = 0.1 * math.log(ntk_scale) + 1.0 mscale_end = 1.0 diff --git a/invokeai/backend/flux/dype/rope.py b/invokeai/backend/flux/dype/rope.py index 611d5feea2f..e876b82f1f1 100644 --- a/invokeai/backend/flux/dype/rope.py +++ b/invokeai/backend/flux/dype/rope.py @@ -60,7 +60,7 @@ def rope_dype( scale_w = target_width / base_res scale = max(scale_h, scale_w) - # If no scaling needed and DyPE disabled, use base method + # If no scaling needed or DyPE disabled, use base method if not dype_config.enable_dype or scale <= 1.0: return _rope_base(pos, dim, theta) @@ -79,7 +79,7 @@ def rope_dype( method = dype_config.method if method == "vision_yarn": - # Compute timestep-dependent mscale (matches ComfyUI-DyPE's _get_mscale) + # Compute timestep-dependent mscale mscale = compute_timestep_mscale(ntk_scale, current_sigma, dype_config) cos, sin = compute_vision_yarn_freqs( pos=pos, diff --git a/tests/backend/flux/dype/test_dype.py b/tests/backend/flux/dype/test_dype.py index 0965bb6a35b..7321a3c23cf 100644 --- a/tests/backend/flux/dype/test_dype.py +++ b/tests/backend/flux/dype/test_dype.py @@ -406,8 +406,8 @@ def test_get_dype_config_for_resolution_above_threshold(self): assert config.enable_dype is True assert config.method == "vision_yarn" - def test_get_dype_config_for_resolution_dynamic_scale(self): - """Higher resolution should result in higher dype_scale.""" + def test_get_dype_config_for_resolution_fixed_scale(self): + """All resolutions above threshold should use fixed dype_scale=2.0.""" config_2k = get_dype_config_for_resolution( width=2048, height=2048, @@ -423,7 +423,8 @@ def test_get_dype_config_for_resolution_dynamic_scale(self): assert config_2k is not None assert config_4k is not None - assert config_4k.dype_scale > config_2k.dype_scale + assert config_2k.dype_scale == 2.0 + assert config_4k.dype_scale == 2.0 def test_get_dype_config_for_area_below_threshold(self): """When area is below threshold area, should return None.""" From c79afc6d6771786141b91e78b84d34bda5246cc8 Mon Sep 17 00:00:00 2001 From: Alexander Eichhorn Date: Sun, 15 Mar 2026 06:39:34 +0100 Subject: [PATCH 6/6] Fix for the wrongly changed test. --- tests/backend/flux/dype/test_dype.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/backend/flux/dype/test_dype.py b/tests/backend/flux/dype/test_dype.py index 7321a3c23cf..0965bb6a35b 100644 --- a/tests/backend/flux/dype/test_dype.py +++ b/tests/backend/flux/dype/test_dype.py @@ -406,8 +406,8 @@ def test_get_dype_config_for_resolution_above_threshold(self): assert config.enable_dype is True assert config.method == "vision_yarn" - def test_get_dype_config_for_resolution_fixed_scale(self): - """All resolutions above threshold should use fixed dype_scale=2.0.""" + def test_get_dype_config_for_resolution_dynamic_scale(self): + """Higher resolution should result in higher dype_scale.""" config_2k = get_dype_config_for_resolution( width=2048, height=2048, @@ -423,8 +423,7 @@ def test_get_dype_config_for_resolution_fixed_scale(self): assert config_2k is not None assert config_4k is not None - assert config_2k.dype_scale == 2.0 - assert config_4k.dype_scale == 2.0 + assert config_4k.dype_scale > config_2k.dype_scale def test_get_dype_config_for_area_below_threshold(self): """When area is below threshold area, should return None."""