diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index edd5ac3645..03cfb654ab 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -75,6 +75,11 @@ jobs: packages/fairchem-demo-ocpapi[dev] \ -r tests/requirements.txt # pin test packages + - name: Install torchsim (Python 3.12+) + if: ${{ matrix.python_version == '3.12' || matrix.python_version == '3.13' }} + run: | + pip install packages/fairchem-core[torchsim] + - name: Install additional dependencies run: | wget https://github.com/m3g/packmol/archive/refs/tags/v20.15.0.tar.gz @@ -160,6 +165,11 @@ jobs: packages/fairchem-data-omat \ -r tests/requirements.txt # pin test packages + - name: Install torchsim (Python 3.12+) + if: ${{ matrix.python_version == '3.12' || matrix.python_version == '3.13' }} + run: | + pip install packages/fairchem-core[torchsim] + - name: Core GPU tests env: HF_TOKEN: ${{ secrets.HF_TOKEN }} diff --git a/packages/fairchem-core/pyproject.toml b/packages/fairchem-core/pyproject.toml index cae0e1f78e..aaf0c604ce 100644 --- a/packages/fairchem-core/pyproject.toml +++ b/packages/fairchem-core/pyproject.toml @@ -37,9 +37,9 @@ dependencies = [ dev = ["pre-commit", "pytest", "pytest-cov", "coverage", "syrupy", "ruff==0.5.1"] docs = ["jupyter-book", "jupytext", "sphinx","sphinx-autoapi==3.3.3", "astroid<4", "umap-learn", "vdict", "ipywidgets", "jupyter_book>=2.0", "torch-dftd"] adsorbml = ["dscribe", "x3dase", "scikit-image"] +torchsim = ["torch-sim-atomistic>=0.5.2; python_version >= '3.12'"] extras = ["pymatgen", "quacc[phonons]>=0.15.3", "pandas", "nvalchemi-toolkit-ops", "pyarrow"] - [project.scripts] fairchem = "fairchem.core._cli:main" diff --git a/src/fairchem/core/calculate/torchsim_interface.py b/src/fairchem/core/calculate/torchsim_interface.py new file mode 100644 index 0000000000..4a7e900909 --- /dev/null +++ b/src/fairchem/core/calculate/torchsim_interface.py @@ -0,0 +1,250 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from __future__ import annotations + +import os +import typing +from pathlib import Path + +import torch + +from fairchem.core import pretrained_mlip +from fairchem.core.calculate.ase_calculator import UMATask +from fairchem.core.common.utils import setup_imports, setup_logging +from fairchem.core.datasets.atomic_data import AtomicData, atomicdata_list_to_batch + +try: + import torch_sim as ts + from torch_sim.models.interface import ModelInterface +except ImportError: + ts = None + ModelInterface = None + + +if typing.TYPE_CHECKING: + from collections.abc import Callable + + from torch_sim import SimState + from torch_sim.typing import StateDict + +# Use object as fallback base class if ModelInterface is not available +# The __init__ method will raise ImportError if torch-sim is not installed +_TSModelInterface = ModelInterface if ModelInterface is not None else object + + +class FairChemModel(_TSModelInterface): # type: ignore[misc] + """FairChem model wrapper for computing atomistic properties. + + Wraps FairChem models to compute energies, forces, and stresses. Can be + initialized with a model checkpoint path or pretrained model name. + + Uses the fairchem-core-2.2.0+ predictor API for batch inference. + + Attributes: + predictor: The FairChem predictor for batch inference + task_name (UMATask): Task type for the model + _device (torch.device): Device where computation is performed + _dtype (torch.dtype): Data type used for computation + _compute_stress (bool): Whether to compute stress tensor + implemented_properties (list): Model outputs the model can compute + + Examples: + >>> model = FairChemModel(model="path/to/checkpoint.pt", compute_stress=True) + >>> results = model(state) + """ + + def __init__( + self, + model: str | Path, + neighbor_list_fn: Callable | None = None, + *, # force remaining arguments to be keyword-only + model_cache_dir: str | Path | None = None, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + compute_stress: bool = False, + task_name: UMATask | str | None = None, + ) -> None: + """Initialize the FairChem model. + + Args: + model (str | Path): Either a pretrained model name or path to model + checkpoint file. The function will first check if the input matches + a known pretrained model name, then check if it's a valid file path. + neighbor_list_fn (Callable | None): Function to compute neighbor lists + (not currently supported) + model_cache_dir (str | Path | None): Path where to save the model + device (torch.device | None): Device to use for computation. If None, + defaults to CUDA if available, otherwise CPU. + dtype (torch.dtype | None): Data type to use for computation + compute_stress (bool): Whether to compute stress tensor + task_name (UMATask | str | None): Task type for UMA models (optional, + only needed for UMA models) + + Raises: + ImportError: If torch-sim is not installed + NotImplementedError: If custom neighbor list function is provided + ValueError: If model is not a known model name or valid file path + """ + if ts is None or ModelInterface is None: + raise ImportError( + "torch-sim is required to use FairChemModel. " + + "Install it with: pip install fairchem-core[torchsim]" + ) + + setup_imports() + setup_logging() + super().__init__() + + self._dtype = dtype or torch.float32 + self._compute_stress = compute_stress + self._compute_forces = True + self._memory_scales_with = "n_atoms" + + if neighbor_list_fn is not None: + raise NotImplementedError( + "Custom neighbor list is not supported for FairChemModel." + ) + + # Convert Path to string for consistency + if isinstance(model, Path): + model = str(model) + + # Convert task_name to UMATask if it's a string (only for UMA models) + if isinstance(task_name, str): + task_name = UMATask(task_name) + + # Use the efficient predictor API for optimal performance + self._device = device or torch.device( + "cuda" if torch.cuda.is_available() else "cpu" + ) + device_str = str(self._device) + self.task_name = task_name + + # Create efficient batch predictor for fast inference + if model in pretrained_mlip.available_models: + if model_cache_dir and model_cache_dir.exists(): + self.predictor = pretrained_mlip.get_predict_unit( + model, device=device_str, cache_dir=model_cache_dir + ) + else: + self.predictor = pretrained_mlip.get_predict_unit( + model, device=device_str + ) + elif os.path.isfile(model): + self.predictor = pretrained_mlip.load_predict_unit(model, device=device_str) + else: + raise ValueError( + f"Invalid model name or checkpoint path: {model}. " + f"Available pretrained models are: {pretrained_mlip.available_models}" + ) + + # Determine implemented properties + # This is a simplified approach - in practice you might want to + # inspect the model configuration more carefully + self.implemented_properties = ["energy", "forces"] + if compute_stress: + self.implemented_properties.append("stress") + + @property + def dtype(self) -> torch.dtype: + """Return the data type used by the model.""" + return self._dtype + + @property + def device(self) -> torch.device: + """Return the device where the model is located.""" + return self._device + + def forward(self, state: SimState | StateDict | dict) -> dict: + """Compute energies, forces, and other properties. + + Args: + state (SimState | StateDict): State object containing positions, cells, + atomic numbers, and other system information. If a dictionary is provided, + it will be converted to a SimState. + + Returns: + dict: Dictionary of model predictions, which may include: + - energy (torch.Tensor): Energy with shape [batch_size] + - forces (torch.Tensor): Forces with shape [n_atoms, 3] + - stress (torch.Tensor): Stress tensor with shape [batch_size, 3, 3] + """ + sim_state = ( + state + if isinstance(state, ts.SimState) + else ts.SimState(**state, masses=torch.ones_like(state["positions"])) + ) + + if sim_state.device != self._device: + sim_state = sim_state.to(self._device) + + # Ensure system_idx has integer dtype (SimState guarantees presence) + if sim_state.system_idx.dtype != torch.int64: + sim_state.system_idx = sim_state.system_idx.to(dtype=torch.int64) + + # Convert SimState to AtomicData objects for efficient batch processing + from ase import Atoms + + n_atoms = torch.bincount(sim_state.system_idx) + atomic_data_list = [] + + for idx, (n, c) in enumerate( + zip(n_atoms, torch.cumsum(n_atoms, dim=0), strict=False) + ): + # Extract system data + positions = sim_state.positions[c - n : c].cpu().numpy() + atomic_nums = sim_state.atomic_numbers[c - n : c].cpu().numpy() + pbc = sim_state.pbc.cpu().numpy() + cell = ( + sim_state.row_vector_cell[idx].cpu().numpy() + if sim_state.row_vector_cell is not None + else None + ) + + # Create ASE Atoms object first + atoms = Atoms( + numbers=atomic_nums, + positions=positions, + cell=cell, + pbc=pbc if cell is not None else False, + ) + + atoms.info["charge"] = sim_state.charge[idx].item() + atoms.info["spin"] = sim_state.spin[idx].item() + + # Convert ASE Atoms to AtomicData (task_name only applies to UMA models) + # r_data_keys must be passed for charge/spin to be read from atoms.info + if self.task_name is None: + atomic_data = AtomicData.from_ase(atoms, r_data_keys=["charge", "spin"]) + else: + atomic_data = AtomicData.from_ase( + atoms, task_name=self.task_name, r_data_keys=["charge", "spin"] + ) + atomic_data_list.append(atomic_data) + + # Create batch for efficient inference + batch = atomicdata_list_to_batch(atomic_data_list) + batch = batch.to(self._device) + + # Run efficient batch prediction + predictions = self.predictor.predict(batch) + + # Convert predictions to torch-sim format + results: dict[str, torch.Tensor] = {} + results["energy"] = predictions["energy"].to(dtype=self._dtype) + results["forces"] = predictions["forces"].to(dtype=self._dtype) + + # Handle stress if requested and available + if self._compute_stress and "stress" in predictions: + stress = predictions["stress"].to(dtype=self._dtype) + # Ensure stress has correct shape [batch_size, 3, 3] + if stress.dim() == 2 and stress.shape[0] == len(atomic_data_list): + stress = stress.view(-1, 3, 3) + results["stress"] = stress + + return results diff --git a/tests/core/calculate/test_torchsim_interface.py b/tests/core/calculate/test_torchsim_interface.py new file mode 100644 index 0000000000..98a4fc02a7 --- /dev/null +++ b/tests/core/calculate/test_torchsim_interface.py @@ -0,0 +1,344 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. + +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest +import torch +from ase.build import bulk, molecule + +from fairchem.core.calculate.torchsim_interface import FairChemModel + +if TYPE_CHECKING: + from collections.abc import Callable + +pytest.importorskip( + "torch_sim", + reason="torch_sim not installed. Install with: pip install fairchem-core[torchsim]", +) + +import torch_sim as ts # noqa: E402 +from torch_sim.models.interface import validate_model_outputs # noqa: E402 + +DTYPE = torch.float32 +DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + +@pytest.fixture() +def torchsim_model_oc20(direct_checkpoint) -> FairChemModel: + """Model for materials (periodic boundary conditions) using locally-trained checkpoint. + + Note: The checkpoint is trained on oc20_omol tasks, so it supports both: + - oc20 task (PBC - surfaces/catalysis) + - omol task (non-PBC - molecules) + """ + checkpoint_path, _ = direct_checkpoint + return FairChemModel(model=checkpoint_path, task_name="oc20", device=DEVICE) + + +@pytest.fixture() +def torchsim_model_omol(direct_checkpoint) -> FairChemModel: + """Model for molecules (non-PBC) using locally-trained checkpoint. + + Note: The checkpoint is trained on oc20_omol tasks, so it supports both: + - oc20 task (PBC - surfaces/catalysis) + - omol task (non-PBC - molecules) + """ + checkpoint_path, _ = direct_checkpoint + return FairChemModel(model=checkpoint_path, task_name="omol", device=DEVICE) + + +@pytest.mark.parametrize("task_name", ["oc20", "omol"]) +def test_task_initialization(direct_checkpoint, task_name: str) -> None: + """Test that different task names initialize correctly.""" + checkpoint_path, _ = direct_checkpoint + model = FairChemModel( + model=checkpoint_path, task_name=task_name, device=torch.device("cpu") + ) + assert model.task_name + assert str(model.task_name.value) == task_name + assert hasattr(model, "predictor") + + +@pytest.mark.parametrize( + ("task_name", "systems_func"), + [ + ( + "oc20", + lambda: [ + bulk("Si", "diamond", a=5.43), + bulk("Al", "fcc", a=4.05), + bulk("Fe", "bcc", a=2.87), + bulk("Cu", "fcc", a=3.61), + ], + ), + ( + "omol", + lambda: [ + molecule("H2O"), + molecule("CO2"), + molecule("CH4"), + molecule("NH3"), + ], + ), + ], +) +def test_homogeneous_batching( + direct_checkpoint, task_name: str, systems_func: Callable +) -> None: + """Test batching multiple systems with the same task.""" + systems = systems_func() + checkpoint_path, _ = direct_checkpoint + + if task_name == "omol": + for mol in systems: + mol.info |= {"charge": 0, "spin": 1} + + model = FairChemModel(model=checkpoint_path, task_name=task_name, device=DEVICE) + state = ts.io.atoms_to_state(systems, device=DEVICE, dtype=DTYPE) + results = model(state) + + assert results["energy"].shape == (4,) + assert results["forces"].shape[0] == sum(len(s) for s in systems) + assert results["forces"].shape[1] == 3 + + energies = results["energy"] + uniq_energies = torch.unique(energies, dim=0) + assert len(uniq_energies) > 1, "Different systems should have different energies" + + +def test_heterogeneous_tasks(direct_checkpoint) -> None: + """Test different task types work with appropriate systems.""" + checkpoint_path, _ = direct_checkpoint + test_cases = [ + ("omol", [molecule("H2O")]), + ("oc20", [bulk("Pt", cubic=True)]), + ] + + for task_name, systems in test_cases: + if task_name == "omol": + systems[0].info |= {"charge": 0, "spin": 1} + + model = FairChemModel( + model=checkpoint_path, + task_name=task_name, + device=DEVICE, + ) + state = ts.io.atoms_to_state(systems, device=DEVICE, dtype=DTYPE) + results = model(state) + + assert results["energy"].shape[0] == 1 + assert results["forces"].dim() == 2 + assert results["forces"].shape[1] == 3 + + +@pytest.mark.parametrize( + ("systems_func", "expected_count"), + [ + (lambda: [bulk("Si", "diamond", a=5.43)], 1), + ( + lambda: [ + bulk("H", "bcc", a=2.0), + bulk("Li", "bcc", a=3.0), + bulk("Si", "diamond", a=5.43), + bulk("Al", "fcc", a=4.05).repeat((2, 1, 1)), + ], + 4, + ), + ( + lambda: [ + bulk(element, "fcc", a=4.0) + for element in ("Al", "Cu", "Ni", "Pd", "Pt") * 3 + ], + 15, + ), + ], +) +def test_batch_size_variations( + direct_checkpoint, systems_func: Callable, expected_count: int +) -> None: + """Test batching with different numbers and sizes of systems.""" + systems = systems_func() + checkpoint_path, _ = direct_checkpoint + + model = FairChemModel(model=checkpoint_path, task_name="oc20", device=DEVICE) + state = ts.io.atoms_to_state(systems, device=DEVICE, dtype=DTYPE) + results = model(state) + + assert results["energy"].shape == (expected_count,) + assert results["forces"].shape[0] == sum(len(s) for s in systems) + assert results["forces"].shape[1] == 3 + assert torch.isfinite(results["energy"]).all() + assert torch.isfinite(results["forces"]).all() + + +@pytest.mark.parametrize("compute_stress", [True, False]) +def test_stress_computation( + conserving_mole_checkpoint, *, compute_stress: bool +) -> None: + """Test stress tensor computation using a conservative (non-direct-force) model.""" + systems = [bulk("Si", "diamond", a=5.43), bulk("Al", "fcc", a=4.05)] + checkpoint_path, _ = conserving_mole_checkpoint + + model = FairChemModel( + model=checkpoint_path, + task_name="oc20", + device=DEVICE, + compute_stress=compute_stress, + ) + state = ts.io.atoms_to_state(systems, device=DEVICE, dtype=DTYPE) + results = model(state) + + assert "energy" in results + assert "forces" in results + if compute_stress: + assert "stress" in results + assert results["stress"].shape == (2, 3, 3) + assert torch.isfinite(results["stress"]).all() + else: + assert "stress" not in results + + +def test_device_consistency(direct_checkpoint) -> None: + """Test device consistency between model and data.""" + checkpoint_path, _ = direct_checkpoint + model = FairChemModel(model=checkpoint_path, task_name="oc20", device=DEVICE) + system = bulk("Si", "diamond", a=5.43) + state = ts.io.atoms_to_state([system], device=DEVICE, dtype=DTYPE) + + results = model(state) + assert results["energy"].device == DEVICE + assert results["forces"].device == DEVICE + + +def test_empty_batch_error(direct_checkpoint) -> None: + """Test that empty batches raise appropriate errors.""" + checkpoint_path, _ = direct_checkpoint + model = FairChemModel( + model=checkpoint_path, task_name="oc20", device=torch.device("cpu") + ) + with pytest.raises((ValueError, RuntimeError, IndexError)): + model(ts.io.atoms_to_state([], device=torch.device("cpu"), dtype=torch.float32)) + + +def test_load_from_checkpoint_path(direct_checkpoint) -> None: + """Test loading model from a saved checkpoint file path.""" + checkpoint_path, _ = direct_checkpoint + loaded_model = FairChemModel(model=checkpoint_path, task_name="oc20", device=DEVICE) + + system = bulk("Si", "diamond", a=5.43) + state = ts.io.atoms_to_state([system], device=DEVICE, dtype=DTYPE) + results = loaded_model(state) + + assert "energy" in results + assert "forces" in results + assert results["energy"].shape == (1,) + assert torch.isfinite(results["energy"]).all() + assert torch.isfinite(results["forces"]).all() + + +@pytest.mark.parametrize( + ("charge", "spin"), + [ + (0.0, 0.0), + (1.0, 1.0), + (-1.0, 0.0), + (0.0, 2.0), + ], +) +def test_charge_spin_handling(direct_checkpoint, charge: float, spin: float) -> None: + """Test that FairChemModel correctly handles charge and spin from atoms.info.""" + mol = molecule("H2O") + mol.info["charge"] = charge + mol.info["spin"] = spin + + state = ts.io.atoms_to_state([mol], device=DEVICE, dtype=DTYPE) + + assert state.charge[0].item() == charge + assert state.spin[0].item() == spin + + checkpoint_path, _ = direct_checkpoint + model = FairChemModel( + model=checkpoint_path, + task_name="omol", + device=DEVICE, + ) + + result = model(state) + + assert "energy" in result + assert result["energy"].shape == (1,) + assert "forces" in result + assert result["forces"].shape == (len(mol), 3) + assert torch.isfinite(result["energy"]).all() + assert torch.isfinite(result["forces"]).all() + + +def test_model_output_validation(torchsim_model_oc20: FairChemModel) -> None: + """Test that the model implementation follows the ModelInterface contract.""" + validate_model_outputs(torchsim_model_oc20, DEVICE, DTYPE) + + +def test_model_output_validation_with_stress(conserving_mole_checkpoint) -> None: + """Test ModelInterface contract for a conservative model that predicts stresses.""" + checkpoint_path, _ = conserving_mole_checkpoint + model = FairChemModel( + model=checkpoint_path, task_name="oc20", device=DEVICE, compute_stress=True + ) + validate_model_outputs(model, DEVICE, DTYPE) + + +def test_missing_torchsim_raises_import_error(monkeypatch) -> None: + """Test that FairChemModel raises ImportError when torch-sim is not installed.""" + # Mock the module-level variables to simulate torch-sim not being installed + import fairchem.core.calculate.torchsim_interface as torchsim_module + + # Save original values + original_ts = torchsim_module.ts + original_model_interface = torchsim_module.ModelInterface + + # Set to None to simulate missing torch-sim + monkeypatch.setattr(torchsim_module, "ts", None) + monkeypatch.setattr(torchsim_module, "ModelInterface", None) + + # Now try to instantiate - should raise ImportError + with pytest.raises( + ImportError, match="torch-sim is required to use FairChemModel.*Install it with" + ): + FairChemModel(model="dummy", task_name="oc20") + + # Restore original values (monkeypatch will do this automatically, but being explicit) + monkeypatch.setattr(torchsim_module, "ts", original_ts) + monkeypatch.setattr(torchsim_module, "ModelInterface", original_model_interface) + + +def test_invalid_model_path_raises_error() -> None: + """Test that FairChemModel raises ValueError for invalid model path.""" + with pytest.raises(ValueError, match="Invalid model name or checkpoint path"): + FairChemModel(model="/nonexistent/path/to/checkpoint.pt", task_name="oc20") + + +def test_invalid_task_name_raises_error(direct_checkpoint) -> None: + """Test that FairChemModel raises error for invalid task name.""" + checkpoint_path, _ = direct_checkpoint + with pytest.raises((ValueError, KeyError)): + FairChemModel(model=checkpoint_path, task_name="invalid_task") + + +def test_custom_neighbor_list_raises_error(direct_checkpoint) -> None: + """Test that FairChemModel raises NotImplementedError for custom neighbor list.""" + checkpoint_path, _ = direct_checkpoint + with pytest.raises( + NotImplementedError, match="Custom neighbor list is not supported" + ): + FairChemModel( + model=checkpoint_path, + task_name="oc20", + neighbor_list_fn=lambda x: x, + )