Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 98 additions & 0 deletions torch_sim/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,38 @@ def merge(cls, constraints: list[Constraint]) -> Self:
constraints: Constraints to merge (all same type, already reindexed)
"""

def split_constraint(
    self,
    n_systems: int,
    n_atoms_per_system: torch.Tensor,
) -> list[Self | None]:
    """Split this constraint into one constraint per system.

    Default fallback that delegates to :meth:`select_sub_constraint` once
    per system. Subclasses should override this when they can do the split
    with fewer host/device synchronisation points.

    Args:
        n_systems: Total number of systems in the batched state.
        n_atoms_per_system: Tensor of shape ``(n_systems,)`` giving atom
            counts per system.

    Returns:
        List of length ``n_systems``; entry ``i`` is the single-system
        constraint for system ``i`` (local indices starting at 0), or
        ``None`` if this constraint does not apply to that system.
    """
    device = n_atoms_per_system.device
    # One host/device sync for all boundaries instead of two `.item()`
    # calls per system; the resulting values are identical.
    boundaries = _cumsum_with_zero(n_atoms_per_system).tolist()
    return [
        self.select_sub_constraint(
            torch.arange(boundaries[i], boundaries[i + 1], device=device), i
        )
        for i in range(n_systems)
    ]

@abstractmethod
def to(
self,
Expand Down Expand Up @@ -268,6 +300,32 @@ def select_sub_constraint(
return None
return type(self)(new_atom_idx)

def split_constraint(
    self,
    n_systems: int,
    n_atoms_per_system: torch.Tensor,
) -> list[Self | None]:
    """Split this atom constraint across systems in a single pass.

    Equivalent to calling :meth:`select_sub_constraint` per system, but
    with a single host/device synchronisation (the per-system counts).

    Args:
        n_systems: Total number of systems in the batched state.
        n_atoms_per_system: Tensor of shape ``(n_systems,)`` giving atom
            counts per system.

    Returns:
        List of length ``n_systems``; entry ``i`` holds that system's
        constrained atoms re-indexed to start at 0, or ``None`` if the
        system has no constrained atoms.
    """
    cumsum = _cumsum_with_zero(n_atoms_per_system)
    # Assign each constrained atom to its owning system. ``right=True``
    # puts an atom sitting exactly on a boundary into the later system and
    # also handles empty systems (duplicate boundary values) correctly.
    system_of_atom = torch.searchsorted(cumsum[1:], self.atom_idx, right=True)
    # One GPU sync: get counts per system on CPU
    counts = torch.bincount(system_of_atom, minlength=n_systems).tolist()
    # Stable sort so atoms keep their original relative order within each
    # system, matching what select_sub_constraint would return (the default
    # argsort makes no ordering guarantee for equal keys).
    sort_order = torch.argsort(system_of_atom, stable=True)
    sorted_atom_idx = self.atom_idx[sort_order]

    result: list[Self | None] = [None] * n_systems
    offset = 0
    for sys_idx in range(n_systems):
        count = counts[sys_idx]
        if count == 0:
            continue  # no constrained atoms in this system -> stays None
        # Shift the global atom indices so they are local to this system.
        local_indices = sorted_atom_idx[offset : offset + count] - cumsum[sys_idx]
        result[sys_idx] = type(self)(local_indices)
        offset += count
    return result

def reindex(self, atom_offset: int, system_offset: int) -> Self: # noqa: ARG002
"""Return copy with atom indices shifted by atom_offset."""
return type(self)(self.atom_idx + atom_offset)
Expand Down Expand Up @@ -373,6 +431,20 @@ def select_sub_constraint(
"""
return type(self)(torch.tensor([0])) if sys_idx in self.system_idx else None

def split_constraint(
self,
n_systems: int,
n_atoms_per_system: torch.Tensor, # noqa: ARG002
) -> list[Self | None]:
"""Split system constraint across systems in one pass."""
# One GPU sync: transfer system_idx to CPU
present = set(self.system_idx.tolist())
device = self.system_idx.device
result: list[Self | None] = [None] * n_systems
for sys_idx in present:
result[sys_idx] = type(self)(torch.tensor([0], device=device))
return result

def reindex(self, atom_offset: int, system_offset: int) -> Self: # noqa: ARG002
"""Return copy with system indices shifted by system_offset."""
return type(self)(self.system_idx + system_offset)
Expand Down Expand Up @@ -1125,6 +1197,32 @@ def select_sub_constraint(
max_cumulative_strain=self.max_cumulative_strain,
)

def split_constraint(
self,
n_systems: int,
n_atoms_per_system: torch.Tensor, # noqa: ARG002
) -> list[Self | None]:
"""Split FixSymmetry across systems in one pass."""
# One GPU sync: transfer system_idx to CPU
system_indices_cpu = self.system_idx.tolist()
device = self.system_idx.device
zero = torch.tensor([0], device=device)
result: list[Self | None] = [None] * n_systems
for local_idx, sys_idx in enumerate(system_indices_cpu):
ref_cells = (
[self.reference_cells[local_idx]] if self.reference_cells else None
)
result[sys_idx] = type(self)(
[self.rotations[local_idx]],
[self.symm_maps[local_idx]],
zero.clone(),
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we use kwargs here? it's hard to understand what you are setting to zero.

adjust_positions=self.do_adjust_positions,
adjust_cell=self.do_adjust_cell,
reference_cells=ref_cells,
max_cumulative_strain=self.max_cumulative_strain,
)
return result

def __repr__(self) -> str:
"""String representation."""
n_ops = [r.shape[0] for r in self.rotations]
Expand Down
26 changes: 13 additions & 13 deletions torch_sim/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1151,11 +1151,20 @@ def _split_state[T: SimState](state: T) -> list[T]: # noqa: C901
for key, val in state.atom_extras.items():
split_atom_extras[key] = list(torch.split(val, system_sizes, dim=0))

# Pre-split all constraints in one pass per constraint (not per system),
# then transpose to per-system lists so each system's constraints are
# already collected.
n_systems = len(system_sizes)
per_system_constraints: list[list[Constraint]] = [[] for _ in range(n_systems)]
for constraint in state.constraints:
for sys_idx, sub in enumerate(
constraint.split_constraint(n_systems, state.n_atoms_per_system)
):
if sub is not None:
per_system_constraints[sys_idx].append(sub)

# Create a state for each system
states: list[T] = []
n_systems = len(system_sizes)
zero_tensor = torch.tensor([0], device=state.device, dtype=torch.int64)
cumsum_atoms = torch.cat((zero_tensor, torch.cumsum(state.n_atoms_per_system, dim=0)))
for sys_idx in range(n_systems):
# Build per-system attributes (padded attributes stay padded for consistency)
per_system_dict = {
Expand Down Expand Up @@ -1183,18 +1192,9 @@ def _split_state[T: SimState](state: T) -> list[T]: # noqa: C901
"_atom_extras": {
key: split_atom_extras[key][sys_idx] for key in split_atom_extras
},
"_constraints": per_system_constraints[sys_idx],
}

start_idx = int(cumsum_atoms[sys_idx].item())
end_idx = int(cumsum_atoms[sys_idx + 1].item())
atom_idx = torch.arange(start_idx, end_idx, device=state.device)
new_constraints: list[Constraint] = []
for constraint in state.constraints:
sub = constraint.select_sub_constraint(atom_idx, sys_idx)
if sub is not None:
new_constraints.append(sub)

system_attrs["_constraints"] = new_constraints
states.append(type(state)(**system_attrs)) # ty: ignore[invalid-argument-type]

return states
Expand Down
Loading