diff --git a/README.md b/README.md index 1ff714e..8c361bb 100644 --- a/README.md +++ b/README.md @@ -115,18 +115,20 @@ examples requires installing `matplotlib` (`pip install matplotlib`). ## Quickstart (C) ```c +#include /* INFINITY */ #include "aa.h" -AaWork *a = aa_init(n, /* dim */ - 10, /* mem */ - 10, /* min_len */ - 1, /* type1 */ - 1e-8, /* regularization */ - 1.0, /* relaxation */ - 2.0, /* safeguard_factor */ - 1e10, /* max_weight_norm */ - 5, /* ir_max_steps */ - 0); /* verbosity */ +AaWork *a = aa_init(n, /* dim */ + 10, /* mem */ + 10, /* min_len */ + 1, /* type1 */ + 1e-8, /* regularization */ + 1.0, /* relaxation */ + 2.0, /* safeguard_factor */ + INFINITY, /* max_weight_norm */ + INFINITY, /* trust_factor */ + 5, /* ir_max_steps */ + 0); /* verbosity */ for (int i = 0; i < N; i++) { if (i > 0) aa_apply(x, x_prev, a); @@ -149,10 +151,11 @@ See [`tests/c/gd.c`](tests/c/gd.c) for a complete runnable example | `mem` | Number of past iterates to look back | 5 – 20 | | `min_len` | Minimum buffered residual pairs before AA begins extrapolating. `min_len = mem` waits for the memory to fill (stable default); `min_len = 1` starts extrapolating immediately. Must be ≥ 1 when `mem > 0`; clamped down when it exceeds `min(mem, dim)`. | `mem` | | `type1` | Type-I if true, Type-II otherwise | see notes below | -| `regularization` | Tikhonov regularization on the AA least-squares system. `> 0`: scaled by `‖A‖_F·‖Y‖_F`. `< 0`: pinned absolute `-regularization` (no scaling). `= 0`: off. | Type-I: `1e-8`, Type-II: `1e-12` | +| `regularization` | Tikhonov regularization on the AA least-squares system. `> 0`: scaled by `‖S‖_F·‖Y‖_F` (same for Type-I and Type-II — historically Type-II used `‖Y‖²` which decays quadratically as `Y→0` and underflows on slow-contraction maps). `< 0`: pinned absolute `-regularization` (no scaling). `= 0`: off. | Type-I: `1e-8`, Type-II: `1e-12` | | `relaxation` | Mixing parameter in `[0, 2]`; `1.0` is vanilla AA | `1.0` | | `safeguard_factor` | Multiplier on the residual-growth ratio beyond which the AA step is rejected. Larger = more aggressive. | `2.0` | -| `max_weight_norm` | Upper bound on the norm of the AA combination weights; rejects numerically unstable steps | `1e6` – `1e10` | +| `max_weight_norm` | Hard cap: reject solves with `‖γ‖₂ ≥ max_weight_norm`. Pass `INFINITY` to disable. | `1e6` – `1e10` or `INFINITY` | +| `trust_factor` | Opt-in trust region + adaptive regularization (see "Trust-region mode" below). Pass `INFINITY` to disable; pass a finite positive value (typically `10`) to enable. | `INFINITY` (off), or `~10` for ADMM/DRS | | `ir_max_steps` | Cap on iterative-refinement passes for the weight solve. The loop stops early when refinement stalls, so this is an upper bound; raise for ill-conditioned problems, lower for tighter cost bounds. | `5` | | `verbosity` | `0` silent, higher values print progress and diagnostics | `0` | @@ -161,6 +164,42 @@ problems but can be sensitive; Type-II is more robust. If one fails, try the other. Both tolerate nonsmooth `F` thanks to the safeguard, though convergence guarantees in that regime are stronger for Type-I (see the paper). +### Trust-region mode (opt-in) + +Setting `trust_factor` to a finite positive value (typically `10`) turns on +two coupled mechanisms that target a failure mode seen on slow-contraction +maps (ADMM / DRS / proximal splitting): AA's least-squares system underflows +near convergence, the weight vector `γ` blows up, and the safeguard's +monotone test (`‖g_new‖ ≤ ‖g_old‖`) is too weak to catch the resulting +"creep" — each step marginally reduces the residual without approaching the +fixed point. + +1. **Trust region.** Each AA solve rejects the step if `‖D γ‖₂ > trust_factor · ‖g‖₂`, + where `D` is the matrix of past `Δf` columns. This bounds the iterate + displacement relative to the current residual and catches the + "γ-passes-the-weight-cap-but-`Dγ`-is-huge" failure mode directly. + +2. **Adaptive regularization.** The Tikhonov term `r` is replaced by a + self-tuning value that starts large (so initial `γ ≈ 0`, i.e. AA ≈ `F`), + shrinks by `0.9×` on each safeguard accept (let AA do more), and grows + by `10×` on each rejection (back off toward `F`). The two mechanisms + feed each other: a trust trip bumps `r`, the next solve produces a + smaller `γ`, the trust check usually passes. + +`trust_factor = INFINITY` (the default) disables both mechanisms and the +library uses the standard `ε · ‖S‖_F · ‖Y‖_F` regularization. The two paths +are independent — turning trust-region mode on only matters for the kind of +problem above. + +When to set `trust_factor`: + +| Situation | Recommendation | +|-------------------------------------------------------------------|----------------------------| +| Gradient descent, prox iterations on well-scaled problems | leave `INFINITY` (default) | +| Operator-splitting solvers (ADMM, PDHG, Douglas-Rachford) | try `trust_factor = 10` | +| Anything where AA produces `‖γ‖₂` in the hundreds but `‖g‖` doesn't keep dropping | try `trust_factor = 10` | +| Newton-like ill-conditioned problems where `γ` is legitimately huge | leave `INFINITY` (default) | + ## Python API ```python @@ -168,12 +207,13 @@ aa.AndersonAccelerator( dim, mem, *, - min_len=None, # defaults to min(mem, dim) + min_len=None, # defaults to min(mem, dim) type1=False, regularization=1e-12, relaxation=1.0, safeguard_factor=1.0, - max_weight_norm=1e6, + max_weight_norm=math.inf, + trust_factor=math.inf, # see "Trust-region mode" above; ~10 for ADMM/DRS ir_max_steps=5, verbosity=0, ) @@ -187,7 +227,7 @@ C-contiguous, writeable `float64` numpy arrays of length `dim`. | `apply(f, x)` | Call once per iteration (skip the first). `f` holds the most recent map output `F(x)`. Overwrites `f` in place with the AA-extrapolated point. | | `safeguard(f_new, x_new)` | Call after running your map on the AA extrapolate. If AA did not make progress, reverts both arrays to the last-known-good state. Returns `0` on accept, `-1` on reject. | | `reset()` | Clears AA state (equivalent to re-initializing) without reallocating. Lifetime `stats` counters are NOT cleared. | -| `stats` | Read-only property returning a dict of lifetime counters: `iter`, `n_accept`, `n_reject_lapack`, `n_reject_rank0`, `n_reject_nonfinite`, `n_reject_weight_cap`, `n_safeguard_reject`, `last_rank`, `last_aa_norm` (NaN until the first solve), `last_regularization`. Useful for diagnosing when AA isn't helping — rising `n_reject_weight_cap` or `n_reject_nonfinite` points at `max_weight_norm` / `regularization` tuning; rising `n_safeguard_reject` points at `safeguard_factor` / `mem`; `n_reject_rank0` is normal near convergence (memory is numerically zero). | +| `stats` | Read-only property returning a dict of lifetime counters: `iter`, `n_accept`, `n_reject_lapack`, `n_reject_rank0`, `n_reject_nonfinite`, `n_reject_weight_cap`, `n_safeguard_reject`, `last_rank`, `last_aa_norm` (NaN until the first solve), `last_regularization`. Useful for diagnosing when AA isn't helping — high `n_safeguard_reject` with low `last_regularization` and high `last_aa_norm` on a slow-contraction problem suggests trying `trust_factor = 10`; rising `n_reject_weight_cap` or `n_reject_nonfinite` points at `max_weight_norm` / `regularization` tuning; rising `n_safeguard_reject` points at `safeguard_factor` / `mem`; `n_reject_rank0` is normal near convergence (memory is numerically zero). | ## C API @@ -198,7 +238,8 @@ Python API exactly: AaWork *aa_init(aa_int dim, aa_int mem, aa_int min_len, aa_int type1, aa_float regularization, aa_float relaxation, aa_float safeguard_factor, aa_float max_weight_norm, - aa_int ir_max_steps, aa_int verbosity); + aa_float trust_factor, aa_int ir_max_steps, + aa_int verbosity); aa_float aa_apply(aa_float *f, const aa_float *x, AaWork *a); aa_int aa_safeguard(aa_float *f_new, aa_float *x_new, AaWork *a); diff --git a/include/aa.h b/include/aa.h index 149a613..9854ed6 100644 --- a/include/aa.h +++ b/include/aa.h @@ -49,7 +49,21 @@ typedef struct ACCEL_WORK AaWork; * @param relaxation finite float \in [0,2], mixing parameter (1.0 is vanilla) * @param safeguard_factor finite nonnegative factor that controls safeguarding checks * larger is more aggressive but less stable - * @param max_weight_norm finite positive float, maximum norm of AA weights + * @param max_weight_norm positive float, maximum L2 norm of AA weights γ; + * the solve is rejected when ||γ||_2 >= max_weight_norm. + * Pass INFINITY to disable this hard cap. + * @param trust_factor positive float, opt-in trust region + adaptive + * regularization. When finite, the solve rejects + * steps with ||D γ||_2 > trust_factor * ||g||_2 and + * AA's regularization adapts via accept/reject + * feedback (start large so AA ≈ DRS; shrink by 0.9× + * on each safeguard accept; grow by 10× on any + * rejection). Useful for slow-contraction maps + * (ADMM/DRS) where the LS basis can produce large + * but unproductive γ that the safeguard alone + * fails to catch. Pass INFINITY to disable — the + * original ε·||S||·||Y|| regularization path is + * then used unchanged. * @param ir_max_steps max iterative refinement passes on the γ solve. * 0 disables IR. Each step is O(mem²) and loops * until the correction stops contracting, so on @@ -65,7 +79,8 @@ typedef struct ACCEL_WORK AaWork; AaWork *aa_init(aa_int dim, aa_int mem, aa_int min_len, aa_int type1, aa_float regularization, aa_float relaxation, aa_float safeguard_factor, aa_float max_weight_norm, - aa_int ir_max_steps, aa_int verbosity); + aa_float trust_factor, aa_int ir_max_steps, + aa_int verbosity); /** * Apply Anderson Acceleration. The usage pattern should be as follows: diff --git a/python/aapy.pyx b/python/aapy.pyx index 3f3ccf6..877414b 100644 --- a/python/aapy.pyx +++ b/python/aapy.pyx @@ -19,7 +19,7 @@ cdef extern from "../include/aa.h": int last_rank double last_aa_norm double last_regularization - AaWork *aa_init(int, int, int, int, double, double, double, double, int, int) + AaWork *aa_init(int, int, int, int, double, double, double, double, double, int, int) double aa_apply(double*, const double*, AaWork*) int aa_safeguard(double*, double*, AaWork*) void aa_reset(AaWork*) @@ -32,7 +32,8 @@ cdef class AndersonAccelerator(object): def __cinit__(self, dim, mem, *, min_len=None, type1=False, regularization=1e-12, relaxation=1.0, safeguard_factor=1.0, - max_weight_norm=1e6, ir_max_steps=5, verbosity=0): + max_weight_norm=math.inf, trust_factor=math.inf, + ir_max_steps=5, verbosity=0): if dim <= 0: raise ValueError("dim must be positive") if mem < 0: @@ -47,8 +48,11 @@ cdef class AndersonAccelerator(object): raise ValueError("relaxation must be finite and in [0, 2]") if not math.isfinite(safeguard_factor) or safeguard_factor < 0: raise ValueError("safeguard_factor must be finite and non-negative") - if not math.isfinite(max_weight_norm) or max_weight_norm <= 0: - raise ValueError("max_weight_norm must be finite and positive") + # max_weight_norm and trust_factor: math.inf disables the cap. + if math.isnan(max_weight_norm) or max_weight_norm <= 0: + raise ValueError("max_weight_norm must be positive (inf for no cap)") + if math.isnan(trust_factor) or trust_factor <= 0: + raise ValueError("trust_factor must be positive (inf for off)") if ir_max_steps < 0: raise ValueError("ir_max_steps must be non-negative") # min_len: minimum # of residual pairs before AA starts extrapolating. @@ -61,7 +65,7 @@ cdef class AndersonAccelerator(object): raise ValueError("min_len must be >= 1 when mem > 0") self._wrk = aa_init(dim, mem, min_len, type1, regularization, relaxation, safeguard_factor, max_weight_norm, - ir_max_steps, verbosity) + trust_factor, ir_max_steps, verbosity) if self._wrk is NULL: raise MemoryError("aa_init failed") self._dim = dim diff --git a/src/aa.c b/src/aa.c index 6fe9c62..053a259 100644 --- a/src/aa.c +++ b/src/aa.c @@ -132,6 +132,27 @@ struct ACCEL_WORK { aa_float safeguard_factor; /* safeguard tolerance factor */ aa_float max_weight_norm; /* maximum norm of AA weights */ + /* Opt-in "trust region + adaptive regularization" mode. When trust_factor + * is a positive finite value, two coupled mechanisms turn on: + * 1. Trust region: each solve rejects the step if ||D γ||_2 > trust_factor + * * ||g||_2. This catches the failure mode where γ passes the LS solve + * and weight cap but the resulting iterate displacement is far larger + * than the current residual — common on slow-contraction maps where the + * LS basis points away from the descent direction. + * 2. Adaptive r: every safeguard or trust-region rejection grows + * r_adaptive by 10×; every accept shrinks it by 0.9×. r_adaptive + * *replaces* the ε·||S||·||Y|| baseline in compute_regularization. + * Starting r_adaptive = 1.0 makes initial γ ≈ 0 (AA ≈ DRS) and the + * shrink/grow feedback walks r to the right scale for the problem. + * This is what keeps LASER-style problems converging even when the + * trust region trips occasionally: each trip damps subsequent γ. + * + * trust_factor = INFINITY (the default for callers passing INFINITY) + * disables both mechanisms and the original ε·||S||·||Y|| regularization + * path is used unchanged. */ + aa_float trust_factor; + aa_float r_adaptive; + aa_float *x; /* x input to map*/ aa_float *f; /* f(x) output of map */ aa_float *g; /* x - f(x) */ @@ -230,17 +251,48 @@ static aa_float frob_from_col_norms(const aa_float *nrm_col, aa_int mem) { * fresh nrm2 over dim·mem entries. */ static aa_float compute_regularization(AaWork *a) { TIME_TIC - aa_float nrm_y = frob_from_col_norms(a->nrm_y_col, a->mem); - aa_float nrm_a = a->type1 ? frob_from_col_norms(a->nrm_s_col, a->mem) : nrm_y; - aa_float r = a->regularization * nrm_a * nrm_y; - if (a->verbosity > 2) { - printf("iter: %i, ||A||_F %.2e, ||Y||_F %.2e, r: %.2e\n", - (int)a->iter, nrm_a, nrm_y, r); + aa_float r; + if (isfinite(a->trust_factor)) { + /* Trust-region mode: r is the adaptive value walked by the + * accept/reject feedback in aa_safeguard and solve(). The column-norm + * baseline is bypassed entirely — it underflows on slow-contraction + * maps, which is the whole reason we're here. */ + r = a->r_adaptive; + if (a->verbosity > 2) { + printf("iter: %i, r_adaptive %.2e (trust mode)\n", (int)a->iter, r); + } + } else { + /* Default mode: ε · ||S||_F · ||Y||_F. Symmetric across types — the + * type-II side was ||Y||² before PR #53; ||S||_F · ||Y||_F keeps the + * LS conditioned without changing well-behaved cases where + * ||S|| ≈ ||Y||. */ + aa_float nrm_y = frob_from_col_norms(a->nrm_y_col, a->mem); + aa_float nrm_s = frob_from_col_norms(a->nrm_s_col, a->mem); + r = a->regularization * nrm_s * nrm_y; + if (a->verbosity > 2) { + printf("iter: %i, ||S||_F %.2e, ||Y||_F %.2e, r: %.2e\n", + (int)a->iter, nrm_s, nrm_y, r); + } } TIME_TOC return r; } +/* Helper: bump r_adaptive when AA produced a bad step. Called from every + * rejection path (in-solve and safeguard). */ +static void trust_grow(AaWork *a) { + if (!isfinite(a->trust_factor)) return; + a->r_adaptive *= 10.0; + if (a->r_adaptive > 1e30) a->r_adaptive = 1e30; +} + +/* Helper: shrink r_adaptive when AA's step survived the safeguard. */ +static void trust_shrink(AaWork *a) { + if (!isfinite(a->trust_factor)) return; + a->r_adaptive *= 0.9; + if (a->r_adaptive < 1e-12) a->r_adaptive = 1e-12; +} + /* Build [M; √r I_len] column-major into `dst` with fixed leading dim * (dim + mem). When len < mem we zero-pad the unused trailing rows so * the QR factorization still operates on a well-defined (dim+mem) x len @@ -586,6 +638,7 @@ static aa_float solve(aa_float *f, AaWork *a, aa_int len) { } else { a->n_reject_weight_cap++; } + trust_grow(a); /* in-solve rejection bumps r toward damped-AA */ a->success = 0; aa_reset(a); TIME_TOC @@ -593,6 +646,28 @@ static aa_float solve(aa_float *f, AaWork *a, aa_int len) { return (aa_norm < 0) ? aa_norm : -aa_norm; } + /* Trust region: only in trust mode. Compute ||D γ|| and reject if it + * exceeds trust_factor · ||g||. c_aug is dim+mem sized; γ has already + * been extracted into `gamma` (aliases a->work), and c_aug is rebuilt + * before being read again on the next iter — safe to scratch here. */ + if (isfinite(a->trust_factor)) { + aa_float zerof = 0.0; + aa_float d_gamma_norm; + BLAS(gemv) + ("NoTrans", &bdim, &blen, &onef, a->D, &bdim, gamma, &one, &zerof, + a->c_aug, &one); + d_gamma_norm = BLAS(nrm2)(&bdim, a->c_aug, &one); + if (isfinite(d_gamma_norm) && + d_gamma_norm > a->trust_factor * a->norm_g) { + a->n_reject_weight_cap++; + trust_grow(a); + a->success = 0; + aa_reset(a); + TIME_TOC + return -aa_norm; + } + } + /* f -= D γ */ BLAS(gemv) ("NoTrans", &bdim, &blen, &neg_onef, a->D, &bdim, gamma, &one, &onef, f, @@ -613,13 +688,16 @@ static aa_float solve(aa_float *f, AaWork *a, aa_int len) { AaWork *aa_init(aa_int dim, aa_int mem, aa_int min_len, aa_int type1, aa_float regularization, aa_float relaxation, aa_float safeguard_factor, aa_float max_weight_norm, - aa_int ir_max_steps, aa_int verbosity) { + aa_float trust_factor, aa_int ir_max_steps, + aa_int verbosity) { TIME_TIC AaWork *a; aa_int mem_clamped = MIN(mem, dim); /* `regularization` is accepted with either sign: positive = scaled by * ||A||_F ||Y||_F; negative = pinned absolute |regularization|; zero = off. * Only NaN / non-finite values are rejected (via the !isfinite check). + * `max_weight_norm` and `trust_factor` accept INFINITY as the "no cap" + * sentinel — NaN and non-positive are rejected. * min_len < 1 is rejected when mem > 0; min_len > mem_clamped is * silently clamped down — same treatment the `mem` argument already * gets against `dim`, so callers can pass `min_len = mem` without @@ -629,7 +707,8 @@ AaWork *aa_init(aa_int dim, aa_int mem, aa_int min_len, aa_int type1, !isfinite(regularization) || !isfinite(relaxation) || relaxation < 0 || relaxation > 2 || !isfinite(safeguard_factor) || safeguard_factor < 0 || - !isfinite(max_weight_norm) || max_weight_norm <= 0 || + isnan(max_weight_norm) || max_weight_norm <= 0 || + isnan(trust_factor) || trust_factor <= 0 || ir_max_steps < 0 || (mem_clamped > 0 && min_len < 1)) { printf("Invalid AA parameters.\n"); @@ -653,6 +732,12 @@ AaWork *aa_init(aa_int dim, aa_int mem, aa_int min_len, aa_int type1, a->relaxation = relaxation; a->safeguard_factor = safeguard_factor; a->max_weight_norm = max_weight_norm; + a->trust_factor = trust_factor; + /* Adaptive r initial value: start at 1.0 so γ ≈ 0 on the first solves + * (AA ≈ pure f-iteration). The shrink/grow feedback walks r toward + * whatever scale this particular problem needs. Only meaningful when + * trust_factor is finite. */ + a->r_adaptive = 1.0; a->ir_max_steps = ir_max_steps; a->success = 0; a->verbosity = verbosity; @@ -850,10 +935,12 @@ aa_int aa_safeguard(aa_float *f_new, aa_float *x_new, AaWork *a) { (int)a->iter, norm_diff, a->norm_g); } a->n_safeguard_reject++; + trust_grow(a); /* bad step → damp AA next time */ aa_reset(a); TIME_TOC return -1; } + trust_shrink(a); /* successful step → let AA do more next time */ TIME_TOC return 0; } diff --git a/tests/c/bench.c b/tests/c/bench.c index 8026c4c..c8f2cf1 100644 --- a/tests/c/bench.c +++ b/tests/c/bench.c @@ -135,7 +135,8 @@ static void run_one(bench_cfg cfg) { AaWork *a = aa_init(n, cfg.mem, /*min_len=*/cfg.mem, cfg.type1, cfg.regularization, cfg.relaxation, /*safeguard_factor=*/2.0, - /*max_weight_norm=*/1e10, /*ir_max_steps=*/5, + /*max_weight_norm=*/INFINITY, + /*trust_factor=*/INFINITY, /*ir_max_steps=*/5, /*verbosity=*/0); if (!a) { printf("%-20s | aa_init failed\n", cfg.label); @@ -214,7 +215,8 @@ int main(void) { aa_float *x = (aa_float *)malloc(sizeof(aa_float) * wd); aa_float *xprev = (aa_float *)malloc(sizeof(aa_float) * wd); for (aa_int i = 0; i < wd; i++) { x[i] = 1.0; xprev[i] = 0.0; } - AaWork *a = aa_init(wd, wm, /*min_len=*/wm, 0, 1e-12, 1.0, 2.0, 1e10, 5, 0); + AaWork *a = aa_init(wd, wm, /*min_len=*/wm, 0, 1e-12, 1.0, 2.0, INFINITY, + /*trust_factor=*/INFINITY, 5, 0); for (aa_int i = 0; i < wi; i++) { if (i > 0) aa_apply(x, xprev, a); memcpy(xprev, x, sizeof(aa_float) * wd); diff --git a/tests/c/gd.c b/tests/c/gd.c index 0648a04..01ec640 100644 --- a/tests/c/gd.c +++ b/tests/c/gd.c @@ -1,6 +1,7 @@ /* Gradient descent (GD) on convex quadratic */ #include "aa.h" #include "aa_blas.h" +#include #include #include #include @@ -144,7 +145,8 @@ int main(int argc, char **argv) { AaWork *a = aa_init(n, memory, /*min_len=*/memory, type1, regularization, relaxation, safeguard_tolerance, - max_aa_norm, /*ir_max_steps=*/5, verbosity); + max_aa_norm, /*trust_factor=*/INFINITY, + /*ir_max_steps=*/5, verbosity); for (i = 0; i < iters; i++) { if (i > 0) { _tic(&aa_timer); diff --git a/tests/c/run_tests.c b/tests/c/run_tests.c index db379cd..f3f634c 100644 --- a/tests/c/run_tests.c +++ b/tests/c/run_tests.c @@ -137,7 +137,7 @@ static const char *gd(aa_int type1, aa_float relaxation) { AaWork *a = aa_init(n, memory, /*min_len=*/memory, type1, regularization, relaxation, safeguard_tolerance, - max_aa_norm, /*ir_max_steps=*/5, verbosity); + max_aa_norm, /*trust_factor=*/INFINITY, /*ir_max_steps=*/5, verbosity); for (i = 0; i < iters; i++) { if (i > 0) { _tic(&aa_timer); @@ -197,7 +197,7 @@ static aa_float diag_gd(const aa_float *Qdiag, aa_int n, aa_float step, x[i] = rand_float(); } AaWork *a = aa_init(n, mem, /*min_len=*/mem, type1, /*reg=*/1e-10, - relaxation, /*safeguard=*/2.0, /*max_w=*/1e10, + relaxation, /*safeguard=*/2.0, /*max_w=*/1e10, /*trust_factor=*/INFINITY, /*ir_max_steps=*/5, /*verbosity=*/0); for (aa_int i = 0; i < iters; i++) { if (i > 0) { @@ -222,7 +222,7 @@ static const char *test_mem_zero_is_noop(void) { /* min_len is ignored when mem=0 — pass the sentinel 0 to make that * explicit (a nonzero value would also be accepted since the range * check is gated on mem>0). */ - AaWork *a = aa_init(10, 0, /*min_len=*/0, 1, 1e-8, 1.0, 2.0, 1e10, 5, 0); + AaWork *a = aa_init(10, 0, /*min_len=*/0, 1, 1e-8, 1.0, 2.0, 1e10, INFINITY, 5, 0); mu_assert("aa_init(mem=0) returned NULL", a != NULL); aa_float x[10], xprev[10]; @@ -247,14 +247,14 @@ static const char *test_mem_zero_is_noop(void) { } static const char *test_dim_zero_rejected(void) { - AaWork *a = aa_init(0, 1, /*min_len=*/1, 1, 1e-8, 1.0, 2.0, 1e10, 5, 0); + AaWork *a = aa_init(0, 1, /*min_len=*/1, 1, 1e-8, 1.0, 2.0, 1e10, INFINITY, 5, 0); mu_assert("aa_init(dim=0) should return NULL", a == NULL); return 0; } /* Negative ir_max_steps is invalid and must be rejected at init. */ static const char *test_ir_max_steps_negative_rejected(void) { - AaWork *a = aa_init(4, 2, /*min_len=*/2, 1, 1e-8, 1.0, 2.0, 1e10, -1, 0); + AaWork *a = aa_init(4, 2, /*min_len=*/2, 1, 1e-8, 1.0, 2.0, 1e10, INFINITY, -1, 0); mu_assert("aa_init(ir_max_steps=-1) should return NULL", a == NULL); return 0; } @@ -264,20 +264,36 @@ static const char *test_ir_max_steps_negative_rejected(void) { static const char *test_nonfinite_scalar_options_rejected(void) { AaWork *a; - a = aa_init(4, 2, /*min_len=*/2, 1, 1e-8, NAN, 2.0, 1e10, 5, 0); + a = aa_init(4, 2, /*min_len=*/2, 1, 1e-8, NAN, 2.0, 1e10, INFINITY, 5, 0); mu_assert("aa_init(relaxation=NaN) should return NULL", a == NULL); - a = aa_init(4, 2, /*min_len=*/2, 1, 1e-8, INFINITY, 2.0, 1e10, 5, 0); + a = aa_init(4, 2, /*min_len=*/2, 1, 1e-8, INFINITY, 2.0, 1e10, INFINITY, 5, 0); mu_assert("aa_init(relaxation=Inf) should return NULL", a == NULL); - a = aa_init(4, 2, /*min_len=*/2, 1, 1e-8, 1.0, NAN, 1e10, 5, 0); + a = aa_init(4, 2, /*min_len=*/2, 1, 1e-8, 1.0, NAN, 1e10, INFINITY, 5, 0); mu_assert("aa_init(safeguard_factor=NaN) should return NULL", a == NULL); - a = aa_init(4, 2, /*min_len=*/2, 1, 1e-8, 1.0, INFINITY, 1e10, 5, 0); + a = aa_init(4, 2, /*min_len=*/2, 1, 1e-8, 1.0, INFINITY, 1e10, INFINITY, 5, 0); mu_assert("aa_init(safeguard_factor=Inf) should return NULL", a == NULL); - a = aa_init(4, 2, /*min_len=*/2, 1, 1e-8, 1.0, 2.0, NAN, 5, 0); + a = aa_init(4, 2, /*min_len=*/2, 1, 1e-8, 1.0, 2.0, NAN, INFINITY, 5, 0); mu_assert("aa_init(max_weight_norm=NaN) should return NULL", a == NULL); - a = aa_init(4, 2, /*min_len=*/2, 1, 1e-8, 1.0, 2.0, INFINITY, 5, 0); - mu_assert("aa_init(max_weight_norm=Inf) should return NULL", a == NULL); + /* INFINITY is the "no cap" sentinel for max_weight_norm — must accept. */ + a = aa_init(4, 2, /*min_len=*/2, 1, 1e-8, 1.0, 2.0, INFINITY, INFINITY, 5, 0); + mu_assert("aa_init(max_weight_norm=Inf) should accept", a != NULL); + aa_finish(a); + + /* trust_factor rejects NaN and non-positive; INFINITY = no cap (default). */ + a = aa_init(4, 2, /*min_len=*/2, 1, 1e-8, 1.0, 2.0, 1e10, NAN, 5, 0); + mu_assert("aa_init(trust_factor=NaN) should return NULL", a == NULL); + a = aa_init(4, 2, /*min_len=*/2, 1, 1e-8, 1.0, 2.0, 1e10, 0.0, 5, 0); + mu_assert("aa_init(trust_factor=0) should return NULL", a == NULL); + a = aa_init(4, 2, /*min_len=*/2, 1, 1e-8, 1.0, 2.0, 1e10, -1.0, 5, 0); + mu_assert("aa_init(trust_factor<0) should return NULL", a == NULL); + a = aa_init(4, 2, /*min_len=*/2, 1, 1e-8, 1.0, 2.0, 1e10, INFINITY, 5, 0); + mu_assert("aa_init(trust_factor=Inf) should accept", a != NULL); + aa_finish(a); + a = aa_init(4, 2, /*min_len=*/2, 1, 1e-8, 1.0, 2.0, 1e10, 10.0, 5, 0); + mu_assert("aa_init(trust_factor=10) should accept", a != NULL); + aa_finish(a); return 0; } @@ -297,7 +313,7 @@ static const char *test_ir_max_steps_zero_still_solves(void) { for (aa_int i = 0; i < n; i++) x[i] = rand_float(); AaWork *a = aa_init(n, /*mem=*/5, /*min_len=*/5, /*type1=*/1, /*reg=*/1e-10, /*relax=*/1.0, /*safeguard=*/2.0, - /*max_w=*/1e10, /*ir_max_steps=*/0, /*verbosity=*/0); + /*max_w=*/1e10, /*trust_factor=*/INFINITY, /*ir_max_steps=*/0, /*verbosity=*/0); mu_assert("aa_init(ir_max_steps=0) must accept", a != NULL); for (aa_int i = 0; i < 500; i++) { if (i > 0) aa_apply(x, xprev, a); @@ -337,7 +353,7 @@ static const char *test_ir_max_steps_no_regression_on_ill_conditioned(void) { for (aa_int i = 0; i < n; i++) x[i] = rand_float(); AaWork *a = aa_init(n, /*mem=*/10, /*min_len=*/10, /*type1=*/0, /*reg=*/1e-10, /*relax=*/1.0, /*safeguard=*/2.0, - /*max_w=*/1e10, caps[k], /*verbosity=*/0); + /*max_w=*/1e10, /*trust_factor=*/INFINITY, caps[k], /*verbosity=*/0); for (aa_int i = 0; i < 2000; i++) { if (i > 0) aa_apply(x, xprev, a); memcpy(xprev, x, n * sizeof(aa_float)); @@ -408,7 +424,7 @@ static const char *test_reset_matches_fresh_init(void) { aa_float step = 1.0; aa_float x0[5] = {1, 1, 1, 1, 1}; - AaWork *a = aa_init(n, 3, /*min_len=*/3, 1, 1e-10, 1.0, 2.0, 1e10, 5, 0); + AaWork *a = aa_init(n, 3, /*min_len=*/3, 1, 1e-10, 1.0, 2.0, 1e10, INFINITY, 5, 0); /* First run: 20 iters from x0. */ aa_float x[5], xprev[5]; @@ -443,7 +459,7 @@ static const char *test_reset_matches_fresh_init(void) { /* reset must clear any "last AA step succeeded" state so a subsequent * safeguard call cannot roll inputs back to pre-reset iterates. */ static const char *test_reset_clears_stale_safeguard_state(void) { - AaWork *a = aa_init(2, 2, /*min_len=*/2, 1, 1e-8, 1.0, 1.0, 1e10, 5, 0); + AaWork *a = aa_init(2, 2, /*min_len=*/2, 1, 1e-8, 1.0, 1.0, 1e10, INFINITY, 5, 0); aa_float x[2] = {1.0, 1.0}; aa_float f[2] = {0.5, 0.5}; @@ -572,7 +588,7 @@ static const char *test_zero_reg_near_singular_y(void) { for (aa_int i = 0; i < n; i++) x[i] = rand_float(); AaWork *a = aa_init(n, /*mem=*/10, /*min_len=*/10, /*type1=*/0, /*reg=*/0.0, /*relax=*/1.0, /*safeguard=*/2.0, - /*max_w=*/1e10, /*ir_max_steps=*/5, /*verbosity=*/0); + /*max_w=*/1e10, /*trust_factor=*/INFINITY, /*ir_max_steps=*/5, /*verbosity=*/0); for (aa_int i = 0; i < 2000; i++) { if (i > 0) aa_apply(x, xprev, a); memcpy(xprev, x, n * sizeof(aa_float)); @@ -648,7 +664,7 @@ static aa_float diag_gd_with_reg(const aa_float *Qdiag, aa_int n, aa_float step, srand(seed); for (aa_int i = 0; i < n; i++) x_out[i] = rand_float(); AaWork *a = aa_init(n, mem, /*min_len=*/mem, type1, reg, /*relax=*/1.0, - /*safeguard=*/2.0, /*max_w=*/1e10, + /*safeguard=*/2.0, /*max_w=*/1e10, /*trust_factor=*/INFINITY, /*ir_max_steps=*/5, /*verbosity=*/0); for (aa_int i = 0; i < iters; i++) { if (i > 0) aa_apply(x_out, xprev, a); @@ -731,7 +747,7 @@ static const char *test_min_len_one_accelerates_early(void) { for (aa_int i = 0; i < n; i++) x[i] = rand_float(); AaWork *a = aa_init(n, /*mem=*/10, /*min_len=*/1, /*type1=*/0, /*reg=*/1e-10, /*relax=*/1.0, /*safeguard=*/2.0, - /*max_w=*/1e10, /*ir_max_steps=*/5, /*verbosity=*/0); + /*max_w=*/1e10, /*trust_factor=*/INFINITY, /*ir_max_steps=*/5, /*verbosity=*/0); mu_assert("aa_init(min_len=1) returned NULL", a != NULL); aa_int applies = 0; for (aa_int i = 0; i < 100; i++) { @@ -772,7 +788,7 @@ static const char *test_min_len_half_mem(void) { for (aa_int i = 0; i < 10; i++) x[i] = rand_float(); AaWork *a = aa_init(10, /*mem=*/8, /*min_len=*/4, /*type1=*/1, /*reg=*/1e-8, /*relax=*/1.0, /*safeguard=*/2.0, - /*max_w=*/1e10, /*ir_max_steps=*/5, /*verbosity=*/0); + /*max_w=*/1e10, /*trust_factor=*/INFINITY, /*ir_max_steps=*/5, /*verbosity=*/0); mu_assert("aa_init(min_len=4) returned NULL", a != NULL); for (aa_int i = 0; i < 300; i++) { if (i > 0) aa_apply(x, xprev, a); @@ -805,7 +821,7 @@ static const char *test_min_len_equal_mem_matches_default(void) { /* min_len=0 with mem>0 must be rejected. */ static const char *test_min_len_zero_rejected(void) { AaWork *a = aa_init(5, /*mem=*/3, /*min_len=*/0, 1, 1e-8, 1.0, - 2.0, 1e10, 5, 0); + 2.0, 1e10, INFINITY, 5, 0); mu_assert("aa_init(min_len=0, mem=3) should return NULL", a == NULL); return 0; } @@ -813,7 +829,7 @@ static const char *test_min_len_zero_rejected(void) { /* min_len is ignored when mem=0 — any value (incl 0 or nonsense) accepted. */ static const char *test_min_len_ignored_when_mem_zero(void) { AaWork *a = aa_init(5, /*mem=*/0, /*min_len=*/999, 1, 1e-8, 1.0, - 2.0, 1e10, 5, 0); + 2.0, 1e10, INFINITY, 5, 0); mu_assert("aa_init(mem=0) should accept any min_len", a != NULL); aa_finish(a); return 0; @@ -832,7 +848,7 @@ static const char *test_min_len_exceeding_mem_is_clamped(void) { for (aa_int i = 0; i < 10; i++) x[i] = rand_float(); AaWork *a = aa_init(10, /*mem=*/5, /*min_len=*/50, /*type1=*/0, /*reg=*/1e-10, /*relax=*/1.0, /*safeguard=*/2.0, - /*max_w=*/1e10, /*ir_max_steps=*/5, /*verbosity=*/0); + /*max_w=*/1e10, /*trust_factor=*/INFINITY, /*ir_max_steps=*/5, /*verbosity=*/0); mu_assert("aa_init should clamp min_len > mem silently", a != NULL); for (aa_int i = 0; i < 300; i++) { if (i > 0) aa_apply(x, xprev, a); @@ -867,7 +883,7 @@ static const char *test_min_len_survives_safeguard_churn(void) { for (aa_int i = 0; i < n; i++) x[i] = rand_float(); AaWork *a = aa_init(n, /*mem=*/10, /*min_len=*/1, /*type1=*/0, /*reg=*/1e-10, /*relax=*/1.0, /*safeguard=*/1.0, - /*max_w=*/1e10, /*ir_max_steps=*/5, /*verbosity=*/0); + /*max_w=*/1e10, /*trust_factor=*/INFINITY, /*ir_max_steps=*/5, /*verbosity=*/0); mu_assert("aa_init(min_len=1) returned NULL", a != NULL); for (aa_int i = 0; i < 2000; i++) { if (i > 0) aa_apply(x, xprev, a); @@ -889,7 +905,7 @@ static const char *test_min_len_survives_safeguard_churn(void) { * unconditionally call aa_apply without branching. */ static const char *test_first_iter_is_noop_on_f(void) { const aa_int n = 4; - AaWork *a = aa_init(n, 3, /*min_len=*/3, 1, 1e-8, 1.0, 2.0, 1e10, 5, 0); + AaWork *a = aa_init(n, 3, /*min_len=*/3, 1, 1e-8, 1.0, 2.0, 1e10, INFINITY, 5, 0); aa_float x[4] = {0.1, 0.2, 0.3, 0.4}; aa_float xprev[4] = {0, 0, 0, 0}; aa_float snapshot[4]; @@ -925,7 +941,7 @@ static const char *test_stats_basic_counters(void) { for (aa_int i = 0; i < n; i++) x[i] = rand_float(); AaWork *a = aa_init(n, /*mem=*/5, /*min_len=*/5, /*type1=*/0, /*reg=*/1e-12, /*relax=*/1.0, /*safeguard=*/2.0, - /*max_w=*/1e10, /*ir_max_steps=*/5, /*verbosity=*/0); + /*max_w=*/1e10, /*trust_factor=*/INFINITY, /*ir_max_steps=*/5, /*verbosity=*/0); mu_assert("aa_init returned NULL", a != NULL); AaStats s = aa_get_stats(a); @@ -979,7 +995,7 @@ static const char *test_stats_survive_safeguard_reject(void) { * a safeguard rejection on the first post-solve iteration. */ AaWork *a = aa_init(n, /*mem=*/2, /*min_len=*/2, /*type1=*/0, /*reg=*/1e-12, /*relax=*/1.0, /*safeguard=*/0.01, - /*max_w=*/1e10, /*ir_max_steps=*/5, /*verbosity=*/0); + /*max_w=*/1e10, /*trust_factor=*/INFINITY, /*ir_max_steps=*/5, /*verbosity=*/0); mu_assert("aa_init returned NULL", a != NULL); /* Hand-driven sequence chosen to land in the safeguard reject branch. */ diff --git a/tests/python/test_aa.py b/tests/python/test_aa.py index 3272daf..bb48be3 100644 --- a/tests/python/test_aa.py +++ b/tests/python/test_aa.py @@ -99,9 +99,14 @@ def test_construct_dim_one(): (dict(dim=DIM, mem=MEM, safeguard_factor=float("nan")), "safeguard_factor must be finite"), (dict(dim=DIM, mem=MEM, safeguard_factor=float("inf")), "safeguard_factor must be finite"), (dict(dim=DIM, mem=MEM, safeguard_factor=-1.0), "safeguard_factor must be finite and non-negative"), - (dict(dim=DIM, mem=MEM, max_weight_norm=float("nan")), "max_weight_norm must be finite"), - (dict(dim=DIM, mem=MEM, max_weight_norm=float("inf")), "max_weight_norm must be finite"), - (dict(dim=DIM, mem=MEM, max_weight_norm=0.0), "max_weight_norm must be finite and positive"), + (dict(dim=DIM, mem=MEM, max_weight_norm=float("nan")), "max_weight_norm must be positive"), + # inf is now accepted as the "no cap" sentinel for max_weight_norm + (dict(dim=DIM, mem=MEM, max_weight_norm=0.0), "max_weight_norm must be positive"), + (dict(dim=DIM, mem=MEM, max_weight_norm=-1.0), "max_weight_norm must be positive"), + # trust_factor: same convention — inf accepted, NaN / non-positive rejected + (dict(dim=DIM, mem=MEM, trust_factor=float("nan")), "trust_factor must be positive"), + (dict(dim=DIM, mem=MEM, trust_factor=0.0), "trust_factor must be positive"), + (dict(dim=DIM, mem=MEM, trust_factor=-1.0), "trust_factor must be positive"), ], ) def test_construct_invalid_args_raise_value_error(kwargs, message):