class SymmetryCorrectedLigandRMSD(AnalysisBase):
    """
    Symmetry-corrected 1D RMSD time series for a ligand AtomGroup.

    Every analysed frame is compared against the coordinates of the first
    analysed frame through ``spyrmsd``'s isomorphism-based RMSD, using the
    atomic numbers and adjacency matrix of the ligand's RDKit molecule.
    Coordinates are passed as they are: ``center=False`` and
    ``minimize=False`` are used, so no centring or superposition is done
    here.

    Parameters
    ----------
    atomgroup : mda.AtomGroup
        Ligand atoms to compute RMSD for. If ``rdmol`` is not provided,
        bonds must be guessed on the atomgroup before instantiating this
        class; use :func:`guess_ligand_bonds` for this purpose.
    rdmol : Chem.Mol, optional
        RDKit molecule corresponding to ``atomgroup``. If provided, it is
        used directly and ``guess_ligand_bonds`` does not need to be called.
        If ``None``, the RDKit molecule is derived from ``atomgroup`` via
        ``convert_to("RDKIT")``. Atoms are paired with ``atomgroup`` by
        index, so both must list the atoms in the same order.
    **kwargs
        Forwarded unchanged to ``AnalysisBase.__init__``. This class does
        not read any extra keyword itself (``mass_weighted`` included).

    Attributes
    ----------
    results.rmsd : np.ndarray
        One RMSD value per analysed frame; the first entry is the
        reference frame compared with itself.

    Raises
    ------
    ValueError
        If the RDKit molecule and ``atomgroup`` have different atom counts.
    """

    _analysis_algorithm_is_parallelizable = False

    def __init__(
        self,
        atomgroup: mda.AtomGroup,
        rdmol: "Chem.Mol | None" = None,
        **kwargs,
    ):
        super().__init__(atomgroup.universe.trajectory, **kwargs)
        self._ag = atomgroup
        self._mol = rdmol if rdmol is not None else atomgroup.convert_to("RDKIT")

        # Atomic numbers and adjacency are indexed like the coordinates, so a
        # size mismatch would pair coordinates with the wrong graph nodes.
        n_mol_atoms = self._mol.GetNumAtoms()
        if n_mol_atoms != len(atomgroup):
            raise ValueError(
                f"atomgroup has {len(atomgroup)} atoms but rdmol has {n_mol_atoms} atoms."
            )

        self._aprops = np.array([atom.GetAtomicNum() for atom in self._mol.GetAtoms()])
        self._am = Chem.rdmolops.GetAdjacencyMatrix(self._mol)

    def _prepare(self):
        self.results.rmsd = np.zeros(self.n_frames, dtype=np.float64)
        # The reference is filled in by the first _single_frame call.
        # _prepare runs before the trajectory has been moved to the first
        # analysed frame, so reading positions here would pick up whatever
        # frame the trajectory happened to be on.
        self._reference = None
        self._isomorphisms = None

    def _single_frame(self) -> None:
        coords = self._ag.positions.copy()
        if self._reference is None:
            # first analyzed frame, not necessarily frame 0
            self._reference = coords.copy()

        frame_rmsd, isomorphisms, _ = srmsd._rmsd_isomorphic_core(
            coords1=coords,
            coords2=self._reference,
            aprops1=self._aprops,
            aprops2=self._aprops,
            am1=self._am,
            am2=self._am,
            center=False,
            minimize=False,
            isomorphisms=self._isomorphisms,
        )
        self.results.rmsd[self._frame_index] = frame_rmsd
        # cache isomorphisms after first frame to avoid redundant graph matching
        if self._isomorphisms is None:
            self._isomorphisms = isomorphisms
def test_symmetry_corrected_ligand_rmsd_nonnegative(mda_universe):
    """RMSD values must be non-negative for all frames."""
    # NOTE(review): `_select_state_ligand` is imported from
    # `openfe_analysis.rmsd`, but this change set does not add it there
    # (gather_rms_data inlines `select_state_atoms(...).select_atoms(...)`).
    # Confirm the helper exists, otherwise this module fails at import.
    # NOTE(review): gather_rms_data calls
    # `guess_ligand_bonds(state_lig, delete_existing=True)` before building
    # the analysis; this test does not. Confirm that is intended.
    state_lig = _select_state_ligand(mda_universe)

    analysis = SymmetryCorrectedLigandRMSD(state_lig).run()

    assert np.all(analysis.results.rmsd >= 0.0)


