diff --git a/pyadjoint/optimization/tao_solver.py b/pyadjoint/optimization/tao_solver.py index 000d0138..a341dbea 100644 --- a/pyadjoint/optimization/tao_solver.py +++ b/pyadjoint/optimization/tao_solver.py @@ -1,5 +1,6 @@ from enum import Enum from numbers import Complex +from functools import cached_property import numpy as np @@ -9,15 +10,6 @@ from .optimization_solver import OptimizationSolver -try: - import petsc4py.PETSc as PETSc -except ModuleNotFoundError: - PETSc = None -try: - import petsctools -except ModuleNotFoundError: - petsctools = None - __all__ = [ "TAOConvergenceError", "TAOSolver" @@ -37,10 +29,7 @@ class PETScVecInterface: """ def __init__(self, x, *, comm=None): - if PETSc is None: - raise RuntimeError("PETSc not available") - if petsctools is None: - raise RuntimeError("petsctools not available") + from petsc4py import PETSc x = Enlist(x) comm = valid_comm(comm) @@ -80,6 +69,7 @@ def new_petsc(self): Returns: petsc4py.PETSc.Vec: The new :class:`petsc4py.PETSc.Vec`. """ + from petsc4py import PETSc vec = PETSc.Vec().create(comm=self.comm) vec.setSizes((self.n, self.N)) @@ -152,6 +142,7 @@ def valid_comm(comm): petsc4py.PETSc.COMM_WORLD if `comm is None`, otherwise `comm.tompi4py()`. """ if comm is None: + from petsc4py import PETSc comm = PETSc.COMM_WORLD if hasattr(comm, "tompi4py"): comm = comm.tompi4py() @@ -435,6 +426,7 @@ def ReducedFunctionalMat(rf, action=RFOperation.HESSIAN, *, apply_riesz=False, a be reevaluated at every call to `mult`. comm (Optional[petsc4py.PETSc.Comm,mpi4py.MPI.Comm]): Communicator that the rf is defined over. """ + from petsc4py import PETSc if action == RFOperation.HESSIAN: ctx = ReducedFunctionalHessianMat( rf, appctx=appctx, apply_riesz=apply_riesz, @@ -511,6 +503,7 @@ def RieszMapMat(controls, symmetric=True, comm=None): symmetric (bool): Whether the Riesz map attached to the Control is symmetric. comm (Optional[petsc4py.PETSc.Comm,mpi4py.MPI.Comm]): Communicator that the controls are defined over. """ + from petsc4py import PETSc ctx = RieszMapMatCtx(controls, comm=comm) n = ctx.vec_interface.n @@ -623,16 +616,6 @@ class TAOConvergenceError(Exception): """ -if PETSc is None: - _tao_reasons = {} -else: - # Same approach as in _make_reasons in firedrake/solving_utils.py, - # Firedrake master branch 57e21cc8ebdb044c1d8423b48f3dbf70975d5548 - _tao_reasons = {getattr(PETSc.TAO.Reason, key): key - for key in dir(PETSc.TAO.Reason) - if not key.startswith("_")} - - class TAOSolver(OptimizationSolver): """Use TAO to solve an optimization problem. @@ -648,10 +631,8 @@ class TAOSolver(OptimizationSolver): def __init__(self, problem, parameters, *, options_prefix=None, appctx=None, Pmat=None, comm=None): - if PETSc is None: - raise RuntimeError("PETSc not available") - if petsctools is None: - raise RuntimeError("petsctools not available") + from petsc4py import PETSc + import petsctools if not isinstance(problem, MinimizationProblem): raise TypeError("MinimizationProblem required") @@ -795,12 +776,24 @@ def x(self): return self._x + @cached_property + def _tao_reasons(self): + """Dictionary of TAO convergence reason int codes -> python objects + """ + from petsc4py import PETSc + # Same approach as in _make_reasons in firedrake/solving_utils.py, + # Firedrake master branch 57e21cc8ebdb044c1d8423b48f3dbf70975d5548 + return {getattr(PETSc.TAO.Reason, key): key + for key in dir(PETSc.TAO.Reason) + if not key.startswith("_")} + def solve(self): """Solve the optimization problem. Returns: OverloadedType or Sequence[OverloadedType]: The solution. """ + import petsctools controls = self.tao_objective.reduced_functional.controls m = tuple(control.tape_value()._ad_copy() for control in controls) @@ -814,7 +807,7 @@ def solve(self): # Using the same format as Firedrake linear solver errors raise TAOConvergenceError( f"TAOSolver failed to converge after {self.tao.getIterationNumber()} iterations " - f"with reason: {_tao_reasons[self.tao.getConvergedReason()]}") + f"with reason: {self._tao_reasons[self.tao.getConvergedReason()]}") if isinstance(controls, Enlist): return controls.delist(m) else: diff --git a/pyproject.toml b/pyproject.toml index 9b4ac2ca..53e80f78 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,7 +65,7 @@ visualisation = [ ] tao = [ "petsc4py", - "petsctools>2025.0" + "petsctools>2025.3" ]