diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e13b1e0c66..fddbbf5b9b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -124,3 +124,63 @@ jobs: name: codecov-umbrella verbose: true fail_ci_if_error: false + + rits-install: + name: RitS install (CPU lane) + runs-on: ubuntu-latest + defaults: + run: + shell: bash -el {0} + + steps: + - name: Checkout ARC + uses: actions/checkout@v6 + with: + path: ARC + + - name: Clean Ubuntu Image + uses: jlumbroso/free-disk-space@main + with: + tool-cache: true + android: true + dotnet: true + haskell: true + large-packages: true + swap-storage: true + + - name: Set up micromamba + uses: mamba-org/setup-micromamba@v2 + with: + micromamba-version: 'latest' + init-shell: bash + + - name: Restore RitS cache (clone + rits_env + checkpoint) + id: rits-cache + uses: actions/cache@v4 + with: + path: | + ${{ github.workspace }}/RitS + ~/micromamba/envs/rits_env + key: rits-cpu-${{ runner.os }}-${{ hashFiles('ARC/devtools/install_rits.sh') }} + + - name: Install RitS (CPU) + shell: bash -el {0} + working-directory: ${{ github.workspace }} + run: | + # Use the existing RitS clone from the cache if present, else clone fresh. + if [[ -d RitS/.git ]]; then + bash ARC/devtools/install_rits.sh --cpu --path "$PWD/RitS" + else + bash ARC/devtools/install_rits.sh --cpu + fi + + - name: Smoke-test 'import megalodon' from rits_env + shell: bash -el {0} + run: | + micromamba run -n rits_env python - <<'PYEOF' + import megalodon, torch, torch_geometric + print("megalodon:", megalodon.__file__) + print("torch :", torch.__version__) + print("pyg :", torch_geometric.__version__) + print("cuda? 
:", torch.cuda.is_available()) + PYEOF diff --git a/Makefile b/Makefile index ff5b1e7091..0853ec5fcf 100644 --- a/Makefile +++ b/Makefile @@ -9,7 +9,7 @@ DEVTOOLS_DIR := devtools .PHONY: all help clean test test-unittests test-functional test-all \ install-all install-ci install-pyrdl install-rmg install-rmgdb install-autotst install-gcn \ install-gcn-cpu install-kinbot install-sella install-xtb install-torchani install-ob \ - lite check-env compile + install-rits lite check-env compile # Default target @@ -38,6 +38,7 @@ help: @echo " install-xtb Install xTB" @echo " install-torchani Install TorchANI" @echo " install-ob Install OpenBabel" + @echo " install-rits Install RitS (TS guesser, ~3 GB env + 364 MB checkpoint)" @echo "" @echo "Maintenance:" @echo " lite Run lite installation (no tests)" @@ -66,8 +67,8 @@ install: bash $(DEVTOOLS_DIR)/install_all.sh --rmg-rms install-ci: - @echo "Installing all external ARC dependencies for CI (no clean)..." - bash $(DEVTOOLS_DIR)/install_all.sh --no-clean + @echo "Installing all external ARC dependencies for CI (no clean, no RitS — RitS runs in its own CI lane)..." + bash $(DEVTOOLS_DIR)/install_all.sh --no-clean --no-rits install-lite: @echo "Installing ARC's lite version (no external dependencies)..." @@ -106,6 +107,9 @@ install-torchani: install-ob: bash $(DEVTOOLS_DIR)/install_ob.sh +install-rits: + bash $(DEVTOOLS_DIR)/install_rits.sh + lite: bash $(DEVTOOLS_DIR)/lite.sh diff --git a/arc/common.py b/arc/common.py index 32df575376..b32aa56276 100644 --- a/arc/common.py +++ b/arc/common.py @@ -141,7 +141,7 @@ def check_ess_settings(ess_settings: Optional[dict] = None) -> dict: f'strings. 
Got: {server_list} which is a {type(server_list)}') # run checks: for ess, server_list in settings_dict.items(): - if ess.lower() not in supported_ess + ['gcn', 'heuristics', 'autotst', 'kinbot', 'xtb_gsm', 'orca_neb']: + if ess.lower() not in supported_ess + ['gcn', 'heuristics', 'autotst', 'kinbot', 'rits', 'xtb_gsm', 'orca_neb']: raise SettingsError(f'Recognized ESS software are {supported_ess}. Got: {ess}') for server in server_list: if not isinstance(server, bool) and server.lower() not in [s.lower() for s in servers.keys()]: diff --git a/arc/job/adapter.py b/arc/job/adapter.py index de8c747718..16a1a9487a 100644 --- a/arc/job/adapter.py +++ b/arc/job/adapter.py @@ -98,6 +98,7 @@ class JobEnum(str, Enum): heuristics = 'heuristics' # ARC's heuristics kinbot = 'kinbot' # KinBot, 10.1016/j.cpc.2019.106947 gcn = 'gcn' # Graph neural network for isomerization, https://doi.org/10.1021/acs.jpclett.0c00500 + rits = 'rits' # Right into the Saddle, flow-matching TS generator, https://github.com/isayevlab/RitS, 10.26434/chemrxiv.15001681/v1 user = 'user' # user guesses xtb_gsm = 'xtb_gsm' # Double ended growing string method (DE-GSM), [10.1021/ct400319w, 10.1063/1.4804162] via xTB orca_neb = 'orca_neb' diff --git a/arc/job/adapters/common.py b/arc/job/adapters/common.py index 82a8db0c40..b44da783a9 100644 --- a/arc/job/adapters/common.py +++ b/arc/job/adapters/common.py @@ -73,12 +73,12 @@ 'Singlet_Carbene_Intra_Disproportionation': ['gcn', 'xtb_gsm', 'orca_neb'], } -all_families_ts_adapters = [] +all_families_ts_adapters = ['rits'] adapters_that_do_not_require_a_level_arg = ['xtb', 'torchani'] # Default is "queue", "pipe" will be called whenever needed. So just list 'incore'. 
-default_incore_adapters = ['autotst', 'crest', 'gcn', 'heuristics', 'kinbot', 'psi4', 'xtb', 'xtb_gsm', 'torchani', - 'openbabel'] +default_incore_adapters = ['autotst', 'crest', 'gcn', 'heuristics', 'kinbot', 'psi4', 'rits', + 'xtb', 'xtb_gsm', 'torchani', 'openbabel'] def _initialize_adapter(obj: 'JobAdapter', diff --git a/arc/job/adapters/scripts/rits_script.py b/arc/job/adapters/scripts/rits_script.py new file mode 100644 index 0000000000..1108e6cdc8 --- /dev/null +++ b/arc/job/adapters/scripts/rits_script.py @@ -0,0 +1,239 @@ +#!/usr/bin/env python3 +# encoding: utf-8 + +""" +A standalone script to run RitS (Right into the Saddle) and emit TS guesses +as a YAML file consumable by ARC's RitSAdapter. + +This script must be invoked from inside the ``rits_env`` conda environment +(it does NOT import ``megalodon`` directly — RitS's own +``scripts/sample_transition_state.py`` does that). The parent ARC process +shells out to this script via ``subprocess.run`` so that ARC's main env +does not have to carry the heavy ML dependency stack. + +Input file (``input.yml``) +-------------------------- +Required keys: + reactant_xyz_path : str absolute path to a plain XYZ file (atom-mapped) + product_xyz_path : str absolute path to the matching product XYZ + rits_repo_path : str absolute path to the RitS source checkout + ckpt_path : str absolute path to the pretrained ``rits.ckpt`` + output_xyz_path : str absolute path RitS should write its raw output to + yml_out_path : str absolute path this script writes the parsed TSGuess list to + +Optional keys (with defaults): + config_path : str defaults to ``/scripts/conf/rits.yaml`` + n_samples : int default 10 + batch_size : int default 32 + charge : int default 0 + device : str default 'auto' (RitS picks GPU if visible, else CPU) + add_stereo : bool default False + num_steps : int default None (use config value) + +Output (``yml_out_path``) +------------------------- +A YAML *list* of TSGuess dictionaries. 
Each entry has: + method : 'RitS' + method_direction : 'F' + method_index : int (0-based sample index) + initial_xyz : str (XYZ-format coordinate block, no header lines) + success : bool + execution_time : str (str(datetime.timedelta)) + +If RitS fails to produce any usable output, the script writes a list with a +single failed-guess entry instead of raising — the parent adapter then logs +the failure but continues running other TS methods. +""" + +import argparse +import datetime +import os +import subprocess +import sys +import traceback +from typing import List, Optional + +import yaml + + +def read_yaml_file(path: str) -> dict: + """Read a YAML file and return its contents as a dict.""" + with open(path, 'r') as f: + return yaml.load(stream=f, Loader=yaml.FullLoader) + + +def string_representer(dumper, data): + """YAML representer that uses block literals for multi-line strings.""" + if len(data.splitlines()) > 1: + return dumper.represent_scalar(tag='tag:yaml.org,2002:str', value=data, style='|') + return dumper.represent_scalar(tag='tag:yaml.org,2002:str', value=data) + + +def save_yaml_file(path: str, content) -> None: + """Save ``content`` to a YAML file at ``path``.""" + yaml.add_representer(str, string_representer) + with open(path, 'w') as f: + f.write(yaml.dump(data=content)) + + +def parse_multi_frame_xyz(xyz_path: str) -> List[str]: + """ + Parse a (possibly multi-frame) XYZ file into a list of coordinate-block strings. + + RitS writes a single XYZ file when ``--n_samples == 1`` and a multi-frame + XYZ when ``n_samples > 1`` (frames concatenated, each prefixed by an atom + count line and a blank/comment line). This parser handles both. + + Args: + xyz_path (str): Path to the XYZ file emitted by RitS. + + Returns: + List[str]: One coordinate block per frame, suitable for passing to + ``arc.species.converter.str_to_xyz`` (atom symbols + xyz only — no + header / comment lines). 
+ """ + if not os.path.isfile(xyz_path): + return list() + with open(xyz_path, 'r') as f: + raw_lines = [line.rstrip('\n') for line in f] + frames = list() + i, n = 0, len(raw_lines) + while i < n: + # Skip blank lines between frames + while i < n and not raw_lines[i].strip(): + i += 1 + if i >= n: + break + # First non-blank line of a frame should be the atom count + try: + n_atoms = int(raw_lines[i].strip()) + except ValueError: + # Not a frame header — bail on this row to avoid an infinite loop + i += 1 + continue + i += 1 + # Comment / energy line (may be blank) + if i < n: + i += 1 + # The next n_atoms lines are coordinates + coord_lines = list() + for _ in range(n_atoms): + if i >= n: + break + coord_lines.append(raw_lines[i]) + i += 1 + if len(coord_lines) == n_atoms: + frames.append('\n'.join(coord_lines)) + return frames + + +def run_rits(input_dict: dict) -> List[dict]: + """ + Invoke ``scripts/sample_transition_state.py`` from the RitS source tree + and parse the resulting XYZ frames into a list of TSGuess dictionaries. + + Args: + input_dict (dict): The parsed contents of ``input.yml``. + + Returns: + List[dict]: One TSGuess-shaped dict per generated sample. Always at + least one entry — a failed sentinel if RitS produced nothing. 
+ """ + repo = input_dict['rits_repo_path'] + sample_script = os.path.join(repo, 'scripts', 'sample_transition_state.py') + config_path = input_dict.get('config_path') or os.path.join(repo, 'scripts', 'conf', 'rits.yaml') + output_xyz = input_dict['output_xyz_path'] + n_samples = int(input_dict.get('n_samples', 10)) + batch_size = int(input_dict.get('batch_size', 32)) + charge = int(input_dict.get('charge', 0)) + device = str(input_dict.get('device', 'auto')) + add_stereo = bool(input_dict.get('add_stereo', False)) + num_steps = input_dict.get('num_steps') + + cmd = [ + sys.executable, sample_script, + '--reactant_xyz', input_dict['reactant_xyz_path'], + '--product_xyz', input_dict['product_xyz_path'], + '--config', config_path, + '--ckpt', input_dict['ckpt_path'], + '--output', output_xyz, + '--n_samples', str(n_samples), + '--batch_size', str(batch_size), + '--charge', str(charge), + '--device', device, + ] + if add_stereo: + cmd.append('--add_stereo') + if num_steps is not None: + cmd.extend(['--num_steps', str(num_steps)]) + + t0 = datetime.datetime.now() + print(f'[rits_script] running: {" ".join(cmd)}', flush=True) + completed = subprocess.run(cmd, cwd=repo) + elapsed = datetime.datetime.now() - t0 + + if completed.returncode != 0: + print(f'[rits_script] sample_transition_state.py exited with code {completed.returncode}', flush=True) + return [_failed_guess(elapsed, index=0)] + + frames = parse_multi_frame_xyz(output_xyz) + if not frames: + print(f'[rits_script] no frames parsed from {output_xyz}', flush=True) + return [_failed_guess(elapsed, index=0)] + + tsgs = list() + for i, coord_block in enumerate(frames): + tsgs.append({ + 'method': 'RitS', + 'method_direction': 'F', + 'method_index': i, + 'initial_xyz': coord_block, + 'success': True, + 'execution_time': str(elapsed), + }) + return tsgs + + +def _failed_guess(elapsed: datetime.timedelta, index: int = 0) -> dict: + """Return a failed-TSGuess sentinel dict.""" + return { + 'method': 'RitS', + 
'method_direction': 'F', + 'method_index': index, + 'initial_xyz': None, + 'success': False, + 'execution_time': str(elapsed), + } + + +def parse_command_line_arguments(command_line_args: Optional[list] = None) -> argparse.Namespace: + """Parse the script's command-line arguments.""" + parser = argparse.ArgumentParser(description='Run RitS to generate TS guesses for an ARC reaction.') + parser.add_argument('--yml_in_path', metavar='input', type=str, default='input.yml', + help='Path to the input YAML file (default: ./input.yml).') + return parser.parse_args(command_line_args) + + +def main(): + """Entry point: read input.yml, run RitS, write output YAML.""" + args = parse_command_line_arguments() + yml_in_path = str(args.yml_in_path) + if not os.path.isfile(yml_in_path): + print(f'[rits_script] input file not found: {yml_in_path}', file=sys.stderr) + sys.exit(1) + input_dict = read_yaml_file(yml_in_path) + + try: + tsgs = run_rits(input_dict) + except Exception: + traceback.print_exc() + tsgs = [_failed_guess(datetime.timedelta(0), index=0)] + + save_yaml_file(path=input_dict['yml_out_path'], content=tsgs) + n_ok = sum(1 for tsg in tsgs if tsg.get('success')) + print(f'[rits_script] wrote {len(tsgs)} TSGuess entries ({n_ok} successful) to {input_dict["yml_out_path"]}', + flush=True) + + +if __name__ == '__main__': + main() diff --git a/arc/job/adapters/ts/__init__.py b/arc/job/adapters/ts/__init__.py index 5d571e8e80..115859da7a 100644 --- a/arc/job/adapters/ts/__init__.py +++ b/arc/job/adapters/ts/__init__.py @@ -2,5 +2,6 @@ import arc.job.adapters.ts.gcn_ts import arc.job.adapters.ts.heuristics import arc.job.adapters.ts.kinbot_ts +import arc.job.adapters.ts.rits_ts import arc.job.adapters.ts.xtb_gsm import arc.job.adapters.ts.orca_neb diff --git a/arc/job/adapters/ts/rits_test.py b/arc/job/adapters/ts/rits_test.py new file mode 100644 index 0000000000..6221fcf211 --- /dev/null +++ b/arc/job/adapters/ts/rits_test.py @@ -0,0 +1,975 @@ +#!/usr/bin/env python3 +# 
encoding: utf-8 + +""" +Unit tests for the RitS TS-guess adapter (``arc.job.adapters.ts.rits_ts``). + +Tier-1 (always runs): + * settings resolution and finder helpers + * pure-Python helpers: ``write_xyz_file``, ``parse_multi_frame_xyz``, + ``process_rits_tsg`` dedup + * adapter instantiation with ``testing=True``, file-path layout + * graceful skip when ``rits_env`` / checkpoint are missing + * input.yml writer (mocked subprocess) + +Tier-2 (gated on ``_rits_environment_ready()``): + * end-to-end ``execute_incore`` against the real ``rits_env`` for a + handful of family-diverse reactions sourced from + ``arc/job/adapters/ts/linear_test.py``. + +The Tier-2 tests are skipped automatically on CI runners that did not run +``install_rits.sh`` — the matching CI lane (``rits-install`` in +``.github/workflows/ci.yml``) installs the env and exercises them. +""" + +import importlib +import math +import os +import shutil +import sys +import unittest +from collections import Counter +from unittest import mock + +import arc.job.adapters.ts.rits_ts as rits_mod +from arc.common import ARC_TESTING_PATH, read_yaml_file +from arc.job.adapters.ts.rits_ts import (RitSAdapter, + _rits_environment_ready, + process_rits_tsg, + write_xyz_file, + ) +from arc.reaction import ARCReaction +from arc.species.converter import str_to_xyz, compare_confs +from arc.species.species import ARCSpecies, TSGuess + +HAS_RITS = _rits_environment_ready() + + +def _build_rxn_isomerization_propyl(): + """nC3H7 → iC3H7. 
The simplest isomerization in ARC's test suite.""" + return ARCReaction(r_species=[ARCSpecies(label='nC3H7', smiles='[CH2]CC')], + p_species=[ARCSpecies(label='iC3H7', smiles='C[CH]C')]) + + +def _build_rxn_diels_alder(): + """C=CC(=C)C + C=CC=O → CC1=CCC(C=O)CC1 — bimolecular Diels-Alder.""" + r1_xyz = """C 1.97753426 -0.34691463 -0.12195850 +C 0.96032171 0.45485914 -0.46215363 +C -0.43629664 0.27157147 -0.09968556 +C -1.35584640 1.15966116 -0.51269091 +C -0.83651671 -0.91436221 0.73635894 +H 2.98719352 -0.11575642 -0.44772907 +H 1.84910220 -1.24076974 0.47792776 +H 1.19368072 1.33006788 -1.06832846 +H -2.40510842 1.04750710 -0.25687679 +H -1.09525737 2.02366247 -1.11636739 +H -0.32888591 -0.89422114 1.70676182 +H -1.91408642 -0.93005704 0.93479551 +H -0.58767904 -1.85093188 0.22577726""" + r2_xyz = """C -1.22034116 -0.10890246 0.02353603 +C -0.04004107 0.51094374 -0.08149118 +C 1.22322531 -0.24393463 0.03286276 +O 2.30875132 0.31445302 -0.06186255 +H -1.30612429 -1.17741471 0.19480533 +H -2.14393224 0.45618508 -0.06217786 +H 0.04657041 1.57753840 -0.25245803 +H 1.13189173 -1.32886845 0.20678550""" + p_xyz = """C 2.60098776 -0.04177774 0.73723478 +C 1.20465630 0.10105432 0.20245819 +C 0.16278370 -0.55312927 0.74494799 +C -1.24024239 -0.46705077 0.21761600 +C -1.33954822 0.16452081 -1.17701034 +C -1.06935354 -0.87644399 -2.25040126 +O -0.50075393 -0.64415323 -3.31363975 +C -0.41124651 1.37364733 -1.29938488 +C 1.04721460 1.02438027 -0.98148987 +H 3.26841747 -0.42094194 -0.04336972 +H 2.64920967 -0.73532885 1.58328037 +H 2.97843218 0.92762356 1.07822844 +H 0.31418172 -1.19138627 1.61332708 +H -1.82762138 0.12846013 0.92764672 +H -1.67646259 -1.47309646 0.21384290 +H -2.37737283 0.48324136 -1.33650826 +H -1.50255476 -1.87505625 -2.06737417 +H -0.75069363 2.15000964 -0.60076538 +H -0.46865428 1.81280411 -2.30253884 +H 1.51473571 0.55339822 -1.85465668 +H 1.59082870 1.95894204 -0.79688170""" + r1 = ARCSpecies(label='R1', smiles='C=CC(=C)C', xyz=r1_xyz) + r2 = 
ARCSpecies(label='R2', smiles='C=CC=O', xyz=r2_xyz) + p = ARCSpecies(label='P', smiles='CC1=CCC(C=O)CC1', xyz=p_xyz) + return ARCReaction(r_species=[r1, r2], p_species=[p]) + + +def _build_rxn_one_plus_two_cycloaddition(): + """Singlet CH2 + C=C=C → C=C1CC1 — bimolecular addition with carbene.""" + ch2_xyz = """C 0.00000000 0.00000000 0.10513200 +H 0.00000000 0.98826300 -0.31539600 +H 0.00000000 -0.98826300 -0.31539600""" + c3h4_xyz = """C 1.29697653 0.02233190 0.00658756 +C 0.00000000 -0.00000034 0.00000210 +C -1.29697654 -0.02233198 -0.00658580 +H 1.86532844 -0.70256077 -0.56460908 +H 1.83420869 0.76626329 0.58339481 +H -1.85591941 0.54211003 -0.74397783 +H -1.84361771 -0.60581213 0.72518823""" + c4h6_xyz = """C 1.59999925 -0.11618654 -0.14166302 +C 0.29517860 -0.02143486 -0.02613492 +C -0.92013120 -0.71833111 0.10894610 +C -0.81238032 0.84414025 0.04444949 +H 2.21797993 0.77036923 -0.22897655 +H 2.09015362 -1.08321135 -0.15246324 +H -1.12327237 -1.17593811 1.06705013 +H -1.28992770 -1.23997489 -0.76270297 +H -0.94547237 1.40230195 0.96062403 +H -1.11212744 1.33826544 -0.86912905""" + r1 = ARCSpecies(label='CH2_singlet', adjlist="""multiplicity 1 +1 C u0 p1 c0 {2,S} {3,S} +2 H u0 p0 c0 {1,S} +3 H u0 p0 c0 {1,S} +""", xyz=ch2_xyz) + r2 = ARCSpecies(label='allene', smiles='C=C=C', xyz=c3h4_xyz) + p = ARCSpecies(label='methylene_cyclopropane', smiles='C=C1CC1', xyz=c4h6_xyz) + return ARCReaction(r_species=[r1, r2], p_species=[p]) + + +def _build_rxn_nh3_elimination(): + """NNN → H2NN(s) + NH3 — 1 reactant → 2 products elimination.""" + n3_xyz = """N -1.26709244 -0.00392551 -0.17821516 +N -0.00831159 0.62912211 -0.22607923 +N -0.03650217 1.66537185 0.72488290 +H -1.36396603 -0.52480010 0.69598616 +H -1.33497366 -0.72150540 -0.90528855 +H 0.20276134 1.00409437 -1.16407646 +H 0.01517757 1.28943240 1.67165685 +H -0.93213409 2.15501337 0.67312449""" + h2nn_xyz = """N 1.24087876 0.00949543 0.60790318 +N -0.09033762 -0.00069128 0.02459641 +H -0.47927195 -0.84665038 
-0.39226764 +H -0.67126919 0.83784623 0.01648883""" + nh3_xyz = """N 0.00064924 -0.00099698 0.29559292 +H -0.41786606 0.84210396 -0.09477452 +H -0.52039228 -0.78225292 -0.10002797 +H 0.93760911 -0.05885406 -0.10079043""" + r = ARCSpecies(label='triazene', smiles='NNN', xyz=n3_xyz) + p1 = ARCSpecies(label='H2NNs', adjlist="""multiplicity 1 +1 N u0 p0 c+1 {2,S} {3,S} {4,D} +2 H u0 p0 c0 {1,S} +3 H u0 p0 c0 {1,S} +4 N u0 p2 c-1 {1,D} +""", xyz=h2nn_xyz) + p2 = ARCSpecies(label='NH3', smiles='N', xyz=nh3_xyz) + return ARCReaction(r_species=[r], p_species=[p1, p2]) + + +# === Group A: 1<->1 isomerizations ========================================= + +def _build_rxn_vinyl_alcohol_to_acetaldehyde(): + """Keto-enol tautomerization C2H4O: C=CO -> CC=O (6 atoms, 1,3-H shift).""" + r = ARCSpecies(label='vinyl_alcohol', smiles='C=CO') + p = ARCSpecies(label='acetaldehyde', smiles='CC=O') + return ARCReaction(r_species=[r], p_species=[p]) + + +def _build_rxn_propenol_to_acetone(): + """Keto-enol tautomerization C3H6O: OC(=C)C -> CC(=O)C (10 atoms).""" + r = ARCSpecies(label='propen_2_ol', smiles='OC(=C)C') + p = ARCSpecies(label='acetone', smiles='CC(=O)C') + return ARCReaction(r_species=[r], p_species=[p]) + + +def _build_rxn_cyclobutene_to_butadiene(): + """Electrocyclic ring opening C4H6: C1=CCC1 -> C=CC=C (10 atoms).""" + r = ARCSpecies(label='cyclobutene', smiles='C1=CCC1') + p = ARCSpecies(label='1_3_butadiene', smiles='C=CC=C') + return ARCReaction(r_species=[r], p_species=[p]) + + +def _build_rxn_methoxy_to_hydroxymethyl(): + """1,2-H migration in CH3O radical: [O]C -> O[CH2] (5 atoms).""" + r = ARCSpecies(label='methoxy', smiles='[O]C') + p = ARCSpecies(label='hydroxymethyl', smiles='O[CH2]') + return ARCReaction(r_species=[r], p_species=[p]) + + +def _build_rxn_ethoxy_to_alpha_hydroxyethyl(): + """1,2-H migration in CH3CH2O radical: CC[O] -> [CH2]CO (8 atoms).""" + r = ARCSpecies(label='ethoxy', smiles='CC[O]') + p = ARCSpecies(label='alpha_hydroxyethyl', 
smiles='[CH2]CO') + return ARCReaction(r_species=[r], p_species=[p]) + + +def _build_rxn_cyclopropane_to_propene(): + """Ring opening C3H6: C1CC1 -> C=CC (9 atoms).""" + r = ARCSpecies(label='cyclopropane', smiles='C1CC1') + p = ARCSpecies(label='propene', smiles='C=CC') + return ARCReaction(r_species=[r], p_species=[p]) + + +# === Group B: 1<->2 / 2<->1 (eliminations / cycloadditions) ================ + +def _build_rxn_cyclobutane_retro_22(): + """Retro [2+2] C4H8 -> 2 C2H4 (cyclobutane -> 2 ethene), 12 atoms.""" + r = ARCSpecies(label='cyclobutane', smiles='C1CCC1') + p1 = ARCSpecies(label='ethene_a', smiles='C=C') + p2 = ARCSpecies(label='ethene_b', smiles='C=C') + return ARCReaction(r_species=[r], p_species=[p1, p2]) + + +def _build_rxn_da_butadiene_ethene(): + """Small Diels-Alder C4H6 + C2H4 -> cyclohexene C6H10 (16 atoms).""" + r1 = ARCSpecies(label='1_3_butadiene', smiles='C=CC=C') + r2 = ARCSpecies(label='ethene', smiles='C=C') + p = ARCSpecies(label='cyclohexene', smiles='C1=CCCCC1') + return ARCReaction(r_species=[r1, r2], p_species=[p]) + + +def _build_rxn_ethanol_dehydration(): + """β-elimination CCO -> C=C + H2O (9 atoms).""" + r = ARCSpecies(label='ethanol', smiles='CCO') + p1 = ARCSpecies(label='ethene', smiles='C=C') + p2 = ARCSpecies(label='water', smiles='O') + return ARCReaction(r_species=[r], p_species=[p1, p2]) + + +def _build_rxn_methylamine_dehydrogenation(): + """1,2-dehydrogenation CN -> C=N + H2 (7 atoms total).""" + r = ARCSpecies(label='methylamine', smiles='CN') + p1 = ARCSpecies(label='methyleneamine', smiles='C=N') + p2 = ARCSpecies(label='dihydrogen', smiles='[H][H]') + return ARCReaction(r_species=[r], p_species=[p1, p2]) + + +def _build_rxn_ethyl_peroxy_ho2_elimination(): + """β-scission CCO[O] -> C=C + O[O] (9 atoms).""" + r = ARCSpecies(label='ethyl_peroxy', smiles='CCO[O]') + p1 = ARCSpecies(label='ethene', smiles='C=C') + p2 = ARCSpecies(label='hydroperoxyl', smiles='O[O]') + return ARCReaction(r_species=[r], p_species=[p1, 
p2]) + + +# === Group C: 2<->2 H-abstractions ========================================= + +def _build_rxn_hab_ch4_oh(): + """H-abstraction CH4 + OH -> CH3 + H2O (6 atoms each side).""" + r1 = ARCSpecies(label='methane', smiles='C') + r2 = ARCSpecies(label='hydroxyl', smiles='[OH]') + p1 = ARCSpecies(label='methyl', smiles='[CH3]') + p2 = ARCSpecies(label='water', smiles='O') + return ARCReaction(r_species=[r1, r2], p_species=[p1, p2]) + + +def _build_rxn_hab_c2h6_h(): + """H-abstraction C2H6 + H -> C2H5 + H2 (9 atoms).""" + r1 = ARCSpecies(label='ethane', smiles='CC') + r2 = ARCSpecies(label='H_atom', smiles='[H]') + p1 = ARCSpecies(label='ethyl', smiles='C[CH2]') + p2 = ARCSpecies(label='dihydrogen', smiles='[H][H]') + return ARCReaction(r_species=[r1, r2], p_species=[p1, p2]) + + +def _build_rxn_hab_nh3_oh(): + """H-abstraction NH3 + OH -> NH2 + H2O (6 atoms).""" + r1 = ARCSpecies(label='ammonia', smiles='N') + r2 = ARCSpecies(label='hydroxyl', smiles='[OH]') + p1 = ARCSpecies(label='amidogen', smiles='[NH2]') + p2 = ARCSpecies(label='water', smiles='O') + return ARCReaction(r_species=[r1, r2], p_species=[p1, p2]) + + +def _build_rxn_hab_ch3oh_h(): + """H-abstraction CH3OH + H -> CH2OH + H2 (7 atoms; abstracts α-CH).""" + r1 = ARCSpecies(label='methanol', smiles='CO') + r2 = ARCSpecies(label='H_atom', smiles='[H]') + p1 = ARCSpecies(label='hydroxymethyl', smiles='[CH2]O') + p2 = ARCSpecies(label='dihydrogen', smiles='[H][H]') + return ARCReaction(r_species=[r1, r2], p_species=[p1, p2]) + + +# --------------------------------------------------------------------------- +# Pure-python helpers + plumbing +# --------------------------------------------------------------------------- + +class TestRitSHelpers(unittest.TestCase): + """Helper-function unit tests that don't need rits_env.""" + + @classmethod + def setUpClass(cls): + cls.tmp_dir = os.path.join(ARC_TESTING_PATH, 'rits_helpers') + os.makedirs(cls.tmp_dir, exist_ok=True) + + @classmethod + def 
tearDownClass(cls): + shutil.rmtree(cls.tmp_dir, ignore_errors=True) + + def test_write_xyz_file_round_trip(self): + """write_xyz_file should produce a parseable XYZ file with correct atom count.""" + xyz_dict = { + 'symbols': ('C', 'H', 'H', 'H', 'H'), + 'isotopes': (12, 1, 1, 1, 1), + 'coords': ( + (0.0, 0.0, 0.0), + (1.0, 0.0, 0.0), + (-1.0, 0.0, 0.0), + (0.0, 1.0, 0.0), + (0.0, -1.0, 0.0), + ), + } + path = os.path.join(self.tmp_dir, 'methane.xyz') + write_xyz_file(xyz_dict, path, comment='methane test') + self.assertTrue(os.path.isfile(path)) + with open(path) as f: + lines = f.read().splitlines() + # Header + self.assertEqual(int(lines[0]), 5) + self.assertEqual(lines[1], 'methane test') + # Body — 5 coordinate lines starting with the right symbols + body_symbols = [ln.split()[0] for ln in lines[2:7]] + self.assertEqual(body_symbols, ['C', 'H', 'H', 'H', 'H']) + # Round-trip via str_to_xyz + rt = str_to_xyz(path) + self.assertEqual(rt['symbols'], xyz_dict['symbols']) + + def test_write_xyz_file_strips_newlines_in_comment(self): + """A multi-line comment must not corrupt the XYZ format.""" + xyz_dict = { + 'symbols': ('H', 'H'), + 'isotopes': (1, 1), + 'coords': ((0.0, 0.0, 0.0), (0.74, 0.0, 0.0)), + } + path = os.path.join(self.tmp_dir, 'h2.xyz') + write_xyz_file(xyz_dict, path, comment='line1\nline2\nline3') + with open(path) as f: + lines = f.read().splitlines() + # Header is exactly 2 lines + 2 atoms = 4 lines minimum + self.assertEqual(int(lines[0]), 2) + self.assertNotIn('\n', lines[1]) + self.assertEqual(len(lines), 4) + + def test_process_rits_tsg_failed_entry(self): + """A failed-sentinel dict should not produce a TSGuess.""" + ts_species = ARCSpecies(label='TS', is_ts=True) + added = process_rits_tsg( + tsg_dict={'method': 'RitS', 'method_direction': 'F', 'method_index': 0, + 'initial_xyz': None, 'success': False, 'execution_time': '0:00:00.0'}, + local_path=self.tmp_dir, + ts_species=ts_species, + ) + self.assertFalse(added) + 
self.assertEqual(len(ts_species.ts_guesses), 0) + + def test_process_rits_tsg_dedup_against_existing(self): + """A RitS guess that matches an existing GCN guess should not be appended; + the existing guess should be re-labeled to credit RitS as well.""" + ts_species = ARCSpecies(label='TS', is_ts=True) + # Plant a GCN guess first. + existing_xyz_str = """C 0.0 0.0 0.0 +H 1.0 0.0 0.0 +H -1.0 0.0 0.0 +H 0.0 1.0 0.0 +H 0.0 -1.0 0.0""" + existing = TSGuess(method='GCN', method_direction='F', method_index=0, + index=0, success=True) + existing.process_xyz(str_to_xyz(existing_xyz_str)) + ts_species.ts_guesses.append(existing) + # Submit a RitS guess with identical coordinates. + added = process_rits_tsg( + tsg_dict={'method': 'RitS', 'method_direction': 'F', 'method_index': 0, + 'initial_xyz': existing_xyz_str, 'success': True, + 'execution_time': '0:00:01.0'}, + local_path=self.tmp_dir, + ts_species=ts_species, + ) + self.assertFalse(added) # not appended + self.assertEqual(len(ts_species.ts_guesses), 1) + # The existing guess should now credit both methods. Note: TSGuess + # lowercases the method string on construction. + merged = ts_species.ts_guesses[0].method.lower() + self.assertIn('rits', merged) + self.assertIn('gcn', merged) + + def test_process_rits_tsg_unique_guess_appended(self): + """A unique non-colliding guess should be appended.""" + ts_species = ARCSpecies(label='TS', is_ts=True) + unique_xyz = """C 0.0 0.0 0.0 +H 1.5 0.0 0.0 +H -1.5 0.0 0.0 +H 0.0 1.5 0.0 +H 0.0 -1.5 0.0""" + added = process_rits_tsg( + tsg_dict={'method': 'RitS', 'method_direction': 'F', 'method_index': 2, + 'initial_xyz': unique_xyz, 'success': True, + 'execution_time': '0:00:02.0'}, + local_path=self.tmp_dir, + ts_species=ts_species, + ) + self.assertTrue(added) + self.assertEqual(len(ts_species.ts_guesses), 1) + # TSGuess lowercases method on construction. 
+ self.assertEqual(ts_species.ts_guesses[0].method.lower(), 'rits') + self.assertEqual(ts_species.ts_guesses[0].method_index, 2) + self.assertTrue(ts_species.ts_guesses[0].success) + + def test_process_rits_tsg_collision_rejected(self): + """A guess where two atoms overlap must be rejected by colliding_atoms.""" + ts_species = ARCSpecies(label='TS', is_ts=True) + bad_xyz = """C 0.0 0.0 0.0 +H 0.0 0.0 0.0 +H -1.5 0.0 0.0 +H 0.0 1.5 0.0 +H 0.0 -1.5 0.0""" + added = process_rits_tsg( + tsg_dict={'method': 'RitS', 'method_direction': 'F', 'method_index': 0, + 'initial_xyz': bad_xyz, 'success': True, + 'execution_time': '0:00:00.5'}, + local_path=self.tmp_dir, + ts_species=ts_species, + ) + self.assertFalse(added) + self.assertEqual(len(ts_species.ts_guesses), 0) + + def test_process_rits_tsg_dedup_catches_rigid_rotation(self): + """A rigidly rotated + translated copy of an existing TSGuess must be + deduped. This is the whole point of switching from byte-level + almost_equal_coords to distance-matrix compare_confs — RitS samples + each TS in its own random orientation, so rotated copies are common. + """ + ts_species = ARCSpecies(label='TS', is_ts=True) + # Plant the original (use atypical CH bond lengths so we can be sure + # the assertion isn't accidentally matching some default geometry). + original_xyz = """C 0.000 0.000 0.000 +H 0.700 0.700 0.700 +H -0.700 -0.700 0.700 +H -0.700 0.700 -0.700 +H 0.700 -0.700 -0.700""" + first = process_rits_tsg( + tsg_dict={'method': 'RitS', 'method_direction': 'F', 'method_index': 0, + 'initial_xyz': original_xyz, 'success': True, + 'execution_time': '0:00:00.0'}, + local_path=self.tmp_dir, + ts_species=ts_species, + ) + self.assertTrue(first) + self.assertEqual(len(ts_species.ts_guesses), 1) + + # Build a 37° z-axis rotation + translation of the same molecule. 
+ theta = math.radians(37.0) + cos_t, sin_t = math.cos(theta), math.sin(theta) + original_coords = [ + (0.000, 0.000, 0.000), + (0.700, 0.700, 0.700), + (-0.700, -0.700, 0.700), + (-0.700, 0.700, -0.700), + (0.700, -0.700, -0.700), + ] + rotated = [] + for x, y, z in original_coords: + rx = cos_t * x - sin_t * y + 10.0 # also translate by (+10, +5, -3) + ry = sin_t * x + cos_t * y + 5.0 + rz = z - 3.0 + rotated.append((rx, ry, rz)) + symbols = ('C', 'H', 'H', 'H', 'H') + rotated_xyz_str = '\n'.join( + f'{s} {x:.6f} {y:.6f} {z:.6f}' for s, (x, y, z) in zip(symbols, rotated) + ) + + added = process_rits_tsg( + tsg_dict={'method': 'RitS', 'method_direction': 'F', 'method_index': 1, + 'initial_xyz': rotated_xyz_str, 'success': True, + 'execution_time': '0:00:00.5'}, + local_path=self.tmp_dir, + ts_species=ts_species, + ) + self.assertFalse(added, + 'rotated+translated duplicate of an existing RitS guess ' + 'must be deduped via compare_confs (distance-matrix RMSD)') + self.assertEqual(len(ts_species.ts_guesses), 1, + 'no new TSGuess should be appended for a rotated duplicate') + + +class TestRitSScriptParser(unittest.TestCase): + """Direct unit tests for arc/job/adapters/scripts/rits_script.py:parse_multi_frame_xyz.""" + + @classmethod + def setUpClass(cls): + # Import the standalone script as a module so we can call its helpers directly. 
+ scripts_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(rits_mod.__file__))), 'scripts') + if scripts_dir not in sys.path: + sys.path.insert(0, scripts_dir) + cls.rits_script = importlib.import_module('rits_script') + cls.tmp_dir = os.path.join(ARC_TESTING_PATH, 'rits_script_parser') + os.makedirs(cls.tmp_dir, exist_ok=True) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(cls.tmp_dir, ignore_errors=True) + + def _write(self, name: str, body: str) -> str: + path = os.path.join(self.tmp_dir, name) + with open(path, 'w') as f: + f.write(body) + return path + + def test_single_frame_xyz(self): + body = "3\n\nC 0.0 0.0 0.0\nH 1.0 0.0 0.0\nH -1.0 0.0 0.0\n" + frames = self.rits_script.parse_multi_frame_xyz(self._write('one.xyz', body)) + self.assertEqual(len(frames), 1) + self.assertEqual(frames[0].splitlines()[0].split()[0], 'C') + + def test_multi_frame_xyz(self): + body = ("3\n\nC 0.0 0.0 0.0\nH 1.0 0.0 0.0\nH -1.0 0.0 0.0\n" + "3\n\nC 0.1 0.0 0.0\nH 1.1 0.0 0.0\nH -0.9 0.0 0.0\n") + frames = self.rits_script.parse_multi_frame_xyz(self._write('two.xyz', body)) + self.assertEqual(len(frames), 2) + # Frame 0 starts at the origin; frame 1 is shifted by +0.1 in x + self.assertAlmostEqual(float(frames[0].splitlines()[0].split()[1]), 0.0) + self.assertAlmostEqual(float(frames[1].splitlines()[0].split()[1]), 0.1) + + def test_missing_file_returns_empty_list(self): + frames = self.rits_script.parse_multi_frame_xyz(os.path.join(self.tmp_dir, 'nope.xyz')) + self.assertEqual(frames, list()) + + def test_garbage_does_not_loop_forever(self): + body = "this is not an xyz\nat all\n" + frames = self.rits_script.parse_multi_frame_xyz(self._write('garbage.xyz', body)) + self.assertEqual(frames, list()) + + +class TestRitSAdapterInstantiation(unittest.TestCase): + """Verify the adapter constructs and lays out files even without rits_env.""" + + @classmethod + def setUpClass(cls): + cls.maxDiff = None + cls.output_dir = os.path.join(ARC_TESTING_PATH, 
'RitS', 'instantiation') + os.makedirs(cls.output_dir, exist_ok=True) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(os.path.join(ARC_TESTING_PATH, 'RitS'), ignore_errors=True) + + def _build_adapter(self, project_dir: str, n_samples: int = 5): + rxn = _build_rxn_isomerization_propyl() + return RitSAdapter( + job_type='tsg', + reactions=[rxn], + testing=True, + project='test_rits', + project_directory=project_dir, + args={'keyword': {'n_samples': n_samples}}, + ) + + def test_instantiation_sets_paths_and_metadata(self): + proj = os.path.join(self.output_dir, 'paths') + adapter = self._build_adapter(proj, n_samples=7) + self.assertEqual(adapter.job_adapter, 'rits') + self.assertEqual(adapter.execution_type, 'incore') + self.assertEqual(adapter.url, 'https://github.com/isayevlab/RitS') + self.assertEqual(adapter.incore_capacity, 1) + self.assertEqual(adapter.n_samples, 7) + # File paths should all live under the local_path the adapter set up + self.assertTrue(adapter.reactant_xyz_path.endswith('reactant.xyz')) + self.assertTrue(adapter.product_xyz_path.endswith('product.xyz')) + self.assertTrue(adapter.ts_out_xyz_path.endswith('rits_ts.xyz')) + self.assertTrue(adapter.yml_in_path.endswith('input.yml')) + self.assertTrue(adapter.yml_out_path.endswith('output.yml')) + # All five paths should share a parent directory + parents = {os.path.dirname(p) for p in (adapter.reactant_xyz_path, + adapter.product_xyz_path, + adapter.ts_out_xyz_path, + adapter.yml_in_path, + adapter.yml_out_path)} + self.assertEqual(len(parents), 1) + + def test_default_n_samples(self): + proj = os.path.join(self.output_dir, 'default_samples') + adapter = RitSAdapter( + job_type='tsg', + reactions=[_build_rxn_isomerization_propyl()], + testing=True, + project='test_rits', + project_directory=proj, + ) + self.assertEqual(adapter.n_samples, rits_mod.DEFAULT_N_SAMPLES) + + def test_n_samples_invalid_args_falls_back_to_default(self): + proj = os.path.join(self.output_dir, 'bad_samples') + 
adapter = RitSAdapter( + job_type='tsg', + reactions=[_build_rxn_isomerization_propyl()], + testing=True, + project='test_rits', + project_directory=proj, + args={'keyword': {'n_samples': 'not-a-number'}}, + ) + self.assertEqual(adapter.n_samples, rits_mod.DEFAULT_N_SAMPLES) + + def test_missing_reactions_raises(self): + proj = os.path.join(self.output_dir, 'no_reactions') + with self.assertRaises(ValueError): + RitSAdapter(job_type='tsg', reactions=None, testing=True, + project='test_rits', project_directory=proj) + + +class TestRitSGracefulSkip(unittest.TestCase): + """When rits_env / checkpoint are missing, execute_incore must NOT raise.""" + + @classmethod + def setUpClass(cls): + cls.output_dir = os.path.join(ARC_TESTING_PATH, 'RitS', 'graceful_skip') + os.makedirs(cls.output_dir, exist_ok=True) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(os.path.join(ARC_TESTING_PATH, 'RitS'), ignore_errors=True) + + def test_missing_python_logs_and_returns(self): + rxn = _build_rxn_isomerization_propyl() + adapter = RitSAdapter( + job_type='tsg', + reactions=[rxn], + testing=True, + project='test_rits', + project_directory=os.path.join(self.output_dir, 'no_python'), + ) + # Patch the module-level constants to simulate a host without rits_env. 
+ with mock.patch.object(rits_mod, 'RITS_PYTHON', None), \ + mock.patch.object(rits_mod, 'RITS_REPO_PATH', '/nonexistent/RitS'), \ + mock.patch.object(rits_mod, 'RITS_CKPT_PATH', '/nonexistent/rits.ckpt'): + # Should not raise + adapter.execute_incore() + # No TS guesses should have been created + if rxn.ts_species is not None: + self.assertEqual(len(rxn.ts_species.ts_guesses), 0) + + def test_missing_checkpoint_logs_and_returns(self): + rxn = _build_rxn_isomerization_propyl() + adapter = RitSAdapter( + job_type='tsg', + reactions=[rxn], + testing=True, + project='test_rits', + project_directory=os.path.join(self.output_dir, 'no_ckpt'), + ) + with mock.patch.object(rits_mod, 'RITS_CKPT_PATH', '/nonexistent/ckpt'): + adapter.execute_incore() + if rxn.ts_species is not None: + self.assertEqual(len(rxn.ts_species.ts_guesses), 0) + + +class TestRitSInputYamlWritten(unittest.TestCase): + """Verify input.yml is written correctly without invoking the real subprocess.""" + + @classmethod + def setUpClass(cls): + cls.output_dir = os.path.join(ARC_TESTING_PATH, 'RitS', 'input_yml') + os.makedirs(cls.output_dir, exist_ok=True) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(os.path.join(ARC_TESTING_PATH, 'RitS'), ignore_errors=True) + + def test_input_yml_contents(self): + """A successful execute_incore should write input.yml with all required keys. + + We mock subprocess.run so the test does not depend on rits_env actually + being installed.""" + rxn = _build_rxn_diels_alder() + adapter = RitSAdapter( + job_type='tsg', + reactions=[rxn], + testing=True, + project='test_rits', + project_directory=os.path.join(self.output_dir, 'da'), + args={'keyword': {'n_samples': 4}}, + ) + + # Pretend the env is fully ready, but make subprocess.run a no-op so we + # never actually invoke RitS — we only care about input.yml + the + # mapped reactant.xyz / product.xyz files we wrote. 
+ fake_completed = mock.Mock(returncode=0) + with mock.patch.object(rits_mod, '_rits_environment_ready', return_value=True), \ + mock.patch.object(rits_mod, 'RITS_PYTHON', '/fake/python'), \ + mock.patch.object(rits_mod, 'RITS_REPO_PATH', '/fake/RitS'), \ + mock.patch.object(rits_mod, 'RITS_CKPT_PATH', '/fake/rits.ckpt'), \ + mock.patch('arc.job.adapters.ts.rits_ts.subprocess.run', + return_value=fake_completed) as run_mock: + adapter.execute_incore() + + self.assertTrue(run_mock.called) + # input.yml should exist with the keys our standalone script expects + self.assertTrue(os.path.isfile(adapter.yml_in_path)) + in_dict = read_yaml_file(adapter.yml_in_path) + for key in ('reactant_xyz_path', 'product_xyz_path', 'rits_repo_path', + 'ckpt_path', 'output_xyz_path', 'yml_out_path', + 'config_path', 'n_samples', 'batch_size', 'charge', 'device'): + self.assertIn(key, in_dict, f'missing key {key} in input.yml') + self.assertEqual(in_dict['n_samples'], 4) + self.assertEqual(in_dict['device'], 'auto') + self.assertEqual(in_dict['rits_repo_path'], '/fake/RitS') + self.assertEqual(in_dict['ckpt_path'], '/fake/rits.ckpt') + self.assertTrue(in_dict['config_path'].endswith('rits.yaml')) + # The reactant + product XYZ files should be on disk and have matching atom counts + self.assertTrue(os.path.isfile(adapter.reactant_xyz_path)) + self.assertTrue(os.path.isfile(adapter.product_xyz_path)) + with open(adapter.reactant_xyz_path) as f: + r_n = int(f.readline()) + with open(adapter.product_xyz_path) as f: + p_n = int(f.readline()) + self.assertEqual(r_n, p_n) + # Diels-Alder C=CC(=C)C + C=CC=O → CC1=CCC(C=O)CC1 has 21 atoms + self.assertEqual(r_n, 21) + + +# --------------------------------------------------------------------------- +# End-to-end runs against the real rits_env (skipped without it) +# --------------------------------------------------------------------------- + +@unittest.skipUnless(HAS_RITS, 'rits_env / checkpoint not installed; run `make install-rits` to 
enable.') +class TestRitSEndToEnd(unittest.TestCase): + """End-to-end runs through subprocess into the real rits_env. + + These tests are gated on `_rits_environment_ready()` so a CI runner that + skipped install_rits.sh still gets a green run. The matching CI lane + `rits-install` in .github/workflows/ci.yml installs the env and exercises + them on every PR. + + Each test asks for a small number of samples (n_samples=2) so the runtime + stays reasonable: even on CPU, two samples per reaction completes in well + under a minute on the model RitS ships. + """ + + @classmethod + def setUpClass(cls): + cls.output_dir = os.path.join(ARC_TESTING_PATH, 'RitS', 'e2e') + os.makedirs(cls.output_dir, exist_ok=True) + + @classmethod + def tearDownClass(cls): + shutil.rmtree(os.path.join(ARC_TESTING_PATH, 'RitS'), ignore_errors=True) + + def _run_e2e(self, rxn, label: str, expected_n_atoms: int, n_samples: int = 2, + expect_success: bool = True): + """Helper: build adapter, execute, return the (rxn, adapter) pair after assertions. + + Args: + rxn: The ARCReaction to feed to RitS. + label: Subdirectory name under the test output dir. + expected_n_atoms: Atom count both reactant and product XYZs should match. + n_samples: Number of TS samples to ask RitS for. + expect_success: When True, assert at least one usable TSGuess was produced. + When False, assert only that the adapter handled RitS's failure + gracefully (output.yml exists, failed-sentinel entry inside, no + crash). Used for reactions RitS cannot handle by design — e.g. + charged/zwitterionic species, where its OpenBabel bond inference + trips RDKit sanitization. 
+ """ + proj = os.path.join(self.output_dir, label) + adapter = RitSAdapter( + job_type='tsg', + reactions=[rxn], + testing=True, + project='test_rits', + project_directory=proj, + args={'keyword': {'n_samples': n_samples}}, + ) + adapter.execute_incore() + + # The reactant + product XYZ that ARC fed to RitS must have matching atom counts + with open(adapter.reactant_xyz_path) as f: + r_n = int(f.readline()) + with open(adapter.product_xyz_path) as f: + p_n = int(f.readline()) + self.assertEqual(r_n, expected_n_atoms) + self.assertEqual(p_n, expected_n_atoms) + # The reactant and product elements must match as multisets — atoms are + # neither created nor destroyed across an elementary reaction. + r_xyz_dict = str_to_xyz(adapter.reactant_xyz_path) + p_xyz_dict = str_to_xyz(adapter.product_xyz_path) + expected_formula = Counter(r_xyz_dict['symbols']) + self.assertEqual(expected_formula, Counter(p_xyz_dict['symbols']), + f'reactant and product element multisets disagree for {label}') + + # The output YAML should exist and be readable in either case + self.assertTrue(os.path.isfile(adapter.yml_out_path), + f'rits_script.py did not write {adapter.yml_out_path}') + out = read_yaml_file(adapter.yml_out_path) or list() + self.assertGreater(len(out), 0, f'rits_script.py produced 0 entries for {label}') + successes = [tsg for tsg in out if tsg.get('success') and tsg.get('initial_xyz')] + + if expect_success: + self.assertGreater(len(successes), 0, + f'RitS produced 0 successful TSGuess entries for {label}') + # Strict check: EVERY successful TS must have the same atom count + # AND the same element multiset as the reactants. Catches both + # atom-count mismatches and element-shuffling bugs. 
+ for i, tsg_dict in enumerate(successes): + ts_xyz = str_to_xyz(tsg_dict['initial_xyz']) + self.assertEqual( + len(ts_xyz['symbols']), expected_n_atoms, + f'{label} TS sample {i}: atom count {len(ts_xyz["symbols"])} ' + f'!= expected {expected_n_atoms}', + ) + actual_formula = Counter(ts_xyz['symbols']) + self.assertEqual( + actual_formula, expected_formula, + f'{label} TS sample {i}: molecular formula ' + f'{dict(actual_formula)} does not match reactant ' + f'{dict(expected_formula)}', + ) + else: + # Failure path: there should be exactly one failed sentinel entry, + # and the adapter must not have created any successful TS guesses + # on the reaction object. + self.assertEqual(len(successes), 0, + f'Expected RitS to fail on {label}, but got ' + f'{len(successes)} successful guess(es)') + self.assertTrue(all(not tsg.get('success') for tsg in out)) + return adapter + + def test_e2e_isomerization_propyl(self): + """nC3H7 → iC3H7 (10 atoms, isomerization). + + With ``n_samples=2`` RitS produces two TS guesses. We assert that + BOTH survive the distance-matrix dedup — verified empirically: + the two samples differ along the reaction coordinate (C-C of the + donor side: 1.52 Å vs 1.76 Å; migrating-H acceptor distance: + 1.28 Å vs 1.16 Å), with a distance-matrix RMSD of ~0.56 Å — well + above the 0.1 Å dedup threshold. They represent two diverse + starting points for downstream Gaussian/ORCA TS optimization, + which is exactly the value of asking for ``n_samples > 1``. + Rotated/translated *exact* copies would be merged — see + TestRitSHelpers.test_process_rits_tsg_dedup_catches_rigid_rotation. + """ + adapter = self._run_e2e(_build_rxn_isomerization_propyl(), + label='isom_propyl', expected_n_atoms=10) + rxn = adapter.reactions[0] + successful = [tsg for tsg in rxn.ts_species.ts_guesses if tsg.success] + # Both samples should survive — they are structurally distinct. 
+ self.assertEqual( + len(successful), 2, + f'Expected 2 unique TS guesses for nC3H7→iC3H7 (each from a ' + f'separate point on the reaction coordinate), got {len(successful)}.', + ) + # Sanity-check they ARE distinct under compare_confs (else dedup is broken). + self.assertFalse( + compare_confs(successful[0].initial_xyz, successful[1].initial_xyz), + 'The two propyl TS guesses unexpectedly compare equal — RitS may have ' + 'collapsed onto a single saddle, or the dedup is mis-tuned.', + ) + + def test_e2e_diels_alder(self): + """Diels-Alder bimolecular addition (21 atoms).""" + self._run_e2e(_build_rxn_diels_alder(), + label='diels_alder', expected_n_atoms=21) + + def test_e2e_one_plus_two_cycloaddition(self): # fails + """1+2 cycloaddition with singlet carbene (10 atoms, bimolecular).""" + self._run_e2e(_build_rxn_one_plus_two_cycloaddition(), + label='one_plus_two', expected_n_atoms=10) + + def test_e2e_nh3_elimination_graceful_failure(self): # fails (as planned) + """1,2-NH3 elimination NNN → H2NN(s) + NH3 — RitS cannot handle this + because its OpenBabel bond inference rejects the zwitterionic + aminonitrene product (4-valent N+). The adapter must: + + * still write input.yml + reactant.xyz + product.xyz + * still get a non-empty output.yml back + * write a failed-sentinel TSGuess entry + * NOT raise + + This test pins the graceful-failure code path so it doesn't regress. + """ + adapter = self._run_e2e(_build_rxn_nh3_elimination(), + label='nh3_elim_graceful', expected_n_atoms=8, + expect_success=False) + # The reaction's ts_species should still exist but have no successful TSGuesses. 
+ rxn = adapter.reactions[0] + self.assertIsNotNone(rxn.ts_species) + successful = [tsg for tsg in rxn.ts_species.ts_guesses if tsg.success] + self.assertEqual(len(successful), 0) + + # ----- Group A: 1<->1 isomerizations ------------------------------------- + + def test_e2e_vinyl_alcohol_to_acetaldehyde(self): + """Keto-enol tautomerization C2H4O (7 atoms: 2C + 4H + 1O).""" + self._run_e2e(_build_rxn_vinyl_alcohol_to_acetaldehyde(), + label='vinyl_alcohol_to_acetaldehyde', expected_n_atoms=7) + + def test_e2e_propenol_to_acetone(self): + """Keto-enol tautomerization C3H6O (10 atoms).""" + self._run_e2e(_build_rxn_propenol_to_acetone(), + label='propenol_to_acetone', expected_n_atoms=10) + + def test_e2e_cyclobutene_to_butadiene(self): + """Electrocyclic ring opening C4H6 (10 atoms).""" + self._run_e2e(_build_rxn_cyclobutene_to_butadiene(), + label='cyclobutene_to_butadiene', expected_n_atoms=10) + + def test_e2e_methoxy_to_hydroxymethyl(self): + """1,2-H migration in CH3O radical (5 atoms).""" + self._run_e2e(_build_rxn_methoxy_to_hydroxymethyl(), + label='methoxy_to_hydroxymethyl', expected_n_atoms=5) + + def test_e2e_ethoxy_to_alpha_hydroxyethyl(self): + """1,2-H migration in CH3CH2O radical (8 atoms).""" + self._run_e2e(_build_rxn_ethoxy_to_alpha_hydroxyethyl(), + label='ethoxy_to_alpha_hydroxyethyl', expected_n_atoms=8) + + def test_e2e_cyclopropane_to_propene(self): + """Cyclopropane ring opening C3H6 (9 atoms).""" + self._run_e2e(_build_rxn_cyclopropane_to_propene(), + label='cyclopropane_to_propene', expected_n_atoms=9) + + # ----- Group B: 1<->2 / 2<->1 (eliminations / cycloadditions) ----------- + + def test_e2e_cyclobutane_retro_22(self): + """Retro [2+2] cyclobutane -> 2 ethene (12 atoms).""" + self._run_e2e(_build_rxn_cyclobutane_retro_22(), + label='cyclobutane_retro_22', expected_n_atoms=12) + + def test_e2e_da_butadiene_ethene(self): + """Small Diels-Alder butadiene + ethene -> cyclohexene (16 atoms).""" + 
self._run_e2e(_build_rxn_da_butadiene_ethene(), + label='da_butadiene_ethene', expected_n_atoms=16) + + def test_e2e_ethanol_dehydration(self): + """β-elimination ethanol -> ethene + water (9 atoms).""" + self._run_e2e(_build_rxn_ethanol_dehydration(), + label='ethanol_dehydration', expected_n_atoms=9) + + def test_e2e_methylamine_dehydrogenation(self): + """1,2-dehydrogenation methylamine -> methyleneamine + H2 (7 atoms).""" + self._run_e2e(_build_rxn_methylamine_dehydrogenation(), + label='methylamine_dehydrogenation', expected_n_atoms=7) + + def test_e2e_ethyl_peroxy_ho2_elimination(self): + """β-scission ethyl peroxy -> ethene + HO2 (9 atoms).""" + self._run_e2e(_build_rxn_ethyl_peroxy_ho2_elimination(), + label='ethyl_peroxy_ho2_elimination', expected_n_atoms=9) + + # ----- Group C: 2<->2 H-abstractions ------- ----------------------------- + + def test_e2e_hab_ch4_oh(self): + """H-abstraction CH4 + OH -> CH3 + H2O (7 atoms total: 1C + 5H + 1O).""" + self._run_e2e(_build_rxn_hab_ch4_oh(), + label='hab_ch4_oh', expected_n_atoms=7) + + def test_e2e_hab_c2h6_h(self): + """H-abstraction C2H6 + H -> C2H5 + H2 (9 atoms).""" + self._run_e2e(_build_rxn_hab_c2h6_h(), + label='hab_c2h6_h', expected_n_atoms=9) + + def test_e2e_hab_nh3_oh(self): + """H-abstraction NH3 + OH -> NH2 + H2O (6 atoms).""" + self._run_e2e(_build_rxn_hab_nh3_oh(), + label='hab_nh3_oh', expected_n_atoms=6) + + def test_e2e_hab_ch3oh_h(self): + """H-abstraction CH3OH + H -> CH2OH + H2 (7 atoms; abstracts α-CH).""" + self._run_e2e(_build_rxn_hab_ch3oh_h(), + label='hab_ch3oh_h', expected_n_atoms=7) + + +if __name__ == '__main__': + unittest.main(testRunner=unittest.TextTestRunner(verbosity=2)) diff --git a/arc/job/adapters/ts/rits_ts.py b/arc/job/adapters/ts/rits_ts.py new file mode 100644 index 0000000000..48f1bd5f15 --- /dev/null +++ b/arc/job/adapters/ts/rits_ts.py @@ -0,0 +1,459 @@ +""" +An adapter for executing RitS (Right into the Saddle) TS-guess jobs. 
+ +RitS is a flow-matching ML model that generates 3D transition-state geometries +directly from atom-mapped reactant + product structures, without requiring an +initial guess. Unlike GCN (which is restricted to isomerizations), RitS can +handle bimolecular reactions and supports charged species, so it covers a +strictly larger reaction space. + +Code source : https://github.com/isayevlab/RitS +Paper : 10.26434/chemrxiv.15001681/v1 +Pretrained ckpt : https://doi.org/10.5281/zenodo.19474153 + +Implementation notes +-------------------- +* The heavy ML stack (torch + torch-geometric + megalodon) lives in its own + conda env (``rits_env``), so this adapter never imports it directly. It + shells out to ``arc/job/adapters/scripts/rits_script.py`` via subprocess, + which in turn invokes RitS's own ``scripts/sample_transition_state.py``. +* RitS requires the reactant and product XYZ files to have the *same atom + count and the same atom ordering* (it aligns them by index). ARC's + ``rxn.get_reactants_xyz`` / ``get_products_xyz`` already produce mapped + outputs via ``rxn.atom_map``, so we can use them as-is. +* Multiple samples per reaction are produced in a single subprocess call + (RitS's ``--n_samples`` flag), avoiding the per-sample model-load overhead + that GCN incurs. +* If ``rits_env`` or the pretrained checkpoint is missing on the host, the + adapter logs a warning and exits cleanly without raising — the rest of + ARC's TS-search pipeline (heuristics, GCN, AutoTST, …) keeps running. +* ``incore_capacity = 1`` so the scheduler serializes RitS jobs and a single + GPU is not asked to load multiple checkpoints in parallel. 
+""" + +import datetime +import os +import subprocess +from typing import TYPE_CHECKING, List, Optional, Tuple, Union + +from arc.common import ARC_PATH, get_logger, save_yaml_file, read_yaml_file +from arc.imports import settings +from arc.job.adapter import JobAdapter +from arc.job.adapters.common import _initialize_adapter +from arc.job.factory import register_job_adapter +from arc.plotter import save_geo +from arc.species.converter import compare_confs, str_to_xyz, xyz_to_str +from arc.species.species import ARCSpecies, TSGuess, colliding_atoms + +if TYPE_CHECKING: + from arc.level import Level + from arc.reaction import ARCReaction + + +RITS_PYTHON = settings.get('RITS_PYTHON') +RITS_REPO_PATH = settings.get('RITS_REPO_PATH') +RITS_CKPT_PATH = settings.get('RITS_CKPT_PATH') + +RITS_SCRIPT_PATH = os.path.join(ARC_PATH, 'arc', 'job', 'adapters', 'scripts', 'rits_script.py') +DEFAULT_N_SAMPLES = 10 +DEFAULT_BATCH_SIZE = 32 + +logger = get_logger() + + +class RitSAdapter(JobAdapter): + """ + A class for executing RitS (Right into the Saddle) TS-guess jobs. + + Args: + project (str): The project's name. Used for setting the remote path. + project_directory (str): The path to the local project directory. + job_type (list, str): The job's type, validated against ``JobTypeEnum``. + args (dict, optional): Methods (including troubleshooting) to be used in + input files. For RitS the only currently-honored entry is + ``args['keyword']['n_samples']`` (int, default 10). + bath_gas (str, optional): A bath gas. Currently only used in OneDMin. + checkfile (str, optional): The path to a previous Gaussian checkfile. + conformer (int, optional): Conformer number if optimizing conformers. + constraints (list, optional): A list of constraints. + cpu_cores (int, optional): The total number of cpu cores requested for a job. + dihedral_increment (float, optional): Unused for RitS. + dihedrals (List[float], optional): The dihedral angles corresponding to + self.torsions. 
+ directed_scan_type (str, optional): The type of the directed scan. + ess_settings (dict, optional): A dictionary of available ESS. + ess_trsh_methods (List[str], optional): A list of troubleshooting methods. + execution_type (str, optional): The execution type, 'incore', 'queue', or 'pipe'. + fine (bool, optional): Whether to use fine geometry optimization parameters. + initial_time (datetime.datetime or str, optional): The time at which this job was initiated. + irc_direction (str, optional): The direction of the IRC job. + job_id (int, optional): The job's ID determined by the server. + job_memory_gb (int, optional): The total job allocated memory in GB. + job_name (str, optional): The job's name. + job_num (int, optional): Used as the entry number in the database. + job_server_name (str, optional): Job's name on the server. + job_status (list, optional): The job's server and ESS statuses. + level (Level, optional): The level of theory to use. + max_job_time (float, optional): The maximal allowed job time on the server in hours. + run_multi_species (bool, optional): Whether to run a job for multiple species in the same input file. + reactions (List[ARCReaction], optional): Entries are ARCReaction instances. + rotor_index (int, optional): The 0-indexed rotor number. + server (str): The server to run on. + server_nodes (list, optional): The nodes this job was previously submitted to. + species (List[ARCSpecies], optional): Entries are ARCSpecies instances. + testing (bool, optional): Whether the object is generated for testing purposes. + times_rerun (int, optional): Number of times this job was re-run. + torsions (List[List[int]], optional): The 0-indexed atom indices of the torsion(s). + tsg (int, optional): TSGuess number if optimizing TS guesses. + xyz (dict, optional): The 3D coordinates to use. 
+ """ + + def __init__(self, + project: str, + project_directory: str, + job_type: Union[List[str], str], + args: Optional[dict] = None, + bath_gas: Optional[str] = None, + checkfile: Optional[str] = None, + conformer: Optional[int] = None, + constraints: Optional[List[Tuple[List[int], float]]] = None, + cpu_cores: Optional[str] = None, + dihedral_increment: Optional[float] = None, + dihedrals: Optional[List[float]] = None, + directed_scan_type: Optional[str] = None, + ess_settings: Optional[dict] = None, + ess_trsh_methods: Optional[List[str]] = None, + execution_type: Optional[str] = None, + fine: bool = False, + initial_time: Optional[Union['datetime.datetime', str]] = None, + irc_direction: Optional[str] = None, + job_id: Optional[int] = None, + job_memory_gb: float = 14.0, + job_name: Optional[str] = None, + job_num: Optional[int] = None, + job_server_name: Optional[str] = None, + job_status: Optional[List[Union[dict, str]]] = None, + level: Optional['Level'] = None, + max_job_time: Optional[float] = None, + run_multi_species: bool = False, + reactions: Optional[List['ARCReaction']] = None, + rotor_index: Optional[int] = None, + server: Optional[str] = None, + server_nodes: Optional[list] = None, + queue: Optional[str] = None, + attempted_queues: Optional[List[str]] = None, + species: Optional[List['ARCSpecies']] = None, + testing: bool = False, + times_rerun: int = 0, + torsions: Optional[List[List[int]]] = None, + tsg: Optional[int] = None, + xyz: Optional[dict] = None, + ): + + # Single in-flight job per scheduler tick — RitS holds an ML model in + # GPU memory, parallelizing it across reactions would risk OOM. + self.incore_capacity = 1 + self.job_adapter = 'rits' + self.execution_type = execution_type or 'incore' + self.command = 'sample_transition_state.py' + self.url = 'https://github.com/isayevlab/RitS' + + if reactions is None: + raise ValueError('Cannot execute RitS without ARCReaction object(s).') + + # Number of TS samples to draw per reaction. 
Honored from args['keyword']['n_samples'] + # so users can bump it via the standard ARC adapter-args path. + self.n_samples = DEFAULT_N_SAMPLES + if args and isinstance(args, dict): + kw = args.get('keyword') or dict() + if 'n_samples' in kw: + try: + self.n_samples = int(kw['n_samples']) + except (TypeError, ValueError): + logger.warning( + f"RitS adapter: could not parse args['keyword']['n_samples']=" + f"{kw['n_samples']!r} as an int; falling back to " + f"DEFAULT_N_SAMPLES={DEFAULT_N_SAMPLES}." + ) + + _initialize_adapter(obj=self, + is_ts=True, + project=project, + project_directory=project_directory, + job_type=job_type, + args=args, + bath_gas=bath_gas, + checkfile=checkfile, + conformer=conformer, + constraints=constraints, + cpu_cores=cpu_cores, + dihedral_increment=dihedral_increment, + dihedrals=dihedrals, + directed_scan_type=directed_scan_type, + ess_settings=ess_settings, + ess_trsh_methods=ess_trsh_methods, + fine=fine, + initial_time=initial_time, + irc_direction=irc_direction, + job_id=job_id, + job_memory_gb=job_memory_gb, + job_name=job_name, + job_num=job_num, + job_server_name=job_server_name, + job_status=job_status, + level=level, + max_job_time=max_job_time, + run_multi_species=run_multi_species, + reactions=reactions, + rotor_index=rotor_index, + server=server, + server_nodes=server_nodes, + queue=queue, + attempted_queues=attempted_queues, + species=species, + testing=testing, + times_rerun=times_rerun, + torsions=torsions, + tsg=tsg, + xyz=xyz, + ) + + def write_input_file(self) -> None: + """No standalone input file — see set_files() (writes input.yml).""" + pass + + def set_files(self) -> None: + """ + Set files to be uploaded and downloaded for queue execution. + + ``self.files_to_upload`` is a list of dictionaries, each with the keys + ``'name'``, ``'source'``, ``'make_x'``, ``'local'``, and ``'remote'``. + """ + # 1. 
Upload + if self.execution_type != 'incore': + self.write_submit_script() + from arc.imports import settings as _s + self.files_to_upload.append(self.get_file_property_dictionary( + file_name=_s['submit_filenames'][_s['servers'][self.server]['cluster_soft']])) + if os.path.isfile(self.yml_in_path): + self.files_to_upload.append(self.get_file_property_dictionary(file_name='input.yml')) + if os.path.isfile(self.reactant_xyz_path): + self.files_to_upload.append(self.get_file_property_dictionary(file_name='reactant.xyz')) + if os.path.isfile(self.product_xyz_path): + self.files_to_upload.append(self.get_file_property_dictionary(file_name='product.xyz')) + # 2. Download + self.files_to_download.append(self.get_file_property_dictionary(file_name='output.yml')) + self.files_to_download.append(self.get_file_property_dictionary(file_name='rits_ts.xyz')) + + def set_additional_file_paths(self) -> None: + """Set the local file paths used by RitS at job time.""" + self.reactant_xyz_path = os.path.join(self.local_path, 'reactant.xyz') + self.product_xyz_path = os.path.join(self.local_path, 'product.xyz') + self.ts_out_xyz_path = os.path.join(self.local_path, 'rits_ts.xyz') + self.yml_in_path = os.path.join(self.local_path, 'input.yml') + self.yml_out_path = os.path.join(self.local_path, 'output.yml') + + def set_input_file_memory(self) -> None: + """Set the input file memory attribute.""" + self.cpu_cores, self.job_memory_gb = 1, 1 + + def execute_incore(self): + """Execute the RitS job locally (in-process subprocess).""" + self._log_job_execution() + self.initial_time = self.initial_time if self.initial_time else datetime.datetime.now() + self.execute_rits() + self.final_time = datetime.datetime.now() + + def execute_queue(self): + """Execute the RitS job to the server's queue.""" + self.execute_rits(exe_type='queue') + + def execute_rits(self, exe_type: str = 'incore'): + """ + Drive the RitS subprocess and stitch its output back into ARC. 
+ + Args: + exe_type (str, optional): Either ``'incore'`` (run locally now) or + ``'queue'`` (just stage the input.yml + submit script). + """ + if not _rits_environment_ready(): + return + rxn = self.reactions[0] + if rxn.ts_species is None: + rxn.ts_species = ARCSpecies(label=self.species_label, + is_ts=True, + charge=rxn.charge, + multiplicity=rxn.multiplicity, + ) + + # Build atom-aligned reactant + product XYZ files. ARC's get_reactants_xyz / + # get_products_xyz already use rxn.atom_map to align orderings. + try: + r_xyz_dict = rxn.get_reactants_xyz(return_format='dict') + p_xyz_dict = rxn.get_products_xyz(return_format='dict') + except Exception as e: + logger.warning(f'RitS: could not build mapped XYZs for {rxn.label}: {e}') + return + if r_xyz_dict is None or p_xyz_dict is None: + logger.warning(f'RitS: empty mapped XYZs for {rxn.label}') + return + if len(r_xyz_dict['symbols']) != len(p_xyz_dict['symbols']): + logger.warning(f'RitS: atom count mismatch for {rxn.label} ' + f'(R has {len(r_xyz_dict["symbols"])}, P has {len(p_xyz_dict["symbols"])}). ' + f'Skipping.') + return + + write_xyz_file(r_xyz_dict, self.reactant_xyz_path, comment=f'{rxn.label} reactant') + write_xyz_file(p_xyz_dict, self.product_xyz_path, comment=f'{rxn.label} product') + + input_dict = { + 'reactant_xyz_path': self.reactant_xyz_path, + 'product_xyz_path': self.product_xyz_path, + 'rits_repo_path': RITS_REPO_PATH, + 'ckpt_path': RITS_CKPT_PATH, + 'output_xyz_path': self.ts_out_xyz_path, + 'yml_out_path': self.yml_out_path, + 'config_path': os.path.join(RITS_REPO_PATH, 'scripts', 'conf', 'rits.yaml'), + 'n_samples': self.n_samples, + 'batch_size': DEFAULT_BATCH_SIZE, + 'charge': int(rxn.charge or 0), + 'device': 'auto', + } + save_yaml_file(path=self.yml_in_path, content=input_dict) + + if exe_type == 'queue': + self.legacy_queue_execution() + return + + # Incore: subprocess into rits_script.py inside rits_env. 
+        # Pass argv as a list (not shell=True) so paths containing spaces or
+        # shell-special characters are handled safely without quoting.
+        cmd = [RITS_PYTHON, RITS_SCRIPT_PATH, '--yml_in_path', self.yml_in_path]
+        result = subprocess.run(cmd, check=False)
+        if result.returncode != 0:
+            logger.warning(f'RitS subprocess returned non-zero exit code {result.returncode} for {rxn.label}.')
+            return
+
+        if not os.path.isfile(self.yml_out_path):
+            logger.warning(f'RitS produced no output YAML at {self.yml_out_path} for {rxn.label}.')
+            return
+
+        tsg_dicts = read_yaml_file(self.yml_out_path) or list()
+        n_added = 0
+        for tsg_dict in tsg_dicts:
+            if process_rits_tsg(tsg_dict=tsg_dict,
+                                local_path=self.local_path,
+                                ts_species=rxn.ts_species):
+                n_added += 1
+
+        if len(self.reactions) < 5:
+            if n_added:
+                logger.info(f'RitS successfully found {n_added} TS guesses for {rxn.label}.')
+            else:
+                logger.info(f'RitS did not find any successful TS guesses for {rxn.label}.')
+
+
+def write_xyz_file(xyz_dict: dict, path: str, comment: str = '') -> None:
+    """
+    Write an ARC xyz dict to a plain XYZ file with the standard
+    ``<n_atoms>\\n<comment>\\n...`` header.
+
+    Args:
+        xyz_dict (dict): An ARC xyz dictionary.
+        path (str): Output file path.
+        comment (str): Optional comment line (kept on a single line).
+    """
+    body = xyz_to_str(xyz_dict)
+    n_atoms = len(xyz_dict['symbols'])
+    safe_comment = comment.replace('\n', ' ').strip()
+    with open(path, 'w') as f:
+        f.write(f'{n_atoms}\n{safe_comment}\n{body}\n')
+
+
+def process_rits_tsg(tsg_dict: dict,
+                     local_path: str,
+                     ts_species: ARCSpecies) -> bool:
+    """
+    Convert a single TSGuess-shaped dict from ``rits_script.py`` into an ARC
+    ``TSGuess`` object, dedup against existing guesses, and append it.
+
+    Dedup uses :func:`arc.species.converter.compare_confs`, which compares
+    *internal distance matrices* — so it correctly merges two RitS samples
+    that are the same TS structure in different rigid orientations. 
This is + a stricter test than the byte-level ``almost_equal_coords`` ARC's older + adapters use; RitS specifically benefits from it because every flow- + matching sample lands the molecule in its own random orientation, so + rotated duplicates are the common case. + + Args: + tsg_dict (dict): One entry from the YAML written by rits_script.py. + local_path (str): The job's local working directory (used by save_geo). + ts_species (ARCSpecies): The reaction's TS species accumulator. + + Returns: + bool: ``True`` if a new (unique, non-colliding) TS guess was appended, + ``False`` otherwise. + """ + if not tsg_dict.get('success') or not tsg_dict.get('initial_xyz'): + return False + try: + ts_xyz = str_to_xyz(tsg_dict['initial_xyz']) + except Exception as e: + logger.warning(f'RitS: could not parse TS xyz: {e}') + return False + if colliding_atoms(ts_xyz): + return False + + # Dedup against every existing TSGuess (regardless of method) using a + # rotation/translation-invariant distance-matrix comparator. If a match + # is found, augment the existing guess's method label instead of + # appending a duplicate. + for other_tsg in ts_species.ts_guesses: + if other_tsg.success and other_tsg.initial_xyz is not None \ + and other_tsg.initial_xyz.get('symbols') == ts_xyz['symbols'] \ + and compare_confs(ts_xyz, other_tsg.initial_xyz): + if 'rits' not in other_tsg.method.lower(): + other_tsg.method += ' and RitS' + return False + + method_index = int(tsg_dict.get('method_index', 0)) + tsg = TSGuess(method='RitS', + method_direction=tsg_dict.get('method_direction', 'F'), + method_index=method_index, + index=len(ts_species.ts_guesses), + success=True, + ) + tsg.process_xyz(ts_xyz) + ts_species.ts_guesses.append(tsg) + save_geo(xyz=ts_xyz, + path=local_path, + filename=f'RitS {method_index}', + format_='xyz', + comment=f'RitS sample {method_index}', + ) + return True + + +def _rits_environment_ready() -> bool: + """ + Check that everything RitS needs at runtime is in place. 
Logs a clear + one-line warning per missing piece and returns ``False`` so the adapter + can skip cleanly without raising. + """ + ok = True + if not RITS_PYTHON or not os.path.isfile(RITS_PYTHON): + logger.warning('RitS adapter: rits_env python not found ' + '(set RITS_PYTHON or run `make install-rits`). Skipping RitS TS guesses.') + ok = False + if not RITS_REPO_PATH or not os.path.isdir(RITS_REPO_PATH): + logger.warning('RitS adapter: RitS source checkout not found ' + '(set ARC_RITS_REPO or run `make install-rits`). Skipping RitS TS guesses.') + ok = False + if not RITS_CKPT_PATH or not os.path.isfile(RITS_CKPT_PATH): + logger.warning('RitS adapter: pretrained checkpoint not found ' + '(set ARC_RITS_CKPT or run `make install-rits`). Skipping RitS TS guesses.') + ok = False + return ok + + +register_job_adapter('rits', RitSAdapter) diff --git a/arc/main_test.py b/arc/main_test.py index 4034c55cd5..175db38cf1 100644 --- a/arc/main_test.py +++ b/arc/main_test.py @@ -87,6 +87,7 @@ def test_as_dict(self): 'orca': ['local'], 'orca_neb': ['local'], 'qchem': ['server1'], + 'rits': ['local'], 'terachem': ['server1'], 'torchani': ['local'], 'xtb': ['local'], diff --git a/arc/settings/settings.py b/arc/settings/settings.py index 7203ef8a8f..7c3cdbb818 100644 --- a/arc/settings/settings.py +++ b/arc/settings/settings.py @@ -9,6 +9,7 @@ import os import string import sys +from typing import Optional # Users should update the following server dictionary. 
# Instructions for RSA key generation can be found here: @@ -72,6 +73,7 @@ 'cfour': 'local', 'gaussian': ['local', 'server2'], 'gcn': 'local', + 'rits': 'local', 'mockter': 'local', 'molpro': ['local', 'server2'], 'onedmin': 'server1', @@ -89,7 +91,7 @@ supported_ess = ['cfour', 'gaussian', 'mockter', 'molpro', 'orca', 'qchem', 'terachem', 'onedmin', 'xtb', 'torchani', 'openbabel'] # TS methods to try when appropriate for a reaction (other than user guesses which are always allowed): -ts_adapters = ['heuristics', 'AutoTST', 'GCN', 'xtb_gsm', 'orca_neb'] +ts_adapters = ['heuristics', 'AutoTST', 'GCN', 'RitS', 'xtb_gsm', 'orca_neb'] # List here job types to execute by default default_job_types = {'conf_opt': True, # defaults to True if not specified @@ -172,6 +174,7 @@ output_filenames = {'cfour': 'output.out', 'gaussian': 'input.log', 'gcn': 'output.yml', + 'rits': 'output.yml', 'mockter': 'output.yml', 'molpro': 'input.out', 'onedmin': 'output.out', @@ -321,8 +324,9 @@ ARC_FAMILIES_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), 'data', 'families') # Default environment names for sister repos -TS_GCN_PYTHON, TANI_PYTHON, AUTOTST_PYTHON, ARC_PYTHON, XTB, OB_PYTHON, RMG_PYTHON, RMG_PATH, RMG_DB_PATH = \ - None, None, None, None, None, None, None, None, None +TS_GCN_PYTHON, TANI_PYTHON, AUTOTST_PYTHON, RITS_PYTHON, RITS_REPO_PATH, RITS_CKPT_PATH, \ + ARC_PYTHON, XTB, OB_PYTHON, RMG_PYTHON, RMG_PATH, RMG_DB_PATH = \ + None, None, None, None, None, None, None, None, None, None, None, None home = os.getenv("HOME") or os.path.expanduser("~") @@ -362,11 +366,72 @@ def find_executable(env_name, executable_name='python'): OB_PYTHON = find_executable('ob_env') TS_GCN_PYTHON = find_executable('ts_gcn') AUTOTST_PYTHON = find_executable('tst_env') +RITS_PYTHON = find_executable('rits_env') ARC_PYTHON = find_executable('arc_env') RMG_ENV_NAME = 'rmg_env' RMG_PYTHON = find_executable('rmg_env') XTB = find_executable('xtb_env', 
'xtb')
+
+
+def find_rits_repo() -> Optional[str]:
+    """
+    Locate a RitS source checkout. Used by the RitS TS adapter to find
+    'scripts/sample_transition_state.py' and 'scripts/conf/rits.yaml',
+    which are not part of the importable 'megalodon' package.
+
+    Search order:
+        1. ``ARC_RITS_REPO`` environment variable (explicit override).
+        2. ``~/Code/RitS`` (default for ARC dev machines).
+        3. Sibling-of-ARC location ``<ARC parent>/RitS`` —
+           matches what ``devtools/install_rits.sh`` produces.
+
+    Returns:
+        Optional[str]: Absolute path to the repo root, or ``None`` if
+        nothing was found. The repo is considered "found" only if it
+        contains ``scripts/sample_transition_state.py``.
+    """
+    candidates = list()
+    env_override = os.getenv('ARC_RITS_REPO')
+    if env_override:
+        candidates.append(env_override)
+    candidates.append(os.path.join(home, 'Code', 'RitS'))
+    arc_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+    candidates.append(os.path.join(os.path.dirname(arc_root), 'RitS'))
+    for path in candidates:
+        if path and os.path.isfile(os.path.join(path, 'scripts', 'sample_transition_state.py')):
+            return os.path.abspath(path)
+    return None
+
+
+def find_rits_ckpt(repo_path: Optional[str] = None) -> Optional[str]:
+    """
+    Locate the pretrained RitS checkpoint file ('rits.ckpt').
+
+    Search order:
+        1. ``ARC_RITS_CKPT`` environment variable (explicit override).
+        2. ``<repo>/data/rits.ckpt`` — what ``install_rits.sh`` writes.
+
+    Args:
+        repo_path (Optional[str]): The RitS repo path returned by
+            ``find_rits_repo()``. If ``None``, only the env-var override
+            is consulted.
+
+    Returns:
+        Optional[str]: Absolute path to the checkpoint, or ``None``. 
+ """ + env_override = os.getenv('ARC_RITS_CKPT') + if env_override and os.path.isfile(env_override): + return os.path.abspath(env_override) + if repo_path: + candidate = os.path.join(repo_path, 'data', 'rits.ckpt') + if os.path.isfile(candidate): + return os.path.abspath(candidate) + return None + + +RITS_REPO_PATH = find_rits_repo() +RITS_CKPT_PATH = find_rits_ckpt(RITS_REPO_PATH) + # Set RMG_DB_PATH with fallback methods rmg_db_candidates, rmg_candidates = list(), list() diff --git a/devtools/install_all.sh b/devtools/install_all.sh index c958fdd548..c324c7cf26 100644 --- a/devtools/install_all.sh +++ b/devtools/install_all.sh @@ -26,6 +26,7 @@ run_devtool () { bash "$DEVTOOLS_DIR/$1" "${@:2}"; } SKIP_CLEAN=false SKIP_EXT=false SKIP_ARC=false +SKIP_RITS=false RMG_ARGS=() ARC_ARGS=() EXT_ARGS=() @@ -36,6 +37,7 @@ while [[ $# -gt 0 ]]; do --no-clean) SKIP_CLEAN=true ;; --no-ext) SKIP_EXT=true ;; --no-arc) SKIP_ARC=true ;; + --no-rits) SKIP_RITS=true ;; --rmg-*) RMG_ARGS+=("--${1#--rmg-}") ;; --arc-*) ARC_ARGS+=("--${1#--arc-}") ;; --ext-*) EXT_ARGS+=("--${1#--ext-}") ;; @@ -44,6 +46,7 @@ while [[ $# -gt 0 ]]; do Usage: $0 [global-flags] [--rmg-xxx] [--arc-yyy] [--ext-zzz] --no-clean Skip micromamba/conda cache cleanup --no-ext Skip external tools (AutoTST, KinBot, …) + --no-rits Skip the RitS installer (heavy ML stack — usually run in its own CI lane) --rmg-path Forward '--path' to RMG installer --rmg-pip Forward '--pip' to RMG installer ... 
@@ -102,8 +105,15 @@ if [[ $SKIP_EXT == false ]]; then
         [xtb]=install_xtb.sh
         [Sella]=install_sella.sh
         [TorchANI]=install_torchani.sh
+        [RitS]=install_rits.sh
     )
 
+    # Optionally drop RitS — used by `make install-ci` since CI runs RitS in its own lane
+    if [[ $SKIP_RITS == true ]]; then
+        unset 'EXT_INSTALLERS[RitS]'
+        echo "ℹ️ --no-rits: skipping RitS installer (run 'make install-rits' or the rits CI lane separately)"
+    fi
+
    # installer-specific flag whitelists
    declare -A EXT_FLAG_WHITELIST=(
        [install_gcn.sh]="--conda"
diff --git a/devtools/install_rits.sh b/devtools/install_rits.sh
new file mode 100755
index 0000000000..3a2e867e65
--- /dev/null
+++ b/devtools/install_rits.sh
@@ -0,0 +1,315 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+# ── defaults ───────────────────────────────────────────────────────────────
+RITS_REPO_URL="https://github.com/isayevlab/RitS.git"
+RITS_ENV_NAME="rits_env"
+FORCE_CPU=false
+RITS_PATH=""
+SKIP_CKPT=false
+CUDA_VARIANT=""  # one of: cpu, cu118, cu121, cu124, cu126 (empty → autodetect)
+TORCH_VERSION="2.7.0"  # must match RitS's pinned torch version
+
+# Pretrained checkpoint mirror (Dana Research Group, Zenodo)
+# Google Drive checkpoint file source: https://drive.google.com/drive/folders/1DD2hmWx3E1klM3Ljon5r4gdquGoN_4v6
+# Source paper: https://github.com/isayevlab/RitS, 10.26434/chemrxiv.15001681/v1
+# Mirror DOI : https://doi.org/10.5281/zenodo.19474153
+RITS_CKPT_URL="https://zenodo.org/records/19474153/files/rits.ckpt?download=1"
+RITS_CKPT_MD5="884121fcf7a5bfcfb826b7d5e28d379a"
+
+# ── parse flags ────────────────────────────────────────────────────────────
+TEMP=$(getopt -o h --long cpu,cuda:,path:,no-ckpt,help -- "$@")
+eval set -- "$TEMP"
+while true; do
+    case "$1" in
+        --cpu)
+            FORCE_CPU=true
+            shift
+            ;;
+        --cuda)
+            CUDA_VARIANT="$2"
+            shift 2
+            ;;
+        --path)
+            RITS_PATH="$2"
+            shift 2
+            ;;
+        --no-ckpt)
+            SKIP_CKPT=true
+            shift
+            ;;
+        -h|--help)
+            cat <<EOF
+Usage: $0 [--cpu] [--cuda <variant>] [--path <dir>] [--no-ckpt] [--help]
+
+  --cpu       force a CPU-only PyTorch 
install (shortcut for --cuda cpu) + --cuda pick a specific PyG wheel variant: cpu, cu118, cu121, cu124, cu126 + (default: autodetect via nvcc / nvidia-smi) + --path use an existing RitS checkout instead of cloning + --no-ckpt skip the pretrained checkpoint download (offline installs) + -h this help + +By default the script clones (or updates) RitS as a sibling of the ARC repo, +creates the '${RITS_ENV_NAME}' conda env with python=3.10, autodetects the +host CUDA version, installs torch=${TORCH_VERSION} + matching PyTorch Geometric +companion wheels (torch-scatter / torch-sparse / torch-cluster / +torch-spline-conv / pyg-lib) from PyG's wheel index, runs 'pip install -e .' +so that 'import megalodon' works inside the env, and downloads + verifies the +pretrained 'rits.ckpt' from Zenodo +(${RITS_CKPT_URL%%\?*}). + +No training is required — RitS ships pretrained weights. +EOF + exit 0 + ;; + --) shift; break ;; + *) echo "Invalid flag: $1" >&2; exit 1 ;; + esac +done + +# ── pick a CUDA variant for the PyG wheels ─────────────────────────────── +# PyG publishes wheels for torch ${TORCH_VERSION} against these variants only: +SUPPORTED_VARIANTS=(cpu cu118 cu121 cu124 cu126) + +map_cuda_to_variant() { # X.Y → cu118|cu121|cu124|cu126|cpu + local ver="$1" + local major minor + major=${ver%%.*} + minor=${ver#*.} + minor=${minor%%.*} + if [[ -z "$major" || -z "$minor" ]]; then echo cpu; return; fi + if (( major > 12 )) || { (( major == 12 )) && (( minor >= 6 )); }; then echo cu126 + elif (( major == 12 )) && (( minor >= 4 )); then echo cu124 + elif (( major == 12 )) && (( minor >= 1 )); then echo cu121 + elif { (( major == 12 )) && (( minor == 0 )); } || \ + { (( major == 11 )) && (( minor >= 8 )); }; then echo cu118 + else echo cpu + fi +} + +if [[ -n "$CUDA_VARIANT" ]]; then + if $FORCE_CPU && [[ "$CUDA_VARIANT" != cpu ]]; then + echo "❌ --cpu and --cuda $CUDA_VARIANT are contradictory" >&2 + exit 1 + fi + # validate against the supported set + if ! 
printf '%s\n' "${SUPPORTED_VARIANTS[@]}" | grep -qx "$CUDA_VARIANT"; then + echo "❌ Unsupported --cuda variant: $CUDA_VARIANT" >&2 + echo " Supported: ${SUPPORTED_VARIANTS[*]}" >&2 + exit 1 + fi +elif $FORCE_CPU; then + CUDA_VARIANT="cpu" +elif command -v nvcc &>/dev/null; then + VER=$(nvcc --version | grep -oP "release \K[0-9]+\.[0-9]+" | head -n1) + CUDA_VARIANT=$(map_cuda_to_variant "$VER") + echo "🔍 nvcc reports CUDA $VER → using PyG variant '$CUDA_VARIANT'" +elif command -v nvidia-smi &>/dev/null; then + # The 'CUDA Version' field in nvidia-smi is the *driver's max supported* CUDA, which is the + # right ceiling for binary wheel compatibility (not driver_version, which is a different number). + VER=$(nvidia-smi 2>/dev/null | grep -oP "CUDA Version: \K[0-9]+\.[0-9]+" | head -n1 || true) + if [[ -n "$VER" ]]; then + CUDA_VARIANT=$(map_cuda_to_variant "$VER") + echo "🔍 nvidia-smi reports max CUDA $VER → using PyG variant '$CUDA_VARIANT'" + else + CUDA_VARIANT="cpu" + echo "🔍 Could not parse CUDA version from nvidia-smi → falling back to CPU" + fi +else + CUDA_VARIANT="cpu" + echo "🔍 No nvcc / nvidia-smi found → falling back to CPU" +fi +echo "→ PyG wheel variant: $CUDA_VARIANT" + +# ── locate ARC repo and the sibling clone root ──────────────────────────── +SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)" +if ARC_ROOT=$(git -C "$SCRIPT_DIR" rev-parse --show-toplevel 2>/dev/null); then + : +else + ARC_ROOT=$(cd "$SCRIPT_DIR/.." 
&& pwd) +fi +CLONE_ROOT="$(dirname "$ARC_ROOT")" +echo "📂 ARC root : $ARC_ROOT" +echo "📂 Clone root: $CLONE_ROOT" + +# ── pick a conda frontend ───────────────────────────────────────────────── +if command -v micromamba &>/dev/null; then + COMMAND_PKG=micromamba +elif command -v mamba &>/dev/null; then + COMMAND_PKG=mamba +elif command -v conda &>/dev/null; then + COMMAND_PKG=conda +else + echo "❌ No micromamba/mamba/conda found in PATH" >&2 + exit 1 +fi +echo "✔️ Using $COMMAND_PKG" + +# Initialize shell integration so 'activate' works in this script +if [[ $COMMAND_PKG == micromamba ]]; then + eval "$(micromamba shell hook --shell=bash)" +else + BASE=$(conda info --base) + source "$BASE/etc/profile.d/conda.sh" +fi + +# ── clone or update RitS ────────────────────────────────────────────────── +if [[ -n "$RITS_PATH" ]]; then + if [[ ! -d "$RITS_PATH" ]]; then + echo "❌ --path was given but directory does not exist: $RITS_PATH" >&2 + exit 1 + fi + RITS_DIR="$(cd "$RITS_PATH" && pwd)" + echo "📂 Using existing RitS checkout at: $RITS_DIR" +else + RITS_DIR="$CLONE_ROOT/RitS" + if [[ -d "$RITS_DIR/.git" ]]; then + echo "🔄 Updating existing RitS clone at $RITS_DIR" + git -C "$RITS_DIR" fetch origin + git -C "$RITS_DIR" pull --ff-only || echo "⚠️ Could not fast-forward; leaving working tree as-is." + else + echo "⬇️ Cloning RitS into $RITS_DIR" + git clone "$RITS_REPO_URL" "$RITS_DIR" + fi +fi + +# ── create / update the rits_env conda environment ─────────────────────── +if $COMMAND_PKG env list | awk '{print $1}' | grep -qx "$RITS_ENV_NAME"; then + echo "♻️ '$RITS_ENV_NAME' already exists — updating in place." 
+else + echo "🆕 Creating '$RITS_ENV_NAME' (python=3.10)" + $COMMAND_PKG create -n "$RITS_ENV_NAME" -c conda-forge python=3.10 -y +fi + +set +u; $COMMAND_PKG activate "$RITS_ENV_NAME"; set -u + +# RDKit & OpenBabel are far smoother via conda-forge than pip +echo "📦 Installing rdkit + openbabel from conda-forge" +$COMMAND_PKG install -n "$RITS_ENV_NAME" -c conda-forge -y \ + "rdkit=2025.3.2" openbabel + +# Install PyTorch + PyTorch Geometric companion wheels for the chosen variant. +# We deliberately do NOT use RitS's requirements.txt because it pins +pt27cu126 +# specifically — we install the variant-matched companion wheels instead so the +# install works on CPU runners and on GPUs with CUDA != 12.6. +python -m pip install --upgrade pip + +if [[ "$CUDA_VARIANT" == "cpu" ]]; then + TORCH_INDEX="https://download.pytorch.org/whl/cpu" +else + TORCH_INDEX="https://download.pytorch.org/whl/${CUDA_VARIANT}" +fi +PYG_WHEELS="https://data.pyg.org/whl/torch-${TORCH_VERSION}+${CUDA_VARIANT}.html" + +echo "🚀 Installing torch==${TORCH_VERSION} (${CUDA_VARIANT}) from $TORCH_INDEX" +python -m pip install "torch==${TORCH_VERSION}" --index-url "$TORCH_INDEX" + +echo "🧮 Installing PyG companion wheels from $PYG_WHEELS" +# --only-binary :all: forces wheels, never source builds (those would need a CUDA toolkit) +python -m pip install \ + pyg-lib torch-scatter torch-sparse torch-cluster torch-spline-conv \ + --only-binary :all: -f "$PYG_WHEELS" + +echo "📦 Installing pure-Python megalodon dependencies from PyPI" +python -m pip install \ + "torch_geometric==2.6.1" \ + "hydra-core==1.3.2" \ + "lightning==2.5.1.post0" \ + "einops==0.8.1" \ + "wandb==0.19.11" \ + "pandas==2.2.3" \ + "tqdm==4.67.1" + +# Editable install of the megalodon package (this is what puts 'import megalodon' on path) +echo "🧷 pip install -e . 
(megalodon, src layout)" +python -m pip install -e "$RITS_DIR" + +# Sanity check — import megalodon AND the PyG companions, since a successful pip +# install does not guarantee the .so files actually load against the host's CUDA. +echo "🔍 Verifying inference stack inside $RITS_ENV_NAME" +python - <<'PYEOF' +import importlib, sys +mods = ["torch", "torch_geometric", "torch_scatter", "torch_sparse", + "torch_cluster", "torch_spline_conv", "megalodon"] +for m in mods: + try: + mod = importlib.import_module(m) + ver = getattr(mod, "__version__", "?") + print(f" ✔️ {m:<22} {ver}") + except Exception as e: + print(f" ❌ {m:<22} FAILED: {e}", file=sys.stderr) + sys.exit(1) +import torch +print(f" ℹ️ torch.cuda.is_available() = {torch.cuda.is_available()}") +PYEOF + +set +u; $COMMAND_PKG deactivate; set -u + +# ── download + verify pretrained checkpoint ────────────────────────────── +CKPT_DIR="$RITS_DIR/data" +CKPT_PATH="$CKPT_DIR/rits.ckpt" + +verify_md5() { # path expected_md5 + local path="$1" expected="$2" + local actual + if command -v md5sum &>/dev/null; then + actual=$(md5sum "$path" | awk '{print $1}') + elif command -v md5 &>/dev/null; then + actual=$(md5 -q "$path") + else + echo "❌ Neither md5sum nor md5 found in PATH; cannot verify checkpoint." >&2 + return 2 + fi + if [[ "$actual" != "$expected" ]]; then + echo "❌ Checksum mismatch for $path" >&2 + echo " expected: $expected" >&2 + echo " actual : $actual" >&2 + return 1 + fi + return 0 +} + +if $SKIP_CKPT; then + echo "ℹ️ --no-ckpt set, skipping checkpoint download." +elif [[ -f "$CKPT_PATH" ]]; then + echo "📦 Existing checkpoint found at $CKPT_PATH — verifying MD5..." + if verify_md5 "$CKPT_PATH" "$RITS_CKPT_MD5"; then + echo "✔️ Checkpoint MD5 OK ($RITS_CKPT_MD5)" + else + echo "❌ Existing checkpoint does not match the expected MD5." >&2 + echo " Refusing to overwrite — move it aside or delete it and re-run." >&2 + exit 1 + fi +else + mkdir -p "$CKPT_DIR" + if ! 
command -v curl &>/dev/null; then + echo "❌ curl is required to download the RitS checkpoint." >&2 + exit 1 + fi + echo "⬇️ Downloading rits.ckpt (~364 MB) from Zenodo:" + echo " $RITS_CKPT_URL" + TMP_CKPT="$(mktemp "${CKPT_DIR}/rits.ckpt.XXXXXX")" + if ! curl -fL --retry 3 --retry-delay 5 -o "$TMP_CKPT" "$RITS_CKPT_URL"; then + rm -f "$TMP_CKPT" + echo "❌ Download failed. Re-run the install, or pass --no-ckpt to skip." >&2 + exit 1 + fi + if verify_md5 "$TMP_CKPT" "$RITS_CKPT_MD5"; then + mv "$TMP_CKPT" "$CKPT_PATH" + echo "✔️ Checkpoint verified and saved to $CKPT_PATH" + else + rm -f "$TMP_CKPT" + echo "❌ Downloaded checkpoint failed MD5 verification — aborting." >&2 + exit 1 + fi +fi + +# ── final notes ─────────────────────────────────────────────────────────── +echo "" +echo "✅ RitS installation complete." +echo " Repo : $RITS_DIR" +echo " Env : $RITS_ENV_NAME" +echo " Ckpt : $([[ -f $CKPT_PATH ]] && echo $CKPT_PATH || echo '(not installed)')" +echo "" +echo " Mirror DOI : https://doi.org/10.5281/zenodo.19474153" +echo " Source : https://github.com/isayevlab/RitS"