def test_symmetry_corrected_ligand_rmsd_zero_for_valid_swap():
    """
    For a water-like symmetric molecule, swapping the two equivalent H atoms
    gives naive RMSD > 0 but SymmetryCorrectedLigandRMSD = 0.
    """
    # Two-frame trajectory: the reference geometry, then the same geometry
    # with the coordinates of H1 and H2 exchanged.
    reference = np.array(
        [
            [0.0, 0.0, 0.0],  # O
            [1.0, 0.0, 0.0],  # H1
            [0.0, 1.0, 0.0],  # H2
        ]
    )
    swapped = reference[[0, 2, 1]]

    u = mda.Universe.empty(3, trajectory=True)
    topology_attrs = (
        ("elements", ["O", "H", "H"]),
        ("names", ["O", "H1", "H2"]),
        ("resnames", ["UNK"]),
        ("resids", [1]),
    )
    for attr_name, values in topology_attrs:
        u.add_TopologyAttr(attr_name, values)
    u.load_new(np.stack([reference, swapped]), order="fac")

    water = u.select_atoms("all")

    corrected = SymmetryCorrectedLigandRMSD(water).run().results.rmsd
    naive = RMSDAnalysis(water).run().results.rmsd

    # Frame 0 is the reference, so both measures must vanish.
    assert corrected[0] == pytest.approx(0.0, abs=1e-5)
    assert naive[0] == pytest.approx(0.0, abs=1e-5)

    # Frame 1 is the swap: naive sees displacement, corrected sees zero.
    assert naive[1] > 0.0
    assert corrected[1] == pytest.approx(0.0, abs=1e-5)


def test_ligand_rmsd_mass_weighting_effect(simulation_skipped_nc, hybrid_system_skipped_pdb):
    """
    Run RMSDAnalysis on the full and the state-only ligand, with and without
    mass weighting.

    The series are printed for manual comparison and checked for basic
    sanity, so the test can fail instead of only producing output.
    """
    with nc.Dataset(simulation_skipped_nc) as ds:
        u_top = mda.Universe(hybrid_system_skipped_pdb)
        u = make_Universe(u_top._topology, ds, state=0)
        ligand = u.select_atoms("resname UNK")
        state_lig = _select_state_ligand(u)

        cases = (
            ("Full ligand, mass weighted", ligand, True),
            ("Full ligand, no mass weighting", ligand, False),
            ("State ligand, mass weighted", state_lig, True),
            ("State ligand, no mass weighting", state_lig, False),
        )
        series = {}
        for label, atoms, weighted in cases:
            series[label] = RMSDAnalysis(atoms, mass_weighted=weighted).run().results.rmsd
            print(f"{label}: {series[label][:6]}")
        print("Old expected: [0.0, 1.092039, 0.839234, 1.228383, 1.533331, 1.276798]")

        # All four runs cover the same frames, so they must be equally long,
        # and an RMSD can be neither negative nor non-finite.
        assert len({len(values) for values in series.values()}) == 1
        for label, values in series.items():
            values = np.asarray(values, dtype=float)
            assert np.all(np.isfinite(values)), label
            assert np.all(values >= 0.0), label
_PAIR_COORDS = np.array([[[0.0, 0.0, 0.0], [1.5, 0.0, 0.0]]])


def _two_atom_ligand(elements, names):
    """Return an "all" selection over a one-frame, one-residue, two-atom universe."""
    u = mda.Universe.empty(2, n_residues=1, trajectory=True)
    u.add_TopologyAttr("elements", elements)
    u.add_TopologyAttr("names", names)
    u.add_TopologyAttr("resnames", ["UNK"])
    u.add_TopologyAttr("resids", [1])
    u.load_new(_PAIR_COORDS, order="fac")
    return u.select_atoms("all")


def _rdmol_of(*atomic_numbers):
    """Return a bond-free RDKit molecule with one atom per atomic number."""
    builder = Chem.RWMol()
    for number in atomic_numbers:
        builder.AddAtom(Chem.Atom(number))
    return builder.GetMol()


@pytest.fixture
def universe(hybrid_system_skipped_pdb, simulation_skipped_nc):
    # NOTE(review): make_Universe annotates `trj` as an nc.Dataset and
    # test_rmsd.py passes `u_top._topology` plus an open Dataset; here the
    # raw fixtures are handed over. Confirm both call forms are accepted.
    u = make_Universe(hybrid_system_skipped_pdb, simulation_skipped_nc, state=0)
    yield u
    u.trajectory.close()


@pytest.fixture
def ligand_ag(universe):
    state_a = select_state_atoms(universe, end_state="A")
    return state_a.select_atoms("resname UNK")


