-
Notifications
You must be signed in to change notification settings - Fork 97
new solvers structure #777
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
ndem0
wants to merge
5
commits into
0.3
Choose a base branch
from
0.3-solver
base: 0.3
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from 3 commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,50 @@ | ||
| """Module for the EquationConditionBase class.""" | ||
|
|
||
| from pina._src.condition.condition_base import ConditionBase | ||
|
|
||
|
|
||
| class EquationConditionBase(ConditionBase): | ||
| """ | ||
| Base class for conditions that involve an equation. | ||
|
|
||
| This class provides the :meth:`evaluate` method, which computes the | ||
| non-aggregated residual of the equation given the input samples and a | ||
| solver. It is intended to be subclassed by conditions that define an | ||
| ``equation`` attribute, such as | ||
| :class:`~pina.condition.DomainEquationCondition` and | ||
| :class:`~pina.condition.InputEquationCondition`. | ||
| """ | ||
|
|
||
| def evaluate(self, batch, solver, loss): | ||
| """ | ||
| Evaluate the equation residual on the given batch using the solver. | ||
|
|
||
| This method computes the non-aggregated, element-wise residual of the | ||
| equation. It performs a forward pass of the solver's model on the | ||
| input samples and then evaluates the equation residual. The returned | ||
| tensor is **not** reduced (i.e., no mean, sum, etc.), preserving the | ||
| per-sample residual values. | ||
|
|
||
| :param batch: The batch containing the ``input`` entry. | ||
| :type batch: dict | _DataManager | ||
| :param solver: The solver containing the model and any additional | ||
| parameters (e.g., unknown parameters for inverse problems). | ||
| :type solver: ~pina.solver.solver.SolverInterface | ||
| :param loss: The non-aggregating loss function to apply to the | ||
| computed residual against zero. | ||
| :type loss: torch.nn.Module | ||
| :return: The non-aggregated loss tensor. | ||
| :rtype: ~pina.label_tensor.LabelTensor | ||
|
|
||
| :Example: | ||
|
|
||
| >>> residuals = condition.evaluate( | ||
| ... {"input": input_samples}, solver, loss | ||
| ... ) | ||
| >>> # residuals is a non-reduced tensor of shape (n_samples, ...) | ||
| """ | ||
| samples = batch["input"] | ||
| residual = self.equation.residual( | ||
| samples, solver.forward(samples), solver._params | ||
| ) | ||
| return residual**2 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,135 @@ | ||
| """Module for the SingleModelSimpleSolver.""" | ||
|
|
||
| import torch | ||
| from torch.nn.modules.loss import _Loss | ||
|
|
||
| from pina._src.condition.domain_equation_condition import ( | ||
| DomainEquationCondition, | ||
| ) | ||
| from pina._src.condition.input_equation_condition import ( | ||
| InputEquationCondition, | ||
| ) | ||
| from pina._src.condition.input_target_condition import InputTargetCondition | ||
| from pina._src.core.utils import check_consistency | ||
| from pina._src.loss.loss_interface import LossInterface | ||
| from pina._src.solver.solver import SingleSolverInterface | ||
|
|
||
|
|
||
| class SingleModelSimpleSolver(SingleSolverInterface): | ||
| """ | ||
| Minimal single-model solver with explicit residual evaluation, reduction, | ||
| and loss aggregation across conditions. | ||
|
|
||
| The solver orchestrates a uniform workflow for all conditions in the batch: | ||
|
|
||
| 1. evaluate the condition and obtain a non-aggregated loss tensor; | ||
| 2. apply a reduction to obtain a scalar loss for that condition; | ||
| 4. return the per-condition losses, which are aggregated by the inherited | ||
| solver machinery through the configured weighting. | ||
| """ | ||
|
|
||
| accepted_conditions_types = ( | ||
| InputTargetCondition, | ||
| InputEquationCondition, | ||
| DomainEquationCondition, | ||
| ) | ||
|
|
||
| def __init__( | ||
| self, | ||
| problem, | ||
| model, | ||
| optimizer=None, | ||
| scheduler=None, | ||
| weighting=None, | ||
| loss=None, | ||
| use_lt=True, | ||
| ): | ||
| """ | ||
| Initialize the single-model simple solver. | ||
|
|
||
| :param AbstractProblem problem: The problem to be solved. | ||
| :param torch.nn.Module model: The neural network model to be used. | ||
| :param Optimizer optimizer: The optimizer to be used. | ||
| :param Scheduler scheduler: Learning rate scheduler. | ||
| :param WeightingInterface weighting: The weighting schema to be used. | ||
| :param torch.nn.Module loss: The element-wise loss module whose | ||
| reduction strategy is reused by the solver. If ``None``, | ||
| :class:`torch.nn.MSELoss` is used. | ||
| :param bool use_lt: If ``True``, the solver uses LabelTensors as input. | ||
| """ | ||
| if loss is None: | ||
| loss = torch.nn.MSELoss() | ||
|
|
||
| check_consistency(loss, (LossInterface, _Loss), subclass=False) | ||
|
|
||
| super().__init__( | ||
| model=model, | ||
| problem=problem, | ||
| optimizer=optimizer, | ||
| scheduler=scheduler, | ||
| weighting=weighting, | ||
| use_lt=use_lt, | ||
| ) | ||
|
|
||
| self._loss_fn = loss | ||
| self._reduction = getattr(loss, "reduction", "mean") | ||
|
|
||
| if hasattr(self._loss_fn, "reduction"): | ||
| self._loss_fn.reduction = "none" | ||
|
|
||
| def optimization_cycle(self, batch): | ||
| """ | ||
| Compute one reduced loss per condition in the batch. | ||
|
|
||
| :param list[tuple[str, dict]] batch: A batch of data. Each element is a | ||
| tuple containing a condition name and a dictionary of points. | ||
| :return: The reduced losses for all conditions. | ||
| :rtype: dict[str, torch.Tensor] | ||
| """ | ||
| condition_losses = {} | ||
|
|
||
| for condition_name, data in batch: | ||
| condition = self.problem.conditions[condition_name] | ||
| condition_data = dict(data) | ||
|
|
||
| if hasattr(condition_data.get("input"), "requires_grad_"): | ||
| condition_data["input"] = condition_data[ | ||
| "input" | ||
| ].requires_grad_() | ||
|
|
||
| condition_loss_tensor = condition.evaluate( | ||
| condition_data, self, self._loss_fn | ||
| ) | ||
| condition_losses[condition_name] = self._apply_reduction( | ||
| condition_loss_tensor | ||
| ) | ||
|
|
||
| return condition_losses | ||
|
|
||
| def _apply_reduction(self, value): | ||
| """ | ||
| Apply the configured reduction to a non-aggregated condition tensor. | ||
|
|
||
| :param value: The non-aggregated tensor returned by a condition. | ||
| :type value: torch.Tensor | ||
| :return: The reduced scalar tensor. | ||
| :rtype: torch.Tensor | ||
| :raises ValueError: If the reduction is not supported. | ||
| """ | ||
| if self._reduction == "none": | ||
| return value | ||
| if self._reduction == "mean": | ||
| return value.mean() | ||
| if self._reduction == "sum": | ||
| return value.sum() | ||
| raise ValueError(f"Unsupported reduction '{self._reduction}'.") | ||
|
|
||
| @property | ||
| def loss(self): | ||
| """ | ||
| The underlying element-wise loss module. | ||
|
|
||
| :return: The stored loss module. | ||
| :rtype: torch.nn.Module | ||
| """ | ||
| return self._loss_fn | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add weighting