Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions pit/_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""Compatibility layer providing JAX-like APIs with NumPy fallbacks."""

from __future__ import annotations

import numpy as _np

try: # pragma: no cover - prefer JAX when available
import jax
import jax.numpy as jnp # type: ignore
import jax.nn as jnn # type: ignore
import jax.random as jr # type: ignore
Array = jax.Array
USING_JAX = True
except ImportError: # pragma: no cover - fallback path
jnp = _np # type: ignore
USING_JAX = False

class _NN:
@staticmethod
def softplus(x):
return _np.log1p(_np.exp(-_np.abs(x))) + _np.maximum(x, 0.0)

jnn = _NN() # type: ignore

class _Random:
@staticmethod
def PRNGKey(seed: int):
return _np.random.default_rng(seed)

@staticmethod
def split(key):
seed1 = int(key.integers(0, 2**32 - 1))
seed2 = int(key.integers(0, 2**32 - 1))
return _np.random.default_rng(seed1), _np.random.default_rng(seed2)

@staticmethod
def multivariate_normal(key, mean, cov, shape):
size = int(shape[0]) if shape else None
return key.multivariate_normal(mean, cov, size=size)

@staticmethod
def normal(key, shape):
return key.normal(size=shape)

jr = _Random() # type: ignore
Array = _np.ndarray # type: ignore

__all__ = ["jnp", "jnn", "jr", "Array", "USING_JAX"]
25 changes: 12 additions & 13 deletions pit/dynamics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
import torch
from torch import nn
"""Base definitions for dynamics models."""

from __future__ import annotations

from ..parameters.definitions import ParameterSample

class Dynamics(nn.Module):
"""Base Class for dynamics"""
def __init__(self) -> None:
super().__init__()

def forward(self, states, inputs, params: ParameterSample):
"""
Dynamics evolutions
class Dynamics:
"""Base class for dynamics models."""

Args:
states: Dimension of (N, state_dims)
inputs: Dimension of (N, control_inputs)
"""
parameter_list: list[str]

def forward(self, states, inputs, params: ParameterSample): # pragma: no cover - abstract
raise NotImplementedError

def __call__(self, states, inputs, params: ParameterSample):
return self.forward(states, inputs, params)
29 changes: 9 additions & 20 deletions pit/dynamics/_batching.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,23 @@
from __future__ import annotations

from typing import Callable, Tuple, TypeVar
from typing import Callable, Tuple

import torch
from .._compat import Array, jnp

TensorLike = TypeVar("TensorLike", bound=torch.Tensor)


def ensure_batch(tensor: TensorLike) -> Tuple[TensorLike, Callable[[TensorLike], TensorLike]]:
def ensure_batch(tensor: Array) -> Tuple[Array, Callable[[Array], Array]]:
"""Ensure a tensor has a batch dimension.

The helper promotes one-dimensional tensors to batched tensors by
unsqueezing a leading dimension. A callable is returned that can be
applied to tensors with the resulting shape to restore the original
dimensionality.

Args:
tensor: A tensor with shape ``(dim,)`` or ``(batch, dim)``.

