diff --git a/torch_sim/constraints.py b/torch_sim/constraints.py
index c6073f01..068c7f28 100644
--- a/torch_sim/constraints.py
+++ b/torch_sim/constraints.py
@@ -151,6 +151,38 @@ def merge(cls, constraints: list[Constraint]) -> Self:
             constraints: Constraints to merge (all same type, already reindexed)
         """
 
+    def split_constraint(
+        self,
+        n_systems: int,
+        n_atoms_per_system: torch.Tensor,
+    ) -> list[Self | None]:
+        """Split this constraint into one constraint per system in a single pass.
+
+        Returns a list of length ``n_systems`` where each element is either a
+        single-system constraint (with local indices starting at 0) or ``None``
+        if this constraint does not apply to that system.
+
+        The default implementation falls back to calling
+        :meth:`select_sub_constraint` in a loop. Subclasses should override
+        for better performance (fewer GPU synchronisation points).
+
+        Args:
+            n_systems: Total number of systems in the batched state.
+            n_atoms_per_system: Tensor of shape ``(n_systems,)`` giving atom
+                counts per system.
+
+        Returns:
+            List of length *n_systems* with per-system constraints or ``None``.
+        """
+        # One GPU sync: move every system boundary to the CPU at once
+        bounds = _cumsum_with_zero(n_atoms_per_system).tolist()
+        device = n_atoms_per_system.device
+        result: list[Self | None] = [None] * n_systems
+        for sys_idx in range(n_systems):
+            atom_idx = torch.arange(bounds[sys_idx], bounds[sys_idx + 1], device=device)
+            result[sys_idx] = self.select_sub_constraint(atom_idx, sys_idx)
+        return result
+
     @abstractmethod
     def to(
         self,
@@ -268,6 +300,32 @@ def select_sub_constraint(
             return None
         return type(self)(new_atom_idx)
 
+    def split_constraint(
+        self,
+        n_systems: int,
+        n_atoms_per_system: torch.Tensor,
+    ) -> list[Self | None]:
+        """Split atom constraint across systems in one pass."""
+        cumsum = _cumsum_with_zero(n_atoms_per_system)
+        # Assign each constrained atom to its system via searchsorted
+        system_of_atom = torch.searchsorted(cumsum[1:], self.atom_idx, right=True)
+        # One GPU sync: get counts per system on CPU
+        counts = torch.bincount(system_of_atom, minlength=n_systems).tolist()
+        # Stable sort groups atoms by system while keeping their original order
+        sort_order = torch.argsort(system_of_atom, stable=True)
+        sorted_atom_idx = self.atom_idx[sort_order]
+
+        result: list[Self | None] = [None] * n_systems
+        offset = 0
+        for sys_idx in range(n_systems):
+            count = counts[sys_idx]
+            if count == 0:
+                continue
+            local_indices = sorted_atom_idx[offset : offset + count] - cumsum[sys_idx]
+            result[sys_idx] = type(self)(local_indices)
+            offset += count
+        return result
+
     def reindex(self, atom_offset: int, system_offset: int) -> Self:  # noqa: ARG002
         """Return copy with atom indices shifted by atom_offset."""
         return type(self)(self.atom_idx + atom_offset)
@@ -373,6 +431,20 @@ def select_sub_constraint(
         """
         return type(self)(torch.tensor([0])) if sys_idx in self.system_idx else None
 
+    def split_constraint(
+        self,
+        n_systems: int,
+        n_atoms_per_system: torch.Tensor,  # noqa: ARG002
+    ) -> list[Self | None]:
+        """Split system constraint across systems in one pass."""
+        # One GPU sync; indices outside [0, n_systems) are ignored, not wrapped
+        present = set(self.system_idx.tolist())
+        device = self.system_idx.device
+        return [
+            type(self)(torch.tensor([0], device=device)) if sys_idx in present else None
+            for sys_idx in range(n_systems)
+        ]
+
     def reindex(self, atom_offset: int, system_offset: int) -> Self:  # noqa: ARG002
         """Return copy with system indices shifted by system_offset."""
         return type(self)(self.system_idx + system_offset)