def test_guess_ligand_bonds_adds_bonds(ligand_ag):
    """Bonds should be present on the atomgroup after guess_ligand_bonds."""
    # Before guessing, the count also includes a stateB bond.
    assert len(ligand_ag.bonds) == 49
    guess_ligand_bonds(ligand_ag, delete_existing=True)
    # Afterwards only the 48 stateA bonds are left.
    assert len(ligand_ag.bonds) == 48


def test_guess_ligand_bonds_modifies_universe_inplace(ligand_ag):
    """Bond topology should be reflected on the parent universe after guessing."""
    # NOTE(review): the sibling test finds 49 bonds on this fixture before
    # any guessing, so `> 0` presumably already holds without the call.
    # Consider a stricter check.
    guess_ligand_bonds(ligand_ag)
    ligand_in_universe = ligand_ag.universe.select_atoms("resname UNK")
    assert len(ligand_in_universe.bonds) > 0


@pytest.mark.parametrize(
    "end_state, expected_bfactors",
    [
        ("A", (0.25, 0.5)),
        ("B", (0.75, 0.5)),
    ],
)
def test_select_state_atoms(universe, end_state, expected_bfactors):
    """State selection should include state-unique and shared atoms."""
    selected = select_state_atoms(universe, end_state=end_state)
    assert len(selected) > 0
    unexpected = [atom for atom in selected if atom.bfactor not in expected_bfactors]
    assert not unexpected


def test_select_state_atoms_invalid_state(universe):
    """Invalid end_state should raise a ValueError."""
    with pytest.raises(ValueError, match="end_state must be 'A' or 'B'"):
        select_state_atoms(universe, end_state="C")


def test_select_state_atoms_shared_atoms(universe):
    """Shared atoms (bfactor 0.5) should appear in both state A and B selections."""
    shared_per_state = [
        {atom.ix for atom in select_state_atoms(universe, end_state=state) if atom.bfactor == 0.5}
        for state in ("A", "B")
    ]
    assert shared_per_state[0] == shared_per_state[1]


def test_correct_elements_fixes_element():
    """correct_elements should update element where rdmol differs."""
    # The second atom is deliberately given the wrong element (C, not N).
    ag = _two_atom_ligand(elements=["C", "C"], names=["C1", "C2"])
    rdmol = _rdmol_of(6, 7)  # C, N

    with pytest.warns(UserWarning, match="No atom_mapping provided"):
        correct_elements(ag, rdmol)

    assert ag[0].element == "C"
    assert ag[1].element == "N"
    assert ag[1].name == "N"


def test_correct_elements_no_change_when_correct():
    """correct_elements should not modify atoms that already have correct elements."""
    ag = _two_atom_ligand(elements=["C", "N"], names=["C1", "N1"])
    rdmol = _rdmol_of(6, 7)  # C, N

    with pytest.warns(UserWarning, match="No atom_mapping provided"):
        correct_elements(ag, rdmol)

    # Elements already matched, so the names must be untouched as well.
    assert ag[0].element == "C"
    assert ag[0].name == "C1"
    assert ag[1].element == "N"
    assert ag[1].name == "N1"


def test_correct_elements_with_atom_mapping():
    """correct_elements with atom_mapping should use mapping without warning."""
    # The second atom is deliberately given the wrong element (C, not N).
    ag = _two_atom_ligand(elements=["C", "C"], names=["C1", "C2"])
    # rdmol lists the atoms in reverse order: N at index 0, C at index 1.
    rdmol = _rdmol_of(7, 6)

    # ag index 0 -> rdmol index 1 (C), ag index 1 -> rdmol index 0 (N)
    correct_elements(ag, rdmol, atom_mapping={0: 1, 1: 0})

    assert ag[0].element == "C"
    assert ag[1].element == "N"
    assert ag[1].name == "N"


def test_correct_elements_raises_size_error():
    """correct_elements should raise ValueError if atom counts don't match."""
    ag = _two_atom_ligand(elements=["C", "N"], names=["C1", "N1"])
    rdmol = _rdmol_of(6)  # only 1 atom

    with pytest.raises(ValueError, match="atomgroup has 2 atoms but rdmol has 1"):
        correct_elements(ag, rdmol)
# B-factor values used to identify atoms present at a given lambda state.
# 0.25 : atoms unique to state A
# 0.75 : atoms unique to state B
# 0.5  : atoms shared by both end states.
_BFACTOR_STATE_A = (0.25, 0.5)
_BFACTOR_STATE_B = (0.75, 0.5)


