diff --git a/docs/configuration.rst b/docs/configuration.rst index 441a10832..26b99b0b5 100644 --- a/docs/configuration.rst +++ b/docs/configuration.rst @@ -2270,6 +2270,21 @@ specific solver are defined in the relevant section below. Number of corrector steps for the predictor-corrector linear solver. 0 means a pure linear solve with no corrector steps. Must be a positive integer. +``atol`` (float | None [default = None]) + Absolute tolerance for fixed-point iterations in the predictor-corrector solver. + If specified, iterations can exit early when the normalized residual falls below this threshold. + +``rtol`` (float | None [default = None]) + Relative tolerance for fixed-point iterations in the predictor-corrector solver. + If specified, iterations can exit early when the normalized residual falls below this fraction of the initial residual. + +``use_backtracking`` (bool [default = True]) + Enables backtracking linesearch to improve stability. Can be used with any + solver. For the Newton-Raphson solver, this option is always enforced as True. + +``delta_reduction_factor`` (float [default = 0.5]) + Factor by which the step size is reduced during backtracking. + ``use_pereverzev`` (bool [default = False]) Use Pereverzev-Corrigan terms in the heat and particle flux when using the linear solver. Critical for stable calculation of stiff transport, at the cost diff --git a/docs/solver_details.rst b/docs/solver_details.rst index 2ed78b086..3072b7edb 100644 --- a/docs/solver_details.rst +++ b/docs/solver_details.rst @@ -234,6 +234,24 @@ these coefficients become known at every iteration step, describing a `linear` system of equations. :math:`\mathbf{x}_{t+\Delta t}^k` can then be solved using standard linear algebra methods implemented in JAX. +Optionally, the fixed-point iteration can be configured to terminate early +once a specified tolerance is achieved, rather than running for a fixed number +of iterations. This is controlled by user-configurable absolute and relative +tolerances on the residual norm, denoted by :math:`\varepsilon_{abs}` and +:math:`\varepsilon_{rel}` respectively. The solve iterates until the normalized +residual falls below the absolute tolerance +:math:`\| \mathbf{R} \|_{norm} < \varepsilon_{abs}` or becomes smaller than the +relative tolerance multiplied by the initial residual, i.e., +:math:`\| \mathbf{R} \|_{norm} < \varepsilon_{rel} \| \mathbf{R}_{0} \|_{norm}`. + +Additionally, a backtracking linesearch can be used to improve stability in +the solvers. When enabled in fixed-point iteration, if an iteration results +in an increase in the residual or an invalid state (e.g., NaN values), the +solver will backtrack along the update direction by reducing the step size. +For the Newton-Raphson solver, this backtracking linesearch is always required +and enforced to ensure robustness. + + To further enhance the stability of the linear solver, particularly in the presence of stiff transport coefficients (e.g., when using the QLKNN turbulent transport model, see :ref:`physics_models`), the |pereverzev-corrigan-method| diff --git a/torax/_src/solver/anderson.py b/torax/_src/solver/anderson.py new file mode 100644 index 000000000..3bc798d8a --- /dev/null +++ b/torax/_src/solver/anderson.py @@ -0,0 +1,215 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Anderson acceleration with safeguarding. + +References: + [1] D.G. Anderson, "Iterative procedures for nonlinear integral + equations," J. ACM, 12(4):547-560, 1965. + [2] H.F. Walker and P. Ni, "Anderson acceleration for fixed-point + iterations," SIAM J. Numer. Anal., 49(4):1715-1735, 2011. + [3] A. Toth and C.T. Kelley, "Convergence analysis for Anderson + acceleration," SIAM J. Numer. Anal., 53(2):805-819, 2015. + [4] J. Zhang, B. O'Donoghue, and S. Boyd, "Globally convergent type-I + Anderson acceleration for non-smooth fixed-point iterations," + SIAM J. Optim., 30(4):3170-3197, 2020. + [5] "Safeguarded Anderson acceleration for parametric nonexpansive + operators". +""" + +import dataclasses +from typing import Callable + +import jax +import jax.numpy as jnp + + +@dataclasses.dataclass(frozen=True) +class AndersonSettings: + """Configuration for Anderson acceleration. + + Attributes: + window_size: Number of retained iterates (m). Set to 0 to disable Anderson + acceleration. + safeguard_eta: Accept the AA candidate only if the residual norm is at most + eta times the current residual norm. eta=1 means "don't make it worse". + Values > 1 allow temporary increases. + regularization: Tikhonov regularization for the least-squares solve, scaled + by Gram matrix average diagonal. + beta: Relaxation parameter in (0, 1] + """ + + window_size: int = 5 + safeguard_eta: float = 1.0 + regularization: float = 1e-10 + beta: float = 1.0 + + +@jax.tree_util.register_dataclass +@dataclasses.dataclass(frozen=True) +class AndersonHistory: + """Circular buffer of past iterates and their fixed-point residuals.""" + + x_history: jnp.ndarray + f_history: jnp.ndarray + count: jnp.ndarray + + @classmethod + def create(cls, n: int, window_size: int, dtype): + """Creates an empty Anderson history buffer.""" + return cls( + x_history=jnp.zeros((window_size, n), dtype=dtype), + f_history=jnp.zeros((window_size, n), dtype=dtype), + count=jnp.array(0, dtype=jnp.int32), + ) + + def push(self, x_k: jnp.ndarray, f_k: jnp.ndarray, window_size: int): + """Pushes a new (x, f) pair into the circular buffer.""" + idx = self.count % window_size + return AndersonHistory( + x_history=self.x_history.at[idx].set(x_k), + f_history=self.f_history.at[idx].set(f_k), + count=self.count + 1, + ) + + def update( + self, + accepted: jnp.ndarray, + x: jnp.ndarray, + picard_step: jnp.ndarray, + settings: AndersonSettings, + ) -> 'AndersonHistory': + """Updates Anderson history based on acceptance. + + If accepted, pushes the current step to the history. If rejected, resets + the history and pushes the current step to the fresh history. + + Args: + accepted: True if the Anderson step was accepted. + x: Current iterate. + picard_step: Current fixed-point residual (or Picard step). + settings: Anderson configuration. + + Returns: + The updated history. 
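+    Note: on rejection the history is rebuilt from scratch, so the next
+    Anderson candidate is formed from only this single (x, picard_step) pair.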
+ """ + base_history = jax.lax.cond( + accepted, + lambda _: self, + lambda _: AndersonHistory.create( + x.shape[0], settings.window_size, dtype=x.dtype + ), + operand=None, + ) + return base_history.push(x, picard_step, settings.window_size) + + def get_deltas( + self, x_k: jnp.ndarray, f_k: jnp.ndarray, window_size: int + ) -> tuple[jnp.ndarray, jnp.ndarray]: + """Computes Delta_F and Delta_X matrices from history.""" + + def _get_deltas(i): + hist_idx = (self.count - 1 - i) % window_size + df = f_k - self.f_history[hist_idx] + dx = x_k - self.x_history[hist_idx] + return df, dx + + indices = jnp.arange(window_size) + all_df, all_dx = jax.vmap(_get_deltas)(indices) + return all_df, all_dx + + +def _compute_candidate( + history: AndersonHistory, + x_k: jnp.ndarray, + f_k: jnp.ndarray, + settings: AndersonSettings, +) -> jnp.ndarray: + """Computes the Anderson acceleration candidate.""" + # Following Walker and Ni [2], section 1 and section 3. + m = settings.window_size + m_actual = jnp.minimum(history.count, m) + beta = settings.beta + + # Following (1.2): Compute residual differences and step differences. + # history class puts them in the correct temporal order. + all_df, all_dx = history.get_deltas(x_k, f_k, m) + # Mask unused columns. + indices = jnp.arange(m) + mask = (indices < m_actual).astype(f_k.dtype) + masked_df = all_df * mask[:, None] + + # Following Eq 3.1, finding the least squares solution of + # ||f_k - Delta_F @ gamma||^2 + rhs = masked_df @ f_k + lhs_raw = masked_df @ masked_df.T # (m, m) + # We want to regularize the least squares solution, as the problem can be + # ill-conditioned. Parametrize with the average of the trace. + trace = jnp.trace(lhs_raw) + regularizer = ( + settings.regularization * (trace / jnp.maximum(m_actual, 1)) + 1e-14 + ) + lhs = lhs_raw + regularizer * jnp.eye(m, dtype=f_k.dtype) + gamma = jnp.linalg.solve(lhs, rhs) + gamma = gamma * mask # Zero out unused coefficients. + + # From eq. 3.1, subtract a weighted sum of the iterates. + damped_picard_step = x_k + beta * f_k + correction = (all_dx + beta * all_df).T @ gamma + candidate = damped_picard_step - correction + return candidate + + +@jax.tree_util.register_dataclass +@dataclasses.dataclass(frozen=True) +class Result: + """Result of attempting a safeguarded Anderson acceleration step.""" + + candidate: jnp.ndarray + residual: jnp.ndarray + residual_norm: jnp.ndarray + accepted: jnp.ndarray + + +def try_step( + x: jnp.ndarray, + picard_step: jnp.ndarray, + residual_fn: Callable[[jnp.ndarray], jnp.ndarray], + current_residual_norm: jnp.ndarray, + current_history: AndersonHistory, + settings: AndersonSettings, +) -> Result: + """Attempts an Anderson acceleration step with safeguarding.""" + + m_actual = jnp.minimum(current_history.count, settings.window_size) + candidate = jax.lax.cond( + m_actual >= 1, + lambda _: _compute_candidate(current_history, x, picard_step, settings), + lambda _: x + settings.beta * picard_step, + operand=None, + ) + res = residual_fn(candidate) + res_norm = jnp.linalg.norm(res) + + # Safeguard ([5], Eq. 13): accept only if the residual norm does not + # increase by more than a factor eta. 
+ safeguard_threshold = settings.safeguard_eta * current_residual_norm + accepted = res_norm <= safeguard_threshold + + return Result( + candidate=candidate, + residual=res, + residual_norm=res_norm, + accepted=accepted, + ) diff --git a/torax/_src/solver/jax_fixed_point.py b/torax/_src/solver/jax_fixed_point.py index ee171ec4d..80e1d3c09 100644 --- a/torax/_src/solver/jax_fixed_point.py +++ b/torax/_src/solver/jax_fixed_point.py @@ -16,8 +16,11 @@ from typing import Any, Callable, Literal, TypeAlias import jax +import jax.flatten_util import jax.numpy as jnp from torax._src import jax_utils +from torax._src.solver import anderson +from torax._src.solver import linesearch PyTree: TypeAlias = Any @@ -28,7 +31,13 @@ def fixed_point( args: tuple[PyTree, ...] = (), xtol: float | None = 1e-08, maxiter: int = 500, - method: Literal['del2', 'iteration'] = 'del2', + method: Literal['del2', 'iteration', 'anderson'] = 'del2', + atol: float | None = None, + rtol: float | None = None, + use_backtracking: bool = False, + delta_reduction_factor: float = 0.5, + max_backtrack_steps: int = 10, + anderson_settings: anderson.AndersonSettings | None = None, ) -> PyTree: """A JAX version of `scipy.optimize.fixed_point`. @@ -44,21 +53,44 @@ def fixed_point( tolerance is used and `maxiter` iterations will be performed. maxiter: The maximum number of iterations to perform. method: The method to use. 'del2' (the default) uses Steffensen’s Method - with Aitken’s Del^2 convergence acceleration, taken from Burden, Faires, - “Numerical Analysis”, 5th edition, pg. 80. 'iteration' just iterates the - function until the tolerance is reached. + with Aitken’s Del^2 convergence acceleration. 'iteration' just iterates + the function until the tolerance is reached. 'anderson' uses Anderson + acceleration with safeguarding. + atol: Absolute tolerance on the residual norm. + rtol: Relative tolerance on the residual norm. + use_backtracking: If true, use backtracking linesearch in 'iteration' method + or as fallback in 'anderson' method. + delta_reduction_factor: Factor by which step_size is reduced each step. + max_backtrack_steps: Maximum number of backtracking steps. + anderson_settings: Settings for Anderson acceleration. Only used if method + is 'anderson'. Returns: The fixed point `jax.Array`. 
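+
+  Example (a minimal usage sketch; cos has a fixed point near 0.739):
+
+    x_star = fixed_point(
+        jnp.cos,
+        jnp.array(1.0),
+        method='anderson',
+        atol=1e-10,
+        xtol=None,
+    )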
""" - if method not in ['del2', 'iteration']: + if method not in ['del2', 'iteration', 'anderson']: raise ValueError(f'Invalid method: {method}') if maxiter <= 0: raise ValueError(f'Invalid maxiter: {maxiter} must be positive.') - def body(x): - x, count, _ = x + def residual_fn(x): + return jax.tree.map(lambda a, b: a - b, func(x, *args), x) + + def norm_fn(res): + return jnp.sqrt(sum(jnp.sum(leaf**2) for leaf in jax.tree.leaves(res))) + + def residual_norm(x): + return norm_fn(residual_fn(x)) + + if rtol is not None: + initial_residual_norm = residual_norm(x0) + else: + initial_residual_norm = jnp.array(0.0) + + def body(x_state): + x, count, _, history = x_state out1 = func(x, *args) + if method == 'del2': out2 = func(out1, *args) @@ -68,10 +100,117 @@ def _del2(p0, p1, p2): return jax.lax.select(d != 0, out3, p2) out = jax.tree.map(_del2, x, out1, out2) - else: - out = out1 - - if xtol: + new_history = history + elif method == 'iteration': + if use_backtracking: + direction = jax.tree.map(lambda a, b: a - b, out1, x) + + init_res = direction + init_norm = norm_fn(init_res) + + decrease = 1e-4 + current_norm_sq = init_norm**2 + + def accept_fn(step_size, trial_norm): + target = (1.0 - 2.0 * decrease * step_size) * current_norm_sq + return (trial_norm**2) <= target + + ls_state = linesearch.backtracking_linesearch( + residual_fn=residual_fn, + x_init=x, + direction=direction, + accept_fn=accept_fn, + norm_fn=norm_fn, + initial_residual=init_res, + initial_residual_norm=init_norm, + delta_reduction_factor=delta_reduction_factor, + max_steps=max_backtrack_steps, + ) + out = ls_state.x + else: + out = out1 + new_history = history + elif method == 'anderson': + # out1 is func(x, *args) = G(x) + # residual is G(x) - x + res = jax.tree.map(lambda a, b: a - b, out1, x) + + flat_x, unflatten = jax.flatten_util.ravel_pytree(x) + flat_res, _ = jax.flatten_util.ravel_pytree(res) + + current_norm = norm_fn(res) + + def residual_fn_flat(flat_candidate): + cand = unflatten(flat_candidate) + cand_out = func(cand, *args) + cand_res = jax.tree.map(lambda a, b: a - b, cand_out, cand) + flat_cand_res, _ = jax.flatten_util.ravel_pytree(cand_res) + return flat_cand_res + + aa_res = anderson.try_step( + x=flat_x, + picard_step=flat_res, + residual_fn=residual_fn_flat, + current_residual_norm=current_norm, + current_history=history, + settings=anderson_settings, + ) + + def accept_aa(): + return unflatten(aa_res.candidate), jnp.array(True) + + def reject_aa(): + if use_backtracking: + direction = res + init_res = direction + init_norm = current_norm + decrease = 1e-4 + current_norm_sq = init_norm**2 + + def ls_accept_fn(step_size, trial_norm): + target = (1.0 - 2.0 * decrease * step_size) * current_norm_sq + return (trial_norm**2) <= target + + ls_state = linesearch.backtracking_linesearch( + residual_fn=residual_fn, + x_init=x, + direction=direction, + accept_fn=ls_accept_fn, + norm_fn=norm_fn, + initial_residual=init_res, + initial_residual_norm=init_norm, + delta_reduction_factor=delta_reduction_factor, + max_steps=max_backtrack_steps, + ) + return ls_state.x, jnp.array(False) + else: + return out1, jnp.array(False) + + out, accepted = jax.lax.cond( + aa_res.accepted, + lambda _: accept_aa(), + lambda _: reject_aa(), + operand=None, + ) + + new_history = history.update( + accepted, + flat_x, + flat_res, + anderson_settings, + ) + + # Terminate based on residual norm. 
+ if atol is not None or rtol is not None: + res_norm = residual_norm(out) + converged = jnp.array(False, dtype=jnp.bool_) + if atol is not None: + converged = converged | (res_norm <= atol) + if rtol is not None: + converged = converged | (res_norm <= rtol * initial_residual_norm) + stop = converged + # Terminate based on relative error. + elif xtol: def _relative_error(actual, expected): relative_error = (actual - expected) / expected @@ -83,17 +222,29 @@ def _relative_error(actual, expected): else: stop = jnp.array(False, dtype=jnp.bool_) count += 1 - return out, count, stop + return out, count, stop, new_history - def cond(x): - _, count, stop = x + def cond(x_state): + _, count, stop, _ = x_state return jnp.logical_not(stop) & (count < maxiter) count = jnp.array(0, dtype=jax_utils.get_int_dtype()) stop = jnp.array(False, dtype=jnp.bool_) - x_init = (x0, count, stop) - if xtol is None: + if method == 'anderson': + if anderson_settings is None: + anderson_settings = anderson.AndersonSettings() + flat_x0, _ = jax.flatten_util.ravel_pytree(x0) + n = flat_x0.shape[0] + history = anderson.AndersonHistory.create( + n, anderson_settings.window_size, dtype=flat_x0.dtype + ) + else: + history = anderson.AndersonHistory.create(1, 1, dtype=jnp.float32) # dummy + + x_init = (x0, count, stop, history) + + if xtol is None and atol is None and rtol is None: return jax.lax.fori_loop(0, maxiter, lambda i, val: body(val), x_init)[0] else: return jax.lax.while_loop(cond, body, x_init)[0] diff --git a/torax/_src/solver/jax_root_finding.py b/torax/_src/solver/jax_root_finding.py index 4c8fd995e..f8172a38c 100644 --- a/torax/_src/solver/jax_root_finding.py +++ b/torax/_src/solver/jax_root_finding.py @@ -13,6 +13,7 @@ # limitations under the License. """JAX root finding functions.""" + import dataclasses import functools from typing import Callable, Final @@ -21,6 +22,7 @@ import jax.numpy as jnp import numpy as np from torax._src import jax_utils +from torax._src.solver import linesearch # Delta is a vector. If no entry of delta is above this magnitude, we terminate # the delta loop. This is to avoid getting stuck in an infinite loop in edge @@ -123,9 +125,7 @@ def back(g, y): if use_jax_custom_root: if custom_jac is not None: - raise ValueError( - 'custom_jac is not compatible with use_jax_custom_root.' - ) + raise ValueError('custom_jac is not compatible with use_jax_custom_root.') x_out, metadata = jax.lax.custom_root( f=fun, initial_guess=x0, @@ -199,109 +199,44 @@ def _body( dtype = input_state['x'].dtype a_mat = jacobian_fun(input_state['x']) rhs = -input_state['residual'] - # delta = x_new - x_old - # tau = delta/delta0, where delta0 is the delta that sets the linearized - # residual to zero. tau < 1 when needed such that x_new meets - # conditions of reduced residual and valid state quantities. 
- # If tau < taumin while residual > tol, then the routine exits with an - # error flag, leading to either a warning or recalculation at lower dt - initial_delta_state = { - 'x': input_state['x'], - 'delta': jnp.linalg.solve(a_mat, rhs), - 'residual_old': input_state['residual'], - 'residual_new': input_state['residual'], - 'tau': jnp.array(1.0, dtype=dtype), - } - output_delta_state = _compute_output_delta_state( - initial_delta_state, residual_fun, delta_reduction_factor + + direction = jnp.linalg.solve(a_mat, rhs) + + def norm_fn(res): + return jnp.mean(jnp.abs(res)) + + init_norm = norm_fn(input_state['residual']) + + def accept_fn(step_size, trial_norm): + del step_size # Unused + return (trial_norm <= init_norm) & (~jnp.isnan(trial_norm)) + + ls_state = linesearch.backtracking_linesearch( + residual_fn=residual_fun, + x_init=input_state['x'], + direction=direction, + accept_fn=accept_fn, + norm_fn=norm_fn, + initial_residual=input_state['residual'], + initial_residual_norm=init_norm, + delta_reduction_factor=delta_reduction_factor, + max_steps=100, + min_step_norm=MIN_DELTA, ) output_state = { - 'x': input_state['x'] + output_delta_state['delta'], - 'residual': output_delta_state['residual_new'], + 'x': ls_state.x, + 'residual': ls_state.residual, 'iterations': jnp.array(input_state['iterations'][...], dtype=dtype) + 1, - 'last_tau': output_delta_state['tau'], + 'last_tau': ls_state.step_size, } + if log_iterations: jax.debug.print( 'Iteration: {iteration:d}. Residual: {residual:.16f}. tau = {tau:.6f}', iteration=output_state['iterations'].astype(jax_utils.get_int_dtype()), residual=_residual_scalar(output_state['residual']), - tau=output_delta_state['tau'], + tau=ls_state.step_size, ) return output_state - - -def _compute_output_delta_state( - initial_state: dict[str, jax.Array], - residual_fun: Callable[[jax.Array], jax.Array], - delta_reduction_factor: float, -): - """Updates output delta state.""" - delta_body_fun = functools.partial( - _delta_body, - delta_reduction_factor=delta_reduction_factor, - ) - delta_cond_fun = functools.partial( - _delta_cond, - residual_fun=residual_fun, - ) - output_delta_state = jax.lax.while_loop( - delta_cond_fun, delta_body_fun, initial_state - ) - - x_new = output_delta_state['x'] + output_delta_state['delta'] - residual_vec_x_new = residual_fun(x_new) - output_delta_state |= dict( - residual_new=residual_vec_x_new, - ) - return output_delta_state - - -def _delta_cond( - delta_state: dict[str, jax.Array], - residual_fun: Callable[[jax.Array], jax.Array], -) -> bool: - """Check if delta obtained from Newton step is valid. - - Args: - delta_state: see `delta_body`. - residual_fun: Residual function. - - Returns: - True if the new value of `x` causes any NaNs or has increased the residual - relative to the old value of `x`. 
- """ - x_old = delta_state['x'] - x_new = x_old + delta_state['delta'] - residual_vec_x_old = delta_state['residual_old'] - residual_scalar_x_old = _residual_scalar(residual_vec_x_old) - # Avoid sanity checking inside residual, since we directly - # afterwards check sanity on the output (NaN checking) - # TODO(b/312453092) consider instead sanity-checking x_new - with jax_utils.enable_errors(False): - residual_vec_x_new = residual_fun(x_new) - residual_scalar_x_new = _residual_scalar(residual_vec_x_new) - delta_state['residual_new'] = residual_vec_x_new - return jnp.bool_( - jnp.logical_and( - jnp.max(jnp.abs(delta_state['delta'])) > MIN_DELTA, - jnp.logical_or( - residual_scalar_x_old < residual_scalar_x_new, - jnp.isnan(residual_scalar_x_new), - ), - ), - ) - - -def _delta_body( - input_delta_state: dict[str, jax.Array], - delta_reduction_factor: float, -) -> dict[str, jax.Array]: - """Reduces step size for this Newton iteration.""" - return input_delta_state | dict( - delta=input_delta_state['delta'] * delta_reduction_factor, - tau=jnp.array(input_delta_state['tau'][...], dtype=jax_utils.get_dtype()) - * delta_reduction_factor, - ) diff --git a/torax/_src/solver/linesearch.py b/torax/_src/solver/linesearch.py new file mode 100644 index 000000000..82493053e --- /dev/null +++ b/torax/_src/solver/linesearch.py @@ -0,0 +1,145 @@ +# Copyright 2026 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Backtracking line search for use in solving functions.""" + +import dataclasses +from typing import Callable + +import jax +import jax.numpy as jnp +import jaxtyping as jt + + +@jax.tree_util.register_dataclass +@dataclasses.dataclass(frozen=True) +class LinesearchState: + """State and result of the backtracking line search. + + Attributes are the values at the accepted step size, or the last value + tried if the search failed. + + Attributes: + iteration: Current iteration of the linesearch. + step_size: Current step size. + next_step_size: Next step size to try. + x: Current location. + residual: Current residual. + residual_norm: Norm of current residual. + step_found: Whether a step has been found. + done: Whether the linesearch is done. + """ + + iteration: jnp.ndarray + step_size: jnp.ndarray + next_step_size: jnp.ndarray + x: jt.PyTree + residual: jt.PyTree + residual_norm: jnp.ndarray + step_found: jt.Bool[jax.Array, ""] + done: jt.Bool[jax.Array, ""] + + +def backtracking_linesearch( + residual_fn: Callable[[jt.PyTree], jt.PyTree], + x_init: jt.PyTree, + direction: jt.PyTree, + accept_fn: Callable[[jnp.ndarray, jnp.ndarray], jnp.ndarray], + norm_fn: Callable[[jt.PyTree], jnp.ndarray], + initial_residual: jt.PyTree, + initial_residual_norm: jnp.ndarray, + delta_reduction_factor: float, + max_steps: int, + min_step_norm: float = 0.0, +) -> LinesearchState: + """Performs backtracking line search. + + A backtracking linesearch seeks a value for step_size such that + x_trial = x_init + step_size * direction + meets the condition specified by accept_fn. 
It performs the search by starting + step_size at 1.0, and decreasing step_size until either accept_fn is true or + the maximum number of iterations is reached. + + Args: + residual_fn: Accepts the location x, and returns the residual R(x). + x_init: Starting location. + direction: Search direction, a PyTree with the same shape as x. + accept_fn: Accepts (step_size, trial_residual_norm) and returns True if the + trial point is acceptable, and false otherwise. + norm_fn: Function compute the norm of the residual. + initial_residual: Residual vector at input x_init. + initial_residual_norm: Norm of initial_residual. + delta_reduction_factor: Factor by which step_size is reduced each step. + max_steps: Maximum number of backtracking steps. + min_step_norm: Minimum value of max(abs(step_size * direction)) allowed. + + Returns: + LinesearchState with the accepted (or last tried) trial point. + """ + + init_step_size = 1.0 + init_state = LinesearchState( + iteration=jnp.array(0, dtype=jnp.int32), + step_size=jnp.array( + init_step_size, + dtype=x_init.dtype if hasattr(x_init, "dtype") else jnp.float32, + ), + next_step_size=jnp.array( + init_step_size, + dtype=x_init.dtype if hasattr(x_init, "dtype") else jnp.float32, + ), + x=x_init, + residual=initial_residual, + residual_norm=initial_residual_norm, + step_found=jnp.array(False), + done=jnp.array(False), + ) + + def cond_fun(state: LinesearchState) -> jt.Bool[jax.Array, ""]: + return jnp.logical_not(state.done) + + def body_fun(state: LinesearchState) -> LinesearchState: + new_iter = state.iteration + 1 + step_size = state.next_step_size + + new_x = jax.tree.map(lambda a, b: a + step_size * b, x_init, direction) + new_res = residual_fn(new_x) + new_norm = norm_fn(new_res) + + new_step_found = accept_fn(step_size, new_norm) + is_max_iter = new_iter >= max_steps + + # Check if step is too small. + max_abs_dir = jnp.max( + jnp.array( + [jnp.max(jnp.abs(leaf)) for leaf in jax.tree.leaves(direction)] + ) + ) + step_too_small = (step_size * max_abs_dir) <= min_step_norm + + new_done = new_step_found | is_max_iter | step_too_small + next_step_size = step_size * delta_reduction_factor + + return LinesearchState( + iteration=new_iter, + step_size=step_size, + next_step_size=next_step_size, + x=new_x, + residual=new_res, + residual_norm=new_norm, + step_found=new_step_found, + done=new_done, + ) + + return jax.lax.while_loop(cond_fun, body_fun, init_state) diff --git a/torax/_src/solver/predictor_corrector_method.py b/torax/_src/solver/predictor_corrector_method.py index 2d1ab0728..0946cc614 100644 --- a/torax/_src/solver/predictor_corrector_method.py +++ b/torax/_src/solver/predictor_corrector_method.py @@ -115,6 +115,10 @@ def loop_body(x_new_guess): maxiter=solver_params.n_corrector_steps + 1, xtol=None, method='iteration', + atol=solver_params.atol, + rtol=solver_params.rtol, + use_backtracking=solver_params.use_backtracking, + delta_reduction_factor=solver_params.delta_reduction_factor, ) else: x_new = loop_body(x_new_guess) diff --git a/torax/_src/solver/pydantic_model.py b/torax/_src/solver/pydantic_model.py index 6daffd026..857d8837e 100644 --- a/torax/_src/solver/pydantic_model.py +++ b/torax/_src/solver/pydantic_model.py @@ -13,9 +13,11 @@ # limitations under the License. 
"""Pydantic config for Solver.""" + import abc import functools from typing import Annotated, Any, Literal +import warnings import pydantic from torax._src import models as models_lib @@ -50,6 +52,8 @@ class BaseSolver(torax_pydantic.BaseModelFrozen, abc.ABC): implicit linear system solve. chi_pereverzev: (deliberately) large heat conductivity for Pereverzev rule. D_pereverzev: (deliberately) large particle diffusion for Pereverzev rule. + atol: Absolute tolerance on the residual norm for the solver. + rtol: Relative tolerance on the residual norm for the solver. """ theta_implicit: Annotated[ @@ -71,6 +75,10 @@ class BaseSolver(torax_pydantic.BaseModelFrozen, abc.ABC): ] = tridiagonal.SolverType.THOMAS chi_pereverzev: pydantic.PositiveFloat = 30.0 D_pereverzev: pydantic.NonNegativeFloat = 15.0 + atol: float | None = None + rtol: float | None = None + use_backtracking: Annotated[bool, torax_pydantic.JAX_STATIC] = True + delta_reduction_factor: float = 0.5 @property @abc.abstractmethod @@ -117,6 +125,10 @@ def build_runtime_params(self) -> runtime_params.RuntimeParams: chi_pereverzev=self.chi_pereverzev, D_pereverzev=self.D_pereverzev, n_corrector_steps=self.n_corrector_steps, + atol=self.atol, + rtol=self.rtol, + use_backtracking=self.use_backtracking, + delta_reduction_factor=self.delta_reduction_factor, ) def build_solver( @@ -157,6 +169,18 @@ class NewtonRaphsonThetaMethod(BaseSolver): delta_reduction_factor: float = 0.5 tau_min: float = 0.01 + @pydantic.model_validator(mode='before') + @classmethod + def enforce_backtracking(cls, data: Any) -> Any: + if isinstance(data, dict) and 'use_backtracking' in data: + if not data['use_backtracking']: + warnings.warn( + 'use_backtracking is always True for Newton-Raphson solver. ' + 'Ignoring user setting of False.' + ) + data['use_backtracking'] = True + return data + @functools.cached_property def build_runtime_params( self, @@ -178,6 +202,9 @@ def build_runtime_params( tau_min=self.tau_min, initial_guess_mode=self.initial_guess_mode.value, log_iterations=self.log_iterations, + atol=self.atol, + rtol=self.rtol, + use_backtracking=self.use_backtracking, ) def build_solver( @@ -225,6 +252,10 @@ def build_runtime_params( loss_tol=self.loss_tol, n_corrector_steps=self.n_corrector_steps, initial_guess_mode=self.initial_guess_mode.value, + atol=self.atol, + rtol=self.rtol, + use_backtracking=self.use_backtracking, + delta_reduction_factor=self.delta_reduction_factor, ) def build_solver( diff --git a/torax/_src/solver/runtime_params.py b/torax/_src/solver/runtime_params.py index 37a65879a..d0fd4b010 100644 --- a/torax/_src/solver/runtime_params.py +++ b/torax/_src/solver/runtime_params.py @@ -33,3 +33,7 @@ class RuntimeParams: ) chi_pereverzev: float D_pereverzev: float # pylint: disable=invalid-name + atol: float | None = dataclasses.field(metadata={'static': True}) + rtol: float | None = dataclasses.field(metadata={'static': True}) + use_backtracking: bool = dataclasses.field(metadata={'static': True}) + delta_reduction_factor: float = dataclasses.field(metadata={'static': True}) diff --git a/torax/_src/solver/tests/anderson_test.py b/torax/_src/solver/tests/anderson_test.py new file mode 100644 index 000000000..bc4d02ef9 --- /dev/null +++ b/torax/_src/solver/tests/anderson_test.py @@ -0,0 +1,222 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Anderson acceleration.""" + +from absl.testing import absltest +from absl.testing import parameterized +import chex +import jax +import jax.numpy as jnp +from torax._src.solver import anderson + + +def _default_settings(window_size=3, beta=1.0): + return anderson.AndersonSettings( + window_size=window_size, + safeguard_eta=1.0, + regularization=1e-10, + beta=beta, + ) + + +class AndersonTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + jax.config.update("jax_enable_x64", True) + + +class HistoryTest(AndersonTest): + + def test_init_history(self): + n = 5 + h = anderson.AndersonHistory.create(n, window_size=3, dtype=jnp.float64) + self.assertEqual(h.x_history.shape, (3, n)) + self.assertEqual(h.f_history.shape, (3, n)) + self.assertEqual(h.count, 0) + + def test_push_history(self): + n = 3 + window_size = 2 + h = anderson.AndersonHistory.create( + n, window_size=window_size, dtype=jnp.float64 + ) + x = jnp.array([1.0, 2.0, 3.0]) + f = jnp.array([0.1, 0.2, 0.3]) + + h = h.push(x, f, window_size=window_size) + self.assertEqual(h.count, 1) + chex.assert_trees_all_close(h.x_history[0], x) + chex.assert_trees_all_close(h.f_history[0], f) + + x2 = jnp.array([4.0, 5.0, 6.0]) + f2 = jnp.array([0.4, 0.5, 0.6]) + h = h.push(x2, f2, window_size=window_size) + self.assertEqual(h.count, 2) + chex.assert_trees_all_close(h.x_history[1], x2) + + # Third push should wrap around to index 0. + x3 = jnp.array([7.0, 8.0, 9.0]) + f3 = jnp.array([0.7, 0.8, 0.9]) + h = h.push(x3, f3, window_size=window_size) + self.assertEqual(h.count, 3) + chex.assert_trees_all_close(h.x_history[0], x3) + chex.assert_trees_all_close(h.f_history[0], f3) + + +class ComputeCandidateTest(AndersonTest): + + def test_no_history_returns_picard(self): + """With empty history, returns the damped Picard step.""" + n = 2 + beta = 0.5 + settings = _default_settings(beta=beta) + h = anderson.AndersonHistory.create( + n, settings.window_size, dtype=jnp.float64 + ) + + x = jnp.array([1.0, 2.0]) + f = jnp.array([0.1, 0.2]) + + candidate = anderson._compute_candidate(h, x, f, settings) + expected = x + beta * f + chex.assert_trees_all_close(candidate, expected) + + def test_with_history_differs_from_picard(self): + """With history, the Anderson candidate differs from Picard.""" + n = 3 + settings = _default_settings() + history = anderson.AndersonHistory.create( + n, settings.window_size, dtype=jnp.float64 + ) + history = history.push( + jnp.array([1.0, 2.0, 0.0]), + jnp.array([0.5, -0.3, 0.1]), + settings.window_size, + ) + history = history.push( + jnp.array([1.5, 1.7, 0.1]), + jnp.array([0.2, -0.1, -0.2]), + settings.window_size, + ) + + x = jnp.array([1.7, 1.6, 0.2]) + f = jnp.array([0.1, -0.05, 0.3]) + + candidate = anderson._compute_candidate(history, x, f, settings) + picard = x + f + self.assertFalse(jnp.allclose(candidate, picard, atol=1e-6)) + + # Test with damping as well. 
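+    # beta=0.5 makes the base update the damped Picard step x + 0.5 * f,
+    # from which the Anderson correction is subtracted.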
+ settings_damped = _default_settings(beta=0.5) + candidate_damped = anderson._compute_candidate( + history, x, f, settings_damped + ) + picard_damped = x + 0.5 * f + self.assertFalse(jnp.allclose(candidate_damped, picard_damped, atol=1e-6)) + + # Damped candidate should differ from undamped candidate. + self.assertFalse(jnp.allclose(candidate, candidate_damped, atol=1e-6)) + + def test_exact_on_affine(self): + """Anderson should exactly solve an affine fixed-point in one step.""" + + def g(x): + # Fixed point of function is x = [0.5, 0.5]. + return 0.5 * x + jnp.array([0.25, 0.25]) + + n = 2 + settings = _default_settings(window_size=3) + h = anderson.AndersonHistory.create( + n, settings.window_size, dtype=jnp.float64 + ) + + # Run multiple iterations of Picard to build history. + x0 = jnp.array([0.0, 0.0]) + f0 = g(x0) - x0 + h = h.push(x0, f0, settings.window_size) + + x1 = g(x0) + f1 = g(x1) - x1 + h = h.push(x1, f1, settings.window_size) + + x2 = g(x1) + f2 = g(x2) - x2 + + # Constructed candidate should be close to the exact answer. + candidate = anderson._compute_candidate(h, x2, f2, settings) + chex.assert_trees_all_close(candidate, jnp.array([0.5, 0.5]), atol=1e-10) + + def test_jit_compatible(self): + """try_step works under jit.""" + n = 2 + settings = _default_settings() + + def residual_fn(x): + return x - jnp.array([0.5, 0.5]) + + @jax.jit + def step(x, picard_step, current_residual_norm, current_history): + return anderson.try_step( + x, + picard_step, + residual_fn, + current_residual_norm, + current_history, + settings, + ) + + h = anderson.AndersonHistory.create( + n, settings.window_size, dtype=jnp.float64 + ) + x = jnp.array([1.0, 2.0]) + picard_step = jnp.array([0.1, 0.2]) + current_residual_norm = jnp.linalg.norm(residual_fn(x)) + + result = step(x, picard_step, current_residual_norm, h) + self.assertEqual(result.candidate.shape, (n,)) + self.assertIsInstance(result, anderson.Result) + + +class TryStepTest(AndersonTest): + + @parameterized.named_parameters( + ("improving", lambda c: jnp.array([0.01, 0.02]), True), + ("worse", lambda c: jnp.array([10.0, 20.0]), False), + ) + def test_try_step_safeguard(self, residual_fn, expected_accepted): + """Tests that try_step accepts improving steps and rejects worse steps.""" + n = 2 + settings = _default_settings() + h = anderson.AndersonHistory.create( + n, settings.window_size, dtype=jnp.float64 + ) + x = jnp.array([1.0, 2.0]) + picard_step = jnp.array([0.1, 0.2]) + current_residual_norm = jnp.array(1.0) + + result = anderson.try_step( + x=x, + picard_step=picard_step, + residual_fn=residual_fn, + current_residual_norm=current_residual_norm, + current_history=h, + settings=settings, + ) + self.assertEqual(result.accepted, expected_accepted) + + +if __name__ == "__main__": + absltest.main() diff --git a/torax/_src/solver/tests/jax_fixed_point_test.py b/torax/_src/solver/tests/jax_fixed_point_test.py index 583dfec88..bd6b8e498 100644 --- a/torax/_src/solver/tests/jax_fixed_point_test.py +++ b/torax/_src/solver/tests/jax_fixed_point_test.py @@ -88,6 +88,115 @@ def test_fixed_point_none(self): ) chex.assert_trees_all_close(out_expected, out_jnp, atol=1e-8) + def test_fixed_point_residual_norm(self): + c1 = np.array([10, 12.0]) + c2 = np.array([3, 5.0]) + x = np.array([1.2, 1.3]) + + # Test with atol + out_jnp_atol = jax_fixed_point.fixed_point( + _func_jnp, + x, + args=(c1, c2), + method='iteration', + maxiter=500, + atol=1e-5, + xtol=None, + ) + + # Verify it gives close result to standard fixed point + out_expected = 
jax_fixed_point.fixed_point( + _func_jnp, + x, + args=(c1, c2), + method='iteration', + maxiter=500, + xtol=1e-5, + ) + chex.assert_trees_all_close(out_expected, out_jnp_atol, atol=1e-5) + + # Test with rtol + out_jnp_rtol = jax_fixed_point.fixed_point( + _func_jnp, + x, + args=(c1, c2), + method='iteration', + maxiter=500, + rtol=1e-5, + xtol=None, + ) + chex.assert_trees_all_close(out_expected, out_jnp_rtol, atol=1e-5) + + def test_fixed_point_backtracking(self): + c1 = np.array([10, 12.0]) + c2 = np.array([3, 5.0]) + x = np.array([1.2, 1.3]) + + out_jnp = jax_fixed_point.fixed_point( + _func_jnp, + x, + args=(c1, c2), + method='iteration', + maxiter=500, + use_backtracking=True, + delta_reduction_factor=0.5, + max_backtrack_steps=5, + atol=1e-5, + xtol=None, + ) + + out_expected = jax_fixed_point.fixed_point( + _func_jnp, + x, + args=(c1, c2), + method='iteration', + maxiter=500, + atol=1e-5, + xtol=None, + ) + chex.assert_trees_all_close(out_expected, out_jnp, atol=1e-5) + + def test_fixed_point_anderson(self): + c1 = np.array([10, 12.0]) + c2 = np.array([3, 5.0]) + x = np.array([1.2, 1.3]) + + # Get reference from scipy + out_expected = optimize.fixed_point( + _func_np, x, args=(c1, c2), method='del2', maxiter=500, xtol=1e-8 + ) + + # Test JAX Anderson + out_jnp = jax_fixed_point.fixed_point( + _func_jnp, + x, + args=(c1, c2), + method='anderson', + maxiter=500, + xtol=1e-8, + ) + chex.assert_trees_all_close(out_expected, out_jnp, atol=1e-5) + + def test_fixed_point_anderson_backtracking(self): + c1 = np.array([10, 12.0]) + c2 = np.array([3, 5.0]) + x = np.array([1.2, 1.3]) + + out_expected = optimize.fixed_point( + _func_np, x, args=(c1, c2), method='del2', maxiter=500, xtol=1e-8 + ) + + out_jnp = jax_fixed_point.fixed_point( + _func_jnp, + x, + args=(c1, c2), + method='anderson', + maxiter=500, + use_backtracking=True, + xtol=1e-8, + ) + chex.assert_trees_all_close(out_expected, out_jnp, atol=1e-5) + if __name__ == '__main__': absltest.main() diff --git a/torax/_src/solver/tests/linesearch_test.py b/torax/_src/solver/tests/linesearch_test.py new file mode 100644 index 000000000..7f4d3b66a --- /dev/null +++ b/torax/_src/solver/tests/linesearch_test.py @@ -0,0 +1,124 @@ +# Copyright 2026 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for linesearch module.""" + +from absl.testing import absltest +from absl.testing import parameterized +import chex +import jax.numpy as jnp +from torax._src.solver import linesearch + + +class BacktrackingLinesearchTest(parameterized.TestCase): + + def test_linesearch_success(self): + def residual_fn(x): + return x - 2.0 + + x_init = jnp.array(0.0) + direction = jnp.array(2.0) # Newton step to root + + def accept_fn(step_size, trial_norm): + del step_size + return trial_norm <= 1.0 # Initial is 2.0, so decreasing is good + + def norm_fn(res): + return jnp.abs(res) + + final = linesearch.backtracking_linesearch( + residual_fn=residual_fn, + x_init=x_init, + direction=direction, + accept_fn=accept_fn, + norm_fn=norm_fn, + initial_residual=x_init - 2.0, + initial_residual_norm=jnp.array(2.0), + delta_reduction_factor=0.5, + max_steps=10, + ) + + self.assertTrue(bool(final.step_found)) + self.assertLessEqual(int(final.iteration), 10) + chex.assert_trees_all_close(final.x, jnp.array(2.0)) + + def test_linesearch_backtracking(self): + # A function that increases residual if step is too large + def residual_fn(x): + # If x > 1.0, return large residual, else return x - 2.0 + return jnp.where(x > 1.0, 10.0, x - 2.0) + + x_init = jnp.array(0.0) + direction = jnp.array(2.0) + + def accept_fn(step_size, trial_norm): + del step_size + return trial_norm <= 1.5 # Initial norm is 2.0. 1.0 is good. + + def norm_fn(res): + return jnp.abs(res) + + final = linesearch.backtracking_linesearch( + residual_fn=residual_fn, + x_init=x_init, + direction=direction, + accept_fn=accept_fn, + norm_fn=norm_fn, + initial_residual=x_init - 2.0, + initial_residual_norm=jnp.array(2.0), + delta_reduction_factor=0.5, + max_steps=10, + ) + + self.assertTrue(bool(final.step_found)) + self.assertGreater(int(final.iteration), 1) # Must have backtracked + chex.assert_trees_all_close(final.x, jnp.array(1.0)) + + def test_linesearch_pytree(self): + def residual_fn(x): + return {'a': x['a'] - 2.0, 'b': x['b'] - 3.0} + + x_init = {'a': jnp.array(0.0), 'b': jnp.array(0.0)} + direction = {'a': jnp.array(2.0), 'b': jnp.array(3.0)} + + def accept_fn(step_size, trial_norm): + del step_size + return trial_norm <= 1.0 + + def norm_fn(res): + return jnp.sqrt(res['a'] ** 2 + res['b'] ** 2) + + init_res = residual_fn(x_init) + init_norm = norm_fn(init_res) + + final = linesearch.backtracking_linesearch( + residual_fn=residual_fn, + x_init=x_init, + direction=direction, + accept_fn=accept_fn, + norm_fn=norm_fn, + initial_residual=init_res, + initial_residual_norm=init_norm, + delta_reduction_factor=0.5, + max_steps=10, + ) + + self.assertTrue(bool(final.step_found)) + chex.assert_trees_all_close( + final.x, {'a': jnp.array(2.0), 'b': jnp.array(3.0)} + ) + + +if __name__ == '__main__': + absltest.main()
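A minimal sketch of how the new tolerance and backtracking options documented
in docs/configuration.rst might be set from a user config. Only the four new
keys come from this change; the 'solver' section nesting, 'solver_type', and
'use_predictor_corrector' are assumed from existing TORAX config conventions:

    CONFIG = {
        # ... other config sections ...
        'solver': {
            'solver_type': 'linear',            # assumed existing option
            'use_predictor_corrector': True,    # assumed existing option
            'n_corrector_steps': 10,
            'atol': 1e-7,                       # early exit on absolute residual norm
            'rtol': 1e-4,                       # early exit relative to initial residual
            'use_backtracking': True,           # backtracking linesearch
            'delta_reduction_factor': 0.5,      # step-size reduction per backtrack
        },
    }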