Returns:
A tuple containing the (potentially) batched tensor and a callable to
restore tensors with matching shape back to their original
dimensionality.
Promotes one-dimensional arrays to batched arrays by adding a leading
dimension. A callable is returned that can be applied to arrays with the
resulting shape to restore the original dimensionality.
"""

if tensor.ndim == 1:
batched = tensor.unsqueeze(0)
batched = jnp.expand_dims(tensor, axis=0)

def restore(result: TensorLike) -> TensorLike:
return result.squeeze(0)
def restore(result: Array) -> Array:
return jnp.squeeze(result, axis=0)

return batched, restore

Expand Down
204 changes: 82 additions & 122 deletions pit/dynamics/dynamic_bicycle.py
Original file line number Diff line number Diff line change
@@ -1,160 +1,120 @@
from __future__ import annotations

from .._compat import jnp

from . import Dynamics
from ._batching import ensure_batch
from ..parameters import PointParameterGroup, CovariantNormalParameterGroup, NormalParameterGroup
from ..parameters.definitions import ParameterSample

import torch
from torch import nn

X, Y, YAW, VX, VY, YAW_RATE, STEERING_ANGLE = 0, 1, 2, 3, 4, 5, 6
DRIVE_FORCE, STEER_SPEED = 0, 1
FRX, FFY, FRY = 0, 1, 2


class DynamicBicycle(Dynamics, nn.Module):
"""
This is a dynamic bicycle model
From AMZ Driverless: The Full Autonomous Racing System
Model reference point: CoG
Longitudinal drive-train forces act on the center of gravity
State Variable [x, y, yaw, vx, vy, yaw rate, steering angle]
Control Inputs [drive force, steering speed]
"""
class DynamicBicycle(Dynamics):
"""Dynamic bicycle model based on the AMZ Driverless formulation."""

def __init__(self, lf, lr, Iz, m, Df, Cf, Bf, Dr, Cr, Br, Cm, Cr0, Cr2, **kwargs) -> None:
super().__init__()
self.parameter_list = ['lf', 'lr', 'Iz', 'm', 'Df', 'Cf', 'Bf', 'Dr', 'Cr', 'Br', 'Cm', 'Cr0', 'Cr2']
del kwargs
self.parameter_list = [
"lf",
"lr",
"Iz",
"m",
"Df",
"Cf",
"Bf",
"Dr",
"Cr",
"Br",
"Cm",
"Cr0",
"Cr2",
]
self.initial_values = {
'lf': lf,
'lr': lr,
'Iz': Iz,
'm': m,
'Df': Df,
'Cf': Cf,
'Bf': Bf,
'Dr': Dr,
'Cr': Cr,
'Br': Br,
'Cm': Cm,
'Cr0': Cr0,
'Cr2': Cr2,
"lf": lf,
"lr": lr,
"Iz": Iz,
"m": m,
"Df": Df,
"Cf": Cf,
"Bf": Bf,
"Dr": Dr,
"Cr": Cr,
"Br": Br,
"Cm": Cm,
"Cr0": Cr0,
"Cr2": Cr2,
}
# if param_type == 'point':
# self.params = PointParameterGroup(self.param_names, self.initial_values)
# elif param_type == 'normal':
# self.params = NormalParameterGroup(self.param_names, self.initial_values)
# elif param_type == 'covariant':
# # raise FutureWarning("CovariantNormalParameterGroup is not implemented yet")
# self.params = CovariantNormalParameterGroup(self.param_names, self.initial_values)

# self.lf = torch.nn.Parameter(torch.tensor(lf, dtype=torch.float32))
# self.lr = torch.nn.Parameter(torch.tensor(lr, dtype=torch.float32))
# self.Iz = torch.nn.Parameter(torch.tensor(Iz, dtype=torch.float32))
# self.mass = torch.nn.Parameter(torch.tensor(mass, dtype=torch.float32))
# self.Df = torch.nn.Parameter(torch.tensor(Df, dtype=torch.float32))
# self.Cf = torch.nn.Parameter(torch.tensor(Cf, dtype=torch.float32))
# self.Bf = torch.nn.Parameter(torch.tensor(Bf, dtype=torch.float32))
# self.Dr = torch.nn.Parameter(torch.tensor(Dr, dtype=torch.float32))
# self.Cr = torch.nn.Parameter(torch.tensor(Cr, dtype=torch.float32))
# self.Br = torch.nn.Parameter(torch.tensor(Br, dtype=torch.float32))
# self.Cm = torch.nn.Parameter(torch.tensor(Cm, dtype=torch.float32))
# self.Cr0 = torch.nn.Parameter(torch.tensor(Cr0, dtype=torch.float32))
# self.Cr2 = torch.nn.Parameter(torch.tensor(Cr2, dtype=torch.float32))

def to(self, *args, **kwargs):
super().to(*args, **kwargs)
# self.params.to(*args, **kwargs)

def calculate_tire_forces(self, states, control_inputs, params: ParameterSample):
"""Get the tire forces at this point.

Args:
states: Shape of ``(B, 7)`` or ``(7,)``.
control_inputs: Shape of ``(B, 2)`` or ``(2,)``.

