This is an extension for Ostris' AI Toolkit that allows training with the loss described in *SOAR: Self-Correction for Optimal Alignment and Refinement in Diffusion Models*.
To use it, do a git clone in the `ai-toolkit/extensions` directory:

```bash
cd ai-toolkit/extensions
git clone https://github.com/ostris/aitk_soar.git
```

Then you can adjust your advanced training config to use the `SOARTrainer` class:
```yaml
---
job: "extension"
config:
  name: "project_name"
  process:
    - type: "SOARTrainer" # <- replace diffusion_trainer with SOARTrainer
      soar: # <- add this section with the desired hyperparameters for the SOAR loss
        num_steps: 30
        cfg_scale: 4.5
        correction_weight: 1
        num_aux_points: 4
```

### num_steps

Corresponds to K in the paper (rollout step count). SOAR constructs an off-trajectory state by taking a single Euler ODE step of size 1/K from the current noise level toward the clean endpoint. With K=30, each rollout step moves σ down by ≈ 0.033, keeping the one-step deviation small and bounded (Appendix A.3.2).
- Larger K (e.g. 30, 50): small step, the deviation stays close to the ideal trajectory, and `z_0` remains a safe correction anchor. This is the paper's default and matches a typical 30-step inference schedule.
- Smaller K (e.g. 4 or 1): a larger single-step jump, which trains the model specifically for few-step inference, but the paper's bounded-deviation argument weakens and diversity/stability can suffer. In the limit K=1, `next_sigma` always clamps to 0, meaning every rollout is a full single-step jump to the clean endpoint (adjacent to consistency-style training).
Pick K to match your intended inference step budget.
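The one-step construction above can be sketched in plain Python. This is a minimal illustration, not the extension's actual implementation: `euler_rollout_step` is a hypothetical name, and the latent is a flat list of floats standing in for a tensor.

```python
def euler_rollout_step(z, sigma, v, num_steps):
    """Take one Euler ODE step of size 1/K from noise level sigma
    toward the clean endpoint (sigma = 0).

    z: current noisy latent (list of floats, tensor stand-in)
    v: predicted velocity dz/dsigma at (z, sigma)
    num_steps: K, the rollout step count
    """
    # With K=1 this always clamps to 0: a full single-step jump.
    next_sigma = max(sigma - 1.0 / num_steps, 0.0)
    z_next = [zi + (next_sigma - sigma) * vi for zi, vi in zip(z, v)]
    return z_next, next_sigma
```

With K=30 each step shrinks σ by 1/30 ≈ 0.033, matching the bounded-deviation argument; with K=1 the step lands directly at σ=0.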
### cfg_scale

The classifier-free guidance scale `w_cfg` used when computing the rollout velocity (eq 25). The off-trajectory state is constructed using `v_cfg = v_uncond + w_cfg · (v_cond − v_uncond)`, so the rollout matches where a CFG-guided inference pass would actually drift. This is applied under `no_grad`: CFG only shapes the off-trajectory state, not the loss gradient.
- 4.5 (paper default): all experiments in the paper fix this value; it is not ablated. It matches the CFG scale used for evaluation, so correction training is on-policy for CFG=4.5 inference.
- 1.0: disables CFG. The rollout uses only the conditional velocity. This skips the extra unconditional forward pass (cheaper, less memory) and trains correction for no-CFG inference. Reasonable if you plan to infer without CFG.
- Higher (e.g. 7.5): matches higher-CFG inference but produces more aggressive rollout drift, which the correction loss must absorb.
Rule of thumb: set this to whatever CFG you plan to use at inference.
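The guided velocity is plain linear extrapolation between the unconditional and conditional predictions. A minimal sketch with scalars standing in for velocity tensors (hypothetical function name; in the extension this computation runs without gradients, per the `no_grad` note above):

```python
def cfg_velocity(v_cond, v_uncond, w_cfg):
    """Classifier-free-guided rollout velocity (eq 25).

    w_cfg = 1.0 reduces to v_cond exactly, which is why the
    unconditional forward pass can be skipped in that case.
    """
    return v_uncond + w_cfg * (v_cond - v_uncond)
```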
### correction_weight

Corresponds to λ in the paper (eq 15 / eq 43). Weight of the correction loss relative to the on-trajectory SFT base loss in the combined objective:

L_SOAR = (L_base + λ · L_corr_sum) / (1 + λ · N)
- 1.0 (paper default): base and correction carry equal per-sample weight.
- 0.0: disables SOAR entirely, collapsing to plain SFT. Useful for A/B comparison.
- Above 1: prioritizes correction over the on-trajectory base loss. Can accelerate exposure-bias reduction but risks drifting from the SFT data distribution if pushed too high.
- Below 1: softens the correction signal, keeping SFT as the dominant objective.
The paper does not sweep λ; 1.0 is the only value used.
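The combined objective is simple to check numerically. A minimal sketch of the weighting (hypothetical helper name; scalars stand in for per-sample losses):

```python
def soar_loss(l_base, l_corr_list, lam):
    """Combined SOAR objective (eq 15 / eq 43):
    (L_base + lam * sum(L_corr)) / (1 + lam * N).

    lam = 0.0 collapses to plain SFT (returns l_base unchanged);
    lam = 1.0 gives base and correction equal per-sample weight.
    """
    n = len(l_corr_list)
    return (l_base + lam * sum(l_corr_list)) / (1.0 + lam * n)
```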
### num_aux_points

Corresponds to N in the paper (auxiliary points per rollout). For each training sample, after constructing the one-step off-trajectory state, SOAR samples N auxiliary noise levels in [next_sigma, 1], re-noises the off-trajectory state toward the same endpoint noise at each level, and computes a correction loss at every one (eq 29 / Algorithm 1, lines 10–15). More points give denser per-timestep supervision across the upstream trajectory.
- 4: the default used in this extension. Balances supervision density against cost (each aux point is an extra forward + backward pass through the model).
- Higher (e.g. 8): denser trajectory coverage, proportionally more compute and memory per step.
- Lower (1–2): cheaper per step, useful when memory is tight on large models with gradient checkpointing.
The paper explicitly defers the N ablation to an extended version — no empirical guidance on the optimal value. Algorithm 1 treats N as a free hyperparameter.
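A rough sketch of the auxiliary-point construction. Two assumptions here that may not match the extension's actual code: uniform sampling over [next_sigma, 1], and the rectified-flow interpolation x_σ = (1 − σ)·x₀ + σ·ε for re-noising. Lists of floats stand in for tensors.

```python
import random

def sample_aux_sigmas(next_sigma, num_aux_points, rng):
    """Draw N auxiliary noise levels in [next_sigma, 1].
    Uniform sampling is an assumption of this sketch."""
    return [rng.uniform(next_sigma, 1.0) for _ in range(num_aux_points)]

def renoise(x0_hat, eps, sigma):
    """Re-noise a clean estimate toward the shared endpoint noise eps
    at level sigma, assuming the rectified-flow interpolation
    x_sigma = (1 - sigma) * x0 + sigma * eps."""
    return [(1.0 - sigma) * x + sigma * e for x, e in zip(x0_hat, eps)]
```

Each of the N re-noised states then gets its own correction loss, which is why cost grows linearly with `num_aux_points`.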