Skip to content
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 30 additions & 3 deletions pina/_src/condition/domain_equation_condition.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""Module for the DomainEquationCondition class."""

from pina._src.condition.condition_base import ConditionBase
from pina._src.condition.equation_condition_base import (
EquationConditionBase,
)
from pina._src.domain.domain_interface import DomainInterface
from pina._src.equation.equation_interface import EquationInterface


class DomainEquationCondition(ConditionBase):
class DomainEquationCondition(EquationConditionBase):
"""
The class :class:`DomainEquationCondition` defines a condition based on a
``domain`` and an ``equation``. This condition is typically used in
Expand Down Expand Up @@ -92,4 +94,29 @@ def store_data(self, **kwargs):
:rtype: dict
"""
setattr(self, "domain", kwargs.get("domain"))
setattr(self, "equation", kwargs.get("equation"))
setattr(self, "_equation", kwargs.get("equation"))

@property
def equation(self):
"""
Return the equation associated with this condition.

:return: Equation associated with this condition.
:rtype: EquationInterface
"""
return self._equation

@equation.setter
def equation(self, value):
"""
Set the equation associated with this condition.

:param EquationInterface value: The equation to associate
with this condition.
"""
if not isinstance(value, EquationInterface):
raise TypeError(
"The equation must be an instance of "
"EquationInterface."
)
self._equation = value
50 changes: 50 additions & 0 deletions pina/_src/condition/equation_condition_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""Module for the EquationConditionBase class."""

from pina._src.condition.condition_base import ConditionBase


class EquationConditionBase(ConditionBase):
"""
Base class for conditions that involve an equation.

This class provides the :meth:`evaluate` method, which computes the
non-aggregated residual of the equation given the input samples and a
solver. It is intended to be subclassed by conditions that define an
``equation`` attribute, such as
:class:`~pina.condition.DomainEquationCondition` and
:class:`~pina.condition.InputEquationCondition`.
"""

def evaluate(self, batch, solver, loss):
"""
Evaluate the equation residual on the given batch using the solver.

This method computes the non-aggregated, element-wise residual of the
equation. It performs a forward pass of the solver's model on the
input samples and then evaluates the equation residual. The returned
tensor is **not** reduced (i.e., no mean, sum, etc.), preserving the
per-sample residual values.

:param batch: The batch containing the ``input`` entry.
:type batch: dict | _DataManager
:param solver: The solver containing the model and any additional
parameters (e.g., unknown parameters for inverse problems).
:type solver: ~pina.solver.solver.SolverInterface
:param loss: The non-aggregating loss function to apply to the
computed residual against zero.
:type loss: torch.nn.Module
:return: The non-aggregated loss tensor.
:rtype: ~pina.label_tensor.LabelTensor

:Example:

>>> residuals = condition.evaluate(
... {"input": input_samples}, solver, loss
... )
>>> # residuals is a non-reduced tensor of shape (n_samples, ...)
"""
samples = batch["input"]
residual = self.equation.residual(
samples, solver.forward(samples), solver._params
)
return residual**2
6 changes: 4 additions & 2 deletions pina/_src/condition/input_equation_condition.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
"""Module for the InputEquationCondition class and its subclasses."""

from pina._src.condition.condition_base import ConditionBase
from pina._src.condition.equation_condition_base import (
EquationConditionBase,
)
from pina._src.core.label_tensor import LabelTensor
from pina._src.core.graph import Graph
from pina._src.equation.equation_interface import EquationInterface
from pina._src.condition.data_manager import _DataManager


class InputEquationCondition(ConditionBase):
class InputEquationCondition(EquationConditionBase):
"""
The class :class:`InputEquationCondition` defines a condition based on
``input`` data and an ``equation``. This condition is typically used in
Expand Down
18 changes: 18 additions & 0 deletions pina/_src/condition/input_target_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,21 @@ def target(self):
list[Data] | tuple[Graph] | tuple[Data]
"""
return self.data.target

