0.3 kan #764 (Merged)
+1,463 −7, 12 commits
- `8307d12` (ajacoby9) KAN implementation (#611)
- `dc31498` (ndem0) KAN with non-vectorized spline
- `a548161` (ndem0) Fix minor problem, black formatter
- `c1d7e26` (ndem0) fix output dimension for vectorized spline
- `e2ec4d0` (GiovanniCanali) fix index mismatch and remove unused function
- `88e4bbd` (GiovanniCanali) Fix vectorized splines and implement working KAN
- `d9b59ff` (GiovanniCanali) minor fix to output dimension in vector splines
- `ad8a27f` (GiovanniCanali) add tests
- `d4dfb65` (GiovanniCanali) add rst files
- `5bd5902` (GiovanniCanali) add docstrings
- `6eb49fb` (GiovanniCanali) implement derivatives for vector splines
- `6caa873` (GiovanniCanali) fix minor shape bug
New file (+146 lines), the `KANBlock` module:

```python
"""Module for the Kolmogorov-Arnold Network block."""

import torch
from pina._src.model.vectorized_spline import VectorizedSpline
from pina._src.core.utils import check_consistency, check_positive_integer


class KANBlock(torch.nn.Module):
    """
    A single Kolmogorov-Arnold Network block. It maps each input feature
    through learnable spline basis functions and optionally adds a linear
    transformation of a base activation function.
    """

    def __init__(
        self,
        input_dimensions,
        output_dimensions,
        spline_order=3,
        n_knots=10,
        grid_range=[0, 1],
        base_function=torch.nn.SiLU,
        use_base_linear=True,
        use_bias=True,
        init_scale_spline=1e-2,
        init_scale_base=1.0,
    ):
        """
        Initialization of the :class:`KANBlock` class.

        :param int input_dimensions: The number of input features.
        :param int output_dimensions: The number of output features.
        :param int spline_order: The order of each spline basis function.
            Default is 3 (cubic splines).
        :param int n_knots: The number of knots for each spline basis function.
            Default is 10.
        :param grid_range: The range for the spline knots. It must be either a
            list or a tuple of the form [min, max]. Default is [0, 1].
        :type grid_range: list | tuple
        :param torch.nn.Module base_function: The base activation function to be
            applied to the input before the linear transformation. Default is
            :class:`torch.nn.SiLU`.
        :param bool use_base_linear: Whether to include a linear transformation
            of the base function output. Default is True.
        :param bool use_bias: Whether to include a bias term in the output.
            Default is True.
        :param init_scale_spline: The scale for initializing the spline
            control points. Default is 1e-2.
        :type init_scale_spline: float | int
        :param init_scale_base: The scale for initializing the base linear
            weights. Default is 1.0.
        :type init_scale_base: float | int
        :raises ValueError: If ``grid_range`` is not of length 2.
        """
        super().__init__()

        # Check consistency
        check_consistency(base_function, torch.nn.Module, subclass=True)
        check_positive_integer(input_dimensions, strict=True)
        check_positive_integer(output_dimensions, strict=True)
        check_positive_integer(spline_order, strict=True)
        check_positive_integer(n_knots, strict=True)
        check_consistency(use_base_linear, bool)
        check_consistency(use_bias, bool)
        check_consistency(init_scale_spline, (int, float))
        check_consistency(init_scale_base, (int, float))
        check_consistency(grid_range, (int, float))

        # Raise error if grid_range is not valid
        if len(grid_range) != 2:
            raise ValueError("Grid must be a list or tuple with two elements.")

        # Boundary knots, repeated `spline_order` times (clamped spline)
        initial_knots = torch.ones(spline_order) * grid_range[0]
        final_knots = torch.ones(spline_order) * grid_range[1]

        # Number of internal knots
        n_internal = max(0, n_knots - 2 * spline_order)

        # Internal knots are uniformly spaced in the grid range
        internal_knots = torch.linspace(
            grid_range[0], grid_range[1], n_internal + 2
        )[1:-1]

        # Define the knots, one knot vector per input feature
        knots = torch.cat((initial_knots, internal_knots, final_knots))
        knots = knots.unsqueeze(0).repeat(input_dimensions, 1)

        # Define the control points for the spline basis functions
        control_points = (
            torch.randn(
                input_dimensions,
                output_dimensions,
                knots.shape[-1] - spline_order,
            )
            * init_scale_spline
        )

        # Define the vectorized spline module
        self.spline = VectorizedSpline(
            order=spline_order, knots=knots, control_points=control_points
        )

        # Initialize the base function
        self.base_function = base_function()

        # Initialize the base linear weights if needed
        if use_base_linear:
            self.base_weight = torch.nn.Parameter(
                torch.randn(output_dimensions, input_dimensions)
                * (init_scale_base / (input_dimensions**0.5))
            )
        else:
            self.register_parameter("base_weight", None)

        # Initialize the bias term if needed
        if use_bias:
            self.bias = torch.nn.Parameter(torch.zeros(output_dimensions))
        else:
            self.register_parameter("bias", None)

    def forward(self, x):
        """
        Forward pass of the :class:`KANBlock`. It transforms the input using a
        vectorized spline basis and optionally adds a linear transformation of a
        base activation function.

        The input is expected to have shape (batch_size, input_dimensions) and
        the output will have shape (batch_size, output_dimensions).

        :param torch.Tensor x: The input tensor for the model.
        :return: The output tensor of the model.
        :rtype: torch.Tensor
        """
        # Per-feature spline contributions, shape (batch, input, output)
        y = self.spline(x)

        if self.base_weight is not None:
            base_x = self.base_function(x)
            base_out = torch.einsum("bi,oi->bio", base_x, self.base_weight)
            y = y + base_out

        # Aggregate contributions from all input dimensions
        y = y.sum(dim=1)

        if self.bias is not None:
            y = y + self.bias

        return y
```
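The knot vector built in `KANBlock.__init__` can be checked in isolation. Below is a plain-Python sketch of the same construction (the helper name is hypothetical and the torch calls are replaced with list arithmetic), useful for reasoning about knot counts without instantiating the block:

```python
def make_clamped_knots(spline_order, n_knots, grid_range):
    """Mirror KANBlock's knot construction: boundary knots repeated
    `spline_order` times, internal knots uniformly spaced in between."""
    lo, hi = grid_range
    # Number of internal knots, as in KANBlock.__init__
    n_internal = max(0, n_knots - 2 * spline_order)
    # Equivalent of torch.linspace(lo, hi, n_internal + 2)[1:-1]
    step = (hi - lo) / (n_internal + 1)
    internal = [lo + step * (i + 1) for i in range(n_internal)]
    return [lo] * spline_order + internal + [hi] * spline_order

knots = make_clamped_knots(spline_order=3, n_knots=10, grid_range=[0, 1])
# With the defaults: 3 boundary + 4 internal + 3 boundary = 10 knots,
# leaving len(knots) - spline_order = 7 control points per (input, output) pair.
```

This matches the shapes in the diff: `control_points` has last dimension `knots.shape[-1] - spline_order`, i.e. 7 under the default arguments.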
New file (+86 lines), the `KolmogorovArnoldNetwork` module:

```python
"""Module for the Kolmogorov-Arnold Network."""

import torch
from pina._src.model.block.kan_block import KANBlock
from pina._src.core.utils import check_consistency


class KolmogorovArnoldNetwork(torch.nn.Module):
    """
    Kolmogorov-Arnold Network (KAN), implemented as a sequence of
    :class:`KANBlock` layers applied one after the other.
    """

    def __init__(
        self,
        layers,
        spline_order=3,
        n_knots=10,
        grid_range=[-1, 1],
        base_function=torch.nn.SiLU,
        use_base_linear=True,
        use_bias=True,
        init_scale_spline=1e-2,
        init_scale_base=1.0,
    ):
        """
        Initialization of the :class:`KolmogorovArnoldNetwork` class.

        :param layers: A list of integers specifying the sizes of each layer,
            including input and output dimensions.
        :type layers: list | tuple
        :param int spline_order: The order of each spline basis function.
            Default is 3 (cubic splines).
        :param int n_knots: The number of knots for each spline basis function.
            Default is 10.
        :param grid_range: The range for the spline knots. It must be either a
            list or a tuple of the form [min, max]. Default is [-1, 1].
        :type grid_range: list | tuple
        :param torch.nn.Module base_function: The base activation function to be
            applied to the input before the linear transformation. Default is
            :class:`torch.nn.SiLU`.
        :param bool use_base_linear: Whether to include a linear transformation
            of the base function output. Default is True.
        :param bool use_bias: Whether to include a bias term in the output.
            Default is True.
        :param init_scale_spline: The scale for initializing the spline
            control points. Default is 1e-2.
        :type init_scale_spline: float | int
        :param init_scale_base: The scale for initializing the base linear
            weights. Default is 1.0.
        :type init_scale_base: float | int
        :raises ValueError: If ``layers`` has fewer than two elements.
        """
        super().__init__()

        # Check consistency -- all other checks are performed in KANBlock
        check_consistency(layers, int)
        if len(layers) < 2:
            raise ValueError(
                "Provide at least two elements for layers (input and output)."
            )

        # Initialize KAN blocks, one per consecutive pair of layer sizes
        self.kan_layers = torch.nn.ModuleList(
            [
                KANBlock(
                    input_dimensions=layers[i],
                    output_dimensions=layers[i + 1],
                    spline_order=spline_order,
                    n_knots=n_knots,
                    grid_range=grid_range,
                    base_function=base_function,
                    use_base_linear=use_base_linear,
                    use_bias=use_bias,
                    init_scale_spline=init_scale_spline,
                    init_scale_base=init_scale_base,
                )
                for i in range(len(layers) - 1)
            ]
        )

    def forward(self, x):
        """
        Forward pass of the :class:`KolmogorovArnoldNetwork`. The input is
        passed through each :class:`KANBlock` in sequence.

        :param torch.Tensor x: The input tensor for the model.
        :return: The output tensor of the model.
        :rtype: torch.Tensor
        """
        for layer in self.kan_layers:
            x = layer(x)

        return x
```
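Given the shapes constructed in the diff, the trainable parameter count of a network is easy to tally. A hedged sketch (the helper name is hypothetical, and it assumes the control points, base weights, and biases are the only trainable parameters, with knots held fixed):

```python
def kan_param_count(layers, spline_order=3, n_knots=10,
                    use_base_linear=True, use_bias=True):
    """Count trainable parameters of a KolmogorovArnoldNetwork,
    following the tensor shapes constructed in KANBlock.__init__."""
    total = 0
    for in_dim, out_dim in zip(layers, layers[1:]):
        # Knot vector length: 2 * spline_order boundary knots plus
        # max(0, n_knots - 2 * spline_order) internal knots.
        n_total_knots = 2 * spline_order + max(0, n_knots - 2 * spline_order)
        # Control points: shape (in_dim, out_dim, n_total_knots - spline_order)
        total += in_dim * out_dim * (n_total_knots - spline_order)
        if use_base_linear:
            total += out_dim * in_dim   # base_weight
        if use_bias:
            total += out_dim            # bias
    return total

# With the defaults each (input, output) pair carries 7 control points,
# so e.g. kan_param_count([2, 16, 1]) sums the two blocks (2, 16) and (16, 1).
print(kan_param_count([2, 16, 1]))
```

This scales linearly in `n_knots`, which is the usual trade-off when widening the spline grid of a KAN layer.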