Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 21 additions & 28 deletions pyadjoint/optimization/tao_solver.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from enum import Enum
from numbers import Complex
from functools import cached_property

import numpy as np

Expand All @@ -9,15 +10,6 @@
from .optimization_solver import OptimizationSolver


try:
import petsc4py.PETSc as PETSc
except ModuleNotFoundError:
PETSc = None
try:
import petsctools
except ModuleNotFoundError:
petsctools = None

__all__ = [
"TAOConvergenceError",
"TAOSolver"
Expand All @@ -37,10 +29,7 @@ class PETScVecInterface:
"""

def __init__(self, x, *, comm=None):
if PETSc is None:
raise RuntimeError("PETSc not available")
if petsctools is None:
raise RuntimeError("petsctools not available")
from petsc4py import PETSc

x = Enlist(x)
comm = valid_comm(comm)
Expand Down Expand Up @@ -80,6 +69,7 @@ def new_petsc(self):
Returns:
petsc4py.PETSc.Vec: The new :class:`petsc4py.PETSc.Vec`.
"""
from petsc4py import PETSc

vec = PETSc.Vec().create(comm=self.comm)
vec.setSizes((self.n, self.N))
Expand Down Expand Up @@ -152,6 +142,7 @@ def valid_comm(comm):
petsc4py.PETSc.COMM_WORLD if `comm is None`, otherwise `comm.tompi4py()`.
"""
if comm is None:
from petsc4py import PETSc
comm = PETSc.COMM_WORLD
if hasattr(comm, "tompi4py"):
comm = comm.tompi4py()
Expand Down Expand Up @@ -435,6 +426,7 @@ def ReducedFunctionalMat(rf, action=RFOperation.HESSIAN, *, apply_riesz=False, a
be reevaluated at every call to `mult`.
comm (Optional[petsc4py.PETSc.Comm,mpi4py.MPI.Comm]): Communicator that the rf is defined over.
"""
from petsc4py import PETSc
if action == RFOperation.HESSIAN:
ctx = ReducedFunctionalHessianMat(
rf, appctx=appctx, apply_riesz=apply_riesz,
Expand Down Expand Up @@ -511,6 +503,7 @@ def RieszMapMat(controls, symmetric=True, comm=None):
symmetric (bool): Whether the Riesz map attached to the Control is symmetric.
comm (Optional[petsc4py.PETSc.Comm,mpi4py.MPI.Comm]): Communicator that the controls are defined over.
"""
from petsc4py import PETSc
ctx = RieszMapMatCtx(controls, comm=comm)

n = ctx.vec_interface.n
Expand Down Expand Up @@ -623,16 +616,6 @@ class TAOConvergenceError(Exception):
"""


if PETSc is None:
_tao_reasons = {}
else:
# Same approach as in _make_reasons in firedrake/solving_utils.py,
# Firedrake master branch 57e21cc8ebdb044c1d8423b48f3dbf70975d5548
_tao_reasons = {getattr(PETSc.TAO.Reason, key): key
for key in dir(PETSc.TAO.Reason)
if not key.startswith("_")}


class TAOSolver(OptimizationSolver):
"""Use TAO to solve an optimization problem.

Expand All @@ -648,10 +631,8 @@ class TAOSolver(OptimizationSolver):
def __init__(self, problem, parameters, *,
options_prefix=None, appctx=None,
Pmat=None, comm=None):
if PETSc is None:
raise RuntimeError("PETSc not available")
if petsctools is None:
raise RuntimeError("petsctools not available")
from petsc4py import PETSc
import petsctools

if not isinstance(problem, MinimizationProblem):
raise TypeError("MinimizationProblem required")
Expand Down Expand Up @@ -795,12 +776,24 @@ def x(self):

return self._x

@cached_property
def _tao_reasons(self):
"""Dictionary of TAO convergence reason int codes -> python objects
"""
from petsc4py import PETSc
# Same approach as in _make_reasons in firedrake/solving_utils.py,
# Firedrake master branch 57e21cc8ebdb044c1d8423b48f3dbf70975d5548
return {getattr(PETSc.TAO.Reason, key): key
for key in dir(PETSc.TAO.Reason)
if not key.startswith("_")}

def solve(self):
"""Solve the optimization problem.

Returns:
OverloadedType or Sequence[OverloadedType]: The solution.
"""
import petsctools

controls = self.tao_objective.reduced_functional.controls
m = tuple(control.tape_value()._ad_copy() for control in controls)
Expand All @@ -814,7 +807,7 @@ def solve(self):
# Using the same format as Firedrake linear solver errors
raise TAOConvergenceError(
f"TAOSolver failed to converge after {self.tao.getIterationNumber()} iterations "
f"with reason: {_tao_reasons[self.tao.getConvergedReason()]}")
f"with reason: {self._tao_reasons[self.tao.getConvergedReason()]}")
if isinstance(controls, Enlist):
return controls.delist(m)
else:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ visualisation = [
]
tao = [
"petsc4py",
"petsctools>2025.0"
"petsctools>2025.3"
]


Expand Down