def evaluate(self, batch, solver, loss):
"""
Evaluate the supervised condition on the given batch using the solver.

This method computes the element-wise loss associated with the
condition using the input and target stored in the provided batch.

:param batch: The batch containing ``input`` and ``target`` entries.
:type batch: dict | _DataManager
:param solver: The solver containing the model.
:type solver: ~pina.solver.solver.SolverInterface
:param loss: The non-aggregating loss function to apply.
:type loss: torch.nn.Module
:return: The non-aggregated loss tensor.
:rtype: LabelTensor | torch.Tensor | Graph | Data
"""
return loss(solver.forward(batch["input"]), batch["target"])
135 changes: 135 additions & 0 deletions pina/_src/solver/single_model_simple_solver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
"""Module for the SingleModelSimpleSolver."""

import torch
from torch.nn.modules.loss import _Loss

from pina._src.condition.domain_equation_condition import (
DomainEquationCondition,
)
from pina._src.condition.input_equation_condition import (
InputEquationCondition,
)
from pina._src.condition.input_target_condition import InputTargetCondition
from pina._src.core.utils import check_consistency
from pina._src.loss.loss_interface import LossInterface
from pina._src.solver.solver import SingleSolverInterface


class SingleModelSimpleSolver(SingleSolverInterface):
"""
Minimal single-model solver with explicit residual evaluation, reduction,
and loss aggregation across conditions.

The solver orchestrates a uniform workflow for all conditions in the batch:

1. evaluate the condition and obtain a non-aggregated loss tensor;
2. apply a reduction to obtain a scalar loss for that condition;
4. return the per-condition losses, which are aggregated by the inherited
solver machinery through the configured weighting.
"""

accepted_conditions_types = (
InputTargetCondition,
InputEquationCondition,
DomainEquationCondition,
)

def __init__(
self,
problem,
model,
optimizer=None,
scheduler=None,
weighting=None,
loss=None,
use_lt=True,
):
"""
Initialize the single-model simple solver.

:param AbstractProblem problem: The problem to be solved.
:param torch.nn.Module model: The neural network model to be used.
:param Optimizer optimizer: The optimizer to be used.
:param Scheduler scheduler: Learning rate scheduler.
:param WeightingInterface weighting: The weighting schema to be used.
:param torch.nn.Module loss: The element-wise loss module whose
reduction strategy is reused by the solver. If ``None``,
:class:`torch.nn.MSELoss` is used.
:param bool use_lt: If ``True``, the solver uses LabelTensors as input.
"""
if loss is None:
loss = torch.nn.MSELoss()

check_consistency(loss, (LossInterface, _Loss), subclass=False)

super().__init__(
model=model,
problem=problem,
optimizer=optimizer,
scheduler=scheduler,
weighting=weighting,
use_lt=use_lt,
)

self._loss_fn = loss
self._reduction = getattr(loss, "reduction", "mean")

if hasattr(self._loss_fn, "reduction"):
self._loss_fn.reduction = "none"

def optimization_cycle(self, batch):
"""
Compute one reduced loss per condition in the batch.

:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
:return: The reduced losses for all conditions.
:rtype: dict[str, torch.Tensor]
"""
condition_losses = {}

for condition_name, data in batch:
condition = self.problem.conditions[condition_name]
condition_data = dict(data)

if hasattr(condition_data.get("input"), "requires_grad_"):
condition_data["input"] = condition_data[
"input"
].requires_grad_()

condition_loss_tensor = condition.evaluate(
condition_data, self, self._loss_fn
)
condition_losses[condition_name] = self._apply_reduction(
condition_loss_tensor
)
Comment on lines +98 to +100
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add weighting


return condition_losses

def _apply_reduction(self, value):
"""
Apply the configured reduction to a non-aggregated condition tensor.

:param value: The non-aggregated tensor returned by a condition.
:type value: torch.Tensor
:return: The reduced scalar tensor.
:rtype: torch.Tensor
:raises ValueError: If the reduction is not supported.
"""
if self._reduction == "none":
return value
if self._reduction == "mean":
return value.mean()
if self._reduction == "sum":
return value.sum()
raise ValueError(f"Unsupported reduction '{self._reduction}'.")

