Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ dependencies:
- pytest-cov
- pytest-xdist
- pytest-rerunfailures
- spyrmsd
98 changes: 94 additions & 4 deletions src/openfe_analysis/rmsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,20 @@
import MDAnalysis as mda
import netCDF4 as nc
import numpy as np
from MDAnalysis.analysis import diffusionmap, rms
import spyrmsd.rmsd as srmsd
from MDAnalysis.analysis import rms
from MDAnalysis.analysis.base import AnalysisBase
from MDAnalysis.guesser.tables import vdwradii as MDA_VDWRADII
from MDAnalysis.transformations import unwrap
from rdkit.Chem import rdmolops

from .reader import FEReader
from .transformations import Aligner, ClosestImageShift, NoJump

# B-factor values used to identify atoms present at a given lambda state.
# 0.25 marks atoms unique to one end state, 0.5 marks atoms shared by both.
_BFACTOR_STATE_VALUES = (0.25, 0.5)


def make_Universe(top: pathlib.Path, trj: nc.Dataset, state: int) -> mda.Universe:
"""
Expand Down Expand Up @@ -204,6 +211,67 @@ def _conclude(self):
self.results.rmsd = np.asarray(self.results.rmsd)


class SymmetryCorrectedLigandRMSD(AnalysisBase):
"""
1D RMSD time series for an AtomGroup.

Parameters
----------
atomgroup : MDAnalysis.AtomGroup
Atoms to compute RMSD for.
mass_weighted : bool, optional
If True, compute mass-weighted RMSD.
"""

def __init__(self, atomgroup, mass_weighted=False, **kwargs):
super().__init__(atomgroup.universe.trajectory, **kwargs)
self._ag = atomgroup
self._mass_weighted = mass_weighted
self._isomorphisms = None

vdwradii = dict(MDA_VDWRADII)
vdwradii.update(
{
"Cl": vdwradii["CL"],
"Br": vdwradii["BR"],
"Na": vdwradii["NA"],
}
)

atomgroup.guess_bonds(vdwradii)
self._mol = atomgroup.convert_to("RDKIT")
self._aprops = np.array([atom.GetAtomicNum() for atom in self._mol.GetAtoms()])
self._am = rdmolops.GetAdjacencyMatrix(self._mol)

def _prepare(self):
self.results.rmsd = []
self._reference = self._ag.positions.copy()

if self._mass_weighted:
self._weights = self._ag.masses / np.mean(self._ag.masses)
else:
self._weights = None

def _single_frame(self):
frame_rmsd, isomorphisms, _ = srmsd._rmsd_isomorphic_core(
coords1=self._ag.positions.copy(),
coords2=self._reference,
aprops1=self._aprops,
aprops2=self._aprops,
am1=self._am,
am2=self._am,
center=False,
minimize=False,
isomorphisms=self._isomorphisms,
)
self.results.rmsd.append(frame_rmsd)
if self._isomorphisms is None:
self._isomorphisms = isomorphisms

def _conclude(self):
self.results.rmsd = np.asarray(self.results.rmsd)


class LigandCOMDrift(AnalysisBase):
"""
Ligand center-of-mass displacement from initial position.
Expand All @@ -230,6 +298,27 @@ def _conclude(self):
self.results.com_drift = np.asarray(self.results.com_drift)


def _select_state_ligand(u: mda.Universe) -> mda.AtomGroup:
"""
Select ligand atoms that are present at the current lambda state.

Atoms are identified by their b-factor values: ``0.25`` marks atoms
unique to one end state and ``0.5`` marks atoms shared by both end
states. Only atoms with these b-factor values and residue name "UNK"
are included.

Parameters
----------
u : mda.Universe

Returns
-------
MDAnalysis.AtomGroup
"""
state_indices = np.array([atom.ix for atom in u.atoms if atom.bfactor in _BFACTOR_STATE_VALUES])
return u.atoms[state_indices].select_atoms("resname UNK")