def select_state_atoms(
    universe: mda.Universe,
    end_state: Literal["A", "B"],
) -> mda.AtomGroup:
    """
    Select all atoms present at a given end state.

    Atoms are identified by their b-factor values:

    - ``0.25`` — unique to state A
    - ``0.75`` — unique to state B
    - ``0.5``  — shared by both end states

    Parameters
    ----------
    universe : mda.Universe
        Universe containing the hybrid topology.
    end_state : {"A", "B"}
        The end state to select atoms for.

    Returns
    -------
    mda.AtomGroup
        All atoms present at the given end state, in universe order. The
        group is empty when no atom carries a matching b-factor.

    Raises
    ------
    ValueError
        If ``end_state`` is not ``"A"`` or ``"B"``.
    """
    if end_state == "A":
        bfactor_values = _BFACTOR_STATE_A
    elif end_state == "B":
        bfactor_values = _BFACTOR_STATE_B
    else:
        raise ValueError(f"end_state must be 'A' or 'B', got '{end_state}'")

    # One vectorised exact-match test instead of a Python loop over every
    # atom. A boolean mask also stays a valid index when nothing matches,
    # whereas an empty list of indices becomes a float array.
    in_state = np.isin(universe.atoms.bfactors, bfactor_values)
    return universe.atoms[in_state]


def guess_ligand_bonds(
    atomgroup: mda.AtomGroup,
    delete_existing: bool = False,
) -> None:
    """
    Guess bonds for a ligand AtomGroup in-place.

    Parameters
    ----------
    atomgroup : mda.AtomGroup
        Ligand atoms for which bonds will be guessed.
    delete_existing : bool, optional
        If ``True``, delete existing bonds on the atomgroup before guessing.
        This may be necessary to avoid cross-state bonds in hybrid topologies.
        Default is ``False``.
    """
    if delete_existing:
        atomgroup.universe.delete_bonds(atomgroup.bonds)
    # MDA vdw radii use uppercase element symbols (e.g. "CL", "BR", "NA"),
    # but RDKit uses mixed case. Alias every tabulated element, not only a
    # hand-picked few, so bond guessing works for any two-letter symbol.
    vdwradii = dict(MDA_VDWRADII)
    vdwradii.update({symbol.capitalize(): radius for symbol, radius in MDA_VDWRADII.items()})
    atomgroup.guess_bonds(vdwradii)


def correct_elements(
    atomgroup: mda.AtomGroup,
    rdmol: Chem.Mol,
    atom_mapping: dict[int, int] | None = None,
) -> None:
    """
    Correct element and atom names in an AtomGroup in-place
    using an RDKit molecule as the source of truth.

    This is needed for hybrid topologies where mapped atoms that
    undergo element changes carry state A's element types, even when
    state B's ligand is selected.

    Only atoms whose element differs from the RDKit atom are modified;
    for those, the name is replaced by the RDKit element symbol.

    Parameters
    ----------
    atomgroup : mda.AtomGroup
        Ligand atoms whose elements and names will be corrected.
    rdmol : Chem.Mol
        RDKit molecule with the correct element and atom name information.
    atom_mapping : dict[int, int], optional
        A mapping of ``{atomgroup_index: rdmol_index}`` defining the
        correspondence between atoms in ``atomgroup`` and ``rdmol``. Atoms
        of ``atomgroup`` that are not keys of the mapping are left untouched.
        If ``None``, atoms are paired by index (first with first, and so
        on), which gives wrong results if the atom order is not the same;
        a ``UserWarning`` is emitted in that case.

    Raises
    ------
    ValueError
        If the number of atoms in ``atomgroup`` and ``rdmol`` do not match.
    """
    n_atoms = len(atomgroup)
    if n_atoms != rdmol.GetNumAtoms():
        raise ValueError(
            f"atomgroup has {n_atoms} atoms but rdmol has {rdmol.GetNumAtoms()} atoms."
        )

    if atom_mapping is None:
        warnings.warn(
            "No atom_mapping provided to correct_elements. Assuming that "
            "atom ordering is the same between atomgroup and rdmol. This may "
            "give incorrect results if the atom ordering differs between the two.",
            UserWarning,
            stacklevel=2,  # attribute the warning to the caller
        )
        index_pairs = ((idx, idx) for idx in range(n_atoms))
    else:
        index_pairs = atom_mapping.items()

    periodic_table = Chem.GetPeriodicTable()
    for ag_idx, rd_idx in index_pairs:
        mda_atom = atomgroup[ag_idx]
        rd_atom = rdmol.GetAtomWithIdx(rd_idx)
        element = periodic_table.GetElementSymbol(rd_atom.GetAtomicNum())
        if mda_atom.element != element:
            mda_atom.element = element
            mda_atom.name = rd_atom.GetSymbol()