Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions malpolon/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from .standard_prediction_systems import (ClassificationSystem,
RegressionSystem,
GenericPredictionSystem)

__all__ = [ # noqa: F405
"GenericPredictionSystem",
"ClassificationSystem",
"RegressionSystem",
]
59 changes: 59 additions & 0 deletions malpolon/models/standard_prediction_systems.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,3 +449,62 @@ def __init__(
}

super().__init__(model, loss, optimizer, metrics=metrics)


class RegressionSystem(GenericPredictionSystem):
"""Regression task class."""
def __init__(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some arguments (optimizer kwargs) are not used because incompatible with the proposed default optimizer. I suggest replacing them with the default optimizer's; or simply removing them.

Please update the docstring accordingly too

self,
model: Union[torch.nn.Module, Mapping],
loss: Union[torch.nn.modules.loss._Loss, str],
optimizer: Union[torch.nn.Module, Mapping] = None,
lr: float = 1e-2,
weight_decay: float = 0,
metrics: Optional[dict[str, Callable]] = None,
loss_kwargs: Optional[dict] = {},
):
"""Class constructor.
Parameters
----------
model : dict
model to use
loss : Union[torch.nn.modules.loss._Loss, str]
loss or string from the predifined LOSS_CALLABLES.
optimizer : Union[torch.nn.Module, Mapping]
optional custom optimizer to use for training
lr : float
learning rate
weight_decay : float
weight decay
metrics : dict
dictionnary containing the metrics to compute.
Keys must match metrics' names and have a subkey with each
metric's functional methods as value. This subkey is either
created from the `malpolon.models.utils.FMETRICS_CALLABLES`
constant or supplied, by the user directly.
loss_kwargs: Optional[dict] = {}
Arguments to be passed to loss constructor.
"""

metrics = check_metric(metrics)

self.lr = lr
self.weight_decay = weight_decay

model = check_model(model)

if optimizer is None:
print(f'[INFO] No optimizer provided: using AdamW with lr={lr}, weight_decay={weight_decay}')
optimizer = torch.optim.AdamW(
model.parameters(),
lr=self.lr,
weight_decay=self.weight_decay
)

if isinstance(loss, torch.nn.modules.loss._Loss):
# If loss is already instantiated, no need to provide kwargs
loss = check_loss(loss)
else:
loss = check_loss(loss)(**loss_kwargs)

super().__init__(model, loss, optimizer, metrics=metrics)
14 changes: 11 additions & 3 deletions malpolon/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@
'reduce_lr_on_plateau': lr_scheduler.ReduceLROnPlateau,
'cosine_annealing_lr': lr_scheduler.CosineAnnealingLR, }

LOSS_CALLABLES = {'huber_loss': nn.HuberLoss,
'mse_loss': nn.MSELoss,
'cross_entropy_loss': nn.CrossEntropyLoss,
'bce_loss': nn.BCELoss, }


class CrashHandler():
"""Saves the model in case of unexpected crash or user interruption."""
Expand Down Expand Up @@ -104,11 +109,11 @@ def check_metric(metrics: OmegaConf) -> OmegaConf:
return metrics


def check_loss(loss: nn.modules.loss._Loss) -> nn.modules.loss._Loss:
def check_loss(loss: Union[nn.modules.loss._Loss, str]) -> nn.modules.loss._Loss:
"""Ensure input loss is a pytorch loss.

Args:
loss (nn.modules.loss._Loss): input loss.
loss (Union[nn.modules.loss._Loss, str]): input loss.

Raises:
ValueError: if input loss isn't a pytorch loss object.
Expand All @@ -118,7 +123,10 @@ def check_loss(loss: nn.modules.loss._Loss) -> nn.modules.loss._Loss:
"""
if isinstance(loss, nn.modules.loss._Loss): # pylint: disable=protected-access # noqa
return loss
raise ValueError(f"Loss must be of type nn.modules.loss. "
elif isinstance(loss, str):
if loss in LOSS_CALLABLES:
return(LOSS_CALLABLES[loss])
raise ValueError(f"Loss must be of type nn.modules.loss or string from LOSS_CALLABLES"
f"Loss given type {type(loss)} instead")


Expand Down