-
Notifications
You must be signed in to change notification settings - Fork 9
Created regression_system and updated check_loss #71
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
161dc8f
c557943
a666126
b4266a7
7480548
6c1f488
2a55661
be779be
cc2d146
422f4c4
62913ae
88f51df
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -449,3 +449,71 @@ def __init__( | |
| } | ||
|
|
||
| super().__init__(model, loss, optimizer, metrics=metrics) | ||
|
|
||
|
|
||
| class RegressionSystem(GenericPredictionSystem): | ||
| """Regression task class.""" | ||
| def __init__( | ||
| self, | ||
| model: Union[torch.nn.Module, Mapping], | ||
| loss: Union[torch.nn.modules.loss._Loss, str], | ||
| optimizer: Union[torch.nn.Module, Mapping] = None, | ||
| lr: float = 1e-2, | ||
| weight_decay: float = 0, | ||
| metrics: Optional[dict[str, Callable]] = None, | ||
| task: str = 'regression_multilabel', | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Argument not used |
||
| loss_kwargs: Optional[dict] = {}, | ||
| hparams_preprocess: bool = True, | ||
| checkpoint_path: Optional[str] = None | ||
| ): | ||
| """Class constructor. | ||
| Parameters | ||
| ---------- | ||
| model : dict | ||
| model to use | ||
| lr : float | ||
| learning rate | ||
| weight_decay : float | ||
| weight decay | ||
| momentum : float | ||
| value of momentum | ||
| nesterov : bool | ||
| if True, uses Nesterov's momentum | ||
| metrics : dict | ||
| dictionnary containing the metrics to compute. | ||
| Keys must match metrics' names and have a subkey with each | ||
| metric's functional methods as value. This subkey is either | ||
| created from the `malpolon.models.utils.FMETRICS_CALLABLES` | ||
| constant or supplied, by the user directly. | ||
| task : str, optional | ||
| Machine learning task (used to format labels accordingly), | ||
| by default 'classification_multiclass'. The value determines | ||
| the loss to be selected. if 'multilabel' or 'binary' is | ||
| in the task, the BCEWithLogitsLoss is selected, otherwise | ||
| the CrossEntropyLoss is used. | ||
| hparams_preprocess : bool, optional | ||
| if True performs preprocessing operations on the hyperparameters, | ||
| by default True | ||
| """ | ||
| if hparams_preprocess: | ||
| task = task.split('regression_')[1] | ||
| metrics = check_metric(metrics) | ||
|
|
||
| self.lr = lr | ||
| self.weight_decay = weight_decay | ||
|
|
||
| self.checkpoint_path = checkpoint_path | ||
| model = check_model(model) | ||
|
|
||
| if optimizer is None: | ||
| print(f'[INFO] No optimizer provided: using AdamW with lr={lr}, weight_decay={weight_decay}') | ||
| optimizer = torch.optim.AdamW( | ||
| model.parameters(), | ||
| lr=self.lr, | ||
| weight_decay=self.weight_decay | ||
| ) | ||
|
|
||
|
|
||
| loss = check_loss(loss)(**loss_kwargs) | ||
|
|
||
| super().__init__(model, loss, optimizer, metrics=metrics) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -104,7 +104,7 @@ def check_metric(metrics: OmegaConf) -> OmegaConf: | |
| return metrics | ||
|
|
||
|
|
||
| def check_loss(loss: nn.modules.loss._Loss) -> nn.modules.loss._Loss: | ||
| def check_loss(loss: Union[nn.modules.loss._Loss, str]) -> nn.modules.loss._Loss: | ||
| """Ensure input loss is a pytorch loss. | ||
|
|
||
| Args: | ||
|
|
@@ -118,7 +118,9 @@ def check_loss(loss: nn.modules.loss._Loss) -> nn.modules.loss._Loss: | |
| """ | ||
| if isinstance(loss, nn.modules.loss._Loss): # pylint: disable=protected-access # noqa | ||
| return loss | ||
| raise ValueError(f"Loss must be of type nn.modules.loss. " | ||
| elif isinstance(loss, str): | ||
| return eval(loss) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We are aiming at removing all calls to |
||
| raise ValueError(f"Loss must be of type nn.modules.loss or callable string" | ||
| f"Loss given type {type(loss)} instead") | ||
|
|
||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Some arguments (optimizer kwargs) are not used because incompatible with the proposed default optimizer. I suggest replacing them with the default optimizer's; or simply removing them.
Please update the docstring accordingly too