diff --git a/pit/_compat.py b/pit/_compat.py new file mode 100644 index 0000000..1f1860f --- /dev/null +++ b/pit/_compat.py @@ -0,0 +1,48 @@ +"""Compatibility layer providing JAX-like APIs with NumPy fallbacks.""" + +from __future__ import annotations + +import numpy as _np + +try: # pragma: no cover - prefer JAX when available + import jax + import jax.numpy as jnp # type: ignore + import jax.nn as jnn # type: ignore + import jax.random as jr # type: ignore + Array = jax.Array + USING_JAX = True +except ImportError: # pragma: no cover - fallback path + jnp = _np # type: ignore + USING_JAX = False + + class _NN: + @staticmethod + def softplus(x): + return _np.log1p(_np.exp(-_np.abs(x))) + _np.maximum(x, 0.0) + + jnn = _NN() # type: ignore + + class _Random: + @staticmethod + def PRNGKey(seed: int): + return _np.random.default_rng(seed) + + @staticmethod + def split(key): + seed1 = int(key.integers(0, 2**32 - 1)) + seed2 = int(key.integers(0, 2**32 - 1)) + return _np.random.default_rng(seed1), _np.random.default_rng(seed2) + + @staticmethod + def multivariate_normal(key, mean, cov, shape): + size = int(shape[0]) if shape else None + return key.multivariate_normal(mean, cov, size=size) + + @staticmethod + def normal(key, shape): + return key.normal(size=shape) + + jr = _Random() # type: ignore + Array = _np.ndarray # type: ignore + +__all__ = ["jnp", "jnn", "jr", "Array", "USING_JAX"] diff --git a/pit/dynamics/__init__.py b/pit/dynamics/__init__.py index a7b9dea..35b87c3 100644 --- a/pit/dynamics/__init__.py +++ b/pit/dynamics/__init__.py @@ -1,18 +1,17 @@ -import torch -from torch import nn +"""Base definitions for dynamics models.""" + +from __future__ import annotations + from ..parameters.definitions import ParameterSample -class Dynamics(nn.Module): - """Base Class for dynamics""" - def __init__(self) -> None: - super().__init__() - def forward(self, states, inputs, params: ParameterSample): - """ - Dynamics evolutions +class Dynamics: + """Base class for dynamics models.""" - Args: - states: Dimension of (N, state_dims) - inputs: Dimension of (N, control_inputs) - """ + parameter_list: list[str] + + def forward(self, states, inputs, params: ParameterSample): # pragma: no cover - abstract raise NotImplementedError + + def __call__(self, states, inputs, params: ParameterSample): + return self.forward(states, inputs, params) diff --git a/pit/dynamics/_batching.py b/pit/dynamics/_batching.py index b5030eb..0f18928 100644 --- a/pit/dynamics/_batching.py +++ b/pit/dynamics/_batching.py @@ -1,34 +1,23 @@ from __future__ import annotations -from typing import Callable, Tuple, TypeVar +from typing import Callable, Tuple -import torch +from .._compat import Array, jnp -TensorLike = TypeVar("TensorLike", bound=torch.Tensor) - -def ensure_batch(tensor: TensorLike) -> Tuple[TensorLike, Callable[[TensorLike], TensorLike]]: +def ensure_batch(tensor: Array) -> Tuple[Array, Callable[[Array], Array]]: """Ensure a tensor has a batch dimension. - The helper promotes one-dimensional tensors to batched tensors by - unsqueezing a leading dimension. A callable is returned that can be - applied to tensors with the resulting shape to restore the original - dimensionality. - - Args: - tensor: A tensor with shape ``(dim,)`` or ``(batch, dim)``. - - Returns: - A tuple containing the (potentially) batched tensor and a callable to - restore tensors with matching shape back to their original - dimensionality. + Promotes one-dimensional arrays to batched arrays by adding a leading + dimension. A callable is returned that can be applied to arrays with the + resulting shape to restore the original dimensionality. """ if tensor.ndim == 1: - batched = tensor.unsqueeze(0) + batched = jnp.expand_dims(tensor, axis=0) - def restore(result: TensorLike) -> TensorLike: - return result.squeeze(0) + def restore(result: Array) -> Array: + return jnp.squeeze(result, axis=0) return batched, restore diff --git a/pit/dynamics/dynamic_bicycle.py b/pit/dynamics/dynamic_bicycle.py index 5320a66..12ccb7b 100644 --- a/pit/dynamics/dynamic_bicycle.py +++ b/pit/dynamics/dynamic_bicycle.py @@ -1,160 +1,120 @@ from __future__ import annotations +from .._compat import jnp + from . import Dynamics from ._batching import ensure_batch -from ..parameters import PointParameterGroup, CovariantNormalParameterGroup, NormalParameterGroup from ..parameters.definitions import ParameterSample -import torch -from torch import nn - X, Y, YAW, VX, VY, YAW_RATE, STEERING_ANGLE = 0, 1, 2, 3, 4, 5, 6 DRIVE_FORCE, STEER_SPEED = 0, 1 FRX, FFY, FRY = 0, 1, 2 -class DynamicBicycle(Dynamics, nn.Module): - """ - This is a dynamic bicycle model - From AMZ Driverless: The Full Autonomous Racing System - Model reference point: CoG - Longitudinal drive-train forces act on the center of gravity - State Variable [x, y, yaw, vx, vy, yaw rate, steering angle] - Control Inputs [drive force, steering speed] - """ +class DynamicBicycle(Dynamics): + """Dynamic bicycle model based on the AMZ Driverless formulation.""" def __init__(self, lf, lr, Iz, m, Df, Cf, Bf, Dr, Cr, Br, Cm, Cr0, Cr2, **kwargs) -> None: - super().__init__() - self.parameter_list = ['lf', 'lr', 'Iz', 'm', 'Df', 'Cf', 'Bf', 'Dr', 'Cr', 'Br', 'Cm', 'Cr0', 'Cr2'] + del kwargs + self.parameter_list = [ + "lf", + "lr", + "Iz", + "m", + "Df", + "Cf", + "Bf", + "Dr", + "Cr", + "Br", + "Cm", + "Cr0", + "Cr2", + ] self.initial_values = { - 'lf': lf, - 'lr': lr, - 'Iz': Iz, - 'm': m, - 'Df': Df, - 'Cf': Cf, - 'Bf': Bf, - 'Dr': Dr, - 'Cr': Cr, - 'Br': Br, - 'Cm': Cm, - 'Cr0': Cr0, - 'Cr2': Cr2, + "lf": lf, + "lr": lr, + "Iz": Iz, + "m": m, + "Df": Df, + "Cf": Cf, + "Bf": Bf, + "Dr": Dr, + "Cr": Cr, + "Br": Br, + "Cm": Cm, + "Cr0": Cr0, + "Cr2": Cr2, } - # if param_type == 'point': - # self.params = PointParameterGroup(self.param_names, self.initial_values) - # elif param_type == 'normal': - # self.params = NormalParameterGroup(self.param_names, self.initial_values) - # elif param_type == 'covariant': - # # raise FutureWarning("CovariantNormalParameterGroup is not implemented yet") - # self.params = CovariantNormalParameterGroup(self.param_names, self.initial_values) - - # self.lf = torch.nn.Parameter(torch.tensor(lf, dtype=torch.float32)) - # self.lr = torch.nn.Parameter(torch.tensor(lr, dtype=torch.float32)) - # self.Iz = torch.nn.Parameter(torch.tensor(Iz, dtype=torch.float32)) - # self.mass = torch.nn.Parameter(torch.tensor(mass, dtype=torch.float32)) - # self.Df = torch.nn.Parameter(torch.tensor(Df, dtype=torch.float32)) - # self.Cf = torch.nn.Parameter(torch.tensor(Cf, dtype=torch.float32)) - # self.Bf = torch.nn.Parameter(torch.tensor(Bf, dtype=torch.float32)) - # self.Dr = torch.nn.Parameter(torch.tensor(Dr, dtype=torch.float32)) - # self.Cr = torch.nn.Parameter(torch.tensor(Cr, dtype=torch.float32)) - # self.Br = torch.nn.Parameter(torch.tensor(Br, dtype=torch.float32)) - # self.Cm = torch.nn.Parameter(torch.tensor(Cm, dtype=torch.float32)) - # self.Cr0 = torch.nn.Parameter(torch.tensor(Cr0, dtype=torch.float32)) - # self.Cr2 = torch.nn.Parameter(torch.tensor(Cr2, dtype=torch.float32)) - - def to(self, *args, **kwargs): - super().to(*args, **kwargs) - # self.params.to(*args, **kwargs) def calculate_tire_forces(self, states, control_inputs, params: ParameterSample): - """Get the tire forces at this point. - - Args: - states: Shape of ``(B, 7)`` or ``(7,)``. - control_inputs: Shape of ``(B, 2)`` or ``(2,)``. - - Returns: - Tire forces with shape ``(B, 3)`` or ``(3,)`` [Frx, Ffy, Fry]. - """ + states = jnp.asarray(states) + control_inputs = jnp.asarray(control_inputs) states, unbatch_states = ensure_batch(states) control_inputs, _ = ensure_batch(control_inputs) - device = params['lf'].device - tire_forces = torch.zeros((*states.shape[:-1], 3), device=device, dtype=states.dtype) - - alpha_f = states[..., STEERING_ANGLE] - torch.arctan( - (states[..., YAW_RATE] * params['lf'] + states[..., VY]) / states[..., VX] + alpha_f = states[..., STEERING_ANGLE] - jnp.arctan( + (states[..., YAW_RATE] * params["lf"] + states[..., VY]) / states[..., VX] ) - alpha_r = torch.arctan( - (states[..., YAW_RATE] * params['lr'] - states[..., VY]) / states[..., VX] + alpha_r = jnp.arctan( + (states[..., YAW_RATE] * params["lr"] - states[..., VY]) / states[..., VX] ) - tire_forces[..., FRX] = ( - params['Cm'] * control_inputs[..., DRIVE_FORCE] - - params['Cr0'] - - params['Cr2'] * states[..., VX] ** 2.0 - ) - tire_forces[..., FFY] = params['Df'] * torch.sin( - params['Cf'] * torch.arctan(params['Bf'] * alpha_f) - ) - tire_forces[..., FRY] = params['Dr'] * torch.sin( - params['Cr'] * torch.arctan(params['Br'] * alpha_r) + frx = ( + params["Cm"] * control_inputs[..., DRIVE_FORCE] + - params["Cr0"] + - params["Cr2"] * states[..., VX] ** 2.0 ) + ffy = params["Df"] * jnp.sin(params["Cf"] * jnp.arctan(params["Bf"] * alpha_f)) + fry = params["Dr"] * jnp.sin(params["Cr"] * jnp.arctan(params["Br"] * alpha_r)) + tire_forces = jnp.stack([frx, ffy, fry], axis=-1) return unbatch_states(tire_forces) def forward(self, states, control_inputs, params: ParameterSample): - """Get the evaluated ODEs of the state at this point. - - Args: - states: Shape of ``(B, 7)`` or ``(7,)``. - control_inputs: Shape of ``(B, 2)`` or ``(2,)``. - params: Parameter sample containing the vehicle parameters. - """ + states = jnp.asarray(states) + control_inputs = jnp.asarray(control_inputs) states, unbatch_states = ensure_batch(states) control_inputs, _ = ensure_batch(control_inputs) - diff = torch.zeros_like(states) tire_forces = self.calculate_tire_forces(states, control_inputs, params) - diff[..., X] = ( - states[..., VX] * torch.cos(states[..., YAW]) - - states[..., VY] * torch.sin(states[..., YAW]) - ) - diff[..., Y] = ( - states[..., VX] * torch.sin(states[..., YAW]) - - states[..., VY] * torch.cos(states[..., YAW]) - ) - diff[..., YAW] = states[..., YAW_RATE] - diff[..., VX] = ( - 1.0 - / params['m'] - * ( - tire_forces[..., FRX] - - tire_forces[..., FFY] * torch.sin(states[..., STEERING_ANGLE]) - + states[..., VY] * states[..., YAW_RATE] * params['m'] - ) + diff_x = ( + states[..., VX] * jnp.cos(states[..., YAW]) + - states[..., VY] * jnp.sin(states[..., YAW]) ) - diff[..., VY] = ( - 1.0 - / params['m'] - * ( - tire_forces[..., FRY] - + tire_forces[..., FFY] * torch.cos(states[..., STEERING_ANGLE]) - - states[..., VX] * states[..., YAW_RATE] * params['m'] - ) + diff_y = ( + states[..., VX] * jnp.sin(states[..., YAW]) + - states[..., VY] * jnp.cos(states[..., YAW]) ) - diff[..., YAW_RATE] = ( - 1.0 - / params['Iz'] - * ( - tire_forces[..., FFY] - * params['lf'] - * torch.cos(states[..., STEERING_ANGLE]) - - tire_forces[..., FRY] * params['lr'] - ) + diff_yaw = states[..., YAW_RATE] + diff_vx = ( + tire_forces[..., FRX] + - tire_forces[..., FFY] * jnp.sin(states[..., STEERING_ANGLE]) + + states[..., VY] * states[..., YAW_RATE] * params["m"] + ) / params["m"] + diff_vy = ( + tire_forces[..., FRY] + + tire_forces[..., FFY] * jnp.cos(states[..., STEERING_ANGLE]) + - states[..., VX] * states[..., YAW_RATE] * params["m"] + ) / params["m"] + diff_yaw_rate = ( + tire_forces[..., FFY] * params["lf"] * jnp.cos(states[..., STEERING_ANGLE]) + - tire_forces[..., FRY] * params["lr"] + ) / params["Iz"] + diff_steer = control_inputs[..., STEER_SPEED] + + diff = jnp.stack( + [ + diff_x, + diff_y, + diff_yaw, + diff_vx, + diff_vy, + diff_yaw_rate, + diff_steer, + ], + axis=-1, ) - diff[..., STEERING_ANGLE] = control_inputs[..., STEER_SPEED] return unbatch_states(diff) diff --git a/pit/dynamics/kinematic_bicycle.py b/pit/dynamics/kinematic_bicycle.py index 6fcbf04..f106453 100644 --- a/pit/dynamics/kinematic_bicycle.py +++ b/pit/dynamics/kinematic_bicycle.py @@ -1,58 +1,49 @@ +from __future__ import annotations + +from .._compat import jnp + from . import Dynamics from ._batching import ensure_batch from ..parameters.definitions import ParameterSample -import torch -from torch import nn -class Bicycle(Dynamics, nn.Module): - """ - This is a kinematic bicycle model, with the center of the vehicle as the control point. - Based on - https://thomasfermi.github.io/Algorithms-for-Automated-Driving/Control/BicycleModel.html - """ - def __init__(self, wheelbase) -> None: - super().__init__() - self.wb = torch.nn.Parameter(torch.tensor(wheelbase, dtype=torch.float32)) +class Bicycle(Dynamics): + """Kinematic bicycle model with the vehicle centre as the reference point.""" - def forward(self, states, control_inputs): - """ Get the evaluated ODEs of the state at this point + def __init__(self, wheelbase: float) -> None: + self.wb = jnp.array(wheelbase, dtype=jnp.float32) - Args: - states (): Shape of (B, 4) or (4) - control_inputs (): Shape of (B, 2) or (2) - """ + def forward(self, states, control_inputs): X, Y, THETA, V = 0, 1, 2, 3 STEER, ACCEL = 0, 1 + + states = jnp.asarray(states) + control_inputs = jnp.asarray(control_inputs) states, unbatch_states = ensure_batch(states) control_inputs, _ = ensure_batch(control_inputs) - diff = torch.zeros_like(states) - diff[..., X] = states[..., V] * torch.cos(states[..., THETA]) - diff[..., Y] = states[..., V] * torch.sin(states[..., THETA]) - diff[..., THETA] = ( - states[..., V] * torch.tan(control_inputs[..., STEER]) - ) / self.wb - diff[..., V] = control_inputs[..., ACCEL] + diff_x = states[..., V] * jnp.cos(states[..., THETA]) + diff_y = states[..., V] * jnp.sin(states[..., THETA]) + diff_theta = states[..., V] * jnp.tan(control_inputs[..., STEER]) / self.wb + diff_v = control_inputs[..., ACCEL] + diff = jnp.stack([diff_x, diff_y, diff_theta, diff_v], axis=-1) return unbatch_states(diff) def kinematic_bicycle(states, control_inputs, params: ParameterSample): - """Get the evaluated ODEs of the state at this point - - Args: - states (): Shape of (B, 5) or (5) - control_inputs (): Shape of (B, 2) or (2) - """ X, Y, THETA, V, YAW = 0, 1, 2, 3, 4 STEER, ACCEL = 0, 1 - beta = torch.atan(torch.tan(states[..., THETA]) * (params['lr']/(params['lf']+params['lr']))) + states = jnp.asarray(states) + control_inputs = jnp.asarray(control_inputs) + + beta = jnp.arctan( + jnp.tan(states[..., THETA]) * (params["lr"] / (params["lf"] + params["lr"])) + ) - diff = torch.zeros_like(states) - diff[..., X] = states[..., V] * torch.cos(states[..., THETA] + beta) - diff[..., Y] = states[..., V] * torch.sin(states[..., THETA] + beta) - diff[..., THETA] = control_inputs[..., STEER] - diff[..., V] = control_inputs[..., ACCEL] - diff[..., YAW] = torch.cos(beta) * torch.tan(states[..., THETA]) / (params['lf']+params['lr']) - return diff \ No newline at end of file + diff_x = states[..., V] * jnp.cos(states[..., THETA] + beta) + diff_y = states[..., V] * jnp.sin(states[..., THETA] + beta) + diff_theta = control_inputs[..., STEER] + diff_v = control_inputs[..., ACCEL] + diff_yaw = jnp.cos(beta) * jnp.tan(states[..., THETA]) / (params["lf"] + params["lr"]) + return jnp.stack([diff_x, diff_y, diff_theta, diff_v, diff_yaw], axis=-1) diff --git a/pit/dynamics/unicycle.py b/pit/dynamics/unicycle.py index 1196810..c2e1000 100644 --- a/pit/dynamics/unicycle.py +++ b/pit/dynamics/unicycle.py @@ -1,33 +1,30 @@ +from __future__ import annotations + +from .._compat import jnp + from . import Dynamics from ._batching import ensure_batch -import torch -from torch import nn -class Unicycle(Dynamics, nn.Module): - """ - This is a kinematic Unicycle model. - """ +class Unicycle(Dynamics): + """Kinematic unicycle model.""" + def __init__(self) -> None: - super().__init__() - self.parameter_list = ['null'] + self.parameter_list = ["null"] def forward(self, states, control_inputs, params): - """ Get the evaluated ODEs of the state at this point - - Args: - states (): Shape of (B, 4) or (4) - control_inputs (): Shape of (B, 2) or (2) - """ + del params # Unused for the unicycle model. X, Y, THETA, V = 0, 1, 2, 3 STEER, ACCEL = 0, 1 + + states = jnp.asarray(states) + control_inputs = jnp.asarray(control_inputs) states, unbatch_states = ensure_batch(states) control_inputs, _ = ensure_batch(control_inputs) - diff = torch.zeros_like(states) - diff[..., X] = states[..., V] * torch.cos(states[..., THETA]) - diff[..., Y] = states[..., V] * torch.sin(states[..., THETA]) - diff[..., THETA] = control_inputs[..., STEER] - diff[..., V] = control_inputs[..., ACCEL] + diff_x = states[..., V] * jnp.cos(states[..., THETA]) + diff_y = states[..., V] * jnp.sin(states[..., THETA]) + diff_theta = control_inputs[..., STEER] + diff_v = control_inputs[..., ACCEL] + diff = jnp.stack([diff_x, diff_y, diff_theta, diff_v], axis=-1) return unbatch_states(diff) - diff --git a/pit/integration/_utils.py b/pit/integration/_utils.py index 2aca7d5..e01dd2a 100644 --- a/pit/integration/_utils.py +++ b/pit/integration/_utils.py @@ -4,34 +4,21 @@ from typing import Any, Tuple -import torch +from .._compat import jnp def _normalize_batch_inputs( - initial_state: torch.Tensor, - control_inputs: torch.Tensor, - time_deltas: torch.Tensor | None, + initial_state, + control_inputs, + time_deltas, default_dt: float, parameter_group: Any | None = None, params_override: Any | None = None, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Any, bool]: - """Normalize integration inputs to batched tensors. +) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, Any, bool]: + """Normalize integration inputs to batched arrays.""" - Args: - initial_state: Tensor with shape ``(B, state_dims)`` or ``(state_dims,)``. - control_inputs: Tensor with shape ``(B, steps, input_dims)`` or - ``(steps, input_dims)``. - time_deltas: Optional tensor with shape ``(B, steps)`` or ``(steps,)``. - default_dt: Default time step to use when ``time_deltas`` is ``None``. - parameter_group: Parameter group used to sample parameters, when - available. - params_override: Optional parameters provided by the caller. - - Returns: - Tuple containing normalized tensors for ``initial_state``, - ``control_inputs``, ``time_deltas``, the parameters used for dynamics, - and a boolean flag indicating whether the original inputs were batched. - """ + initial_state = jnp.asarray(initial_state) + control_inputs = jnp.asarray(control_inputs) if control_inputs.ndim < 2: raise ValueError("Control inputs are not in the correct shape") @@ -39,31 +26,26 @@ def _normalize_batch_inputs( was_batched = initial_state.ndim == 2 if not was_batched: - initial_state = initial_state.unsqueeze(0) - control_inputs = control_inputs.unsqueeze(0) + initial_state = jnp.expand_dims(initial_state, axis=0) + control_inputs = jnp.expand_dims(control_inputs, axis=0) batch_size = initial_state.shape[0] steps = control_inputs.shape[1] if time_deltas is None: - time_deltas = torch.full( - (batch_size, steps), - fill_value=default_dt, - device=initial_state.device, - dtype=initial_state.dtype, - ) + time_deltas = jnp.full((batch_size, steps), default_dt, dtype=initial_state.dtype) else: + time_deltas = jnp.asarray(time_deltas, dtype=initial_state.dtype) if time_deltas.ndim == 1: - time_deltas = time_deltas.unsqueeze(0) + time_deltas = jnp.expand_dims(time_deltas, axis=0) if time_deltas.ndim != 2: raise ValueError("time_deltas must have shape (B, L) or (L,)") if time_deltas.shape[0] == 1 and batch_size != 1: - time_deltas = time_deltas.expand(batch_size, -1) + time_deltas = jnp.broadcast_to(time_deltas, (batch_size, time_deltas.shape[1])) elif time_deltas.shape[0] != batch_size: raise ValueError("time_deltas batch dimension does not match inputs") if time_deltas.shape[1] != steps: raise ValueError("time_deltas step dimension does not match inputs") - time_deltas = time_deltas.to(device=initial_state.device, dtype=initial_state.dtype) params = None if params_override == "BYPASS": @@ -79,4 +61,3 @@ def _normalize_batch_inputs( params = parameter_group.draw_parameters() return initial_state, control_inputs, time_deltas, params, was_batched - diff --git a/pit/integration/euler.py b/pit/integration/euler.py index dcbdea1..25ab767 100644 --- a/pit/integration/euler.py +++ b/pit/integration/euler.py @@ -2,25 +2,22 @@ from __future__ import annotations -import torch -from torch import nn - +from .._compat import jnp from ..parameters.definitions import AbstractParameterGroup from ..parameters.point import PointParameterGroup from ._utils import _normalize_batch_inputs -class Euler(nn.Module): - """Module to do Euler integration.""" +class Euler: + """Module to perform Euler integration using JAX arrays.""" def __init__( self, dynamics, parameters: AbstractParameterGroup | None = None, - timestep=0.10, - include_initial_state=False, + timestep: float = 0.10, + include_initial_state: bool = False, ) -> None: - super().__init__() self.dynamics = dynamics if parameters is None: self.model_params = PointParameterGroup(self.dynamics.parameter_list) @@ -30,8 +27,6 @@ def __init__( self.include_initial_state = include_initial_state def forward(self, initial_state, control_inputs, time_deltas=None): - """Integrate the dynamics using the Euler method.""" - ( initial_state, control_inputs, @@ -50,18 +45,20 @@ def forward(self, initial_state, control_inputs, time_deltas=None): integrated_states = [current_state] for i in range(control_inputs.shape[1]): - dt = time_deltas[:, i].unsqueeze(1) + dt = jnp.expand_dims(time_deltas[:, i], axis=1) control = control_inputs[:, i] diff = self.dynamics(current_state, control, params) current_state = current_state + diff * dt integrated_states.append(current_state) - integrated_states = torch.stack(integrated_states, dim=1) + integrated_states = jnp.stack(integrated_states, axis=1) if not self.include_initial_state: integrated_states = integrated_states[:, 1:] if not was_batched: - integrated_states = integrated_states.squeeze(0) + integrated_states = jnp.squeeze(integrated_states, axis=0) return integrated_states + + __call__ = forward diff --git a/pit/integration/rk4.py b/pit/integration/rk4.py index 1c96bed..4955eb8 100644 --- a/pit/integration/rk4.py +++ b/pit/integration/rk4.py @@ -1,24 +1,23 @@ -from typing import Union +from __future__ import annotations -import torch -from torch import nn +from typing import Union +from .._compat import jnp from ..parameters.definitions import AbstractParameterGroup from ..parameters.point import PointParameterGroup from ._utils import _normalize_batch_inputs -class RK4(nn.Module): - """Module to do RK4 integration""" +class RK4: + """Runge-Kutta 4th order integrator implemented with JAX arrays.""" def __init__( self, dynamics, - parameters: Union[AbstractParameterGroup, str] = None, - timestep=0.10, - include_initial_state=False, + parameters: Union[AbstractParameterGroup, str, None] = None, + timestep: float = 0.10, + include_initial_state: bool = False, ) -> None: - super().__init__() self.dynamics = dynamics if parameters == "BYPASS": self.model_params = "BYPASS" @@ -30,17 +29,6 @@ def __init__( self.include_initial_state = include_initial_state def forward(self, initial_state, control_inputs, time_deltas=None, params=None): - """ - We integrate the specified dynamics - - Args: - initial_state: Shape of (B, state_dims) or (state_dims) - control_inputs: Shape of (B, L, input_dims) or (L, input_dims) - time_deltas: Shape of (B, L) or (L) - - Output: - integrated_states: Shape of (B, L, state_dims) or (L, state_dims) - """ ( initial_state, control_inputs, @@ -60,26 +48,28 @@ def forward(self, initial_state, control_inputs, time_deltas=None, params=None): integrated_states = [current_state] for i in range(control_inputs.shape[1]): - dt = time_deltas[:, i].unsqueeze(1) + dt = jnp.expand_dims(time_deltas[:, i], axis=1) control = control_inputs[:, i] k1 = self.dynamics(current_state, control, params) - k2_state = current_state + dt * k1 / 2 + k2_state = current_state + dt * k1 / 2.0 k2 = self.dynamics(k2_state, control, params) - k3_state = current_state + dt * k2 / 2 + k3_state = current_state + dt * k2 / 2.0 k3 = self.dynamics(k3_state, control, params) k4_state = current_state + dt * k3 k4 = self.dynamics(k4_state, control, params) - current_state = current_state + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6 + current_state = current_state + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6.0 integrated_states.append(current_state) - integrated_states = torch.stack(integrated_states, dim=1) + integrated_states = jnp.stack(integrated_states, axis=1) if not self.include_initial_state: integrated_states = integrated_states[:, 1:] if not was_batched: - integrated_states = integrated_states.squeeze(0) + integrated_states = jnp.squeeze(integrated_states, axis=0) return integrated_states + + __call__ = forward diff --git a/pit/parameters/__init__.py b/pit/parameters/__init__.py index c1acba5..4f76bf7 100644 --- a/pit/parameters/__init__.py +++ b/pit/parameters/__init__.py @@ -1,3 +1,4 @@ -from .distribution import CovariantNormalParameterGroup, NormalParameterGroup -from .point import PointParameterGroup, ResidualPointParameterGroup -from .restricted import BoundedParameterGroup +from .definitions import ParameterSample +from .point import PointParameterGroup + +__all__ = ["ParameterSample", "PointParameterGroup"] diff --git a/pit/parameters/definitions.py b/pit/parameters/definitions.py index fabc6ed..47094d6 100644 --- a/pit/parameters/definitions.py +++ b/pit/parameters/definitions.py @@ -1,65 +1,68 @@ -import torch -from torch import nn +"""Minimal utilities for working with parameter groups.""" +from __future__ import annotations +from dataclasses import dataclass +from typing import Any, Dict, Iterable + +from .._compat import jnp + + +@dataclass(frozen=True) class ParameterSample: - def __init__(self, parameters: torch.Tensor, parameter_lookup: dict): - self.parameters = parameters - self.parameter_lookup = parameter_lookup + """Dictionary-like container for parameter arrays.""" - def __getitem__(self, key): - """This method should return the parameter value(s) for the key in text.""" + parameters: jnp.ndarray + parameter_lookup: Dict[str, int] + + def __getitem__(self, key: str) -> jnp.ndarray: return self.parameters[self.parameter_lookup[key]] -class AbstractParameterGroup(nn.Module): - def __init__(self, parameter_list: list, initial_value: dict=None): - super().__init__() - self.parameter_list = parameter_list - self.parameter_lookup = {param: i for i, param in enumerate(parameter_list)} +class AbstractParameterGroup: + """Tiny base class for parameter groups backed by JAX arrays.""" + + def __init__( + self, + parameter_list: Iterable[str], + initial_value: Dict[str, Any] | None = None, + ) -> None: + self.parameter_list = list(parameter_list) + self.parameter_lookup = {name: i for i, name in enumerate(self.parameter_list)} self.initialize_parameters() if initial_value: self.apply_initial_value(initial_value) - + @property - def num_params(self): + def num_params(self) -> int: return len(self.parameter_list) - def disable_gradients(self, parameter_name: str): - """This function should disable gradients for the given parameter.""" - raise NotImplementedError - - def enable_gradients(self, parameter_name: str): - """This function should enable gradients for the given parameter.""" + # The following methods define the required API for concrete subclasses. + def initialize_parameters(self) -> None: # pragma: no cover - interface definition raise NotImplementedError - def initialize_parameters(self): - """This function should initialize the parameters of this object.""" + def apply_initial_value(self, initial_value: Dict[str, Any]) -> None: # pragma: no cover - interface definition raise NotImplementedError - def apply_initial_value(self, initial_value: dict): - """This function should apply the initial parameter set to this object.""" + def get_evaluation_sample(self, batch_size: int = 1) -> ParameterSample: # pragma: no cover - interface definition raise NotImplementedError - def get_evaluation_sample(self, batch_size: int=1) -> ParameterSample: - """ - This method should return a sample of the parameters that can be used for evaluation. - This may be stable over multiple calls. - """ + def sample_parameters(self, batch_size: int = 1) -> ParameterSample: # pragma: no cover - interface definition raise NotImplementedError - def sample_parameters(self, batch_size: int=1) -> ParameterSample: - """ - This method should return a sample of the parameters that can be used for training. - This could return a different sample each time it is called. - """ - raise NotImplementedError - - def draw_parameters(self, batch_size: int=1) -> ParameterSample: - """ - This method will call the appropriate method based on training state. - """ - if self.training: - return self.sample_parameters(batch_size) - else: - return self.get_evaluation_sample(batch_size) \ No newline at end of file + # Backwards compatibility helpers ------------------------------------------------- + def train(self, mode: bool = True) -> "AbstractParameterGroup": # pragma: no cover - API compat + del mode + return self + + def eval(self) -> "AbstractParameterGroup": # pragma: no cover - API compat + return self + + def disable_gradients(self, parameter_name: str) -> None: # pragma: no cover - API compat + del parameter_name + + def enable_gradients(self, parameter_name: str) -> None: # pragma: no cover - API compat + del parameter_name + + def draw_parameters(self, batch_size: int = 1) -> ParameterSample: + return self.sample_parameters(batch_size) diff --git a/pit/parameters/distribution.py b/pit/parameters/distribution.py deleted file mode 100644 index 6de77da..0000000 --- a/pit/parameters/distribution.py +++ /dev/null @@ -1,116 +0,0 @@ -import torch -from torch import nn -from torch.nn import functional as F -from torch.nn.utils.parametrize import register_parametrization -from .definitions import AbstractParameterGroup, ParameterSample -from torch.distributions import MultivariateNormal, Normal - - -class ScaleTril(nn.Module): - """Module to ensure that a matrix is symmetric positive definite""" - def forward(self, matrix, n): - # Return a positive triangular matrix, with positive diagonal - matrix = matrix @ matrix.T - matrix = matrix + torch.diag(F.softplus(n)) - return torch.linalg.cholesky(matrix) - -class Positive(nn.Module): - """Module to ensure that a parameter is non-negative""" - def forward(self, x): - return F.softplus(x) - -class CovariantNormalParameterGroup(AbstractParameterGroup): - """This class represents a group of parameters that are drawn from a multivariate normal distribution.""" - def __init__(self, parameter_list: list, initial_value: dict=None): - super().__init__(parameter_list, initial_value) - - def initialize_parameters(self): - self.loc = nn.Parameter(torch.zeros(len(self.parameter_list))) - self.raw_covariance = nn.Parameter(torch.eye(len(self.parameter_list))+0.5) - self.n = nn.Parameter(torch.ones(len(self.parameter_list))) - self.scale_tril = ScaleTril() - # self.scale_tril = nn.Parameter(torch.cholesky(covariance)) - # self.distribution = MultivariateNormal(loc=self.loc, covariance_matrix=self.covariance) - - @property - def covariance(self): - st = self.scale_tril(self.raw_covariance, self.n) - return st @ st.T - - def disable_gradients(self, parameter_name: str): - if parameter_name not in self.parameter_lookup: - raise ValueError(f"Parameter {parameter_name} not found in parameter list.") - self.loc[self.parameter_lookup[parameter_name]].requires_grad = False - - def enable_gradients(self, parameter_name: str): - if parameter_name not in self.parameter_lookup: - raise ValueError(f"Parameter {parameter_name} not found in parameter list.") - self.loc[self.parameter_lookup[parameter_name]].requires_grad = True - - def apply_initial_value(self, initial_value: dict): - self.loc.data = torch.tensor([initial_value[item] if item in initial_value else 0.0 for item in self.parameter_list]) - if 'covariance' in initial_value: - self.raw_covariance.data = torch.tensor(initial_value['covariance']) - else: - self.raw_covariance.data = torch.eye(len(self.parameter_list))+0.5 - # self.distribution = torch.distributions.MultivariateNormal(self.loc, self.covariance) - - def get_evaluation_sample(self, batch_size: int=1): - return ParameterSample( - torch.tile(self.loc.unsqueeze(1), (1, batch_size)), - self.parameter_lookup - ) - - def sample_parameters(self, batch_size: int=1): - return ParameterSample( - MultivariateNormal(loc=self.loc, scale_tril=self.scale_tril(self.raw_covariance, self.n)).rsample((batch_size, )).T, - self.parameter_lookup - ) - - # def to(self, *args, **kwargs): - # super().to(*args, **kwargs) - # # self.distribution = torch.distributions.MultivariateNormal(self.loc, self.covariance) - # self.distribution._unbroadcasted_scale_tril = self.distribution._unbroadcasted_scale_tril.to(*args, **kwargs) - # self.distribution.loc = self.distribution.loc.to(*args, **kwargs) - # if self.distribution.scale_tril is not None: - # self.distribution.scale_tril = self.distribution.scale_tril.to(*args, **kwargs) - # if self.distribution.covariance_matrix is not None: - # self.distribution.covariance_matrix = self.distribution.covariance_matrix.to(*args, **kwargs) - # if self.distribution.precision_matrix is not None: - # self.distribution.precision_matrix = self.distribution.precision_matrix.to(*args, **kwargs) - -class NormalParameterGroup(AbstractParameterGroup): - """This class represents a group of parameters that are drawn from a normal distribution.""" - def __init__(self, parameter_list: list, initial_value: dict=None): - super().__init__(parameter_list, initial_value) - - def initialize_parameters(self): - self.loc = nn.Parameter(torch.zeros(len(self.parameter_list))) - self.positive = Positive() - self.raw_scale = nn.Parameter(torch.ones(len(self.parameter_list))) - - @property - def scale(self): - return self.positive(self.raw_scale) - - def apply_initial_value(self, initial_value: dict): - self.loc.data = torch.tensor([initial_value[item] if item in initial_value else 0.0 for item in self.parameter_list]) - self.raw_scale.data = torch.tensor([initial_value[item+"_scale"] if item+"_scale" in initial_value else 1.0 for item in self.parameter_list]) - - def get_evaluation_sample(self, batch_size: int=1): - return ParameterSample( - torch.tile(self.loc.unsqueeze(1), (1, batch_size)), - self.parameter_lookup - ) - - def sample_parameters(self, batch_size: int=1): - return ParameterSample( - Normal(self.loc, self.scale).rsample((batch_size,)).T, - self.parameter_lookup - ) - - # def to(self, *args, **kwargs): - # super().to(*args, **kwargs) - # # self.distribution = torch.distributions.Normal(self.loc, self.scale) - # self.distribution.loc = self.distribution.loc.to(*args, **kwargs) - # self.distribution.scale = self.distribution.scale.to(*args, **kwargs) diff --git a/pit/parameters/point.py b/pit/parameters/point.py index 929eb49..8a861a1 100644 --- a/pit/parameters/point.py +++ b/pit/parameters/point.py @@ -1,102 +1,41 @@ -import torch -from torch import nn -from torch.nn import functional as F -from .definitions import AbstractParameterGroup, ParameterSample -import warnings - - -class PointParameterGroup(AbstractParameterGroup): - """This class represents a group of parameters that are point estimates.""" - - def __init__(self, parameter_list: list, initial_value: dict = None): - super().__init__(parameter_list, initial_value) - - def initialize_parameters(self): - self.params = nn.ParameterDict() - for param in self.parameter_list: - self.params[param] = nn.Parameter(torch.tensor(0.0, dtype=torch.float64)) +"""Deterministic parameter group backed by JAX arrays.""" - def disable_gradients(self, parameter_name: str): - self.params[parameter_name].requires_grad = False +from __future__ import annotations - def enable_gradients(self, parameter_name: str): - self.params[parameter_name].requires_grad = True +from typing import Dict - def apply_initial_value(self, initial_value: dict): - for param in self.parameter_list: - self.params[param].data = torch.tensor( - initial_value[param] if param in initial_value else 0.0, - dtype=torch.float64, - ) +from .._compat import jnp - def get_evaluation_sample(self, batch_size: int = 1): - return ParameterSample( - torch.tile( - torch.stack( - [self.params[param].data for param in self.parameter_list] - ).reshape(-1, 1), - (1, batch_size), - ), - self.parameter_lookup, - ) - - def sample_parameters(self, batch_size: int = 1): - return ParameterSample( - torch.tile( - torch.stack( - [self.params[param] for param in self.parameter_list] - ).reshape(-1, 1), - (1, batch_size), - ), - self.parameter_lookup, - ) +from .definitions import AbstractParameterGroup, ParameterSample -class ResidualPointParameterGroup(PointParameterGroup): - """This class represents a group of parameters that are point estimates.""" +class PointParameterGroup(AbstractParameterGroup): + """Parameter group whose values are simple point estimates.""" - def __init__(self, parameter_list: list, initial_value: dict = None): + def __init__(self, parameter_list: list, initial_value: dict | None = None): super().__init__(parameter_list, initial_value) - warnings.warn("This is also enforcing positivity.") - - def initialize_parameters(self): - self.baseline = dict() - self.params = nn.ParameterDict() - for param in self.parameter_list: - self.baseline[param] = torch.tensor(0.0, dtype=torch.float64) - self.params[param] = nn.Parameter(torch.tensor(0.0, dtype=torch.float64)) - - def apply_initial_value(self, initial_value: dict): - for param in self.parameter_list: - if param in initial_value: - self.baseline[param] = torch.tensor( - initial_value[param], dtype=torch.float64 - ) - - def get_evaluation_sample(self, batch_size: int = 1): - return ParameterSample( - torch.tile( - torch.stack( - [ - F.softplus(self.baseline[param] + self.params[param].data) - for param in self.parameter_list - ] - ).reshape(-1, 1), - (1, batch_size), - ), - self.parameter_lookup, - ) - def sample_parameters(self, batch_size: int = 1): - return ParameterSample( - torch.tile( - torch.stack( - [ - F.softplus(self.baseline[param] + self.params[param]) - for param in self.parameter_list - ] - ).reshape(-1, 1), - (1, batch_size), - ), - self.parameter_lookup, - ) + def initialize_parameters(self) -> None: + self.params: Dict[str, jnp.ndarray] = { + name: jnp.array(0.0, dtype=jnp.float32) for name in self.parameter_list + } + + def apply_initial_value(self, initial_value: dict) -> None: + for name, value in initial_value.items(): + if name in self.params: + self.params[name] = jnp.asarray(value, dtype=jnp.float32) + + def _stack_params(self) -> jnp.ndarray: + if not self.parameter_list: + return jnp.zeros((0,), dtype=jnp.float32) + return jnp.stack([self.params[name] for name in self.parameter_list]) + + def get_evaluation_sample(self, batch_size: int = 1) -> ParameterSample: + values = self._stack_params() + if values.ndim == 0: + values = values[None] + values = jnp.broadcast_to(values[:, None], (values.shape[0], batch_size)) + return ParameterSample(values, self.parameter_lookup) + + def sample_parameters(self, batch_size: int = 1) -> ParameterSample: + return self.get_evaluation_sample(batch_size) diff --git a/pit/parameters/restricted.py b/pit/parameters/restricted.py deleted file mode 100644 index bdf40e4..0000000 --- a/pit/parameters/restricted.py +++ /dev/null @@ -1,83 +0,0 @@ -import torch -from torch import nn -from torch.nn import functional as F -from .definitions import AbstractParameterGroup, ParameterSample - - -class BoundedParameterGroup(AbstractParameterGroup): - """This class represents a group of parameters that are bounded.""" - - def __init__( - self, parameter_list: list, initial_value: dict = None, bounds: dict = None - ): - self.bounds = bounds - super().__init__(parameter_list, initial_value) - - def initialize_parameters(self): - self.params = nn.ParameterDict() - for param in self.parameter_list: - self.params[param] = nn.Parameter(torch.tensor(0.0, dtype=torch.float64)) - - def apply_initial_value(self, initial_value: dict): - for param in self.parameter_list: - if param in initial_value: - lower, upper = self.bounds[param] - # We assume that the initial value is within the bounds - assert ( - initial_value[param] >= lower and initial_value[param] <= upper - ), ( - f"Initial value {initial_value[param]} for parameter {param} is out of bounds [{lower}, {upper}]" - ) - # What is the value of the parameter when normalized to where -1 is lower and 1 is upper? - normalized_value = torch.atanh( - ((torch.tensor(initial_value[param]) - lower) / (upper - lower)) - 1 - ) - - self.params[param].data = torch.tensor( - normalized_value, - dtype=torch.float64, - ) - else: - # If the parameter is not in the initial value, we set it to 0 - self.params[param].data = torch.tensor( - 0.0, - dtype=torch.float64, - ) - - def get_evaluation_sample(self, batch_size: int = 1): - return ParameterSample( - torch.tile( - torch.stack( - [ - (self.bounds[param][1] - self.bounds[param][0]) - * (F.tanh(self.params[param]) + 1) - + self.bounds[param][0] - for param in self.parameter_list - ] - ).reshape(-1, 1), - (1, batch_size), - ), - self.parameter_lookup, - ) - - def sample_parameters(self, batch_size: int = 1): - return ParameterSample( - torch.tile( - torch.stack( - [ - (self.bounds[param][1] - self.bounds[param][0]) - * (F.tanh(self.params[param]) + 1) - + self.bounds[param][0] - for param in self.parameter_list - ] - ).reshape(-1, 1), - (1, batch_size), - ), - self.parameter_lookup, - ) - - def disable_gradients(self, parameter_name: str): - self.params[parameter_name].requires_grad = False - - def enable_gradients(self, parameter_name: str): - self.params[parameter_name].requires_grad = True diff --git a/pyproject.toml b/pyproject.toml index afae8b2..0508185 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,8 +6,9 @@ authors = [{ name = "Nandan Tumu", email = "nandan.t@live.com" }] requires-python = "~=3.10" readme = "README.md" dependencies = [ - "torch>=2.7.0,<3", + "jax>=0.4.31", "matplotlib>=3.9.2,<4", + "numpy>=1.26", "pytest>=8.3.2,<9", ] diff --git a/tests/dynamics/test_batching.py b/tests/dynamics/test_batching.py index 85d7d48..bbea5a1 100644 --- a/tests/dynamics/test_batching.py +++ b/tests/dynamics/test_batching.py @@ -1,6 +1,6 @@ -from __future__ import annotations +import numpy as np -import torch +from pit._compat import jnp from pit.dynamics.dynamic_bicycle import DynamicBicycle from pit.dynamics.kinematic_bicycle import Bicycle @@ -9,9 +9,23 @@ def _make_parameter_sample(values): - names = ['lf', 'lr', 'Iz', 'm', 'Df', 'Cf', 'Bf', 'Dr', 'Cr', 'Br', 'Cm', 'Cr0', 'Cr2'] + names = [ + "lf", + "lr", + "Iz", + "m", + "Df", + "Cf", + "Bf", + "Dr", + "Cr", + "Br", + "Cm", + "Cr0", + "Cr2", + ] lookup = {name: idx for idx, name in enumerate(names)} - tensor = torch.as_tensor(values, dtype=torch.float32) + tensor = jnp.asarray(values, dtype=jnp.float32) return ParameterSample(tensor, lookup) @@ -32,13 +46,13 @@ def test_dynamic_bicycle_single_matches_batch(): Cr2=0.01, ) - state = torch.tensor([0.0, 0.0, 0.2, 5.0, 0.3, 0.05, 0.02], dtype=torch.float32) - control = torch.tensor([0.4, 0.01], dtype=torch.float32) + state = jnp.array([0.0, 0.0, 0.2, 5.0, 0.3, 0.05, 0.02], dtype=jnp.float32) + control = jnp.array([0.4, 0.01], dtype=jnp.float32) params_single = _make_parameter_sample( [1.2, 1.3, 1400.0, 1500.0, 1.0, 1.2, 10.0, 1.0, 1.3, 11.0, 0.5, 0.1, 0.01] ) params_batch = _make_parameter_sample( - torch.tensor( + jnp.array( [ [1.2, 1.3, 1400.0, 1500.0, 1.0, 1.2, 10.0, 1.0, 1.3, 11.0, 0.5, 0.1, 0.01], ] @@ -46,12 +60,14 @@ def test_dynamic_bicycle_single_matches_batch(): ) diff_single = model.forward(state, control, params_single) - diff_batch = model.forward(state.unsqueeze(0), control.unsqueeze(0), params_batch) + diff_batch = model.forward(jnp.expand_dims(state, axis=0), jnp.expand_dims(control, axis=0), params_batch) tire_single = model.calculate_tire_forces(state, control, params_single) - tire_batch = model.calculate_tire_forces(state.unsqueeze(0), control.unsqueeze(0), params_batch) + tire_batch = model.calculate_tire_forces( + jnp.expand_dims(state, axis=0), jnp.expand_dims(control, axis=0), params_batch + ) - torch.testing.assert_close(diff_batch.squeeze(0), diff_single) - torch.testing.assert_close(tire_batch.squeeze(0), tire_single) + np.testing.assert_allclose(diff_batch.squeeze(0), diff_single) + np.testing.assert_allclose(tire_batch.squeeze(0), tire_single) def test_dynamic_bicycle_multi_batch_matches_single_calls(): @@ -71,22 +87,25 @@ def test_dynamic_bicycle_multi_batch_matches_single_calls(): Cr2=0.01, ) - states = torch.stack( + states = jnp.stack( [ - torch.tensor([0.0, 0.0, 0.1, 5.0, 0.2, 0.05, 0.02], dtype=torch.float32), - torch.tensor([1.0, -0.5, 0.3, 7.0, -0.1, 0.1, -0.05], dtype=torch.float32), - torch.tensor([-2.0, 0.4, -0.2, 3.5, 0.15, -0.02, 0.03], dtype=torch.float32), + jnp.array([0.0, 0.0, 0.1, 5.0, 0.2, 0.05, 0.02], dtype=jnp.float32), + jnp.array([1.0, -0.5, 0.3, 7.0, -0.1, 0.1, -0.05], dtype=jnp.float32), + jnp.array([-2.0, 0.4, -0.2, 3.5, 0.15, -0.02, 0.03], dtype=jnp.float32), ] ) - controls = torch.stack( + controls = jnp.stack( [ - torch.tensor([0.4, 0.01], dtype=torch.float32), - torch.tensor([-0.2, 0.05], dtype=torch.float32), - torch.tensor([0.1, -0.03], dtype=torch.float32), + jnp.array([0.4, 0.01], dtype=jnp.float32), + jnp.array([-0.2, 0.05], dtype=jnp.float32), + jnp.array([0.1, -0.03], dtype=jnp.float32), ] ) - base_params = torch.tensor([1.2, 1.3, 1400.0, 1500.0, 1.0, 1.2, 10.0, 1.0, 1.3, 11.0, 0.5, 0.1, 0.01], dtype=torch.float32) - params_matrix = torch.stack( + base_params = jnp.array( + [1.2, 1.3, 1400.0, 1500.0, 1.0, 1.2, 10.0, 1.0, 1.3, 11.0, 0.5, 0.1, 0.01], + dtype=jnp.float32, + ) + params_matrix = jnp.stack( [ base_params, base_params * 1.01, @@ -105,60 +124,60 @@ def test_dynamic_bicycle_multi_batch_matches_single_calls(): singles.append(model.forward(states[idx], controls[idx], params_single)) tires.append(model.calculate_tire_forces(states[idx], controls[idx], params_single)) - torch.testing.assert_close(batch_result, torch.stack(singles)) - torch.testing.assert_close(batch_tire, torch.stack(tires)) + np.testing.assert_allclose(batch_result, jnp.stack(singles)) + np.testing.assert_allclose(batch_tire, jnp.stack(tires)) def test_kinematic_bicycle_batching_matches_single(): model = Bicycle(wheelbase=2.5) - state = torch.tensor([0.0, 0.0, 0.2, 5.0], dtype=torch.float32) - control = torch.tensor([0.1, 0.2], dtype=torch.float32) + state = jnp.array([0.0, 0.0, 0.2, 5.0], dtype=jnp.float32) + control = jnp.array([0.1, 0.2], dtype=jnp.float32) diff_single = model.forward(state, control) - diff_batch = model.forward(state.unsqueeze(0), control.unsqueeze(0)) + diff_batch = model.forward(jnp.expand_dims(state, axis=0), jnp.expand_dims(control, axis=0)) - torch.testing.assert_close(diff_batch.squeeze(0), diff_single) + np.testing.assert_allclose(jnp.squeeze(diff_batch, axis=0), diff_single) - states = torch.stack( + states = jnp.stack( [ - torch.tensor([0.0, 0.0, 0.1, 5.0], dtype=torch.float32), - torch.tensor([1.0, -0.5, 0.3, 7.0], dtype=torch.float32), + jnp.array([0.0, 0.0, 0.1, 5.0], dtype=jnp.float32), + jnp.array([1.0, -0.5, 0.3, 7.0], dtype=jnp.float32), ] ) - controls = torch.stack( + controls = jnp.stack( [ - torch.tensor([0.1, 0.2], dtype=torch.float32), - torch.tensor([-0.2, 0.5], dtype=torch.float32), + jnp.array([0.1, 0.2], dtype=jnp.float32), + jnp.array([-0.2, 0.5], dtype=jnp.float32), ] ) batch_result = model.forward(states, controls) singles = [model.forward(states[i], controls[i]) for i in range(states.shape[0])] - torch.testing.assert_close(batch_result, torch.stack(singles)) + np.testing.assert_allclose(batch_result, jnp.stack(singles)) def test_unicycle_batching_matches_single(): model = Unicycle() - state = torch.tensor([0.0, 0.0, 0.2, 5.0], dtype=torch.float32) - control = torch.tensor([0.1, 0.2], dtype=torch.float32) + state = jnp.array([0.0, 0.0, 0.2, 5.0], dtype=jnp.float32) + control = jnp.array([0.1, 0.2], dtype=jnp.float32) diff_single = model.forward(state, control, params=None) - diff_batch = model.forward(state.unsqueeze(0), control.unsqueeze(0), params=None) - torch.testing.assert_close(diff_batch.squeeze(0), diff_single) + diff_batch = model.forward(jnp.expand_dims(state, axis=0), jnp.expand_dims(control, axis=0), params=None) + np.testing.assert_allclose(diff_batch.squeeze(0), diff_single) - states = torch.stack( + states = jnp.stack( [ - torch.tensor([0.0, 0.0, 0.1, 5.0], dtype=torch.float32), - torch.tensor([1.0, -0.5, 0.3, 7.0], dtype=torch.float32), + jnp.array([0.0, 0.0, 0.1, 5.0], dtype=jnp.float32), + jnp.array([1.0, -0.5, 0.3, 7.0], dtype=jnp.float32), ] ) - controls = torch.stack( + controls = jnp.stack( [ - torch.tensor([0.1, 0.2], dtype=torch.float32), - torch.tensor([-0.2, 0.5], dtype=torch.float32), + jnp.array([0.1, 0.2], dtype=jnp.float32), + jnp.array([-0.2, 0.5], dtype=jnp.float32), ] ) batch_result = model.forward(states, controls, params=None) singles = [model.forward(states[i], controls[i], params=None) for i in range(states.shape[0])] - torch.testing.assert_close(batch_result, torch.stack(singles)) + np.testing.assert_allclose(batch_result, jnp.stack(singles)) diff --git a/tests/integration/test_time_delta.py b/tests/integration/test_time_delta.py index 88967cd..bb07319 100644 --- a/tests/integration/test_time_delta.py +++ b/tests/integration/test_time_delta.py @@ -1,26 +1,29 @@ -import pytest -import torch +import numpy as np + +from pit._compat import jnp from pit.integration import RK4, Euler from pit.dynamics.unicycle import Unicycle + def test_time_delta_euler(): unicycle = Unicycle() euler = Euler(unicycle, timestep=0.1) - initial_state = torch.tensor([0.0, 0.0, 0.0, 0.0]) - control_inputs = torch.tensor([[0.0, 1.0], [0.0, 0.0], [0.0, 0.0]]) + initial_state = jnp.array([0.0, 0.0, 0.0, 0.0]) + control_inputs = jnp.array([[0.0, 1.0], [0.0, 0.0], [0.0, 0.0]]) euler_states = euler(initial_state, control_inputs) - skip_control_inputs = torch.tensor([[0.0, 1.0], [0.0, 0.0]]) - time_deltas = torch.tensor([0.1, 0.2]) + skip_control_inputs = jnp.array([[0.0, 1.0], [0.0, 0.0]]) + time_deltas = jnp.array([0.1, 0.2]) euler_states_skip = euler(initial_state, skip_control_inputs, time_deltas) - assert torch.allclose(euler_states[-1], euler_states_skip[-1]) + np.testing.assert_allclose(euler_states[-1], euler_states_skip[-1]) + def test_time_delta_rk4(): unicycle = Unicycle() rk4 = RK4(unicycle, timestep=0.1) - initial_state = torch.tensor([0.0, 0.0, 0.0, 0.0]) - control_inputs = torch.tensor([[0.0, 1.0], [0.0, 0.0], [0.0, 0.0]]) + initial_state = jnp.array([0.0, 0.0, 0.0, 0.0]) + control_inputs = jnp.array([[0.0, 1.0], [0.0, 0.0], [0.0, 0.0]]) rk4_states = rk4(initial_state, control_inputs) - skip_control_inputs = torch.tensor([[0.0, 1.0], [0.0, 0.0]]) - time_deltas = torch.tensor([0.1, 0.2]) + skip_control_inputs = jnp.array([[0.0, 1.0], [0.0, 0.0]]) + time_deltas = jnp.array([0.1, 0.2]) rk4_states_skip = rk4(initial_state, skip_control_inputs, time_deltas) - assert torch.allclose(rk4_states[-1], rk4_states_skip[-1]) \ No newline at end of file + np.testing.assert_allclose(rk4_states[-1], rk4_states_skip[-1])