def gather_rms_data(
pdb_topology: pathlib.Path,
dataset: pathlib.Path,
Expand Down Expand Up @@ -299,9 +388,9 @@ def gather_rms_data(
# cheeky, but we can read the PDB topology once and reuse per universe
# this then only hits the PDB file once for all replicas
u = make_Universe(u_top._topology, ds, state=state_idx)

prot = u.select_atoms("protein and name CA")
ligand = u.select_atoms("resname UNK")
state_lig = _select_state_ligand(u)

if prot:
prot_rmsd = RMSDAnalysis(prot).run(step=skip)
Expand All @@ -328,8 +417,9 @@ def gather_rms_data(
# flattened = dist_mat[i, j]
# output["protein_2D_RMSD"].append(flattened)

if ligand.n_atoms > 0:
lig_rmsd = RMSDAnalysis(ligand, mass_weighted=True).run(step=skip)
if state_lig.n_atoms > 0:
# lig_rmsd = RMSDAnalysis(ligand, mass_weighted=True).run(step=skip)
lig_rmsd = SymmetryCorrectedLigandRMSD(state_lig, mass_weighted=True).run(step=skip)
output["ligand_RMSD"].append(lig_rmsd.results.rmsd)
# # Using the MDAnalysis RMSD class instead
# groupselections = ["resname UNK"]
Expand Down
84 changes: 83 additions & 1 deletion src/openfe_analysis/tests/test_rmsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,21 @@
import netCDF4 as nc
import numpy as np
import pytest
import spyrmsd.rmsd as srmsd
from MDAnalysis.analysis import rms
from MDAnalysis.lib.mdamath import make_whole
from MDAnalysis.transformations import unwrap
from numpy.testing import assert_allclose
from rdkit.Chem import rdmolops

from openfe_analysis.reader import FEReader
from openfe_analysis.rmsd import gather_rms_data, make_Universe
from openfe_analysis.rmsd import (
RMSDAnalysis,
SymmetryCorrectedLigandRMSD,
_select_state_ligand,
gather_rms_data,
make_Universe,
)
from openfe_analysis.transformations import Aligner


Expand Down Expand Up @@ -177,3 +185,77 @@ def test_ligand_com_continuity(mda_universe):

assert max(jumps) < 5.0
u.trajectory.close()


def test_symmetry_corrected_ligand_rmsd_nonnegative(mda_universe):
"""RMSD values must be non-negative for all frames."""
u = mda_universe
state_lig = _select_state_ligand(u)

result = SymmetryCorrectedLigandRMSD(state_lig).run()

assert np.all(result.results.rmsd >= 0.0)


def test_symmetry_corrected_ligand_rmsd_zero_for_valid_swap():
"""
For a water-like symmetric molecule, swapping the two equivalent H atoms
gives naive RMSD > 0 but SymmetryCorrectedLigandRMSD = 0.
"""
# Build a minimal universe with two frames: reference and swapped
coords_ref = np.array(
[
[0.0, 0.0, 0.0], # O
[1.0, 0.0, 0.0], # H1
[0.0, 1.0, 0.0], # H2
]
)
coords_swapped = np.array(
[
[0.0, 0.0, 0.0], # O
[0.0, 1.0, 0.0], # H2 in H1's slot
[1.0, 0.0, 0.0], # H1 in H2's slot
]
)

u = mda.Universe.empty(3, trajectory=True)
u.add_TopologyAttr("elements", ["O", "H", "H"])
u.add_TopologyAttr("names", ["O", "H1", "H2"])
u.add_TopologyAttr("resnames", ["UNK"])
u.add_TopologyAttr("resids", [1])
u.load_new(
np.array([coords_ref, coords_swapped]),
order="fac",
)

ag = u.select_atoms("all")

corrected = SymmetryCorrectedLigandRMSD(ag).run()
naive = RMSDAnalysis(ag).run()

# Frame 0 is reference — both should be 0
assert corrected.results.rmsd[0] == pytest.approx(0.0, abs=1e-5)
assert naive.results.rmsd[0] == pytest.approx(0.0, abs=1e-5)

# Frame 1 is the swap — naive sees displacement, corrected sees zero
assert naive.results.rmsd[1] > 0.0
assert corrected.results.rmsd[1] == pytest.approx(0.0, abs=1e-5)


def test_ligand_rmsd_mass_weighting_effect(simulation_skipped_nc, hybrid_system_skipped_pdb):
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test is just here for trouble shooting of the differences between the old RMSD and symmetry corrected. Will remove once things are sorted out.

with nc.Dataset(simulation_skipped_nc) as ds:
u_top = mda.Universe(hybrid_system_skipped_pdb)
u = make_Universe(u_top._topology, ds, state=0)
ligand = u.select_atoms("resname UNK")
state_lig = _select_state_ligand(u)

rmsd_full_mw = RMSDAnalysis(ligand, mass_weighted=True).run()
rmsd_full_no_mw = RMSDAnalysis(ligand, mass_weighted=False).run()
rmsd_state_mw = RMSDAnalysis(state_lig, mass_weighted=True).run()
rmsd_state_no_mw = RMSDAnalysis(state_lig, mass_weighted=False).run()

print(f"Full ligand, mass weighted: {rmsd_full_mw.results.rmsd[:6]}")
print(f"Full ligand, no mass weighting: {rmsd_full_no_mw.results.rmsd[:6]}")
print(f"State ligand, mass weighted: {rmsd_state_mw.results.rmsd[:6]}")
print(f"State ligand, no mass weighting: {rmsd_state_no_mw.results.rmsd[:6]}")
print("Old expected: [0.0, 1.092039, 0.839234, 1.228383, 1.533331, 1.276798]")