AI Toolkit - SOAR Extension

This is an extension for the Ostris AI Toolkit that enables training with the loss described in SOAR: Self-Correction for Optimal Alignment and Refinement in Diffusion Models.

To use it, clone this repository into the ai-toolkit/extensions directory:

cd ai-toolkit/extensions
git clone https://github.com/ostris/aitk_soar.git

Then you can adjust your advanced training config to use the SOARTrainer class:

---
job: "extension"
config:
  name: "project_name"
  process:
    - type: "SOARTrainer" # <- replace diffusion_trainer with SOARTrainer
      soar:  # <- add this section with the desired hyperparameters for the SOAR loss
        num_steps: 30
        cfg_scale: 4.5
        correction_weight: 1
        num_aux_points: 4

SOAR config items

num_steps (default: 30)

Corresponds to K in the paper (rollout step count). SOAR constructs an off-trajectory state by taking a single Euler ODE step of size 1/K from the current noise level toward the clean endpoint. With K=30, each rollout step moves σ down by ≈ 0.033, keeping the one-step deviation small and bounded (Appendix A.3.2).

  • Larger K (e.g. 30, 50): small step, deviation stays close to the ideal trajectory, z_0 remains a safe correction anchor. This is the paper's default and matches a typical 30-step inference schedule.
  • Smaller K (e.g. 4 or 1): larger single-step jump. This trains the model specifically for few-step inference, but the paper's bounded-deviation argument weakens, and diversity/stability can suffer. In the limit K=1, next_sigma always clamps to 0, so every rollout is a full single-step jump to the clean endpoint (akin to consistency-style training).

Pick K to match your intended inference step budget.
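A minimal sketch of that rollout step (hypothetical names, not the extension's actual API; assumes the flow-matching convention x_sigma = (1 - sigma) * x_0 + sigma * eps, so the velocity points from data toward noise):

def euler_rollout_step(z, v, sigma, num_steps):
    # One Euler ODE step of size 1/K toward the clean endpoint.
    # With K=1 this always clamps next_sigma to 0 (full jump to x_0).
    next_sigma = max(sigma - 1.0 / num_steps, 0.0)
    z_next = z + (next_sigma - sigma) * v  # v ~ eps - x_0 under this convention
    return z_next, next_sigma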

cfg_scale (default: 4.5)

The classifier-free guidance scale w_cfg used when computing the rollout velocity (eq 25). The off-trajectory state is constructed using v_cfg = v_uncond + w_cfg · (v_cond − v_uncond) so the rollout matches where a CFG-guided inference pass would actually drift. This is applied under no_grad — CFG only shapes the off-trajectory state, not the loss gradient.

  • 4.5 (paper default): all experiments in the paper fix this value. Not ablated. Matches the CFG scale used for evaluation, so correction training is on-policy for CFG=4.5 inference.
  • 1.0: disables CFG. The rollout uses only the conditional velocity. This skips the extra unconditional forward pass (cheaper, less memory) and trains correction for no-CFG inference. Reasonable if you plan to infer without CFG.
  • Higher (e.g. 7.5): matches higher-CFG inference but produces more aggressive rollout drift, which the correction loss must absorb.

Rule of thumb: set this to whatever CFG you plan to use at inference.
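A minimal sketch of the rollout velocity (hypothetical model signature; only the eq 25 combination comes from the paper):

import torch

@torch.no_grad()  # CFG shapes the off-trajectory state only; no loss gradient flows here
def rollout_velocity(model, z, sigma, cond, uncond, cfg_scale=4.5):
    v_cond = model(z, sigma, cond)
    if cfg_scale == 1.0:
        return v_cond  # cfg_scale 1.0 reduces eq 25 to v_cond; skip the uncond pass
    v_uncond = model(z, sigma, uncond)
    return v_uncond + cfg_scale * (v_cond - v_uncond)  # eq 25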

correction_weight (default: 1.0)

Corresponds to λ in the paper (eq 15 / eq 43). Weight of the correction loss relative to the on-trajectory SFT base loss in the combined objective:

L_SOAR = (L_base + λ · L_corr_sum) / (1 + λ · N)

where L_corr_sum is the sum of the N per-aux-point correction losses.
  • 1.0 (paper default): base and correction carry equal per-sample weight.
  • 0.0: disables SOAR entirely, collapsing to plain SFT. Useful for A/B comparison.
  • > 1: prioritize correction over the on-trajectory base loss. Can accelerate exposure-bias reduction but risks drifting from the SFT data distribution if pushed too high.
  • < 1: softens the correction signal, keeping SFT as the dominant objective.

The paper does not sweep λ; 1.0 is the only value used.
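The same objective as a sketch (illustrative names):

def soar_loss(base_loss, corr_losses, correction_weight=1.0):
    # (L_base + lambda * sum(L_corr)) / (1 + lambda * N);
    # lambda = 0 collapses to plain SFT and returns base_loss unchanged.
    n = len(corr_losses)
    total = base_loss + correction_weight * sum(corr_losses)
    return total / (1.0 + correction_weight * n)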

num_aux_points (default: 4)

Corresponds to N in the paper (auxiliary points per rollout). For each training sample, after constructing the one-step off-trajectory state, SOAR samples N auxiliary noise levels in [next_sigma, 1], re-noises the off-trajectory state toward the same endpoint noise at each level, and computes a correction loss at every one (eq 29 / Algorithm 1 lines 10–15). More points give denser per-timestep supervision across the upstream trajectory.

  • 4: value used in this extension as a reasonable default. Balances supervision density against cost (each aux point is an extra forward + backward pass through the model).
  • Higher (e.g. 8): denser trajectory coverage, proportionally more compute and memory per step.
  • Lower (1–2): cheaper per step, useful when memory is tight on large models with gradient checkpointing.

The paper explicitly defers the N ablation to an extended version — no empirical guidance on the optimal value. Algorithm 1 treats N as a free hyperparameter.
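A sketch of the sampling and re-noising steps (the uniform range [next_sigma, 1] follows the description above; the interpolation convention and every name here are assumptions):

import torch

def sample_aux_sigmas(next_sigma, num_aux_points=4):
    # Draw N auxiliary noise levels uniformly in [next_sigma, 1].
    return next_sigma + (1.0 - next_sigma) * torch.rand(num_aux_points)

def renoise(z_off, eps, sigma, sigma_aux):
    # Re-noise the off-trajectory state from level sigma up to sigma_aux along
    # the line toward the same endpoint noise eps, assuming
    # x_sigma = (1 - sigma) * x_0 + sigma * eps. Needs sigma < 1, which holds
    # because sigma = next_sigma <= 1 - 1/K after the rollout step.
    w = (sigma_aux - sigma) / (1.0 - sigma)
    return (1.0 - w) * z_off + w * eps

Each auxiliary point then costs one extra forward and backward pass, which is why compute scales linearly with N.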