@@ -1125,6 +1197,32 @@ def select_sub_constraint(
             max_cumulative_strain=self.max_cumulative_strain,
         )
 
+    def split_constraint(
+        self,
+        n_systems: int,
+        n_atoms_per_system: torch.Tensor,  # noqa: ARG002
+    ) -> list[Self | None]:
+        """Split FixSymmetry across systems in one pass."""
+        # One GPU sync: transfer system_idx to CPU
+        system_indices_cpu = self.system_idx.tolist()
+        device = self.system_idx.device
+        zero = torch.tensor([0], device=device)
+        result: list[Self | None] = [None] * n_systems
+        for local_idx, sys_idx in enumerate(system_indices_cpu):
+            ref_cells = (
+                [self.reference_cells[local_idx]] if self.reference_cells else None
+            )
+            result[sys_idx] = type(self)(
+                [self.rotations[local_idx]],
+                [self.symm_maps[local_idx]],
+                zero.clone(),
+                adjust_positions=self.do_adjust_positions,
+                adjust_cell=self.do_adjust_cell,
+                reference_cells=ref_cells,
+                max_cumulative_strain=self.max_cumulative_strain,
+            )
+        return result
+
     def __repr__(self) -> str:
         """String representation."""
         n_ops = [r.shape[0] for r in self.rotations]
diff --git a/torch_sim/state.py b/torch_sim/state.py
index 6584480a..8c6a7a04 100644
--- a/torch_sim/state.py
+++ b/torch_sim/state.py
@@ -1151,11 +1151,20 @@ def _split_state[T: SimState](state: T) -> list[T]:  # noqa: C901
     for key, val in state.atom_extras.items():
         split_atom_extras[key] = list(torch.split(val, system_sizes, dim=0))
 
+    # Pre-split all constraints in one pass per constraint (not per system),
+    # then transpose to per-system lists so each system's constraints are
+    # already collected.
+    n_systems = len(system_sizes)
+    per_system_constraints: list[list[Constraint]] = [[] for _ in range(n_systems)]
+    for constraint in state.constraints:
+        for sys_idx, sub in enumerate(
+            constraint.split_constraint(n_systems, state.n_atoms_per_system)
+        ):
+            if sub is not None:
+                per_system_constraints[sys_idx].append(sub)
+
     # Create a state for each system
     states: list[T] = []
-    n_systems = len(system_sizes)
-    zero_tensor = torch.tensor([0], device=state.device, dtype=torch.int64)
-    cumsum_atoms = torch.cat((zero_tensor, torch.cumsum(state.n_atoms_per_system, dim=0)))
     for sys_idx in range(n_systems):
         # Build per-system attributes (padded attributes stay padded for consistency)
         per_system_dict = {
@@ -1183,18 +1192,9 @@ def _split_state[T: SimState](state: T) -> list[T]:  # noqa: C901
             "_atom_extras": {
                 key: split_atom_extras[key][sys_idx] for key in split_atom_extras
             },
+            "_constraints": per_system_constraints[sys_idx],
         }
 
-        start_idx = int(cumsum_atoms[sys_idx].item())
-        end_idx = int(cumsum_atoms[sys_idx + 1].item())
-        atom_idx = torch.arange(start_idx, end_idx, device=state.device)
-        new_constraints: list[Constraint] = []
-        for constraint in state.constraints:
-            sub = constraint.select_sub_constraint(atom_idx, sys_idx)
-            if sub is not None:
-                new_constraints.append(sub)
-
-        system_attrs["_constraints"] = new_constraints
         states.append(type(state)(**system_attrs))  # ty: ignore[invalid-argument-type]
 
     return states