diff --git a/malpolon/models/__init__.py b/malpolon/models/__init__.py index 49e322ff..62da981c 100644 --- a/malpolon/models/__init__.py +++ b/malpolon/models/__init__.py @@ -1,7 +1,9 @@ from .standard_prediction_systems import (ClassificationSystem, + RegressionSystem, GenericPredictionSystem) __all__ = [ # noqa: F405 "GenericPredictionSystem", "ClassificationSystem", + "RegressionSystem", ] diff --git a/malpolon/models/standard_prediction_systems.py b/malpolon/models/standard_prediction_systems.py index 0a0f0d24..9e40c56f 100644 --- a/malpolon/models/standard_prediction_systems.py +++ b/malpolon/models/standard_prediction_systems.py @@ -449,3 +449,62 @@ def __init__( } super().__init__(model, loss, optimizer, metrics=metrics) + + +class RegressionSystem(GenericPredictionSystem): + """Regression task class.""" + def __init__( + self, + model: Union[torch.nn.Module, Mapping], + loss: Union[torch.nn.modules.loss._Loss, str], + optimizer: Union[torch.nn.Module, Mapping] = None, + lr: float = 1e-2, + weight_decay: float = 0, + metrics: Optional[dict[str, Callable]] = None, + loss_kwargs: Optional[dict] = {}, + ): + """Class constructor. + Parameters + ---------- + model : dict + model to use + loss : Union[torch.nn.modules.loss._Loss, str] + loss or string from the predifined LOSS_CALLABLES. + optimizer : Union[torch.nn.Module, Mapping] + optional custom optimizer to use for training + lr : float + learning rate + weight_decay : float + weight decay + metrics : dict + dictionnary containing the metrics to compute. + Keys must match metrics' names and have a subkey with each + metric's functional methods as value. This subkey is either + created from the `malpolon.models.utils.FMETRICS_CALLABLES` + constant or supplied, by the user directly. + loss_kwargs: Optional[dict] = {} + Arguments to be passed to loss constructor. + """ + + metrics = check_metric(metrics) + + self.lr = lr + self.weight_decay = weight_decay + + model = check_model(model) + + if optimizer is None: + print(f'[INFO] No optimizer provided: using AdamW with lr={lr}, weight_decay={weight_decay}') + optimizer = torch.optim.AdamW( + model.parameters(), + lr=self.lr, + weight_decay=self.weight_decay + ) + + if isinstance(loss, torch.nn.modules.loss._Loss): + # If loss is already instantiated, no need to provide kwargs + loss = check_loss(loss) + else: + loss = check_loss(loss)(**loss_kwargs) + + super().__init__(model, loss, optimizer, metrics=metrics) \ No newline at end of file diff --git a/malpolon/models/utils.py b/malpolon/models/utils.py index 12f28439..ecda1bfc 100644 --- a/malpolon/models/utils.py +++ b/malpolon/models/utils.py @@ -35,6 +35,11 @@ 'reduce_lr_on_plateau': lr_scheduler.ReduceLROnPlateau, 'cosine_annealing_lr': lr_scheduler.CosineAnnealingLR, } +LOSS_CALLABLES = {'huber_loss': nn.HuberLoss, + 'mse_loss': nn.MSELoss, + 'cross_entropy_loss': nn.CrossEntropyLoss, + 'bce_loss': nn.BCELoss, } + class CrashHandler(): """Saves the model in case of unexpected crash or user interruption.""" @@ -104,11 +109,11 @@ def check_metric(metrics: OmegaConf) -> OmegaConf: return metrics -def check_loss(loss: nn.modules.loss._Loss) -> nn.modules.loss._Loss: +def check_loss(loss: Union[nn.modules.loss._Loss, str]) -> nn.modules.loss._Loss: """Ensure input loss is a pytorch loss. Args: - loss (nn.modules.loss._Loss): input loss. + loss (Union[nn.modules.loss._Loss, str]): input loss. Raises: ValueError: if input loss isn't a pytorch loss object. @@ -118,7 +123,10 @@ def check_loss(loss: nn.modules.loss._Loss) -> nn.modules.loss._Loss: """ if isinstance(loss, nn.modules.loss._Loss): # pylint: disable=protected-access # noqa return loss - raise ValueError(f"Loss must be of type nn.modules.loss. " + elif isinstance(loss, str): + if loss in LOSS_CALLABLES: + return(LOSS_CALLABLES[loss]) + raise ValueError(f"Loss must be of type nn.modules.loss or string from LOSS_CALLABLES" f"Loss given type {type(loss)} instead")