Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions malpolon/models/standard_prediction_systems.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,3 +449,71 @@ def __init__(
}

super().__init__(model, loss, optimizer, metrics=metrics)


class RegressionSystem(GenericPredictionSystem):
"""Regression task class."""
def __init__(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some arguments (optimizer kwargs) are not used because incompatible with the proposed default optimizer. I suggest replacing them with the default optimizer's; or simply removing them.

Please update the docstring accordingly too

self,
model: Union[torch.nn.Module, Mapping],
loss: Union[torch.nn.modules.loss._Loss, str],
optimizer: Union[torch.nn.Module, Mapping] = None,
lr: float = 1e-2,
weight_decay: float = 0,
metrics: Optional[dict[str, Callable]] = None,
task: str = 'regression_multilabel',
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Argument not used

loss_kwargs: Optional[dict] = {},
hparams_preprocess: bool = True,
checkpoint_path: Optional[str] = None
):
"""Class constructor.
Parameters
----------
model : dict
model to use
lr : float
learning rate
weight_decay : float
weight decay
momentum : float
value of momentum
nesterov : bool
if True, uses Nesterov's momentum
metrics : dict
dictionnary containing the metrics to compute.
Keys must match metrics' names and have a subkey with each
metric's functional methods as value. This subkey is either
created from the `malpolon.models.utils.FMETRICS_CALLABLES`
constant or supplied, by the user directly.
task : str, optional
Machine learning task (used to format labels accordingly),
by default 'classification_multiclass'. The value determines
the loss to be selected. if 'multilabel' or 'binary' is
in the task, the BCEWithLogitsLoss is selected, otherwise
the CrossEntropyLoss is used.
hparams_preprocess : bool, optional
if True performs preprocessing operations on the hyperparameters,
by default True
"""
if hparams_preprocess:
task = task.split('regression_')[1]
metrics = check_metric(metrics)

self.lr = lr
self.weight_decay = weight_decay

self.checkpoint_path = checkpoint_path
model = check_model(model)

if optimizer is None:
print(f'[INFO] No optimizer provided: using AdamW with lr={lr}, weight_decay={weight_decay}')
optimizer = torch.optim.AdamW(
model.parameters(),
lr=self.lr,
weight_decay=self.weight_decay
)


loss = check_loss(loss)(**loss_kwargs)

super().__init__(model, loss, optimizer, metrics=metrics)
6 changes: 4 additions & 2 deletions malpolon/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def check_metric(metrics: OmegaConf) -> OmegaConf:
return metrics


def check_loss(loss: nn.modules.loss._Loss) -> nn.modules.loss._Loss:
def check_loss(loss: Union[nn.modules.loss._Loss, str]) -> nn.modules.loss._Loss:
"""Ensure input loss is a pytorch loss.

Args:
Expand All @@ -118,7 +118,9 @@ def check_loss(loss: nn.modules.loss._Loss) -> nn.modules.loss._Loss:
"""
if isinstance(loss, nn.modules.loss._Loss): # pylint: disable=protected-access # noqa
return loss
raise ValueError(f"Loss must be of type nn.modules.loss. "
elif isinstance(loss, str):
return eval(loss)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are aiming at removing all calls to eval() for security purposes. Please take a look at functions check_optimizer(), check_scheduler() or check_metrics(), on our new branch "no-more-evals" to get an understanding at what a similar check_loss() function should look like

raise ValueError(f"Loss must be of type nn.modules.loss or callable string"
f"Loss given type {type(loss)} instead")


Expand Down