Returns:
Tire forces with shape ``(B, 3)`` or ``(3,)`` [Frx, Ffy, Fry].
"""
states = jnp.asarray(states)
control_inputs = jnp.asarray(control_inputs)

states, unbatch_states = ensure_batch(states)
control_inputs, _ = ensure_batch(control_inputs)

device = params['lf'].device
tire_forces = torch.zeros((*states.shape[:-1], 3), device=device, dtype=states.dtype)

alpha_f = states[..., STEERING_ANGLE] - torch.arctan(
(states[..., YAW_RATE] * params['lf'] + states[..., VY]) / states[..., VX]
alpha_f = states[..., STEERING_ANGLE] - jnp.arctan(
(states[..., YAW_RATE] * params["lf"] + states[..., VY]) / states[..., VX]
)
alpha_r = torch.arctan(
(states[..., YAW_RATE] * params['lr'] - states[..., VY]) / states[..., VX]
alpha_r = jnp.arctan(
(states[..., YAW_RATE] * params["lr"] - states[..., VY]) / states[..., VX]
)

tire_forces[..., FRX] = (
params['Cm'] * control_inputs[..., DRIVE_FORCE]
- params['Cr0']
- params['Cr2'] * states[..., VX] ** 2.0
)
tire_forces[..., FFY] = params['Df'] * torch.sin(
params['Cf'] * torch.arctan(params['Bf'] * alpha_f)
)
tire_forces[..., FRY] = params['Dr'] * torch.sin(
params['Cr'] * torch.arctan(params['Br'] * alpha_r)
frx = (
params["Cm"] * control_inputs[..., DRIVE_FORCE]
- params["Cr0"]
- params["Cr2"] * states[..., VX] ** 2.0
)
ffy = params["Df"] * jnp.sin(params["Cf"] * jnp.arctan(params["Bf"] * alpha_f))
fry = params["Dr"] * jnp.sin(params["Cr"] * jnp.arctan(params["Br"] * alpha_r))
tire_forces = jnp.stack([frx, ffy, fry], axis=-1)
return unbatch_states(tire_forces)

def forward(self, states, control_inputs, params: ParameterSample):
"""Get the evaluated ODEs of the state at this point.

Args:
states: Shape of ``(B, 7)`` or ``(7,)``.
control_inputs: Shape of ``(B, 2)`` or ``(2,)``.
params: Parameter sample containing the vehicle parameters.
"""
states = jnp.asarray(states)
control_inputs = jnp.asarray(control_inputs)

states, unbatch_states = ensure_batch(states)
control_inputs, _ = ensure_batch(control_inputs)

diff = torch.zeros_like(states)
tire_forces = self.calculate_tire_forces(states, control_inputs, params)

diff[..., X] = (
states[..., VX] * torch.cos(states[..., YAW])
- states[..., VY] * torch.sin(states[..., YAW])
)
diff[..., Y] = (
states[..., VX] * torch.sin(states[..., YAW])
- states[..., VY] * torch.cos(states[..., YAW])
)
diff[..., YAW] = states[..., YAW_RATE]
diff[..., VX] = (
1.0
/ params['m']
* (
tire_forces[..., FRX]
- tire_forces[..., FFY] * torch.sin(states[..., STEERING_ANGLE])
+ states[..., VY] * states[..., YAW_RATE] * params['m']
)
diff_x = (
states[..., VX] * jnp.cos(states[..., YAW])
- states[..., VY] * jnp.sin(states[..., YAW])
)
diff[..., VY] = (
1.0
/ params['m']
* (
tire_forces[..., FRY]
+ tire_forces[..., FFY] * torch.cos(states[..., STEERING_ANGLE])
- states[..., VX] * states[..., YAW_RATE] * params['m']
)
diff_y = (
states[..., VX] * jnp.sin(states[..., YAW])
- states[..., VY] * jnp.cos(states[..., YAW])
)
diff[..., YAW_RATE] = (
1.0
/ params['Iz']
* (
tire_forces[..., FFY]
* params['lf']
* torch.cos(states[..., STEERING_ANGLE])
- tire_forces[..., FRY] * params['lr']
)
diff_yaw = states[..., YAW_RATE]
diff_vx = (
tire_forces[..., FRX]
- tire_forces[..., FFY] * jnp.sin(states[..., STEERING_ANGLE])
+ states[..., VY] * states[..., YAW_RATE] * params["m"]
) / params["m"]
diff_vy = (
tire_forces[..., FRY]
+ tire_forces[..., FFY] * jnp.cos(states[..., STEERING_ANGLE])
- states[..., VX] * states[..., YAW_RATE] * params["m"]
) / params["m"]
diff_yaw_rate = (
tire_forces[..., FFY] * params["lf"] * jnp.cos(states[..., STEERING_ANGLE])
- tire_forces[..., FRY] * params["lr"]
) / params["Iz"]
diff_steer = control_inputs[..., STEER_SPEED]

diff = jnp.stack(
[
diff_x,
diff_y,
diff_yaw,
diff_vx,
diff_vy,
diff_yaw_rate,
diff_steer,
],
axis=-1,
)
diff[..., STEERING_ANGLE] = control_inputs[..., STEER_SPEED]
return unbatch_states(diff)
Loading