Skip to content
131 changes: 131 additions & 0 deletions tests/test_autobatching.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,137 @@ def test_binning_auto_batcher_restore_order_with_split_states(
assert torch.all(restored_states[1].atomic_numbers == states[1].atomic_numbers)


def test_binning_auto_batcher_with_iterator(
    si_sim_state: ts.SimState,
    fe_supercell_sim_state: ts.SimState,
    lj_model: LennardJonesModel,
) -> None:
    """Test BinningAutoBatcher with an iterator input."""
    original = [si_sim_state, fe_supercell_sim_state]

    batcher = BinningAutoBatcher(
        model=lj_model,
        memory_scales_with="n_atoms",
        max_memory_scaler=260.0,
    )
    batcher.load_states(iter(original))

    # Drain the batcher, keeping only the batch objects.
    collected = []
    for batch, _indices in batcher:
        collected.append(batch)

    # Every input system must appear in exactly one emitted batch.
    assert sum(b.n_systems for b in collected) == len(original)

    # Round-trip: restoring order must recover each input state's contents.
    restored = batcher.restore_original_order(collected)
    assert len(restored) == len(original)
    for got, expected in zip(restored, original):
        assert got.n_atoms == expected.n_atoms
        assert torch.all(got.atomic_numbers == expected.atomic_numbers)


def test_binning_auto_batcher_with_generator(
    si_sim_state: ts.SimState,
    fe_supercell_sim_state: ts.SimState,
    lj_model: LennardJonesModel,
) -> None:
    """Test BinningAutoBatcher with a generator input.

    Mirrors the iterator test: loads states lazily, drains all batches,
    and verifies that restore_original_order recovers the inputs.
    """
    states = [si_sim_state, fe_supercell_sim_state]

    def state_generator() -> Generator[ts.SimState, None, None]:
        # Lazily yield states so the batcher exercises its streaming path.
        yield from states

    batcher = BinningAutoBatcher(
        model=lj_model,
        memory_scales_with="n_atoms",
        max_memory_scaler=260.0,
    )
    batcher.load_states(state_generator())

    batches = [batch for batch, _ in batcher]
    total_systems = sum(b.n_systems for b in batches)
    assert total_systems == len(states)

    restored_states = batcher.restore_original_order(batches)
    assert len(restored_states) == len(states)
    assert restored_states[0].n_atoms == states[0].n_atoms
    assert restored_states[1].n_atoms == states[1].n_atoms
    # Consistency with the iterator test: check contents, not just sizes.
    assert torch.all(restored_states[0].atomic_numbers == states[0].atomic_numbers)
    assert torch.all(restored_states[1].atomic_numbers == states[1].atomic_numbers)


def test_binning_auto_batcher_streaming_multiple_batches(
    si_sim_state: ts.SimState,
    fe_supercell_sim_state: ts.SimState,
    lj_model: LennardJonesModel,
) -> None:
    """Test streaming path produces correct batches across multiple pulls."""
    # max_memory_scaler=260 forces fe_supercell (216 atoms) into its own batch,
    # si (8 atoms) states can batch together up to 32 per batch.
    inputs = [si_sim_state, fe_supercell_sim_state, si_sim_state]

    def stream():
        yield from inputs

    batcher = BinningAutoBatcher(
        model=lj_model,
        memory_scales_with="n_atoms",
        max_memory_scaler=260.0,
    )
    batcher.load_states(stream())

    # Pull everything, recording both batches and their original indices.
    pulled = [(batch, indices) for batch, indices in batcher]
    emitted_batches = [batch for batch, _ in pulled]
    seen_indices = []
    for _, indices in pulled:
        seen_indices.extend(indices)

    # All 3 states should have been processed
    assert sum(batch.n_systems for batch in emitted_batches) == len(inputs)

    # Indices should cover all original positions
    assert sorted(seen_indices) == [0, 1, 2]

    # index_bins should be populated incrementally
    assert len(batcher.index_bins) == len(emitted_batches)

    # Restore order should work across streaming batches
    recovered = batcher.restore_original_order(emitted_batches)
    assert len(recovered) == len(inputs)
    for position, source in enumerate(inputs):
        assert recovered[position].n_atoms == source.n_atoms
        assert torch.all(
            recovered[position].atomic_numbers == source.atomic_numbers
        )


def test_binning_auto_batcher_empty_iterator(
    lj_model: LennardJonesModel,
) -> None:
    """Test BinningAutoBatcher raises ValueError for empty iterator."""
    empty_batcher = BinningAutoBatcher(
        model=lj_model,
        memory_scales_with="n_atoms",
        max_memory_scaler=260.0,
    )
    # An iterator that yields nothing must be rejected at load time.
    with pytest.raises(ValueError, match="Iterator yielded no states"):
        empty_batcher.load_states(iter(()))


def test_binning_auto_batcher_iterator_requires_max_memory_scaler(
    si_sim_state: ts.SimState,
    lj_model: LennardJonesModel,
) -> None:
    """Iterator inputs should require an explicit max_memory_scaler."""
    # No max_memory_scaler configured: streaming inputs cannot be sized.
    unsized_batcher = BinningAutoBatcher(
        model=lj_model,
        memory_scales_with="n_atoms",
    )
    single_state_iter = iter([si_sim_state])
    with pytest.raises(ValueError, match="Iterator inputs require max_memory_scaler"):
        unsized_batcher.load_states(single_state_iter)


def test_in_flight_max_metric_too_small(
si_sim_state: ts.SimState,
fe_supercell_sim_state: ts.SimState,
Expand Down
Loading
Loading