diff --git a/invokeai/app/invocations/flux_denoise.py b/invokeai/app/invocations/flux_denoise.py index d6102b105b3..c6f7a4b32bf 100644 --- a/invokeai/app/invocations/flux_denoise.py +++ b/invokeai/app/invocations/flux_denoise.py @@ -280,10 +280,28 @@ def _run_diffusion( transformer_config.base is BaseModelType.Flux and transformer_config.variant is FluxVariantType.Schnell ) + # Prepare DyPE config early to adjust schedule if needed. + dype_config = get_dype_config_from_preset( + preset=self.dype_preset, + width=self.width, + height=self.height, + custom_scale=self.dype_scale, + custom_exponent=self.dype_exponent, + ) + # Calculate the timestep schedule. + # When DyPE is active, cap image_seq_len at the base resolution's seq_len + # to prevent excessive mu/shift values. DyPE handles position extrapolation, + # so the scheduler shouldn't also compensate for higher resolution. + schedule_seq_len = packed_h * packed_w + if dype_config is not None: + base_patches = dype_config.base_resolution // 8 // 2 + base_seq_len = base_patches * base_patches + schedule_seq_len = min(schedule_seq_len, base_seq_len) + timesteps = get_schedule( num_steps=self.num_steps, - image_seq_len=packed_h * packed_w, + image_seq_len=schedule_seq_len, shift=not is_schnell, ) @@ -461,14 +479,8 @@ def _run_diffusion( img_cond_seq, img_cond_seq_ids = kontext_extension.kontext_latents, kontext_extension.kontext_ids # Prepare DyPE extension for high-resolution generation + # (dype_config was already computed above for schedule adjustment) dype_extension: DyPEExtension | None = None - dype_config = get_dype_config_from_preset( - preset=self.dype_preset, - width=self.width, - height=self.height, - custom_scale=self.dype_scale, - custom_exponent=self.dype_exponent, - ) if dype_config is not None: dype_extension = DyPEExtension( config=dype_config, diff --git a/invokeai/backend/flux/dype/base.py b/invokeai/backend/flux/dype/base.py index 7b25a7f71f3..4f49655aed7 100644 --- a/invokeai/backend/flux/dype/base.py +++ b/invokeai/backend/flux/dype/base.py @@ -1,4 +1,8 @@ -"""DyPE base configuration and utilities.""" +"""DyPE base configuration and utilities. + +Implements Dynamic Position Extrapolation (DyPE) with YaRN-style frequency blending. +Based on ComfyUI-DyPE: https://github.com/wildminder/ComfyUI-DyPE +""" import math from dataclasses import dataclass @@ -7,6 +11,17 @@ import torch from torch import Tensor +# YaRN default parameters for FLUX +# These define the frequency correction ranges for blending +YARN_BETA_0 = 1.25 # Low-frequency ratio (β₀) +YARN_BETA_1 = 0.75 # High-frequency ratio (β₁) +YARN_GAMMA_0 = 16.0 # Original position range (γ₀) +YARN_GAMMA_1 = 2.0 # Extended position range (γ₁) + +# FLUX model constants +# Base position embedding length = base_resolution / patch_size / packing = 1024/8/2 = 64 +FLUX_BASE_PE_LEN = 64 + @dataclass class DyPEConfig: @@ -20,145 +35,302 @@ class DyPEConfig: dype_start_sigma: float = 1.0 # When DyPE decay starts -def get_mscale(scale: float, mscale_factor: float = 1.0) -> float: - """Calculate magnitude scaling factor. +def get_mscale(scale: float) -> float: + """Calculate magnitude scaling factor (mscale). + + Uses the formula from YaRN paper: mscale = 1 + 0.1 * log(s) / sqrt(s) + This provides better attention score normalization for high-resolution. Args: - scale: The resolution scaling factor - mscale_factor: Adjustment factor for the scaling + scale: The resolution scaling factor (NTK scale) Returns: The magnitude scaling factor """ if scale <= 1.0: return 1.0 - return mscale_factor * math.log(scale) + 1.0 + return 1.0 + 0.1 * math.log(scale) / math.sqrt(scale) -def get_timestep_mscale( - scale: float, +def find_correction_factor( + num_rotations: float, + dim: int, + base: int, + max_position_embeddings: int, +) -> float: + """Calculate correction factor for YaRN frequency masking. + + Finds the dimension index where the wavelength equals a given number of rotations. + Used to determine which frequency components need interpolation. + + Args: + num_rotations: Number of rotations to find the factor for + dim: Embedding dimension + base: RoPE base frequency (theta) + max_position_embeddings: Original maximum position embeddings + + Returns: + The dimension index (can be fractional) where wavelength matches num_rotations + """ + # Wavelength at dimension d = 2π * base^(d/dim) + # We want to find d where: wavelength / max_pe_len = num_rotations + # => 2π * base^(d/dim) = num_rotations * max_pe_len + # => d = dim * log(num_rotations * max_pe_len / 2π) / log(base) + return (dim * math.log(max_position_embeddings / (num_rotations * 2.0 * math.pi))) / (2.0 * math.log(base)) + + +def find_correction_range( + low_ratio: float, + high_ratio: float, + dim: int, + base: int, + ori_max_pe_len: int, +) -> tuple[float, float]: + """Find the dimension range for frequency correction. + + Determines the range of dimensions that need interpolation between + different frequency scaling methods. + + Args: + low_ratio: Low frequency ratio (beta or gamma low) + high_ratio: High frequency ratio (beta or gamma high) + dim: Embedding dimension + base: RoPE base frequency (theta) + ori_max_pe_len: Original maximum position embedding length + + Returns: + Tuple of (low_dim, high_dim) indices for the correction range + """ + low = math.floor(find_correction_factor(low_ratio, dim, base, ori_max_pe_len)) + high = math.ceil(find_correction_factor(high_ratio, dim, base, ori_max_pe_len)) + return max(low, 0.0), min(high, dim - 1.0) + + +def linear_ramp_mask( + min_val: float, + max_val: float, + dim: int, + device: torch.device, + dtype: torch.dtype, +) -> Tensor: + """Create linear interpolation mask between frequency bands. + + Creates a tensor that ramps linearly from 0 to 1 between min_val and max_val, + with values clamped to [0, 1] outside that range. + + Args: + min_val: Dimension index where mask starts (becomes > 0) + max_val: Dimension index where mask ends (becomes 1) + dim: Number of frequency components (half the embedding dimension) + device: Target device for the tensor + dtype: Target dtype for the tensor + + Returns: + Tensor of shape (dim,) with values in [0, 1] + """ + if max_val <= min_val: + # Degenerate case: no interpolation range, return step function + indices = torch.arange(dim, device=device, dtype=dtype) + return (indices >= min_val).to(dtype) + + # Linear ramp: (i - min) / (max - min), clamped to [0, 1] + indices = torch.arange(dim, device=device, dtype=dtype) + mask = (indices - min_val) / (max_val - min_val) + return torch.clamp(mask, 0.0, 1.0) + + +def compute_dype_k_t( current_sigma: float, dype_scale: float, dype_exponent: float, dype_start_sigma: float, ) -> float: - """Calculate timestep-dependent magnitude scaling. + """Compute the DyPE timestep modulation factor k_t. The key insight of DyPE: early steps focus on low frequencies (global structure), - late steps on high frequencies (details). This function modulates the scaling - based on the current timestep/sigma. + late steps on high frequencies (details). This function computes k_t which + modulates the YaRN correction parameters (beta, gamma). + + Formula: k_t = dype_scale * (timestep ^ dype_exponent) + + Where timestep is the normalized sigma value (0 at end, 1 at start of denoising). Args: - scale: Resolution scaling factor current_sigma: Current noise level (1.0 = full noise, 0.0 = clean) - dype_scale: DyPE magnitude (λs) - dype_exponent: DyPE decay speed (λt) + dype_scale: DyPE magnitude (λs, 0.0-8.0) + dype_exponent: DyPE decay speed (λt, 0.0-1000.0) dype_start_sigma: Sigma threshold to start decay Returns: - Timestep-modulated scaling factor + k_t modulation factor - larger values mean stronger extrapolation """ - if scale <= 1.0: - return 1.0 - # Normalize sigma to [0, 1] range relative to start_sigma if current_sigma >= dype_start_sigma: - t_normalized = 1.0 + timestep = 1.0 else: - t_normalized = current_sigma / dype_start_sigma + timestep = current_sigma / dype_start_sigma - # Apply exponential decay: stronger extrapolation early, weaker late - # decay = exp(-λt * (1 - t)) where t=1 is early (high sigma), t=0 is late - decay = math.exp(-dype_exponent * (1.0 - t_normalized)) + # DyPE formula: k_t = scale * (timestep ^ exponent) + # At timestep=1 (early, high sigma): k_t = dype_scale + # At timestep=0 (late, low sigma): k_t = 0 + k_t = dype_scale * (timestep**dype_exponent) - # Base mscale from resolution - base_mscale = get_mscale(scale) + return k_t - # Interpolate between base_mscale and 1.0 based on decay and dype_scale - # When decay=1 (early): use scaled value - # When decay=0 (late): use base value - scaled_mscale = 1.0 + (base_mscale - 1.0) * dype_scale * decay - return scaled_mscale +def compute_timestep_mscale( + ntk_scale: float, + current_sigma: float, + dype_config: DyPEConfig, +) -> float: + """Compute timestep-dependent magnitude scaling. + + Interpolates from aggressive mscale at early steps to 1.0 at late steps. + + Args: + ntk_scale: Global NTK scaling factor + current_sigma: Current noise level (1.0 = full noise, 0.0 = clean) + dype_config: DyPE configuration + + Returns: + Timestep-modulated magnitude scaling factor + """ + if ntk_scale <= 1.0: + return 1.0 + + # Aggressive mscale formula + mscale_start = 0.1 * math.log(ntk_scale) + 1.0 + mscale_end = 1.0 + + # Normalize sigma + if current_sigma >= dype_config.dype_start_sigma: + t_norm = 1.0 + else: + t_norm = current_sigma / dype_config.dype_start_sigma + + # Interpolate: full mscale at early steps, 1.0 at late steps + return mscale_end + (mscale_start - mscale_end) * (t_norm**dype_config.dype_exponent) def compute_vision_yarn_freqs( pos: Tensor, dim: int, theta: int, - scale_h: float, - scale_w: float, + linear_scale: float, + ntk_scale: float, current_sigma: float, dype_config: DyPEConfig, + ori_max_pe_len: int = FLUX_BASE_PE_LEN, + mscale_override: float | None = None, ) -> tuple[Tensor, Tensor]: - """Compute RoPE frequencies using NTK-aware scaling for high-resolution. + """Compute RoPE frequencies using DyPE-modulated YaRN 3-band blending. - This method extends FLUX's position encoding to handle resolutions beyond - the 1024px training resolution by scaling the base frequency (theta). + Implements the Vision YaRN method from ComfyUI-DyPE: three frequency bands + (base, linear, NTK) are blended using beta and gamma correction masks. + DyPE modulates the correction parameters via k_t = scale * (sigma ^ exponent). - The NTK-aware approach smoothly interpolates frequencies to cover larger - position ranges without breaking the attention patterns. + The three bands: + - freqs_base: Original RoPE frequencies (unchanged) + - freqs_linear: Position Interpolation (freqs_base / linear_scale) + - freqs_ntk: NTK-scaled (theta * ntk_alpha for new base) - DyPE (Dynamic Position Extrapolation) modulates the NTK scaling based on - the current timestep - stronger extrapolation in early steps (global structure), - weaker in late steps (fine details). + Blending order: + 1. Beta mask blends linear <-> NTK (low freqs -> NTK, high freqs -> linear) + 2. Gamma mask blends result <-> base (low freqs -> base, high freqs -> keep blend) Args: pos: Position tensor dim: Embedding dimension theta: RoPE base frequency - scale_h: Height scaling factor - scale_w: Width scaling factor + linear_scale: Per-axis linear scaling factor (height or width ratio) + ntk_scale: Global NTK scaling factor (max of height/width ratios) current_sigma: Current noise level (1.0 = full noise, 0.0 = clean) dype_config: DyPE configuration + ori_max_pe_len: Original maximum position embedding length + mscale_override: Optional timestep-dependent mscale (from compute_timestep_mscale) Returns: Tuple of (cos, sin) frequency tensors """ assert dim % 2 == 0 - # Use the larger scale for NTK calculation - scale = max(scale_h, scale_w) - device = pos.device dtype = torch.float64 if device.type != "mps" else torch.float32 - # NTK-aware theta scaling: extends position coverage for high-res - # Formula: theta_scaled = theta * scale^(dim/(dim-2)) - # This increases the wavelength of position encodings proportionally - if scale > 1.0: - ntk_alpha = scale ** (dim / (dim - 2)) - - # Apply timestep-dependent DyPE modulation - # mscale controls how strongly we apply the NTK extrapolation - # Early steps (high sigma): stronger extrapolation for global structure - # Late steps (low sigma): weaker extrapolation for fine details - mscale = get_timestep_mscale( - scale=scale, + linear_scale = max(linear_scale, 1.0) + ntk_scale = max(ntk_scale, 1.0) + + if ntk_scale <= 1.0: + # No scaling needed - use base frequencies + freq_seq = torch.arange(0, dim, 2, dtype=dtype, device=device) / dim + freqs = 1.0 / (theta**freq_seq) + angles = torch.einsum("...n,d->...nd", pos.to(dtype), freqs) + return torch.cos(angles).to(pos.dtype), torch.sin(angles).to(pos.dtype) + + half_dim = dim // 2 + + # === Step 1: YaRN correction parameters, modulated by DyPE k_t === + beta_0: float = YARN_BETA_0 # 1.25 + beta_1: float = YARN_BETA_1 # 0.75 + gamma_0: float = YARN_GAMMA_0 # 16.0 + gamma_1: float = YARN_GAMMA_1 # 2.0 + + if dype_config.enable_dype: + k_t = compute_dype_k_t( current_sigma=current_sigma, dype_scale=dype_config.dype_scale, dype_exponent=dype_config.dype_exponent, dype_start_sigma=dype_config.dype_start_sigma, ) + beta_0 = beta_0**k_t + beta_1 = beta_1**k_t + gamma_0 = gamma_0**k_t + gamma_1 = gamma_1**k_t - # Modulate NTK alpha by mscale - # When mscale > 1: interpolate towards stronger extrapolation - # When mscale = 1: use base NTK alpha - modulated_alpha = 1.0 + (ntk_alpha - 1.0) * mscale - scaled_theta = theta * modulated_alpha - else: - scaled_theta = theta - - # Standard RoPE frequency computation + # === Step 2: Three frequency bands === freq_seq = torch.arange(0, dim, 2, dtype=dtype, device=device) / dim - freqs = 1.0 / (scaled_theta**freq_seq) - # Compute angles = position * frequency + # Band 1: Base frequencies (original RoPE) + freqs_base = 1.0 / (theta**freq_seq) + + # Band 2: Linear interpolation (Position Interpolation) + freqs_linear = freqs_base / linear_scale + + # Band 3: NTK-scaled frequencies + new_base = theta * (ntk_scale ** (dim / (dim - 2))) + freqs_ntk = 1.0 / (new_base**freq_seq) + + # === Step 3: Beta blending (linear <-> NTK) === + # Low-frequency components -> NTK, high-frequency -> linear + low, high = find_correction_range(beta_0, beta_1, dim, theta, ori_max_pe_len) + low, high = max(0, low), min(half_dim, high) + ramp_beta = linear_ramp_mask(low, high, half_dim, device, dtype) + mask_beta = 1.0 - ramp_beta + freqs = freqs_linear * (1.0 - mask_beta) + freqs_ntk * mask_beta + + # === Step 4: Gamma blending (result <-> base) === + # Low-frequency components -> base (unchanged), high-frequency -> keep blend + low, high = find_correction_range(gamma_0, gamma_1, dim, theta, ori_max_pe_len) + low, high = max(0, low), min(half_dim, high) + ramp_gamma = linear_ramp_mask(low, high, half_dim, device, dtype) + mask_gamma = 1.0 - ramp_gamma + freqs = freqs * (1.0 - mask_gamma) + freqs_base * mask_gamma + + # === Step 5: Compute angles === angles = torch.einsum("...n,d->...nd", pos.to(dtype), freqs) cos = torch.cos(angles) sin = torch.sin(angles) + # === Step 6: Apply mscale === + if mscale_override is not None: + mscale = mscale_override + else: + mscale = 1.0 + 0.1 * math.log(ntk_scale) / math.sqrt(ntk_scale) + + cos = cos * mscale + sin = sin * mscale + return cos.to(pos.dtype), sin.to(pos.dtype) @@ -169,56 +341,35 @@ def compute_yarn_freqs( scale: float, current_sigma: float, dype_config: DyPEConfig, + ori_max_pe_len: int = FLUX_BASE_PE_LEN, ) -> tuple[Tensor, Tensor]: - """Compute RoPE frequencies using YARN/NTK method. + """Compute RoPE frequencies using DyPE-modulated YaRN with uniform scale. - Uses NTK-aware theta scaling for high-resolution support with - timestep-dependent DyPE modulation. + Uses the same 3-band blending as vision_yarn but with a uniform scale + factor for both linear and NTK components. Args: pos: Position tensor dim: Embedding dimension theta: RoPE base frequency - scale: Uniform scaling factor + scale: Uniform scaling factor (used for both linear and NTK) current_sigma: Current noise level (1.0 = full noise, 0.0 = clean) dype_config: DyPE configuration + ori_max_pe_len: Original maximum position embedding length Returns: Tuple of (cos, sin) frequency tensors """ - assert dim % 2 == 0 - - device = pos.device - dtype = torch.float64 if device.type != "mps" else torch.float32 - - # NTK-aware theta scaling with DyPE modulation - if scale > 1.0: - ntk_alpha = scale ** (dim / (dim - 2)) - - # Apply timestep-dependent DyPE modulation - mscale = get_timestep_mscale( - scale=scale, - current_sigma=current_sigma, - dype_scale=dype_config.dype_scale, - dype_exponent=dype_config.dype_exponent, - dype_start_sigma=dype_config.dype_start_sigma, - ) - - # Modulate NTK alpha by mscale - modulated_alpha = 1.0 + (ntk_alpha - 1.0) * mscale - scaled_theta = theta * modulated_alpha - else: - scaled_theta = theta - - freq_seq = torch.arange(0, dim, 2, dtype=dtype, device=device) / dim - freqs = 1.0 / (scaled_theta**freq_seq) - - angles = torch.einsum("...n,d->...nd", pos.to(dtype), freqs) - - cos = torch.cos(angles) - sin = torch.sin(angles) - - return cos.to(pos.dtype), sin.to(pos.dtype) + return compute_vision_yarn_freqs( + pos=pos, + dim=dim, + theta=theta, + linear_scale=scale, + ntk_scale=scale, + current_sigma=current_sigma, + dype_config=dype_config, + ori_max_pe_len=ori_max_pe_len, + ) def compute_ntk_freqs( diff --git a/invokeai/backend/flux/dype/embed.py b/invokeai/backend/flux/dype/embed.py index ace6a56ab0f..8a062d460a6 100644 --- a/invokeai/backend/flux/dype/embed.py +++ b/invokeai/backend/flux/dype/embed.py @@ -6,6 +6,11 @@ from invokeai.backend.flux.dype.base import DyPEConfig from invokeai.backend.flux.dype.rope import rope_dype +# FLUX uses 8x8 patch compression with 2x2 packing +# base_resolution / 8 / 2 = base_patch_grid +FLUX_PATCH_SIZE = 8 +FLUX_PACKING_FACTOR = 2 + class DyPEEmbedND(nn.Module): """N-dimensional position embedding with DyPE support. @@ -17,6 +22,7 @@ class DyPEEmbedND(nn.Module): - Maintains step state (current_sigma, target dimensions) - Uses rope_dype() instead of rope() for frequency computation - Applies timestep-dependent scaling for better high-resolution generation + - Calculates base_patch_grid for proper scale computation """ def __init__( @@ -40,6 +46,14 @@ def __init__( self.axes_dim = axes_dim self.dype_config = dype_config + # Calculate base patch grid from base resolution + # FLUX: 1024 / 8 / 2 = 64 patches per side + self.base_patch_grid = dype_config.base_resolution // FLUX_PATCH_SIZE // FLUX_PACKING_FACTOR + + # Original max position embedding length = base patch grid + # This is the number of spatial positions per side at base resolution + self.ori_max_pe_len = self.base_patch_grid + # Step state - updated before each denoising step self._current_sigma: float = 1.0 self._target_height: int = 1024 @@ -83,6 +97,8 @@ def forward(self, ids: Tensor) -> Tensor: target_height=self._target_height, target_width=self._target_width, dype_config=self.dype_config, + axis_index=i, + ori_max_pe_len=self.ori_max_pe_len, ) embeddings.append(axis_emb) diff --git a/invokeai/backend/flux/dype/rope.py b/invokeai/backend/flux/dype/rope.py index f6a1594f6be..e876b82f1f1 100644 --- a/invokeai/backend/flux/dype/rope.py +++ b/invokeai/backend/flux/dype/rope.py @@ -5,8 +5,10 @@ from torch import Tensor from invokeai.backend.flux.dype.base import ( + FLUX_BASE_PE_LEN, DyPEConfig, compute_ntk_freqs, + compute_timestep_mscale, compute_vision_yarn_freqs, compute_yarn_freqs, ) @@ -20,12 +22,17 @@ def rope_dype( target_height: int, target_width: int, dype_config: DyPEConfig, + axis_index: int = 0, + ori_max_pe_len: int = FLUX_BASE_PE_LEN, ) -> Tensor: """Compute RoPE with Dynamic Position Extrapolation. This is the core DyPE function that replaces the standard rope() function. It applies resolution-aware and timestep-aware scaling to position embeddings. + DyPE scaling is only applied to spatial axes (axis_index > 0). Axis 0 + (time/channel) always uses plain RoPE to avoid distorting temporal attention. + Args: pos: Position indices tensor dim: Embedding dimension per axis @@ -34,34 +41,56 @@ def rope_dype( target_height: Target image height in pixels target_width: Target image width in pixels dype_config: DyPE configuration + axis_index: Which axis this is (0=time/channel, 1=height, 2=width). + Axis 0 always uses plain RoPE without DyPE scaling. + ori_max_pe_len: Original maximum position embedding length for YaRN correction Returns: Rotary position embedding tensor with shape suitable for FLUX attention """ assert dim % 2 == 0 + # Axis 0 (time/channel) never gets DyPE scaling - only spatial axes do + if axis_index == 0: + return _rope_base(pos, dim, theta) + # Calculate scaling factors base_res = dype_config.base_resolution scale_h = target_height / base_res scale_w = target_width / base_res scale = max(scale_h, scale_w) - # If no scaling needed and DyPE disabled, use base method + # If no scaling needed or DyPE disabled, use base method if not dype_config.enable_dype or scale <= 1.0: return _rope_base(pos, dim, theta) + # Compute per-axis linear_scale and global ntk_scale + # linear_scale: the scale for THIS specific axis (height or width) + # ntk_scale: the global scale = max(scale_h, scale_w) + if axis_index == 1: + linear_scale = scale_h + elif axis_index == 2: + linear_scale = scale_w + else: + linear_scale = scale + ntk_scale = scale + # Select method and compute frequencies method = dype_config.method if method == "vision_yarn": + # Compute timestep-dependent mscale + mscale = compute_timestep_mscale(ntk_scale, current_sigma, dype_config) cos, sin = compute_vision_yarn_freqs( pos=pos, dim=dim, theta=theta, - scale_h=scale_h, - scale_w=scale_w, + linear_scale=linear_scale, + ntk_scale=ntk_scale, current_sigma=current_sigma, dype_config=dype_config, + ori_max_pe_len=ori_max_pe_len, + mscale_override=mscale, ) elif method == "yarn": cos, sin = compute_yarn_freqs( @@ -71,6 +100,7 @@ def rope_dype( scale=scale, current_sigma=current_sigma, dype_config=dype_config, + ori_max_pe_len=ori_max_pe_len, ) elif method == "ntk": cos, sin = compute_ntk_freqs( diff --git a/tests/backend/flux/dype/test_dype.py b/tests/backend/flux/dype/test_dype.py index cc0b99011cd..0965bb6a35b 100644 --- a/tests/backend/flux/dype/test_dype.py +++ b/tests/backend/flux/dype/test_dype.py @@ -1,14 +1,22 @@ """Tests for DyPE (Dynamic Position Extrapolation) module.""" +import math + import torch from invokeai.backend.flux.dype.base import ( + FLUX_BASE_PE_LEN, + YARN_BETA_0, + YARN_BETA_1, DyPEConfig, + compute_dype_k_t, compute_ntk_freqs, compute_vision_yarn_freqs, compute_yarn_freqs, + find_correction_factor, + find_correction_range, get_mscale, - get_timestep_mscale, + linear_ramp_mask, ) from invokeai.backend.flux.dype.embed import DyPEEmbedND from invokeai.backend.flux.dype.presets import ( @@ -68,36 +76,153 @@ def test_get_mscale_with_scaling(self): assert mscale_2x > 1.0 assert mscale_4x > mscale_2x - def test_get_timestep_mscale_no_scaling(self): - """When scale <= 1.0, timestep_mscale should be 1.0.""" - result = get_timestep_mscale( - scale=1.0, - current_sigma=0.5, + def test_get_mscale_formula(self): + """Test mscale uses the correct YaRN formula: 1 + 0.1 * log(s) / sqrt(s).""" + scale = 4.0 + expected = 1.0 + 0.1 * math.log(scale) / math.sqrt(scale) + actual = get_mscale(scale) + assert abs(actual - expected) < 1e-10 + + +class TestDyPEKt: + """Tests for DyPE k_t calculation.""" + + def test_compute_dype_k_t_formula(self): + """Test k_t uses correct formula: scale * (timestep ^ exponent).""" + dype_scale = 2.0 + dype_exponent = 2.0 + current_sigma = 0.5 # normalized timestep + dype_start_sigma = 1.0 + + k_t = compute_dype_k_t( + current_sigma=current_sigma, + dype_scale=dype_scale, + dype_exponent=dype_exponent, + dype_start_sigma=dype_start_sigma, + ) + + # Expected: 2.0 * (0.5 ^ 2.0) = 2.0 * 0.25 = 0.5 + expected = dype_scale * (current_sigma**dype_exponent) + assert abs(k_t - expected) < 1e-10 + + def test_compute_dype_k_t_at_start(self): + """At timestep=1.0 (start), k_t should equal dype_scale.""" + k_t = compute_dype_k_t( + current_sigma=1.0, dype_scale=2.0, dype_exponent=2.0, dype_start_sigma=1.0, ) - assert result == 1.0 + assert abs(k_t - 2.0) < 1e-10 - def test_get_timestep_mscale_high_sigma(self): - """Early steps (high sigma) should have stronger scaling.""" - early_mscale = get_timestep_mscale( - scale=2.0, - current_sigma=1.0, # Early step + def test_compute_dype_k_t_at_end(self): + """At timestep=0.0 (end), k_t should be 0.""" + k_t = compute_dype_k_t( + current_sigma=0.0, dype_scale=2.0, dype_exponent=2.0, dype_start_sigma=1.0, ) - late_mscale = get_timestep_mscale( - scale=2.0, - current_sigma=0.1, # Late step + assert k_t == 0.0 + + def test_compute_dype_k_t_decreases_over_time(self): + """k_t should decrease as sigma decreases (denoising progresses).""" + k_t_early = compute_dype_k_t( + current_sigma=1.0, + dype_scale=2.0, + dype_exponent=2.0, + dype_start_sigma=1.0, + ) + k_t_mid = compute_dype_k_t( + current_sigma=0.5, + dype_scale=2.0, + dype_exponent=2.0, + dype_start_sigma=1.0, + ) + k_t_late = compute_dype_k_t( + current_sigma=0.1, dype_scale=2.0, dype_exponent=2.0, dype_start_sigma=1.0, ) - # Early steps should have larger mscale than late steps - assert early_mscale >= late_mscale + assert k_t_early > k_t_mid > k_t_late + + +class TestYaRNHelpers: + """Tests for YaRN helper functions.""" + + def test_find_correction_factor(self): + """Test correction factor calculation.""" + # Basic sanity check - should return a reasonable dimension index + factor = find_correction_factor( + num_rotations=1.25, + dim=32, + base=10000, + max_position_embeddings=256, + ) + assert factor >= 0 + assert factor <= 32 + + def test_find_correction_range(self): + """Test correction range returns valid bounds.""" + low, high = find_correction_range( + low_ratio=YARN_BETA_0, + high_ratio=YARN_BETA_1, + dim=28, # half of 56 (FLUX spatial dim) + base=10000, + ori_max_pe_len=FLUX_BASE_PE_LEN, + ) + + # Low should be <= high + assert low <= high + # Both should be in valid range + assert low >= 0 + assert high <= 27 # dim - 1 + + def test_linear_ramp_mask_shape(self): + """Test linear ramp mask has correct shape.""" + mask = linear_ramp_mask( + min_val=5.0, + max_val=15.0, + dim=28, + device=torch.device("cpu"), + dtype=torch.float32, + ) + assert mask.shape == (28,) + + def test_linear_ramp_mask_values(self): + """Test linear ramp mask has correct values.""" + mask = linear_ramp_mask( + min_val=5.0, + max_val=15.0, + dim=20, + device=torch.device("cpu"), + dtype=torch.float32, + ) + + # Values before min should be 0 + assert mask[0].item() == 0.0 + assert mask[4].item() == 0.0 + + # Values after max should be 1 + assert mask[15].item() == 1.0 + assert mask[19].item() == 1.0 + + # Values in between should be in (0, 1) + assert 0.0 < mask[10].item() < 1.0 + + def test_linear_ramp_mask_degenerate(self): + """Test linear ramp mask handles degenerate case (min >= max).""" + mask = linear_ramp_mask( + min_val=10.0, + max_val=5.0, # max < min + dim=20, + device=torch.device("cpu"), + dtype=torch.float32, + ) + # Should return step function at min_val + assert mask.shape == (20,) class TestRopeDype: @@ -131,7 +256,7 @@ def test_rope_dype_no_scaling(self): config = DyPEConfig(base_resolution=1024) - # No scaling needed + # No scaling needed (spatial axis) result_no_scale = rope_dype( pos=pos, dim=dim, @@ -140,9 +265,10 @@ def test_rope_dype_no_scaling(self): target_height=1024, target_width=1024, dype_config=config, + axis_index=1, # spatial axis ) - # With scaling + # With scaling (spatial axis) result_with_scale = rope_dype( pos=pos, dim=dim, @@ -151,11 +277,44 @@ def test_rope_dype_no_scaling(self): target_height=2048, target_width=2048, dype_config=config, + axis_index=1, # spatial axis ) # Results should be different when scaling is applied assert not torch.allclose(result_no_scale, result_with_scale) + def test_rope_dype_axis0_always_base(self): + """Axis 0 (time/channel) should always return base RoPE regardless of scaling.""" + pos = torch.arange(16).unsqueeze(0).float() + dim = 16 + theta = 10000 + + config = DyPEConfig(base_resolution=1024) + + result_no_scale = rope_dype( + pos=pos, + dim=dim, + theta=theta, + current_sigma=0.5, + target_height=1024, + target_width=1024, + dype_config=config, + axis_index=0, + ) + result_with_scale = rope_dype( + pos=pos, + dim=dim, + theta=theta, + current_sigma=0.5, + target_height=4096, + target_width=4096, + dype_config=config, + axis_index=0, + ) + + # Axis 0 should be identical regardless of target resolution + assert torch.allclose(result_no_scale, result_with_scale) + class TestDyPEEmbedND: """Tests for DyPEEmbedND module.""" @@ -364,8 +523,8 @@ def test_compute_vision_yarn_freqs_shape(self): pos=pos, dim=32, theta=10000, - scale_h=2.0, - scale_w=2.0, + linear_scale=2.0, + ntk_scale=2.0, current_sigma=0.5, dype_config=config, ) @@ -404,3 +563,107 @@ def test_compute_ntk_freqs_shape(self): assert cos.shape == sin.shape assert cos.shape[0] == 1 + + +class TestThreeBandBlending: + """Tests for 3-band frequency blending in YaRN implementation.""" + + def test_different_timesteps_produce_different_freqs(self): + """Different timesteps should produce different frequency outputs.""" + pos = torch.arange(16).unsqueeze(0).float() + config = DyPEConfig() + + cos_early, sin_early = compute_vision_yarn_freqs( + pos=pos, + dim=32, + theta=10000, + linear_scale=2.0, + ntk_scale=2.0, + current_sigma=1.0, # Early step + dype_config=config, + ) + + cos_late, sin_late = compute_vision_yarn_freqs( + pos=pos, + dim=32, + theta=10000, + linear_scale=2.0, + ntk_scale=2.0, + current_sigma=0.1, # Late step + dype_config=config, + ) + + # Outputs should be different due to DyPE timestep modulation + assert not torch.allclose(cos_early, cos_late) + assert not torch.allclose(sin_early, sin_late) + + def test_no_scaling_returns_base_freqs(self): + """When scale <= 1.0, should return base frequencies without mscale.""" + pos = torch.arange(16).unsqueeze(0).float() + config = DyPEConfig() + + cos, sin = compute_vision_yarn_freqs( + pos=pos, + dim=32, + theta=10000, + linear_scale=1.0, + ntk_scale=1.0, + current_sigma=0.5, + dype_config=config, + ) + + # Verify shape is correct + assert cos.shape[0] == 1 + assert cos.shape[1] == 16 + + def test_yarn_freqs_matches_vision_yarn_for_uniform_scale(self): + """yarn and vision_yarn should produce same results for uniform scale.""" + pos = torch.arange(16).unsqueeze(0).float() + config = DyPEConfig() + + cos_vision, sin_vision = compute_vision_yarn_freqs( + pos=pos, + dim=32, + theta=10000, + linear_scale=2.0, + ntk_scale=2.0, + current_sigma=0.5, + dype_config=config, + ) + + cos_yarn, sin_yarn = compute_yarn_freqs( + pos=pos, + dim=32, + theta=10000, + scale=2.0, + current_sigma=0.5, + dype_config=config, + ) + + # Should be very close (same algorithm with same scale) + assert torch.allclose(cos_vision, cos_yarn, atol=1e-6) + assert torch.allclose(sin_vision, sin_yarn, atol=1e-6) + + def test_mscale_applied_to_output(self): + """Verify mscale is applied to cos/sin outputs.""" + pos = torch.arange(16).unsqueeze(0).float() + config = DyPEConfig() + + cos, sin = compute_vision_yarn_freqs( + pos=pos, + dim=32, + theta=10000, + linear_scale=4.0, + ntk_scale=4.0, + current_sigma=0.5, + dype_config=config, + ) + + # With mscale applied, max values can exceed 1.0 + # (pure cos/sin are in [-1, 1], but mscale > 1 stretches them) + ntk_scale = 4.0 ** (32 / (32 - 2)) + mscale = get_mscale(ntk_scale) + assert mscale > 1.0 + + # The actual values depend on the blending, but shape should be correct + assert cos.shape == (1, 16, 16) # (batch, seq_len, dim/2)