@property
def loss(self):
"""
The underlying element-wise loss module.

:return: The stored loss module.
:rtype: torch.nn.Module
"""
return self._loss_fn
4 changes: 4 additions & 0 deletions pina/condition/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"Condition",
"ConditionInterface",
"ConditionBase",
"EquationConditionBase",
"DomainEquationCondition",
"InputTargetCondition",
"InputEquationCondition",
Expand All @@ -18,6 +19,9 @@

from pina._src.condition.condition_interface import ConditionInterface
from pina._src.condition.condition_base import ConditionBase
from pina._src.condition.equation_condition_base import (
EquationConditionBase,
)
from pina._src.condition.condition import Condition
from pina._src.condition.domain_equation_condition import (
DomainEquationCondition,
Expand Down
4 changes: 4 additions & 0 deletions pina/solver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"SolverInterface",
"SingleSolverInterface",
"MultiSolverInterface",
"SingleModelSimpleSolver",
"PINNInterface",
"PINN",
"GradientPINN",
Expand All @@ -36,6 +37,9 @@
SingleSolverInterface,
MultiSolverInterface,
)
from pina._src.solver.single_model_simple_solver import (
SingleModelSimpleSolver,
)
from pina._src.solver.physics_informed_solver.pinn import PINNInterface, PINN
from pina._src.solver.physics_informed_solver.gradient_pinn import GradientPINN
from pina._src.solver.physics_informed_solver.causal_pinn import CausalPINN
Expand Down
30 changes: 30 additions & 0 deletions tests/test_condition/test_domain_equation_condition.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
import pytest
import torch
from pina import Condition
from pina import LabelTensor
from pina.domain import CartesianDomain
from pina._src.equation.equation_factory import FixedValue
from pina.equation import Equation
from pina.condition import DomainEquationCondition


class DummySolver:
def __init__(self):
self._params = {"shift": torch.tensor(0.25)}

def forward(self, samples):
return samples.extract(["x"]) - samples.extract(["y"])

example_domain = CartesianDomain({"x": [0, 1], "y": [0, 1]})
example_equation = FixedValue(0.0)

Expand All @@ -27,3 +38,22 @@ def test_getitem_not_implemented():
cond = Condition(domain=example_domain, equation=FixedValue(0.0))
with pytest.raises(NotImplementedError):
cond[0]


def test_evaluate_domain_equation_condition():
def equation_func(input_, output_, params_):
return output_ + input_.extract(["y"]) - params_["shift"]

samples = LabelTensor(torch.randn(12, 2), labels=["x", "y"])
cond = Condition(domain=example_domain, equation=Equation(equation_func))
solver = DummySolver()
batch = {"input": samples}
loss = torch.nn.MSELoss(reduction="none")

residual = cond.evaluate(batch, solver, loss)
expected = loss(
samples.extract(["x"]) - solver._params["shift"],
torch.zeros_like(samples.extract(["x"]) - solver._params["shift"]),
)

torch.testing.assert_close(residual, expected)
27 changes: 27 additions & 0 deletions tests/test_condition/test_input_equation_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@
from pina._src.condition.data_manager import _DataManager


class DummySolver:
def __init__(self):
self._params = {"shift": torch.tensor(1.5)}

def forward(self, samples):
return samples.extract(["x"]) + samples.extract(["y"])


def _create_pts_and_equation():
def dummy_equation(pts):
return pts["x"] ** 2 + pts["y"] ** 2 - 1
Expand Down Expand Up @@ -77,3 +85,22 @@ def test_getitems_tensor_equation_condition():
assert isinstance(item, _DataManager)
assert hasattr(item, "input")
assert item.input.shape == (3, 2)


def test_evaluate_tensor_equation_condition():
def equation_func(input_, output_, params_):
return output_ - input_.extract(["x"]) - params_["shift"]

pts = LabelTensor(torch.randn(10, 2), labels=["x", "y"])
condition = Condition(input=pts, equation=Equation(equation_func))
solver = DummySolver()
batch = {"input": pts}
loss = torch.nn.MSELoss(reduction="none")

residual = condition.evaluate(batch, solver, loss)
expected = loss(
pts.extract(["y"]) - solver._params["shift"],
torch.zeros_like(pts.extract(["y"]) - solver._params["shift"]),
)

torch.testing.assert_close(residual, expected)
Loading
Loading