Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ jobs:
packages/fairchem-demo-ocpapi[dev] \
-r tests/requirements.txt # pin test packages

- name: Install torchsim (Python 3.12+)
if: ${{ matrix.python_version == '3.12' || matrix.python_version == '3.13' }}
run: |
pip install packages/fairchem-core[torchsim]

- name: Install additional dependencies
run: |
wget https://github.com/m3g/packmol/archive/refs/tags/v20.15.0.tar.gz
Expand Down Expand Up @@ -160,6 +165,11 @@ jobs:
packages/fairchem-data-omat \
-r tests/requirements.txt # pin test packages

- name: Install torchsim (Python 3.12+)
if: ${{ matrix.python_version == '3.12' || matrix.python_version == '3.13' }}
run: |
pip install packages/fairchem-core[torchsim]

- name: Core GPU tests
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
Expand Down
2 changes: 1 addition & 1 deletion packages/fairchem-core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ dependencies = [
dev = ["pre-commit", "pytest", "pytest-cov", "coverage", "syrupy", "ruff==0.5.1"]
docs = ["jupyter-book", "jupytext", "sphinx","sphinx-autoapi==3.3.3", "astroid<4", "umap-learn", "vdict", "ipywidgets", "jupyter_book>=2.0", "torch-dftd"]
adsorbml = ["dscribe", "x3dase", "scikit-image"]
torchsim = ["torch-sim-atomistic>=0.5.2; python_version >= '3.12'"]
extras = ["pymatgen", "quacc[phonons]>=0.15.3", "pandas", "nvalchemi-toolkit-ops", "pyarrow"]


[project.scripts]
fairchem = "fairchem.core._cli:main"

Expand Down
250 changes: 250 additions & 0 deletions src/fairchem/core/calculate/torchsim_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,250 @@
"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
Comment thread
lbluque marked this conversation as resolved.

from __future__ import annotations

import os
import typing
from pathlib import Path

import torch

from fairchem.core import pretrained_mlip
from fairchem.core.calculate.ase_calculator import UMATask
from fairchem.core.common.utils import setup_imports, setup_logging
from fairchem.core.datasets.atomic_data import AtomicData, atomicdata_list_to_batch

try:
import torch_sim as ts
from torch_sim.models.interface import ModelInterface
except ImportError:
ts = None
ModelInterface = None


if typing.TYPE_CHECKING:
from collections.abc import Callable

from torch_sim import SimState
from torch_sim.typing import StateDict

# Use object as fallback base class if ModelInterface is not available
# The __init__ method will raise ImportError if torch-sim is not installed
_TSModelInterface = ModelInterface if ModelInterface is not None else object


class FairChemModel(_TSModelInterface):  # type: ignore[misc]
    """FairChem model wrapper for computing atomistic properties.

    Wraps FairChem models to compute energies, forces, and stresses. Can be
    initialized with a model checkpoint path or pretrained model name.

    Uses the fairchem-core-2.2.0+ predictor API for batch inference.

    Attributes:
        predictor: The FairChem predictor for batch inference
        task_name (UMATask): Task type for the model
        _device (torch.device): Device where computation is performed
        _dtype (torch.dtype): Data type used for computation
        _compute_stress (bool): Whether to compute stress tensor
        implemented_properties (list): Model outputs the model can compute

    Examples:
        >>> model = FairChemModel(model="path/to/checkpoint.pt", compute_stress=True)
        >>> results = model(state)
    """

    def __init__(
        self,
        model: str | Path,
        neighbor_list_fn: Callable | None = None,
        *,  # force remaining arguments to be keyword-only
        model_cache_dir: str | Path | None = None,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
        compute_stress: bool = False,
        task_name: UMATask | str | None = None,
    ) -> None:
        """Initialize the FairChem model.

        Args:
            model (str | Path): Either a pretrained model name or path to model
                checkpoint file. The function will first check if the input matches
                a known pretrained model name, then check if it's a valid file path.
            neighbor_list_fn (Callable | None): Function to compute neighbor lists
                (not currently supported)
            model_cache_dir (str | Path | None): Path where to save the model
            device (torch.device | None): Device to use for computation. If None,
                defaults to CUDA if available, otherwise CPU.
            dtype (torch.dtype | None): Data type to use for computation
            compute_stress (bool): Whether to compute stress tensor
            task_name (UMATask | str | None): Task type for UMA models (optional,
                only needed for UMA models)

        Raises:
            ImportError: If torch-sim is not installed
            NotImplementedError: If custom neighbor list function is provided
            ValueError: If model is not a known model name or valid file path
        """
        if ts is None or ModelInterface is None:
            raise ImportError(
                "torch-sim is required to use FairChemModel. "
                + "Install it with: pip install fairchem-core[torchsim]"
            )

        setup_imports()
        setup_logging()
        super().__init__()

        self._dtype = dtype or torch.float32
        self._compute_stress = compute_stress
        self._compute_forces = True
        self._memory_scales_with = "n_atoms"

        if neighbor_list_fn is not None:
            raise NotImplementedError(
                "Custom neighbor list is not supported for FairChemModel."
            )

        # Convert Path to string for consistency
        if isinstance(model, Path):
            model = str(model)

        # Convert task_name to UMATask if it's a string (only for UMA models)
        if isinstance(task_name, str):
            task_name = UMATask(task_name)

        # Use the efficient predictor API for optimal performance
        self._device = device or torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        device_str = str(self._device)
        self.task_name = task_name

        # Normalize the cache dir so both str and Path inputs work.
        # (Calling .exists() on a plain str would raise AttributeError.)
        cache_dir = Path(model_cache_dir) if model_cache_dir is not None else None

        # Create efficient batch predictor for fast inference
        if model in pretrained_mlip.available_models:
            if cache_dir is not None and cache_dir.exists():
                self.predictor = pretrained_mlip.get_predict_unit(
                    model, device=device_str, cache_dir=cache_dir
                )
            else:
                self.predictor = pretrained_mlip.get_predict_unit(
                    model, device=device_str
                )
        elif os.path.isfile(model):
            self.predictor = pretrained_mlip.load_predict_unit(model, device=device_str)
        else:
            raise ValueError(
                f"Invalid model name or checkpoint path: {model}. "
                f"Available pretrained models are: {pretrained_mlip.available_models}"
            )

        # Determine implemented properties
        # This is a simplified approach - in practice you might want to
        # inspect the model configuration more carefully
        self.implemented_properties = ["energy", "forces"]
        if compute_stress:
            self.implemented_properties.append("stress")

    @property
    def dtype(self) -> torch.dtype:
        """Return the data type used by the model."""
        return self._dtype

    @property
    def device(self) -> torch.device:
        """Return the device where the model is located."""
        return self._device

    def forward(self, state: SimState | StateDict | dict) -> dict:
        """Compute energies, forces, and other properties.

        Args:
            state (SimState | StateDict): State object containing positions, cells,
                atomic numbers, and other system information. If a dictionary is provided,
                it will be converted to a SimState.

        Returns:
            dict: Dictionary of model predictions, which may include:
                - energy (torch.Tensor): Energy with shape [batch_size]
                - forces (torch.Tensor): Forces with shape [n_atoms, 3]
                - stress (torch.Tensor): Stress tensor with shape [batch_size, 3, 3]
        """
        if isinstance(state, ts.SimState):
            sim_state = state
        else:
            # SimState expects per-atom scalar masses of shape (n_atoms,);
            # ones_like(positions) would yield (n_atoms, 3). Unit masses are
            # fine here since predictions do not depend on masses.
            positions = state["positions"]
            sim_state = ts.SimState(
                **state, masses=torch.ones_like(positions[:, 0])
            )

        if sim_state.device != self._device:
            sim_state = sim_state.to(self._device)

        # Ensure system_idx has integer dtype (SimState guarantees presence)
        if sim_state.system_idx.dtype != torch.int64:
            sim_state.system_idx = sim_state.system_idx.to(dtype=torch.int64)

        # Convert SimState to AtomicData objects for efficient batch processing.
        # Lazy import: ase is only needed on the inference path.
        from ase import Atoms

        n_atoms = torch.bincount(sim_state.system_idx)
        atomic_data_list = []

        # Walk each system's [c - n, c) slice of the flat per-atom arrays.
        for idx, (n, c) in enumerate(
            zip(n_atoms, torch.cumsum(n_atoms, dim=0), strict=False)
        ):
            # Extract system data
            positions = sim_state.positions[c - n : c].cpu().numpy()
            atomic_nums = sim_state.atomic_numbers[c - n : c].cpu().numpy()
            pbc = sim_state.pbc.cpu().numpy()
            cell = (
                sim_state.row_vector_cell[idx].cpu().numpy()
                if sim_state.row_vector_cell is not None
                else None
            )

            # Create ASE Atoms object first
            atoms = Atoms(
                numbers=atomic_nums,
                positions=positions,
                cell=cell,
                pbc=pbc if cell is not None else False,
            )

            atoms.info["charge"] = sim_state.charge[idx].item()
            atoms.info["spin"] = sim_state.spin[idx].item()

            # Convert ASE Atoms to AtomicData (task_name only applies to UMA models)
            # r_data_keys must be passed for charge/spin to be read from atoms.info
            if self.task_name is None:
                atomic_data = AtomicData.from_ase(atoms, r_data_keys=["charge", "spin"])
            else:
                atomic_data = AtomicData.from_ase(
                    atoms, task_name=self.task_name, r_data_keys=["charge", "spin"]
                )
            atomic_data_list.append(atomic_data)

        # Create batch for efficient inference
        batch = atomicdata_list_to_batch(atomic_data_list)
        batch = batch.to(self._device)

        # Run efficient batch prediction
        predictions = self.predictor.predict(batch)

        # Convert predictions to torch-sim format
        results: dict[str, torch.Tensor] = {}
        results["energy"] = predictions["energy"].to(dtype=self._dtype)
        results["forces"] = predictions["forces"].to(dtype=self._dtype)

        # Handle stress if requested and available
        if self._compute_stress and "stress" in predictions:
            stress = predictions["stress"].to(dtype=self._dtype)
            # Ensure stress has correct shape [batch_size, 3, 3]; a 2-D
            # (batch_size, 9) prediction is reshaped, anything else passes through.
            if stress.dim() == 2 and stress.shape[0] == len(atomic_data_list):
                stress = stress.view(-1, 3, 3)
            results["stress"] = stress

        return results
Loading
Loading