diff --git a/tests/test_autobatching.py b/tests/test_autobatching.py index a05e0608..716031c4 100644 --- a/tests/test_autobatching.py +++ b/tests/test_autobatching.py @@ -399,6 +399,137 @@ def test_binning_auto_batcher_restore_order_with_split_states( assert torch.all(restored_states[1].atomic_numbers == states[1].atomic_numbers) +def test_binning_auto_batcher_with_iterator( + si_sim_state: ts.SimState, + fe_supercell_sim_state: ts.SimState, + lj_model: LennardJonesModel, +) -> None: + """Test BinningAutoBatcher with an iterator input.""" + states = [si_sim_state, fe_supercell_sim_state] + + batcher = BinningAutoBatcher( + model=lj_model, + memory_scales_with="n_atoms", + max_memory_scaler=260.0, + ) + batcher.load_states(iter(states)) + + batches = [batch for batch, _ in batcher] + + # Check we got the expected number of systems + total_systems = sum(b.n_systems for b in batches) + assert total_systems == len(states) + + # Test restore_original_order + restored_states = batcher.restore_original_order(batches) + assert len(restored_states) == len(states) + assert restored_states[0].n_atoms == states[0].n_atoms + assert restored_states[1].n_atoms == states[1].n_atoms + assert torch.all(restored_states[0].atomic_numbers == states[0].atomic_numbers) + assert torch.all(restored_states[1].atomic_numbers == states[1].atomic_numbers) + + +def test_binning_auto_batcher_with_generator( + si_sim_state: ts.SimState, + fe_supercell_sim_state: ts.SimState, + lj_model: LennardJonesModel, +) -> None: + """Test BinningAutoBatcher with a generator input.""" + states = [si_sim_state, fe_supercell_sim_state] + + def state_generator(): + yield from states + + batcher = BinningAutoBatcher( + model=lj_model, + memory_scales_with="n_atoms", + max_memory_scaler=260.0, + ) + batcher.load_states(state_generator()) + + batches = [batch for batch, _ in batcher] + total_systems = sum(b.n_systems for b in batches) + assert total_systems == len(states) + + restored_states = batcher.restore_original_order(batches) + assert len(restored_states) == len(states) + assert restored_states[0].n_atoms == states[0].n_atoms + assert restored_states[1].n_atoms == states[1].n_atoms + + +def test_binning_auto_batcher_streaming_multiple_batches( + si_sim_state: ts.SimState, + fe_supercell_sim_state: ts.SimState, + lj_model: LennardJonesModel, +) -> None: + """Test streaming path produces correct batches across multiple pulls.""" + # max_memory_scaler=260 forces fe_supercell (216 atoms) into its own batch, + # si (8 atoms) states can batch together up to 32 per batch. + states = [si_sim_state, fe_supercell_sim_state, si_sim_state] + + def state_generator(): + yield from states + + batcher = BinningAutoBatcher( + model=lj_model, + memory_scales_with="n_atoms", + max_memory_scaler=260.0, + ) + batcher.load_states(state_generator()) + + batches = [] + all_indices = [] + for batch, indices in batcher: + batches.append(batch) + all_indices.extend(indices) + + # All 3 states should have been processed + total_systems = sum(b.n_systems for b in batches) + assert total_systems == len(states) + + # Indices should cover all original positions + assert sorted(all_indices) == [0, 1, 2] + + # index_bins should be populated incrementally + assert len(batcher.index_bins) == len(batches) + + # Restore order should work across streaming batches + restored_states = batcher.restore_original_order(batches) + assert len(restored_states) == len(states) + assert restored_states[0].n_atoms == states[0].n_atoms + assert restored_states[1].n_atoms == states[1].n_atoms + assert restored_states[2].n_atoms == states[2].n_atoms + assert torch.all(restored_states[0].atomic_numbers == states[0].atomic_numbers) + assert torch.all(restored_states[1].atomic_numbers == states[1].atomic_numbers) + assert torch.all(restored_states[2].atomic_numbers == states[2].atomic_numbers) + + +def test_binning_auto_batcher_empty_iterator( + lj_model: LennardJonesModel, +) -> None: + """Test BinningAutoBatcher raises ValueError for empty iterator.""" + batcher = BinningAutoBatcher( + model=lj_model, + memory_scales_with="n_atoms", + max_memory_scaler=260.0, + ) + with pytest.raises(ValueError, match="Iterator yielded no states"): + batcher.load_states(iter([])) + + +def test_binning_auto_batcher_iterator_requires_max_memory_scaler( + si_sim_state: ts.SimState, + lj_model: LennardJonesModel, +) -> None: + """Iterator inputs should require an explicit max_memory_scaler.""" + batcher = BinningAutoBatcher( + model=lj_model, + memory_scales_with="n_atoms", + ) + with pytest.raises(ValueError, match="Iterator inputs require max_memory_scaler"): + batcher.load_states(iter([si_sim_state])) + + def test_in_flight_max_metric_too_small( si_sim_state: ts.SimState, fe_supercell_sim_state: ts.SimState, diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py index 26a6775a..db0980ae 100644 --- a/torch_sim/autobatching.py +++ b/torch_sim/autobatching.py @@ -121,10 +121,12 @@ def _rev_argsort_bins(lst: list[float]) -> list[int]: weights = _get_bins(vals, n_dcs) keys = _get_bins(keys, n_dcs) - bins = [[]] if is_tuple_list else [{}] + list_bins: list[list[Any]] | None = [[]] if is_tuple_list else None + dict_bins: list[dict[int, float]] | None = [{}] if not is_tuple_list else None else: weights = sorted(items, key=lambda x: -x) - bins = [[]] + list_bins = [[]] + dict_bins = None # find the valid indices if lower_bound is not None and upper_bound is not None and lower_bound < upper_bound: @@ -174,9 +176,18 @@ def _rev_argsort_bins(lst: list[float]) -> list[int]: b = len(weight_sum) weight_sum.append(0.0) if isinstance(items, dict): - bins.append([] if is_tuple_list else {}) + if is_tuple_list: + if list_bins is None: + raise TypeError("tuple-list mode requires list bins") + list_bins.append([]) + else: + if dict_bins is None: + raise TypeError("dict mode requires dict bins") + dict_bins.append({}) else: - bins.append([]) + if list_bins is None: + raise TypeError("list items require list bins") + list_bins.append([]) # if we are at the very first item, use the empty bin already open else: @@ -184,15 +195,22 @@ def _rev_argsort_bins(lst: list[float]) -> list[int]: # put it in if isinstance(items, dict): - bin_ = bins[b] if is_tuple_list: + if list_bins is None: + raise TypeError("tuple-list mode requires list bins") + bin_ = list_bins[b] if not isinstance(bin_, list): raise TypeError("bins contain lists when tuple-list mode is used") bin_.append(item_key) - elif isinstance(bin_, dict): + else: + if dict_bins is None: + raise TypeError("dict mode requires dict bins") + bin_ = dict_bins[b] bin_[item_key] = weight else: - bin_ = bins[b] + if list_bins is None: + raise TypeError("list items require list bins") + bin_ = list_bins[b] if not isinstance(bin_, list): raise TypeError("bins contain lists when items is not dict") bin_.append(weight) @@ -202,8 +220,16 @@ def _rev_argsort_bins(lst: list[float]) -> list[int]: weight_sum[b] += weight if not is_tuple_list: - return bins - return [[new_dict[item_key] for item_key in bin_keys] for bin_keys in bins] + if isinstance(items, dict): + if dict_bins is None: + raise TypeError("dict mode requires dict bins") + return dict_bins + if list_bins is None: + raise TypeError("list items require list bins") + return list_bins + if list_bins is None: + raise TypeError("tuple-list mode requires list bins") + return [[new_dict[item_key] for item_key in bin_keys] for bin_keys in list_bins] def measure_model_memory_forward(state: SimState, model: ModelInterface) -> float: @@ -521,12 +547,13 @@ class BinningAutoBatcher[T: SimState]: """Batcher that groups states into bins of similar computational cost. Divides a collection of states into batches that can be processed efficiently - without exceeding GPU memory. States are grouped based on a memory scaling - metric to maximize GPU utilization. This approach is ideal for scenarios where - all states need to be evolved the same number of steps. + without exceeding GPU memory. For eager inputs, states are grouped based on a + memory scaling metric to maximize GPU utilization using global bin packing. + For iterator inputs, batches are formed lazily using greedy first-fit packing. + This approach is ideal for scenarios where all states need to be evolved the + same number of steps. - To avoid a slow memory estimation step, set the `max_memory_scaler` to a - known value. + To avoid a slow memory estimation step, set ``max_memory_scaler`` to a known value. Attributes: model (ModelInterface): Model used for memory estimation and processing. @@ -555,6 +582,17 @@ class BinningAutoBatcher[T: SimState]: # Restore original order ordered_final_states = batcher.restore_original_order(final_states) + + + # Or stream states from a generator using greedy packing + def state_generator(): + for atoms in large_dataset: + yield ts.initialize_state(atoms, device, dtype) + + + batcher.load_states(state_generator()) + for batch, _indices in batcher: + process(batch) """ index_bins: list[list[int]] @@ -608,26 +646,29 @@ def __init__( self.memory_scaling_factor = memory_scaling_factor self.max_memory_padding = max_memory_padding self.oom_error_message = oom_error_message + self._states_iterator: Iterator[T] | None = None - def load_states(self, states: T | Sequence[T]) -> float: + def load_states(self, states: T | Sequence[T] | Iterator[T]) -> float: """Load new states into the batcher. - Processes the input states, computes memory scaling metrics for each, - and organizes them into optimal batches using a bin-packing algorithm - to maximize GPU utilization. + Eager inputs (``SimState`` and ``Sequence``) are fully materialized and + packed up front using global bin packing. Iterator inputs are consumed + lazily and packed greedily in input order. Args: - states (SimState | list[SimState]): Collection of states to batch. Either a - list of individual SimState objects or a single batched SimState that - will be split into individual states. Each SimState has shape - information specific to its instance. + states (SimState | Sequence[SimState] | Iterator[SimState]): Collection + of states to batch. Can be a list of individual SimState objects, + a single batched SimState that will be split into individual states, + or an + iterator/generator yielding individual SimState objects. Returns: float: Maximum memory scaling metric that fits in GPU memory. Raises: ValueError: If any individual state has a memory scaling metric greater - than the maximum allowed value. + than the maximum allowed value, if an iterator yields no states, + or if an iterator is provided without ``max_memory_scaler``. Example:: @@ -637,13 +678,32 @@ def load_states(self, states: T | Sequence[T]) -> float: # Or load a batched state that will be split batcher.load_states(batched_state) + # Or stream states from an iterator/generator + batcher.load_states(iter(states)) + batcher.load_states(state_generator()) + Notes: - This method resets the current state bin index, so any ongoing iteration - will be restarted when this method is called. + Iterator inputs require ``max_memory_scaler`` to be set explicitly. + This method resets batching state, so any ongoing iteration restarts + when it is called. """ - batched = ( - states if isinstance(states, SimState) else ts.concatenate_states(states) - ) + self.memory_scalers: list[float] = [] + self.index_bins = [] + self.batched_states: list[list[T]] = [] + self.current_state_bin = 0 + self._states_iterator = None + + if isinstance(states, SimState): + self._load_eager(states) + elif isinstance(states, Sequence): + self._load_eager(ts.concatenate_states(list(states))) + else: + self._load_streaming(states) + + return self.max_memory_scaler # ty: ignore[invalid-return-type] + + def _load_eager(self, batched: T) -> None: + """Compute metrics and pack all batches upfront (for Sequence / SimState).""" self.memory_scalers = calculate_memory_scalers( batched, self.memory_scales_with, self.cutoff ) @@ -659,7 +719,6 @@ def load_states(self, states: T | Sequence[T]) -> float: self.max_memory_scaler = self.max_memory_scaler * self.max_memory_padding logger.debug("Estimated max memory scaler: %.3g", self.max_memory_scaler) - # verify that no systems are too large max_metric_value = max(self.memory_scalers) max_metric_idx = self.memory_scalers.index(max_metric_value) if max_metric_value > self.max_memory_scaler: @@ -673,11 +732,9 @@ def load_states(self, states: T | Sequence[T]) -> float: self.index_to_scaler = dict(enumerate(self.memory_scalers)) index_bins = to_constant_volume_bins( self.index_to_scaler, max_volume=self.max_memory_scaler - ) # list[dict[original_index: int, memory_scale:float]] - # Convert to list of lists of indices + ) self.index_bins = [list(batch.keys()) for batch in index_bins] self.batched_states = [[batched[index_bin]] for index_bin in self.index_bins] - self.current_state_bin = 0 logger.info( "BinningAutoBatcher: %d systems → %d batch(es), max_memory_scaler=%.3g", @@ -685,14 +742,29 @@ def load_states(self, states: T | Sequence[T]) -> float: len(self.index_bins), self.max_memory_scaler, ) - return self.max_memory_scaler + + def _load_streaming(self, states: Iterator[T]) -> None: + """Prepare for lazy greedy packing from an iterator.""" + states_iter = iter(states) + try: + first = next(states_iter) + except StopIteration as exc: + raise ValueError("Iterator yielded no states") from exc + + if not self.max_memory_scaler: + raise ValueError( + "Iterator inputs require max_memory_scaler to be set explicitly." + ) + + self._states_iterator = chain([first], states_iter) + self._iterator_idx = 0 def next_batch(self) -> tuple[T | None, list[int]]: """Get the next batch of states. - Returns batches sequentially until all states have been processed. Each batch - contains states grouped together to maximize GPU utilization without exceeding - memory constraints. + Returns batches sequentially until all states have been processed. Eager + inputs use pre-computed globally packed batches. Iterator inputs pull + states on demand and pack greedily without materializing the full input. Returns: tuple[T | None, list[int]]: A tuple containing: @@ -705,18 +777,13 @@ def next_batch(self) -> tuple[T | None, list[int]]: # Get batches one by one for batch, indices in batcher: process_batch(batch) - """ - # TODO: need to think about how this intersects with reporting too - # TODO: definitely a clever treatment to be done with iterators here - if self.current_state_bin < len(self.batched_states): + if self._states_iterator is None: + if self.current_state_bin >= len(self.batched_states): + return None, [] state_bin = self.batched_states[self.current_state_bin] state = ts.concatenate_states(state_bin) - indices = ( - self.index_bins[self.current_state_bin] - if self.current_state_bin < len(self.index_bins) - else [] - ) + indices = self.index_bins[self.current_state_bin] self.current_state_bin += 1 remaining = len(self.batched_states) - self.current_state_bin logger.info( @@ -730,7 +797,44 @@ def next_batch(self) -> tuple[T | None, list[int]]: remaining, ) return state, indices - return None, [] + + batch_states: list[T] = [] + batch_indices: list[int] = [] + current_sum = 0.0 + for state in self._states_iterator: + metric = calculate_memory_scalers( + state, self.memory_scales_with, self.cutoff + )[0] + if metric > self.max_memory_scaler: # ty: ignore[unsupported-operator] + raise ValueError( + f"State {metric=} is greater than max_metric " + f"{self.max_memory_scaler}, please set a larger max_metric " + f"or run smaller systems metric." + ) + if ( + current_sum + metric > self.max_memory_scaler # ty: ignore[unsupported-operator] + and batch_states + ): + self._states_iterator = chain([state], self._states_iterator) + break + batch_states.append(state) + batch_indices.append(self._iterator_idx) + self.memory_scalers.append(metric) + self._iterator_idx += 1 + current_sum += metric + + if not batch_states: + return None, [] + + self.index_bins.append(batch_indices) + self.current_state_bin += 1 + batch = ts.concatenate_states(batch_states) + logger.info( + "BinningAutoBatcher: returning batch %d with %d system(s) (streaming)", + self.current_state_bin, + batch.n_systems, + ) + return batch, batch_indices def __iter__(self) -> Iterator[tuple[T, list[int]]]: """Return self as an iterator.