diff --git a/.github/workflows/jqmc-deploy-gh-pages.yml b/.github/workflows/jqmc-deploy-gh-pages.yml index 9d7db4ca..749f22b0 100644 --- a/.github/workflows/jqmc-deploy-gh-pages.yml +++ b/.github/workflows/jqmc-deploy-gh-pages.yml @@ -1,4 +1,4 @@ -name: publish jqmc gh-pages +name: Publish jqmc gh-pages workflow on: push: @@ -9,6 +9,10 @@ permissions: jobs: docs: + name: publish jqmc gh-pages + + if: github.repository == 'jqmc-project/jQMC' + runs-on: ubuntu-latest defaults: run: diff --git a/.github/workflows/jqmc-deploy-test.yml b/.github/workflows/jqmc-deploy-test.yml index e267c69e..2b58b5d4 100644 --- a/.github/workflows/jqmc-deploy-test.yml +++ b/.github/workflows/jqmc-deploy-test.yml @@ -1,10 +1,11 @@ -name: Publish jQMC distributions to test-PyPI +name: publish jQMC to test-PyPI workflow on: push: branches: [ "rc" ] jobs: deploy-test-pypi: + name: publish jqmc to test-pypi if: github.repository == 'jqmc-project/jQMC' runs-on: ubuntu-latest diff --git a/.github/workflows/jqmc-deploy.yml b/.github/workflows/jqmc-deploy.yml index c9eafe8f..af559863 100644 --- a/.github/workflows/jqmc-deploy.yml +++ b/.github/workflows/jqmc-deploy.yml @@ -1,4 +1,4 @@ -name: Publish jQMC distributions to PyPI +name: Publish jQMC to PyPI workflow # Trigger only when tags that start with "v" are pushed (e.g., v0.1.0) on: @@ -9,6 +9,7 @@ on: jobs: # validate tag validate_tag: + name: validate jqmc tag # Run only if this repository is rc if: startsWith(github.ref, 'refs/tags/v') && github.repository == 'jqmc-project/jQMC' runs-on: ubuntu-latest @@ -81,6 +82,7 @@ jobs: # deploy deploy-pypi: + name: publish jqmc to pypi # Run only if this repository is rc needs: validate_tag if: startsWith(github.ref, 'refs/tags/v') && github.repository == 'jqmc-project/jQMC' diff --git a/.github/workflows/jqmc-lint-ruff.yml b/.github/workflows/jqmc-lint-ruff.yml index 8c65f6f2..db6c2252 100644 --- a/.github/workflows/jqmc-lint-ruff.yml +++ b/.github/workflows/jqmc-lint-ruff.yml @@ -7,7 +7,7 @@ # To enforce additional ruff rules, fix the violations listed in # `lint.extend-ignore` in pyproject.toml and remove them from that list. -name: jqmc lint (ruff + pre-commit) +name: jqmc lint (ruff + pre-commit) workflow on: push: @@ -17,6 +17,7 @@ on: jobs: lint: + name: jqmc lint (ruff + pre-commit) runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 diff --git a/.github/workflows/jqmc-run-full-pytest-ubuntu.yml b/.github/workflows/jqmc-run-full-pytest-ubuntu.yml new file mode 100644 index 00000000..ba2a4e61 --- /dev/null +++ b/.github/workflows/jqmc-run-full-pytest-ubuntu.yml @@ -0,0 +1,136 @@ +# A full manual test of jqmc on a self-hosted runner. + +name: jqmc full test workflow (ubuntu) + +on: + workflow_dispatch: + +permissions: + contents: read + +jobs: + run: + name: jqmc full test + Codecov / Python ${{ matrix.python-version }} + + if: github.repository == 'jqmc-project/jQMC' + + runs-on: + group: jqmc-nightly-runners-ubuntu + labels: [self-hosted, Linux, X64] + + strategy: + fail-fast: false + matrix: + python-version: ["3.10.14", "3.11.9", "3.12.3"] + + timeout-minutes: 1440 + + steps: + - name: Show runner information + run: | + hostname + uname -a + cat /etc/os-release + + - uses: actions/checkout@v4 + + - name: Select Python + run: | + PY="$HOME/.pyenv/versions/${{ matrix.python-version }}/bin" + echo "$PY" >> "$GITHUB_PATH" + "$PY/python" --version + + - name: Install pytest, pytest-xdist, pytest-cov + run: | + python -m pip install --upgrade pip + python -m pip install pytest pytest-xdist pytest-cov + + - name: Install jqmc + run: | + python -m pip install . + + - name: Test jqmc FP64 (Intra-software comparisons) + run: | + pytest -n 8 -v tests/test_trexio.py --cov=jqmc --cov-branch --no-cov-on-fail + pytest -n 8 -v tests/test_init_electron_configurations.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append + pytest -n 8 -v tests/test_structure.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append + pytest -n 8 -v tests/test_AOs.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append + pytest -n 8 -v tests/test_MOs.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append + pytest -n 8 -v tests/test_determinant.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append + pytest -n 8 -v tests/test_jastrow.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append + pytest -n 8 -v tests/test_wave_function.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append + pytest -n 8 -v tests/test_ecps.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append + pytest -n 8 -v tests/test_swct.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append + pytest -n 8 -v tests/test_mcmc_force.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append + pytest -n 8 -v tests/test_lrdmc_force.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append + pytest -n 8 -v tests/test_checkpoint_components.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append + pytest -n 8 -v tests/test_checkpoint_mcmc.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append + pytest -n 8 -v tests/test_checkpoint_gfmc.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append + pytest -n 8 -v tests/test_ao_basis_optimization.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append + pytest -n 8 -v tests/test_mixed_precision.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append + + - name: Test jqmc FP32+FP64 (Intra-software comparisons) + run: | + pytest -n 8 -v tests/test_trexio.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append --precision-mode=mixed + pytest -n 8 -v tests/test_init_electron_configurations.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append --precision-mode=mixed + pytest -n 8 -v tests/test_structure.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append --precision-mode=mixed + pytest -n 8 -v tests/test_AOs.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append --precision-mode=mixed + pytest -n 8 -v tests/test_MOs.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append --precision-mode=mixed + pytest -n 8 -v tests/test_determinant.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append --precision-mode=mixed + pytest -n 8 -v tests/test_jastrow.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append --precision-mode=mixed + pytest -n 8 -v tests/test_wave_function.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append --precision-mode=mixed + pytest -n 8 -v tests/test_ecps.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append --precision-mode=mixed + pytest -n 8 -v tests/test_swct.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append --precision-mode=mixed + pytest -n 8 -v tests/test_mcmc_force.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append --precision-mode=mixed + pytest -n 8 -v tests/test_lrdmc_force.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append --precision-mode=mixed + pytest -n 8 -v tests/test_checkpoint_components.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append --precision-mode=mixed + pytest -n 8 -v tests/test_checkpoint_mcmc.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append --precision-mode=mixed + pytest -n 8 -v tests/test_checkpoint_gfmc.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append --precision-mode=mixed + pytest -n 8 -v tests/test_ao_basis_optimization.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append --precision-mode=mixed + pytest -n 8 -v tests/test_mixed_precision.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append --precision-mode=mixed + + - name: Test jqmc FP64 (Inter-software comparisons) + run: | + pytest -n 8 -v tests/test_comparison_with_turborvb_ECP.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append + pytest -n 8 -v tests/test_comparison_with_turborvb_AE.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append + + - name: Test jqmc FP64 (QMC kernels without MPI, FP64) + run: | + pytest -n 8 -v tests/test_jqmc_command_lines.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append + pytest -n 8 -v tests/test_jqmc_mcmc.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append + pytest -n 8 -v tests/test_jqmc_gfmc_tau.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append + pytest -n 8 -v tests/test_jqmc_gfmc_bra.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append + + - name: Test jqmc FP32+FP64 (QMC kernels without MPI, FP32+FP64) + run: | + pytest -n 8 -v tests/test_jqmc_command_lines.py --precision-mode=mixed --cov=jqmc --cov-branch --no-cov-on-fail --cov-append + pytest -n 8 -v tests/test_jqmc_mcmc.py --precision-mode=mixed --cov=jqmc --cov-branch --no-cov-on-fail --cov-append + pytest -n 8 -v tests/test_jqmc_gfmc_tau.py --precision-mode=mixed --cov=jqmc --cov-branch --no-cov-on-fail --cov-append + pytest -n 8 -v tests/test_jqmc_gfmc_bra.py --precision-mode=mixed --cov=jqmc --cov-branch --no-cov-on-fail --cov-append + + - name: Test jqmc FP64 (QMC kernels with 2MPIs, FP64) + run: | + mpirun -np 2 pytest -v tests/test_jqmc_mcmc.py + mpirun -np 2 pytest -v tests/test_jqmc_gfmc_tau.py + mpirun -np 2 pytest -v tests/test_jqmc_gfmc_bra.py + + - name: Test jqmc FP32+FP64 (QMC kernels with 2MPIs, FP32+FP64) + run: | + mpirun -np 2 pytest -v tests/test_jqmc_mcmc.py --precision-mode=mixed + mpirun -np 2 pytest -v tests/test_jqmc_gfmc_tau.py --precision-mode=mixed + mpirun -np 2 pytest -v tests/test_jqmc_gfmc_bra.py --precision-mode=mixed + + - name: Test jqmc-tool (Toolset for jqmc, FP64) + run: | + pytest -n 8 -v tests/test_jqmc_tool.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append + + - name: Combine Codecov xml files + run: | + python -m coverage xml -o coverage.xml + + - name: Upload coverage reports to Codecov + if: matrix.python-version == '3.12.3' + uses: codecov/codecov-action@v5 + with: + files: coverage.xml + token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.github/workflows/jqmc-run-full-pytest.yml b/.github/workflows/jqmc-run-full-pytest.yml deleted file mode 100644 index 2744aa70..00000000 --- a/.github/workflows/jqmc-run-full-pytest.yml +++ /dev/null @@ -1,118 +0,0 @@ -# A full test of jqmc. - -name: jqmc full test - -on: - push: - branches: [ "main" ] - paths-ignore: - - '.gitignore' - - '.github/**' - - 'doc/**' - - 'examples/**' - - 'benchmarks/**' - - 'README.md' - - '.pre-commit-config.yaml' - - 'jqmc_workflow/**' - -jobs: - run: - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - python-version: ["3.10", "3.11", "3.12"] - - steps: - - name: Install gfortran and gcc - run: | - sudo apt-get update - sudo apt-get install gfortran - - - name: Install OpenBLAS and LAPACK - run: sudo apt-get install libopenblas-dev liblapack-dev - - - name: Install OpenMPI - run: sudo apt-get install openmpi-bin libopenmpi-dev - - - uses: actions/checkout@v3 - - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v3 - with: - python-version: ${{ matrix.python-version }} - - - name: Install jqmc - run: | - python -m pip install flake8 pytest pytest-cov - python -m pip install . - - - name: Test jqmc FP64 (Intra-software comparisons) - run: | - pytest -s -v tests/test_trexio.py --cov=jqmc --cov-branch --no-cov-on-fail - pytest -s -v tests/test_init_electron_configurations.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append - pytest -s -v tests/test_structure.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append - pytest -s -v tests/test_AOs.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append - pytest -s -v tests/test_MOs.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append - pytest -s -v tests/test_determinant.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append - pytest -s -v tests/test_jastrow.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append - pytest -s -v tests/test_wave_function.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append - pytest -s -v tests/test_ecps.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append - pytest -s -v tests/test_swct.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append - pytest -s -v tests/test_mcmc_force.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append - pytest -s -v tests/test_lrdmc_force.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append - pytest -s -v tests/test_checkpoint_components.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append - pytest -s -v tests/test_checkpoint_mcmc.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append - pytest -s -v tests/test_checkpoint_gfmc.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append - pytest -s -v tests/test_ao_basis_optimization.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append - # Skipped under full mode: - # pytest -s -v tests/test_mixed_precision.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append - - - name: Test jqmc FP32+FP64 (Intra-software comparisons) - run: | - # Skipped under mixed mode: precision-insensitive (coverage already obtained in FP64 block). - # pytest -s -v tests/test_trexio.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append --precision-mode=mixed - # pytest -s -v tests/test_init_electron_configurations.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append --precision-mode=mixed - # pytest -s -v tests/test_structure.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append --precision-mode=mixed - pytest -s -v tests/test_AOs.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append --precision-mode=mixed - pytest -s -v tests/test_MOs.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append --precision-mode=mixed - pytest -s -v tests/test_determinant.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append --precision-mode=mixed - pytest -s -v tests/test_jastrow.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append --precision-mode=mixed - pytest -s -v tests/test_wave_function.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append --precision-mode=mixed - pytest -s -v tests/test_ecps.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append --precision-mode=mixed - pytest -s -v tests/test_swct.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append --precision-mode=mixed - pytest -s -v tests/test_mcmc_force.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append --precision-mode=mixed - pytest -s -v tests/test_lrdmc_force.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append --precision-mode=mixed - # Skipped under mixed mode: HDF5 roundtrip tests, precision-insensitive (coverage already obtained in FP64 block). - # pytest -s -v tests/test_checkpoint_components.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append --precision-mode=mixed - # pytest -s -v tests/test_checkpoint_mcmc.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append --precision-mode=mixed - # pytest -s -v tests/test_checkpoint_gfmc.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append --precision-mode=mixed - pytest -s -v tests/test_ao_basis_optimization.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append --precision-mode=mixed - pytest -s -v tests/test_mixed_precision.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append --precision-mode=mixed - - - name: Test jqmc FP64 (Inter-software comparisons) - run: | - pytest -s -v tests/test_comparison_with_turborvb_ECP.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append - pytest -s -v tests/test_comparison_with_turborvb_AE.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append - - - name: Test jqmc FP64 (QMC kernels without MPI) - run: | - pytest -s -v tests/test_jqmc_command_lines.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append - pytest -s -v tests/test_jqmc_mcmc.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append - pytest -s -v tests/test_jqmc_gfmc_tau.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append - pytest -s -v tests/test_jqmc_gfmc_bra.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append - - - name: Test jqmc-tool (Toolset for jqmc) - run: | - pytest -s -v tests/test_jqmc_tool.py --cov=jqmc --cov-branch --no-cov-on-fail --cov-append - - - name: Combine Codecov xml files - run: | - coverage xml -o coverage.xml - - - name: Upload coverage reports to Codecov - if: matrix.python-version == '3.12' - uses: codecov/codecov-action@v5 - with: - files: coverage.xml - token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.github/workflows/jqmc-run-long-pytest.yml b/.github/workflows/jqmc-run-long-pytest.yml index 99248a8a..24229ce1 100644 --- a/.github/workflows/jqmc-run-long-pytest.yml +++ b/.github/workflows/jqmc-run-long-pytest.yml @@ -1,6 +1,6 @@ # A long test of jqmc. -name: jqmc long test +name: jqmc long test workflow on: pull_request: @@ -17,6 +17,7 @@ on: jobs: run: + name: jqmc long test / Python ${{ matrix.python-version }} runs-on: ubuntu-latest strategy: fail-fast: false @@ -44,7 +45,7 @@ jobs: - name: Install jqmc run: | - python -m pip install flake8 pytest pytest-cov + python -m pip install pytest pytest-cov python -m pip install . - name: Test jqmc FP64/FP32+FP64 (Intra-software comparisons) @@ -76,13 +77,11 @@ jobs: pytest -s -v tests/test_comparison_with_turborvb_ECP.py pytest -s -v tests/test_comparison_with_turborvb_AE.py - - name: Test jqmc FP64/FP32+FP64 (QMC kernels without MPI) + - name: Test jqmc FP64 (QMC kernels without MPI) run: | pytest -s -v tests/test_jqmc_command_lines.py pytest -s -v tests/test_jqmc_mcmc.py - pytest -s -v tests/test_jqmc_mcmc.py --precision-mode=mixed pytest -s -v tests/test_jqmc_gfmc_tau.py - pytest -s -v tests/test_jqmc_gfmc_bra.py --precision-mode=mixed - name: Test jqmc-tool (toolset for jqmc) run: | diff --git a/.github/workflows/jqmc-run-rc-full-precision-pytest.yml b/.github/workflows/jqmc-run-rc-full-precision-pytest.yml deleted file mode 100644 index 07e43316..00000000 --- a/.github/workflows/jqmc-run-rc-full-precision-pytest.yml +++ /dev/null @@ -1,73 +0,0 @@ -# An rc test of jqmc. - -name: jqmc rc test - -on: - pull_request: - branches: [ "rc" ] - paths-ignore: - - '.gitignore' - - '.github/**' - - 'doc/**' - - 'examples/**' - - 'benchmarks/**' - - 'README.md' - - '.pre-commit-config.yaml' - - 'jqmc_workflow/**' - -jobs: - run: - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - python-version: ["3.10", "3.11", "3.12"] - - steps: - - name: Install gfortran and gcc - run: | - sudo apt-get update - sudo apt-get install gfortran - - - name: Install OpenBLAS and LAPACK - run: sudo apt-get install libopenblas-dev liblapack-dev - - - name: Install OpenMPI - run: sudo apt-get install openmpi-bin libopenmpi-dev - - - uses: actions/checkout@v3 - - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v3 - with: - python-version: ${{ matrix.python-version }} - - - name: Install jqmc - run: | - python -m pip install flake8 pytest pytest-cov - python -m pip install . - - - name: Test jqmc command-line - run: | - pytest -s -v tests/test_jqmc_command_lines.py - - - name: Test jqmc FP64 (Inter-software comparisons) - run: | - pytest -s -v tests/test_comparison_with_turborvb_ECP.py - pytest -s -v tests/test_comparison_with_turborvb_AE.py - - - name: Test jqmc FP64 (QMC kernels without MPI, FP64) - run: | - pytest -s -v tests/test_jqmc_mcmc.py - pytest -s -v tests/test_jqmc_gfmc_tau.py - pytest -s -v tests/test_jqmc_gfmc_bra.py - - - name: Test jqmc FP64 (QMC kernels with 2MPIs, FP64) - run: | - mpirun -np 2 pytest -s -v tests/test_jqmc_mcmc.py - mpirun -np 2 pytest -s -v tests/test_jqmc_gfmc_tau.py - mpirun -np 2 pytest -s -v tests/test_jqmc_gfmc_bra.py - - - name: Test jqmc-tool (toolset for jqmc) - run: | - pytest -s -v tests/test_jqmc_tool.py diff --git a/.github/workflows/jqmc-run-rc-mixed-precision-pytest.yml b/.github/workflows/jqmc-run-rc-mixed-precision-pytest.yml deleted file mode 100644 index c45a2974..00000000 --- a/.github/workflows/jqmc-run-rc-mixed-precision-pytest.yml +++ /dev/null @@ -1,68 +0,0 @@ -# An rc test of jqmc. - -name: jqmc rc test - -on: - pull_request: - branches: [ "rc" ] - paths-ignore: - - '.gitignore' - - '.github/**' - - 'doc/**' - - 'examples/**' - - 'benchmarks/**' - - 'README.md' - - '.pre-commit-config.yaml' - - 'jqmc_workflow/**' - -jobs: - run: - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - python-version: ["3.10", "3.11", "3.12"] - - steps: - - name: Install gfortran and gcc - run: | - sudo apt-get update - sudo apt-get install gfortran - - - name: Install OpenBLAS and LAPACK - run: sudo apt-get install libopenblas-dev liblapack-dev - - - name: Install OpenMPI - run: sudo apt-get install openmpi-bin libopenmpi-dev - - - uses: actions/checkout@v3 - - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v3 - with: - python-version: ${{ matrix.python-version }} - - - name: Install jqmc - run: | - python -m pip install flake8 pytest pytest-cov - python -m pip install . - - - name: Test jqmc command-line - run: | - pytest -s -v tests/test_jqmc_command_lines.py - - - name: Test jqmc FP32+FP64 (QMC kernels without MPI, FP32+FP64) - run: | - pytest -s -v tests/test_jqmc_mcmc.py --precision-mode=mixed - pytest -s -v tests/test_jqmc_gfmc_tau.py --precision-mode=mixed - pytest -s -v tests/test_jqmc_gfmc_bra.py --precision-mode=mixed - - - name: Test jqmc FP32+FP64 (QMC kernels with 2MPIs, FP32+FP64) - run: | - mpirun -np 2 pytest -s -v tests/test_jqmc_mcmc.py --precision-mode=mixed - mpirun -np 2 pytest -s -v tests/test_jqmc_gfmc_tau.py --precision-mode=mixed - mpirun -np 2 pytest -s -v tests/test_jqmc_gfmc_bra.py --precision-mode=mixed - - - name: Test jqmc-tool (toolset for jqmc) - run: | - pytest -s -v tests/test_jqmc_tool.py diff --git a/.github/workflows/jqmc-run-short-pytest.yml b/.github/workflows/jqmc-run-short-pytest.yml index 55877b9d..9cd14c54 100644 --- a/.github/workflows/jqmc-run-short-pytest.yml +++ b/.github/workflows/jqmc-run-short-pytest.yml @@ -1,6 +1,6 @@ # A short test jqmc. -name: jqmc short test +name: jqmc short test workflow on: push: @@ -28,6 +28,7 @@ on: jobs: run: + name: jqmc short test / Python ${{ matrix.python-version }} runs-on: ubuntu-latest strategy: fail-fast: false diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 04a2d812..067339fe 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,7 +14,7 @@ repos: - id: check-added-large-files - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.15.12 + rev: v0.15.13 hooks: - id: ruff name: ruff (ambiguous unicode only) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 4c215da6..325bbfc8 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -54,7 +54,7 @@ We are willing to sacrifice some computational speed to achieve these goals. To * Submit a **Pull Request** (PR). * Upon PR creation or update, GitHub Actions will run the test suite. -* If all tests pass, @kousuke-nakano (a main maintainer) will review your changes. +* If all tests pass, @kousuke-nakano or another maintainer of @jqmc-project will review your changes. * Once approved, your PR will be merged into `main`. --- @@ -68,13 +68,13 @@ We are willing to sacrifice some computational speed to achieve these goals. To * `scipy` * `jax` * `flax` -* Other third-party packages should be avoided unless absolutely necessary. Any new dependency must be approved by @kousuke-nakano. +* Other third-party packages should be avoided unless absolutely necessary. Any new dependency must be approved by @kousuke-nakano or another maintainer of @jqmc-project. --- ### Release Process -* All official package releases are performed by @kousuke-nakano as needed. +* All official package releases are performed by @kousuke-nakano or another maintainer of @jqmc-project as needed. --- diff --git a/README.md b/README.md index 45549c8d..a2c59b3d 100644 --- a/README.md +++ b/README.md @@ -2,15 +2,15 @@ ![jqmc_logo](logo/logo_yoko2.jpg) -**jQMC** is an ab initio quantum Monte Carlo (QMC) simulation package developed entirely from scratch using `Python` and `JAX`. Originally designed for molecular systems --with future extensions planned for periodic systems-- **jQMC** implements two well-established QMC algorithms: Variational Monte Carlo (VMC) and a robust and efficient variant of Diffusion Monte Carlo algorithm known as Lattice Regularized Diffusion Monte Carlo (LRDMC). By leveraging `JAX` just-in-time (`jit`) compilation and vectorized mapping (`vmap`) functionalities, `jQMC` achieves high-performance computations **especially on GPUs** while remaining portable across CPUs and GPUs. See [here](http://jax.readthedocs.io/) for the details of `JAX`. The **jQMC** users and developers manual is available from [GitHub Pages](https://kousuke-nakano.github.io/jQMC/). - -![license](https://img.shields.io/github/license/kousuke-nakano/jQMC) -![tag](https://img.shields.io/github/v/tag/kousuke-nakano/jQMC) -![fork](https://img.shields.io/github/forks/kousuke-nakano/jQMC?style=social) -![stars](https://img.shields.io/github/stars/kousuke-nakano/jQMC?style=social) -![short-pytest](https://github.com/kousuke-nakano/jQMC/actions/workflows/jqmc-run-short-pytest.yml/badge.svg) -![full-pytest](https://github.com/kousuke-nakano/jQMC/actions/workflows/jqmc-run-full-pytest.yml/badge.svg) -![codecov](https://codecov.io/github/kousuke-nakano/jQMC/graph/badge.svg) +**jQMC** is an ab initio quantum Monte Carlo (QMC) simulation package developed entirely from scratch using `Python` and `JAX`. Originally designed for molecular systems --with future extensions planned for periodic systems-- **jQMC** implements two well-established QMC algorithms: Variational Monte Carlo (VMC) and a robust and efficient variant of Diffusion Monte Carlo algorithm known as Lattice Regularized Diffusion Monte Carlo (LRDMC). By leveraging `JAX` just-in-time (`jit`) compilation and vectorized mapping (`vmap`) functionalities, `jQMC` achieves high-performance computations **especially on GPUs** while remaining portable across CPUs and GPUs. See [here](http://jax.readthedocs.io/) for the details of `JAX`. The **jQMC** users and developers manual is available from [GitHub Pages](https://jqmc-project.github.io/jQMC/). + +![license](https://img.shields.io/github/license/jqmc-project/jQMC) +![tag](https://img.shields.io/github/v/tag/jqmc-project/jQMC) +![fork](https://img.shields.io/github/forks/jqmc-project/jQMC?style=social) +![stars](https://img.shields.io/github/stars/jqmc-project/jQMC?style=social) +![short-pytest](https://github.com/jqmc-project/jQMC/actions/workflows/jqmc-run-short-pytest.yml/badge.svg) +![full-pytest](https://github.com/jqmc-project/jQMC/actions/workflows/jqmc-run-full-pytest.yml/badge.svg) +![codecov](https://codecov.io/github/jqmc-project/jQMC/graph/badge.svg) ![DL](https://img.shields.io/pypi/dm/jqmc) ![python_version](https://img.shields.io/pypi/pyversions/jqmc) ![pypi_version](https://badge.fury.io/py/jqmc.svg) @@ -54,7 +54,7 @@ Kosuke Nakano (National Institute for Materials Science (NIMS), Japan) **The latest version of jQMC** can be installed via pip from the cloned GitHub repository. ```bash -% git clone https://github.com/kousuke-nakano/jQMC +% git clone https://github.com/jqmc-project/jQMC % cd jQMC % pip install . ``` @@ -100,7 +100,7 @@ Once the `main` branch is merged into the `rc` branch, the `GitHub` workflow lau ## How to deploy the documentation -Once the `main` branch is merged into the `rc-gh-pages` branch, the `GitHub` workflow launches the implemented documentaion building process (`jqmc-deploy-gh-pages.yml`) and deploy the compiled documentaiton to [GitHub Pages](https://kousuke-nakano.github.io/jQMC/). +Once the `main` branch is merged into the `rc-gh-pages` branch, the `GitHub` workflow launches the implemented documentaion building process (`jqmc-deploy-gh-pages.yml`) and deploy the compiled documentaiton to [GitHub Pages](https://jqmc-project.github.io/jQMC/). ## Contribution diff --git a/doc/UML.pu b/doc/UML.pu index a4bd7733..43093faf 100644 --- a/doc/UML.pu +++ b/doc/UML.pu @@ -1,9 +1,18 @@ @startuml uml ' size -scale 595*842 +scale 1200*1000 ' PlantUML configuration allowmixing +top to bottom direction + +' Compact layout +skinparam nodesep 15 +skinparam ranksep 25 +skinparam padding 4 +skinparam classAttributeIconSize 0 +skinparam defaultFontSize 18 +skinparam classFontSize 15 ' Command + Shift + P to toggle the PlantUML export mode ' inkscape uml.svg -o uml.pdf @@ -22,6 +31,10 @@ class Hamiltonian_data <> { class Structure_data <> { - positions: jax.Array + - pbc_flag: bool + - vec_a: tuple[float] + - vec_b: tuple[float] + - vec_c: tuple[float] - atomic_numbers: tuple[int] - element_symbols: tuple[str] - atomic_labels: tuple[str] @@ -81,9 +94,9 @@ class Geminal_data <> { } class Jastrow_data <> { - - jastrow_one_body_data: Jastrow_one_body_data - - jastrow_two_body_data: Jastrow_two_body_data - - jastrow_three_body_data: Jastrow_three_body_data + - jastrow_one_body_data: Jastrow_one_body_data | None + - jastrow_two_body_data: Jastrow_two_body_data | None + - jastrow_three_body_data: Jastrow_three_body_data | None } class Wavefunction_data <> { @@ -112,6 +125,7 @@ class MCMC { - hamiltonian_data : Hamiltonian_data - mcmc_seed : int - num_walkers : int + - Dt : float + run(num_mcmc_steps: int, max_time: int) : None + run_optimize(...) : None + get_E(...) : tuple @@ -119,70 +133,101 @@ class MCMC { + get_gF(...) : tuple } -class GFMC { +class GFMC_t { + Tau-step LRDMC (GFMC with fixed tau). + -- - hamiltonian_data : Hamiltonian_data - mcmc_seed : int - num_walkers : int - + run(num_gfmc_steps: int, max_time: int) : None + - tau : float + - alat : float + + run(num_mcmc_steps: int, max_time: int) : None + + get_E(...) : tuple + + get_aF(...) : tuple +} + +class GFMC_n { + Node-based LRDMC (GFMC with fixed alat). + -- + - hamiltonian_data : Hamiltonian_data + - mcmc_seed : int + - num_walkers : int + - E_scf : float + - alat : float + + run(num_mcmc_steps: int, max_time: int) : None + get_E(...) : tuple + get_aF(...) : tuple } ' Functions -rectangle "compute_local_energy_jax(\n hamiltonian_data: Hamiltonian_data,\n r_up_carts: jax.Array,\n r_dn_carts: jax.Array\n) -> float" as compute_local_energy_jax -rectangle "jax.grad(compute_local_energy_jax)(\n hamiltonian_data: Hamiltonian_data,\n r_up_carts: jax.Array,\n r_dn_carts: jax.Array\n) -> float" as grad_compute_local_energy_jax -rectangle "compute_discretized_kinetic_energy_jax(\n alat: float,\n wavefunction_data: Wavefunction_data,\n r_up_carts: jax.Array,\n r_dn_carts: jax.Array\n)" as compute_discretized_kinetic_energy_jax -' rectangle "compute_ecp_non_local_parts_jax(\n coulomb_potential_data: Coulomb_potential_data,\n wavefunction_data: Wavefunction_data,\n r_up_carts: jax.Array,\n r_dn_carts: jax.Array\n)" as compute_ecp_non_local_parts_jax +rectangle "compute_local_energy(\n hamiltonian_data: Hamiltonian_data,\n r_up_carts: jax.Array,\n r_dn_carts: jax.Array,\n RT: jax.Array\n) -> float" as compute_local_energy +rectangle "jax.grad(compute_local_energy)(\n hamiltonian_data: Hamiltonian_data,\n r_up_carts: jax.Array,\n r_dn_carts: jax.Array,\n RT: jax.Array\n) -> Hamiltonian_data" as grad_compute_local_energy +rectangle "compute_discretized_kinetic_energy(\n alat: float,\n wavefunction_data: Wavefunction_data,\n r_up_carts: jax.Array,\n r_dn_carts: jax.Array,\n RT: jax.Array\n) -> tuple" as compute_discretized_kinetic_energy -' Dependency relationships -note top of compute_local_energy_jax +' Notes on functions +note top of compute_local_energy This function computes the local energy with a given Hamiltonian_data and electron -positions (r_up_cart and r_dn_carts). +positions (r_up_carts and r_dn_carts). end note -note bottom of grad_compute_local_energy_jax +note bottom of grad_compute_local_energy This function computes **derivatives** of the local energy with a given Hamiltonian_data -and electron positions (r_up_cart and r_dn_carts). +and electron positions (r_up_carts and r_dn_carts). end note ' =============================== -' Class relationships (composition/aggregation) +' Class relationships +' +' *-- : composition (filled diamond) +' child cannot exist independently; lifecycle tied to parent. +' o-- : aggregation (hollow diamond) +' child can exist independently; used for optional (| None) +' fields and union-type alternatives. +' ..> : dependency (dashed arrow) +' one element uses / calls another. ' =============================== MCMC *-- Hamiltonian_data -GFMC *-- Hamiltonian_data -MCMC ..> compute_local_energy_jax: calls -MCMC ..> grad_compute_local_energy_jax: calls -GFMC ..> compute_discretized_kinetic_energy_jax: calls -' GFMC ..> compute_ecp_non_local_parts_jax: calls +GFMC_t *-- Hamiltonian_data +GFMC_n *-- Hamiltonian_data +MCMC ..> compute_local_energy: calls +MCMC ..> grad_compute_local_energy: calls +GFMC_t ..> compute_discretized_kinetic_energy: calls +GFMC_n ..> compute_discretized_kinetic_energy: calls + Jastrow_one_body_data *-- Structure_data -Jastrow_three_body_data *-- AOs_sphe_data -Jastrow_three_body_data *-- AOs_cart_data -Jastrow_three_body_data *-- MOs_data + +' Union-type alternatives: orb_data field holds exactly one of these +Jastrow_three_body_data o-- AOs_sphe_data +Jastrow_three_body_data o-- AOs_cart_data +Jastrow_three_body_data o-- MOs_data Hamiltonian_data *-- Structure_data Hamiltonian_data *-- Coulomb_potential_data Hamiltonian_data *-- Wavefunction_data Coulomb_potential_data *-- Structure_data -Coulomb_potential_data *-- Wavefunction_data AOs_cart_data *-- Structure_data AOs_sphe_data *-- Structure_data -MOs_data *-- AOs_cart_data -MOs_data *-- AOs_sphe_data -Geminal_data *-- AOs_cart_data -Geminal_data *-- AOs_sphe_data -Geminal_data *-- MOs_data +' Union-type alternatives: aos_data holds exactly one of these +MOs_data o-- AOs_cart_data +MOs_data o-- AOs_sphe_data + +' Union-type alternatives: orb_data_up/dn_spin holds exactly one of these +Geminal_data o-- AOs_cart_data +Geminal_data o-- AOs_sphe_data +Geminal_data o-- MOs_data Wavefunction_data *-- Jastrow_data Wavefunction_data *-- Geminal_data -Jastrow_data *-- Jastrow_one_body_data -Jastrow_data *-- Jastrow_two_body_data -Jastrow_data *-- Jastrow_three_body_data +' Optional fields (| None) -> aggregation +Jastrow_data o-- Jastrow_one_body_data +Jastrow_data o-- Jastrow_two_body_data +Jastrow_data o-- Jastrow_three_body_data @enduml diff --git a/doc/changelog.md b/doc/changelog.md index 3c2fbd57..2de9f256 100644 --- a/doc/changelog.md +++ b/doc/changelog.md @@ -2,6 +2,61 @@ # Change Log +## May-29-2026: v0.2.2 + +First stable release since v0.1.0. v0.2.2 ships everything accumulated across four alphas (v0.2.0a1, v0.2.1a1, v0.2.1a2, v0.2.2a1) plus a final round of polish. Per-alpha sections are preserved below; this entry is a roll-up of the highlights from v0.1.0 to v0.2.2. + +### Highlights (v0.1.0 -> v0.2.2) + +#### Optimization + +* **Linear Method (LM) optimizer** integrated under `method="sr"` with a unified `use_lm` / `lm_subspace_dim` hierarchy (plain SR / aSR / LM). New `|v_0|^2 < 0.9` fallback to plain SR keeps non-linear-regime updates from producing NaN energies. +* **Adaptive learning rate** for Stochastic Reconfiguration. +* **MO optimization** for JSD via the projection method with Attacalite-Sorella regularization, plus geminal AO -> MO projection. +* **AO basis optimization** (`opt_J3_basis_coeff/exp`, `opt_lambda_basis_coeff/exp`) with shell-shared constraint and dual symmetrization. +* **Distributed tall-CG SR** solver via `psum`, removing `mpi_size`-scaling memory in the SR solve. + +#### Performance + +* **Fast-update use** across MCMC / VMC / LRDMC, with mat-vec hot paths converted to GEMM for better GPU utilization. +* **On-GPU VMC optimization** with `use_device_collectives` auto-selected by JAX backend; multi-GPU `run_optimize` supported. +* **LU -> SVD** in determinant / geminal / GFMC_n / GFMC_t for ill-conditioned stability; Cartesian / Spherical AO conversion (Cartesian GTOs are substantially faster on GPU); ECP fast path (`compute_ecp_coulomb_potential_fast`). + +#### Numerical precision + +* **Mixed-precision support** with `"full"` / `"mixed"` modes and per-zone dtype control. Three explicit design principles. AGP/SD geminal stays fp64 to prevent `log|det|` amplification; electron-nucleus `r - R` differences are reconstructed in fp64 before downcast to avoid catastrophic cancellation. `ao_grad_lap` and `mo_grad_lap` zones are split for finer-grained control. + +#### Features + +* **LRDMC atomic forces** with the Pathak-Wagner regularization. +* **Runtime-selectable Jastrow forms**: `jastrow_1b_type` and `jastrow_2b_type` (`exp` / `pade`). +* **`use_swct` flag** to toggle Space Warp Coordinate Transformation in MCMC and GFMC_n / GFMC_t. + +#### `jqmc_workflow` automation package + +* **jqmc-workflow** is introduced as a multi-stage QMC pipeline orchestrator (WF conversion -> VMC opt -> MCMC / LRDMC production) with automatic step estimation, checkpointing, and remote job management. + +#### bug fixes + +* GFMC_n / GFMC_t spin-polarized (`n_up != n_dn`, `n_dn >= 1`) MPI bug. +* MPI deadlock in `max_time` / `stop_flag` checks; `Allreduce` vs `allreduce` for scalars. +* Optimizer step estimation; force NaN; MCMC memory overflow from `r_up_history` / `r_dn_history` storage. + +#### Infrastructure + +* **Restart files** migrated from pickle `.chk` to HDF5 `.h5` (no backward compatibility). +* **Ruff lint pipeline** (`jqmc-lint-ruff.yml`) and pre-commit updates; non-ASCII cleanup across code and docstrings. +* **Nightly CI + Codecov** activated with the `pytest-xdist` support. +* **Examples**: 11 end-to-end tutorials (`jqmc-example01` to `jqmc-example08`, `jqmc-workflow-example01` to `jqmc-workflow-example03`). +* **Project ownership** transferred to the `jqmc-project` GitHub organization; URLs updated. + +### Breaking changes since v0.1.0 + +* Restart files: pickle `.chk` is no longer supported; HDF5 `.h5` is the only format. +* Optimizer API: `num_param_opt`, `opt_filter_min_SN_ratio`, `adaptive_learning_rate`, and `method="lm"` are all removed or replaced; the Linear Method is now accessed via `method="sr"` with `use_lm=true` (and the new `lm_subspace_dim` / `lm_cond` parameters). + +See the per-alpha sections below for full details. + ## May-18-2026: v0.2.2a1 This release brings configurable mixed-precision support, deep kernel-level performance work (AOs, Jastrow, det/Jastrow ratios, GFMC), on-GPU VMC optimization, and a project-wide lint/cleanup. diff --git a/doc/conf.py b/doc/conf.py index cbbf5e18..af6fb80a 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -50,7 +50,7 @@ def _generate_examples_page(): "# Examples\n" "\n" "Example files for **jQMC** are found at\n" - ".\n" + ".\n" "\n" ) @@ -212,7 +212,7 @@ def _dedup_footnote(m): # documentation. html_theme_options = { "icon_links": [ - {"name": "GitHub", "url": "https://github.com/kousuke-nakano/jQMC", "icon": "fa-brands fa-github"}, + {"name": "GitHub", "url": "https://github.com/jqmc-project/jQMC", "icon": "fa-brands fa-github"}, ], } # html_theme_options = { diff --git a/doc/examples.md b/doc/examples.md index 22191a5e..9c01c816 100644 --- a/doc/examples.md +++ b/doc/examples.md @@ -3,7 +3,7 @@ # Examples Example files for **jQMC** are found at -. +. ## jqmc-example01: @@ -2510,12 +2510,18 @@ This example uses local execution (`jqmc_setting_local/`): localhost: machine_type: local queuing: false + jobsubmit: bash # required even for queuing=false: command used to invoke the submit script + # (use "bash" or "sh"; for queuing=true add jobcheck / jobdel / jobnum_index) ## Remote machines require ssh_host: ## cluster: ## ssh_host: my-cluster # Host alias in ~/.ssh/config ## machine_type: remote ## queuing: true +## jobsubmit: qsub # required: scheduler submit command (qsub / sbatch / ...) +## jobcheck: qstat # required when queuing=true: status query command +## jobdel: qdel # required when queuing=true: cancel command +## jobnum_index: 0 # required when queuing=true: index of the job-id token in jobsubmit's stdout ## ... ``` diff --git a/doc/install.md b/doc/install.md index 66f07533..b0dd62f5 100644 --- a/doc/install.md +++ b/doc/install.md @@ -2,4 +2,4 @@ # Installation -How to Install **jQMC** is written in `Readme.md` https://github.com/kousuke-nakano/jQMC/tree/main/. +How to Install **jQMC** is written in `Readme.md` https://github.com/jqmc-project/jQMC/tree/main/. diff --git a/doc/overview.md b/doc/overview.md index d7aa2a4b..696b6be3 100644 --- a/doc/overview.md +++ b/doc/overview.md @@ -2,12 +2,12 @@ **jQMC** is an ab initio quantum Monte Carlo (QMC) simulation package developed entirely from scratch using Python and JAX. Originally designed for molecular systems--with future extensions planned for periodic systems--**jQMC** implements two well-established QMC algorithms: Variational Monte Carlo (VMC) and a robust and efficient variant of Diffusion Monte Carlo known as Lattice Regularized Diffusion Monte Carlo (LRDMC). By leveraging JAX just-in-time (jit) compilation and vectorized mapping (vmap) functionalities, jQMC achieves high-performance computations especially on GPUs while remaining portable across CPUs and GPUs. -![license](https://img.shields.io/github/license/kousuke-nakano/jQMC) -![tag](https://img.shields.io/github/v/tag/kousuke-nakano/jQMC) -![fork](https://img.shields.io/github/forks/kousuke-nakano/jQMC?style=social) -![stars](https://img.shields.io/github/stars/kousuke-nakano/jQMC?style=social) -![full-pytest](https://github.com/kousuke-nakano/jQMC/actions/workflows/jqmc-run-full-pytest.yml/badge.svg) -![codecov](https://codecov.io/github/kousuke-nakano/jQMC/graph/badge.svg) +![license](https://img.shields.io/github/license/jqmc-project/jQMC) +![tag](https://img.shields.io/github/v/tag/jqmc-project/jQMC) +![fork](https://img.shields.io/github/forks/jqmc-project/jQMC?style=social) +![stars](https://img.shields.io/github/stars/jqmc-project/jQMC?style=social) +![full-pytest](https://github.com/jqmc-project/jQMC/actions/workflows/jqmc-run-full-pytest.yml/badge.svg) +![codecov](https://codecov.io/github/jqmc-project/jQMC/graph/badge.svg) ![DL](https://img.shields.io/pypi/dm/jqmc) ![python_version](https://img.shields.io/pypi/pyversions/jqmc) ![pypi_version](https://badge.fury.io/py/jqmc.svg) diff --git a/doc/workflows.key b/doc/workflows.key deleted file mode 100755 index 671ac59a..00000000 Binary files a/doc/workflows.key and /dev/null differ diff --git a/jqmc/_setting.py b/jqmc/_setting.py index ef6ea535..2e0f8874 100644 --- a/jqmc/_setting.py +++ b/jqmc/_setting.py @@ -56,7 +56,6 @@ GFMC_MIN_COLLECT_STEPS = 5 # on the fly statistics param -GFMC_ON_THE_FLY_WARMUP_STEPS = 20 GFMC_ON_THE_FLY_COLLECT_STEPS = 10 GFMC_ON_THE_FLY_BIN_BLOCKS = 10 diff --git a/jqmc/jqmc_gfmc.py b/jqmc/jqmc_gfmc.py index ab86d865..43cf0c6b 100644 --- a/jqmc/jqmc_gfmc.py +++ b/jqmc/jqmc_gfmc.py @@ -63,7 +63,6 @@ GFMC_MIN_WARMUP_STEPS, GFMC_ON_THE_FLY_BIN_BLOCKS, GFMC_ON_THE_FLY_COLLECT_STEPS, - GFMC_ON_THE_FLY_WARMUP_STEPS, get_eps, ) from .coulomb_potential import ( @@ -660,6 +659,7 @@ def run(self, num_mcmc_steps: int = 50, max_time: int = 86400) -> None: timer_mpi_barrier = 0.0 timer_collection = 0.0 timer_reconfiguration = 0.0 + mpi_comm.Barrier() gfmc_total_start = time.perf_counter() # toml(control) filename @@ -1291,6 +1291,7 @@ def _projection_t_streaming( self.__alat, self.__hamiltonian_data, ) + mpi_comm.Barrier() end_init = time.perf_counter() timer_projection_init += end_init - start_init logger.info("End compilation of the GFMC projection funciton.") @@ -1544,6 +1545,7 @@ def _compute_local_energy_t( _ = _jit_vmap_swct_domega_t(self.__hamiltonian_data.structure_data, self.__latest_r_up_carts) _ = _jit_vmap_swct_domega_t(self.__hamiltonian_data.structure_data, self.__latest_r_dn_carts) end_init_force = time.perf_counter() + timer_projection_init += end_init_force - start_init_force logger.info("End compilation of force gradient functions.") logger.info(f"Elapsed Time = {end_init_force - start_init_force:.2f} sec.") logger.info("") @@ -1677,6 +1679,7 @@ def _run_projection_loop_streaming(pcl, tll, wll, ru, rd, Ainv, key, ks): self.__jax_PRNG_key_list, _init_kinetic_state_list_compile, ).compile() + mpi_comm.Barrier() end_warmup = time.perf_counter() timer_projection_init += end_warmup - start_warmup logger.info("End compilation of the GFMC projection while_loop driver.") @@ -2295,6 +2298,9 @@ def _run_projection_loop_streaming(pcl, tll, wll, ru, rd, Ainv, key, ks): self.__latest_A_old_inv = vmap(_compute_initial_A_inv_t, in_axes=(0, 0))( self.__latest_r_up_carts, self.__latest_r_dn_carts ) + # block before Barrier so the A_inv GPU work is included in timer_reconfiguration + # and dispatch-queue skew does not leak into the next step's barrier wait. + self.__latest_A_old_inv.block_until_ready() # Barrier after MPI operation mpi_comm.Barrier() @@ -5783,6 +5789,7 @@ def _compute_local_energy_n( _ = _jit_vmap_swct_omega_n(self.__hamiltonian_data.structure_data, self.__latest_r_dn_carts) _ = _jit_vmap_swct_domega_n(self.__hamiltonian_data.structure_data, self.__latest_r_up_carts) _ = _jit_vmap_swct_domega_n(self.__hamiltonian_data.structure_data, self.__latest_r_dn_carts) + mpi_comm.Barrier() end_init = time.perf_counter() timer_projection_init += end_init - start_init logger.info("End compilation of the GFMC projection funciton.") @@ -6365,6 +6372,9 @@ def _compute_local_energy_n( self.__latest_r_up_carts = jnp.asarray(latest_r_up_carts_after_branching, dtype=jnp.float64) self.__latest_r_dn_carts = jnp.asarray(latest_r_dn_carts_after_branching, dtype=jnp.float64) self.__latest_A_old_inv = _jit_vmap_A_inv_n(self.__latest_r_up_carts, self.__latest_r_dn_carts) + # block before Barrier so the A_inv GPU work is included in timer_reconfiguration + # and dispatch-queue skew does not leak into the next step's barrier wait. + self.__latest_A_old_inv.block_until_ready() mpi_comm.Barrier() @@ -6376,10 +6386,10 @@ def _compute_local_energy_n( start_update_E_scf = time.perf_counter() ## parameters for E_scf - eq_steps = GFMC_ON_THE_FLY_WARMUP_STEPS num_gfmc_collect_steps = GFMC_ON_THE_FLY_COLLECT_STEPS num_gfmc_bin_blocks = GFMC_ON_THE_FLY_BIN_BLOCKS + # (A) accumulate __G_L / __G_e_L every step (after enough stored_w_L) if mpi_rank == 0: if i_mcmc_step >= num_gfmc_collect_steps: e_L = self.__stored_e_L[self.__mcmc_counter + num_mcmc_done] @@ -6390,40 +6400,46 @@ def _compute_local_energy_n( self.__G_L.append(G_L) self.__G_e_L.append(G_L * e_L) - if (i_mcmc_step + 1) % mcmc_interval == 0: - if i_mcmc_step > eq_steps: - if mpi_rank == 0: - num_gfmc_warmup_steps = np.minimum(eq_steps, i_mcmc_step - eq_steps) - logger.debug(f" Computing E_scf at step {i_mcmc_step}.") - G_eq = np.array(self.__G_L[num_gfmc_warmup_steps:]) - G_e_L_eq = np.array(self.__G_e_L[num_gfmc_warmup_steps:]) - G_e_L_split = np.array_split(G_e_L_eq, num_gfmc_bin_blocks) - G_e_L_binned = np.array([np.sum(G_e_L_list) for G_e_L_list in G_e_L_split]) - G_split = np.array_split(G_eq, num_gfmc_bin_blocks) - G_binned = np.array([np.sum(G_list) for G_list in G_split]) - G_e_L_binned_sum = np.sum(G_e_L_binned) - G_binned_sum = np.sum(G_binned) - E_jackknife = [ - (G_e_L_binned_sum - G_e_L_binned[m]) / (G_binned_sum - G_binned[m]) - for m in range(num_gfmc_bin_blocks) - ] - E_mean = np.average(E_jackknife) - E_std = np.sqrt(num_gfmc_bin_blocks - 1) * np.std(E_jackknife) - E_mean = float(E_mean) - E_std = float(E_std) - else: - E_mean = None - E_std = None - - E_mean = mpi_comm.bcast(E_mean, root=0) - E_std = mpi_comm.bcast(E_std, root=0) + # (B) E_scf update schedule: + # - rapid phase (i_mcmc_step < mcmc_interval = N/100): update every step + # - thereafter: update every mcmc_interval steps + # - skip when there are not yet enough G_L samples for jackknife + n_G_L = max(0, i_mcmc_step - num_gfmc_collect_steps + 1) + have_enough = n_G_L >= num_gfmc_bin_blocks + in_rapid_phase = i_mcmc_step < mcmc_interval + on_throttle = (i_mcmc_step + 1) % mcmc_interval == 0 + + if have_enough and (in_rapid_phase or on_throttle): + if mpi_rank == 0: + # Skip the bad-regime stored_w_L that bleed into __G_L via the K-product. + # During the very early phase the available samples are limited, so + # clamp to keep at least num_gfmc_bin_blocks samples for jackknife. + num_gfmc_warmup_steps = min( + num_gfmc_collect_steps + num_gfmc_bin_blocks, + max(n_G_L - num_gfmc_bin_blocks, 0), + ) + G_eq = np.array(self.__G_L[num_gfmc_warmup_steps:]) + G_e_L_eq = np.array(self.__G_e_L[num_gfmc_warmup_steps:]) + G_e_L_split = np.array_split(G_e_L_eq, num_gfmc_bin_blocks) + G_e_L_binned = np.array([np.sum(G_e_L_list) for G_e_L_list in G_e_L_split]) + G_split = np.array_split(G_eq, num_gfmc_bin_blocks) + G_binned = np.array([np.sum(G_list) for G_list in G_split]) + G_e_L_binned_sum = np.sum(G_e_L_binned) + G_binned_sum = np.sum(G_binned) + E_jackknife = [ + (G_e_L_binned_sum - G_e_L_binned[m]) / (G_binned_sum - G_binned[m]) for m in range(num_gfmc_bin_blocks) + ] + E_mean = float(np.average(E_jackknife)) + E_std = float(np.sqrt(num_gfmc_bin_blocks - 1) * np.std(E_jackknife)) + else: + E_mean = None + E_std = None - self.__E_scf = E_mean - E_scf_std = E_std + E_mean = mpi_comm.bcast(E_mean, root=0) + E_std = mpi_comm.bcast(E_std, root=0) - logger.debug(f" Updated E_scf = {self.__E_scf:.5f} +- {E_scf_std:.5f} Ha.") - else: - logger.debug(f" Init E_scf = {self.__E_scf:.5f} Ha. Being equilibrated.") + self.__E_scf = E_mean + logger.debug(f" E_scf = {self.__E_scf:.5f} +- {E_std:.5f} Ha.") mpi_comm.Barrier() end_update_E_scf = time.perf_counter() @@ -8224,19 +8240,29 @@ def _compute_local_energy_n_debug( self.__latest_r_dn_carts = jnp.asarray(latest_r_dn_carts_after_branching, dtype=jnp.float64) # update E_scf - eq_steps = GFMC_ON_THE_FLY_WARMUP_STEPS num_gfmc_collect_steps = GFMC_ON_THE_FLY_COLLECT_STEPS num_gfmc_bin_blocks = GFMC_ON_THE_FLY_BIN_BLOCKS - if (i_mcmc_step + 1) % mcmc_interval == 0: - if i_mcmc_step > eq_steps: - self.__E_scf, E_scf_std = self.get_E_on_the_fly( - num_gfmc_warmup_steps=np.minimum(eq_steps, i_mcmc_step - eq_steps), - num_gfmc_bin_blocks=num_gfmc_bin_blocks, - num_gfmc_collect_steps=num_gfmc_collect_steps, - ) - logger.debug(f" Updated E_scf = {self.__E_scf:.5f} +- {E_scf_std:.5f} Ha.") - else: - logger.debug(f" Init E_scf = {self.__E_scf:.5f} Ha. Being equilibrated.") + + # E_scf update schedule: + # - rapid phase (i_mcmc_step < mcmc_interval = N/100): update every step + # - thereafter: update every mcmc_interval steps + # - skip when there are not yet enough G_L samples for jackknife + n_G_L = max(0, i_mcmc_step - num_gfmc_collect_steps + 1) + have_enough = n_G_L >= num_gfmc_bin_blocks + in_rapid_phase = i_mcmc_step < mcmc_interval + on_throttle = (i_mcmc_step + 1) % mcmc_interval == 0 + + if have_enough and (in_rapid_phase or on_throttle): + num_gfmc_warmup_steps = min( + num_gfmc_collect_steps + num_gfmc_bin_blocks, + max(n_G_L - num_gfmc_bin_blocks, 0), + ) + self.__E_scf, E_scf_std = self.get_E_on_the_fly( + num_gfmc_warmup_steps=num_gfmc_warmup_steps, + num_gfmc_bin_blocks=num_gfmc_bin_blocks, + num_gfmc_collect_steps=num_gfmc_collect_steps, + ) + logger.debug(f" E_scf = {self.__E_scf:.5f} +- {E_scf_std:.5f} Ha.") # count up, here is the end of the branching step. num_mcmc_done += 1 diff --git a/jqmc/jqmc_mcmc.py b/jqmc/jqmc_mcmc.py index 6d67ad02..cb1ecaa5 100644 --- a/jqmc/jqmc_mcmc.py +++ b/jqmc/jqmc_mcmc.py @@ -133,7 +133,9 @@ def _loglevel_devel(self, message, *args, **kwargs): # - wide + direct : (X X^T + eps I) y = X F via psum # - wide + CG : same system, conjugate gradient with psum'd matvec # - tall + direct : (X^T X + eps I) z = F, theta = X z via all_gather -# - tall + CG : same system, conjugate gradient on replicated inputs +# - tall + CG : same system, conjugate gradient via psum on +# sharded (N_local,) state -- ``X`` stays sharded so +# per-rank memory is independent of ``mpi_size``. # # Compiled kernels are cached at module level so the JIT cost is paid once # per process. The same code path runs on: @@ -198,6 +200,60 @@ def body(state): return x_f, jnp.sqrt(rs_f), k_f +def _cg_while_loop_sharded(b, apply_A, x0, max_iter, tol, dtype, axis_name): + """``_cg_while_loop`` variant for vectors sharded along ``axis_name``. + + All vectors (``b``, ``x0`` and the output of ``apply_A``) live in + the per-rank slice ``(N_local,)``; inner products are summed + globally with ``jax.lax.psum`` so the loop's convergence check and + step-size scalars match the replicated (un-sharded) implementation + bit-for-bit when only the round-off order changes. + + This helper is what makes the "distributed" CG kernel below feasible + without materialising the full design matrix on every rank. + """ + tiny = jnp.asarray(jnp.finfo(dtype).tiny, dtype=dtype) + tol_sq = tol * tol + + def gdot(a, c): + # Global inner product across the sharded axis. + return jax.lax.psum(jnp.dot(a, c), axis_name) + + r0 = b - apply_A(x0) + rs0 = gdot(r0, r0) + state0 = ( + x0, + r0, + r0, # p + rs0, + jnp.int32(0), + jnp.bool_(False), # breakdown + ) + + def cond(state): + _x, _r, _p, rs, k, breakdown = state + return (k < max_iter) & (rs > tol_sq) & jnp.logical_not(breakdown) + + def body(state): + x, r, p, rs_old, k, breakdown = state + Ap = apply_A(p) + denom = gdot(p, Ap) + new_breakdown = breakdown | jnp.logical_not(jnp.isfinite(denom)) | (jnp.abs(denom) <= tiny) + safe_denom = jnp.where(new_breakdown, jnp.asarray(1.0, dtype=dtype), denom) + alpha = rs_old / safe_denom + x_new = jnp.where(new_breakdown, x, x + alpha * p) + r_new = jnp.where(new_breakdown, r, r - alpha * Ap) + rs_new_real = gdot(r_new, r_new) + rs_new = jnp.where(new_breakdown, rs_old, rs_new_real) + safe_rs_old = jnp.where(rs_old > 0, rs_old, jnp.asarray(1.0, dtype=dtype)) + beta = rs_new / safe_rs_old + p_new = jnp.where(new_breakdown, p, r_new + beta * p) + return (x_new, r_new, p_new, rs_new, k + 1, new_breakdown) + + x_f, _r_f, _p_f, rs_f, k_f, _bk_f = jax.lax.while_loop(cond, body, state0) + return x_f, jnp.sqrt(rs_f), k_f + + def _get_sr_wide_direct_kernel(): """Wide-matrix direct SR solve: ``theta = (X X^T + eps I)^{-1} (X F)``. @@ -304,7 +360,21 @@ def _solve(X, F, epsilon): def _get_sr_tall_cg_kernel(): - """Tall-matrix CG SR solve via push-through identity.""" + """Tall-matrix CG SR solve via push-through identity. + + ``X`` stays sharded as ``(P, N_local)`` on every rank; the only + communications are + + * ``psum`` of ``X v_local`` to form the ``(P,)`` projection used by + ``apply_A``; + * ``psum`` of the CG inner products (scalars); + * ``psum`` of ``X y_local`` to assemble the final ``theta``; + * one ``all_gather`` of ``y_local`` (size ``N_total``, vector) so + the returned warm-start dual has the canonical replicated shape. + + Per-rank peak memory therefore stays proportional to ``P * N_local`` + plus ``O(P) + O(N_total)`` scratch, independent of ``mpi_size``. + """ cached = getattr(_get_sr_tall_cg_kernel, "_cached", None) if cached is not None: return cached @@ -323,20 +393,38 @@ def _get_sr_tall_cg_kernel(): PSpec(), # tol PSpec(), # x0 (N_total,) replicated ), + # All four outputs replicated: theta (P,) via psum; y (N_total,) + # via final all_gather; residual and num_iter are scalars. out_specs=(PSpec(), PSpec(), PSpec(), PSpec()), - check_vma=False, # all_gather output is replicated but not statically inferrable + check_vma=False, # final all_gather output is replicated but not statically inferrable ) def _solve(X, F, epsilon, max_iter, tol, x0): - X_full = jax.lax.all_gather(X, "rank", axis=1, tiled=True) - F_full = jax.lax.all_gather(F, "rank", tiled=True) + # x0 arrives replicated as (N_total,); slice to this rank's + # contiguous chunk so the CG state lives in (N_local,) form. + n_local = F.shape[0] + rank_idx = jax.lax.axis_index("rank") + x0_local = jax.lax.dynamic_slice(x0, (rank_idx * n_local,), (n_local,)) - def apply_A(v): - return X_full.T @ (X_full @ v) + epsilon * v + def apply_A(v_local): + # v_local : (N_local,) sharded on "rank" + # X v -- accumulate per-rank contributions of the column-sharded matmul. + u = jax.lax.psum(X @ v_local, "rank") # (P,) replicated + # X^T u -- replicated u dotted with this rank's X columns gives the + # local slice of the result; no further collective needed. + return X.T @ u + epsilon * v_local # (N_local,) + + y_local, residual, num_iter = _cg_while_loop_sharded(F, apply_A, x0_local, max_iter, tol, X.dtype, axis_name="rank") + + # theta = X y -- assembled via psum of per-rank X_local @ y_local. + theta = jax.lax.psum(X @ y_local, "rank") # (P,) replicated - y, residual, num_iter = _cg_while_loop(F_full, apply_A, x0, max_iter, tol, X.dtype) - # Return y (sample-space CG solution) too so the caller can persist - # it as a warm-start for the next optimization step. - return X_full @ y, y, residual, num_iter + # Gather y to (N_total,) replicated so the host-side warm-start + # dual stored by the caller is the canonical full-sample vector. + # This is a (N_total,)-sized broadcast -- negligible vs the + # (P, N_total) we would otherwise pay to replicate X. + y_full = jax.lax.all_gather(y_local, "rank", tiled=True) + + return theta, y_full, residual, num_iter _get_sr_tall_cg_kernel._cached = _solve return _solve @@ -696,6 +784,7 @@ def run(self, num_mcmc_steps: int = 0, max_time=86400) -> None: timer_MPI_barrier = 0.0 # mcmc timer starts + mpi_comm.Barrier() mcmc_total_start = time.perf_counter() # toml(control) filename @@ -873,12 +962,14 @@ def run(self, num_mcmc_steps: int = 0, max_time=86400) -> None: ) self.__mcmc_kernels_warmed_up = True + mpi_comm.Barrier() mcmc_update_init_end = time.perf_counter() timer_mcmc_update_init += mcmc_update_init_end - mcmc_update_init_start logger.info("End compilation of the MCMC_update funciton.") logger.info(f"Elapsed Time = {mcmc_update_init_end - mcmc_update_init_start:.2f} sec.") logger.info("") else: + mpi_comm.Barrier() logger.info("Skipping compilation (JAX cache is warm from previous run).") logger.info("") @@ -2479,7 +2570,7 @@ def solve_linear_method( K_matrix: npt.NDArray, B_matrix: npt.NDArray, epsilon: float, - ) -> tuple[npt.NDArray, float]: + ) -> tuple[npt.NDArray, float, float]: r"""Solve the Linear Method generalized eigenvalue problem. Constructs extended matrices :math:`\bar H` and :math:`\bar S` of @@ -2487,7 +2578,9 @@ def solve_linear_method( and solves :math:`\bar H v = E \bar S v`. The eigenvector with the largest :math:`|v_0|^2` is selected, and the - parameter update is :math:`c_k = v_k / v_0`. + parameter update is :math:`c_k = v_k / v_0`. :math:`|v_0|^2` is also + returned so the caller can guard against updates that fall outside + the linear regime (small overlap with the current wavefunction). Args: H_0: Current energy :math:`E_\alpha`. @@ -2498,8 +2591,10 @@ def solve_linear_method( epsilon: Eigenvalue cutoff for S matrix. Returns: - tuple: ``(c_vec, E_lm)`` where ``c_vec`` has shape ``(p,)`` - (in the original parameter space) and ``E_lm`` is the selected eigenvalue. + tuple: ``(c_vec, E_lm, v0_sq_best)`` where ``c_vec`` has shape + ``(p,)`` (in the original parameter space), ``E_lm`` is the + selected eigenvalue, and ``v0_sq_best`` is the :math:`|v_0|^2` + of the selected eigenvector (in [0, 1]). """ p = len(f_vec) @@ -2526,7 +2621,7 @@ def solve_linear_method( if not np.any(alive): logger.warning(" LM dgelscut: all parameters removed in Step 1; returning zero update.") - return np.zeros(p, dtype=dtype_mcmc_np), H_0 + return np.zeros(p, dtype=dtype_mcmc_np), H_0, 0.0 # ---- Step 2: Build correlation matrix for alive parameters ---- alive_idx = np.where(alive)[0] @@ -2539,7 +2634,7 @@ def solve_linear_method( n_alive = len(idx) if n_alive == 0: logger.warning(" LM dgelscut: all parameters removed; returning zero update.") - return np.zeros(p, dtype=dtype_mcmc_np), H_0 + return np.zeros(p, dtype=dtype_mcmc_np), H_0, 0.0 # Build correlation matrix for current alive set D_sub = D_inv_sqrt[idx] # (n_alive,) @@ -2599,7 +2694,7 @@ def solve_linear_method( if p_prime == 0: logger.warning(" LM: no positive S eigenvalues after dgelscut; returning zero update.") - return np.zeros(p, dtype=dtype_mcmc_np), H_0 + return np.zeros(p, dtype=dtype_mcmc_np), H_0, 0.0 # P = U Lambda^{-1/2} (S-orthonormal basis) inv_sqrt_Lambda = 1.0 / np.sqrt(Lambda) @@ -2623,9 +2718,15 @@ def solve_linear_method( eigvals_lm, eigvecs_lm = np.linalg.eigh(H_bar) # ---- Select eigenvector with max |v_0|^2 ---- + # |v_0|^2 measures the overlap with the current wavefunction; large + # overlap means the LM step stays in the linear regime. The caller + # is expected to reject the update (e.g. fall back to plain SR) when + # ``v0_sq_best`` is below its safety threshold -- this routine only + # surfaces the value, it does not enforce a cutoff. v0_sq = eigvecs_lm[0, :] ** 2 best_idx = int(np.argmax(v0_sq)) E_lm = float(eigvals_lm[best_idx]) + v0_sq_best = float(v0_sq[best_idx]) # Diagnostic lowest_idx = 0 @@ -2635,13 +2736,10 @@ def solve_linear_method( eigvals_lm[lowest_idx], v0_sq[lowest_idx], E_lm, - v0_sq[best_idx], + v0_sq_best, ) else: - logger.debug(" LM: selected eigenvalue E_LM = %.6f (|v0|^2 = %.4f)", E_lm, v0_sq[best_idx]) - - if v0_sq[best_idx] < 0.01: - logger.warning(" LM: max |v0|^2 = %.4f is small; update may be unreliable.", v0_sq[best_idx]) + logger.debug(" LM: selected eigenvalue E_LM = %.6f (|v0|^2 = %.4f)", E_lm, v0_sq_best) w = eigvecs_lm[:, best_idx] w0 = w[0] @@ -2655,12 +2753,12 @@ def solve_linear_method( logger.info( " LM: E_LM = %.6f (|v0|^2 = %.4f), ||c|| = %.3e, max|c| = %.3e", E_lm, - v0_sq[best_idx], + v0_sq_best, np.linalg.norm(c_vec), np.max(np.abs(c_vec)), ) - return c_vec, E_lm + return c_vec, E_lm, v0_sq_best @staticmethod def _shard_X_F(X_local: npt.NDArray, F_local: npt.NDArray): @@ -2776,6 +2874,9 @@ def _sr_solve_tall_cg_device( ``x0`` and ``y`` live in the sample space, shape ``(N_total,)``; ``y`` is the CG solution (suitable as warm-start next iteration). ``theta = X y`` lives in parameter space, shape ``(P,)``. + + Uses ``psum`` collectives (not ``all_gather(X)``) so per-rank peak + memory does not scale with ``mpi_size``. """ solver = _get_sr_tall_cg_kernel() X_g, F_g = self._shard_X_F(X_local, F_local) @@ -3711,7 +3812,7 @@ def apply_S_primal_numpy(v): else: logger.info( "Using conjugate gradient for the inverse of S " - "(device-resident, shard_map + all_gather, push-through identity)." + "(device-resident, shard_map + psum, push-through identity)." ) logger.info(f" [CG] threshold {sr_cg_tol}.") logger.info(f" [CG] max iteration: {sr_cg_max_iter}.") @@ -4065,13 +4166,26 @@ def apply_dual_S_numpy(v): ) # Solve LM eigenvalue problem - c_vec, E_lm = self.solve_linear_method(H_0_lm, f_vec_lm, S_mat, K_mat, B_mat, epsilon=lm_cond) - - if E_lm > H_0_lm + 3.0 * E_std: - logger.warning( - f"LM: E_LM={E_lm:.6f} > E_0 + 3*sigma = {H_0_lm:.6f} + 3*{E_std:.6f} = {H_0_lm + 3.0 * E_std:.6f}; " - f"LM does not predict improvement. Falling back to plain SR." - ) + c_vec, E_lm, v0_sq_best = self.solve_linear_method(H_0_lm, f_vec_lm, S_mat, K_mat, B_mat, epsilon=lm_cond) + + # Safety thresholds: + # (i) E_LM exceeds current energy by > 3 sigma -- LM does not + # predict an improvement; the linear model is unreliable. + # (ii) |v_0|^2 of the selected eigenvector is small -- the LM + # update sits far outside the linear regime, where the + # first-order extrapolation can produce wild parameter + # moves and downstream NaN/Inf energies. + # In either case fall back to a conservative plain-SR step. + _V0_SQ_MIN = 0.9 + _e_lm_bad = E_lm > H_0_lm + 3.0 * E_std + _v0_bad = v0_sq_best < _V0_SQ_MIN + if _e_lm_bad or _v0_bad: + reasons = [] + if _e_lm_bad: + reasons.append(f"E_LM={E_lm:.6f} > E_0+3*sigma={H_0_lm:.6f}+3*{E_std:.6f}={H_0_lm + 3.0 * E_std:.6f}") + if _v0_bad: + reasons.append(f"|v0|^2={v0_sq_best:.4f} < {_V0_SQ_MIN} (LM update outside linear regime)") + logger.warning("LM: " + "; ".join(reasons) + ". Falling back to plain SR.") theta = 0.1 * g_sr else: # Back-transform: c_vec[0] = c_0 (SR direction), c_vec[1:] = c_k (individual params) @@ -6702,7 +6816,7 @@ def solve_linear_method( K_matrix: npt.NDArray, B_matrix: npt.NDArray, epsilon: float, - ) -> tuple[npt.NDArray, float]: + ) -> tuple[npt.NDArray, float, float]: r"""Debug implementation of the Linear Method with dgelscut preconditioning. This mirrors ``MCMC.solve_linear_method`` using the same dgelscut @@ -6718,7 +6832,8 @@ def solve_linear_method( epsilon: dgelscut threshold (correlation matrix min eigenvalue). Returns: - (c_vec, E_lm): parameter update in original space and selected eigenvalue. + (c_vec, E_lm, v0_sq_best): parameter update in original space, + selected eigenvalue, and ``|v_0|^2`` of the selected eigenvector. """ # Delegate to MCMC.solve_linear_method -- the production version uses # the same dgelscut + S-orthonormalization + standard eigenvalue problem. diff --git a/jqmc_workflow/_cli.py b/jqmc_workflow/_cli.py index 2bba4af3..02a71a9e 100644 --- a/jqmc_workflow/_cli.py +++ b/jqmc_workflow/_cli.py @@ -44,13 +44,13 @@ import os import shutil -from datetime import datetime from logging import Formatter, StreamHandler, getLogger import toml import typer from ._config import get_config_dir, template_dir +from ._state import WorkflowStatus, get_jobs, update_job, update_status logger = getLogger("jqmc-workflow").getChild(__name__) @@ -74,14 +74,25 @@ def __init__(self, root_dir: str): self.root_dir = root_dir self.job_counter = 0 self.entries = [] # list of dicts with path, state info + self._visited: set[str] = set() def discover(self): """Walk tree and collect entries from workflow_state.toml.""" self.entries = [] self.job_counter = 0 + self._visited.clear() self._walk(self.root_dir) def _walk(self, path): + # Guard against symlink cycles: dedupe by realpath. + try: + key = os.path.realpath(path) + except OSError: + return + if key in self._visited: + return + self._visited.add(key) + state_file = os.path.join(path, "workflow_state.toml") if os.path.isfile(state_file): try: @@ -112,9 +123,12 @@ def _walk(self, path): except Exception as e: logger.warning(f"Failed to read {state_file}: {e}") - # Recurse into subdirs + # Recurse into subdirs. Pilot subdirectories (``_pilot*``) hold + # internal bookkeeping state for the parent workflow and should + # not be listed as separate user-facing jobs. This matches the + # exclusion in :func:`_state.get_all_workflow_statuses`. try: - subdirs = sorted(d for d in os.listdir(path) if os.path.isdir(os.path.join(path, d))) + subdirs = sorted(d for d in os.listdir(path) if os.path.isdir(os.path.join(path, d)) and not d.startswith("_pilot")) except PermissionError: return for d in subdirs: @@ -319,14 +333,16 @@ def delete_job(self, job_id: int, server_machine_name: str): except Exception as ex: logger.error(f" Failed to delete job on {server}: {ex}") - # Update workflow_state.toml - state_file = os.path.join(e["dir"], "workflow_state.toml") - if os.path.isfile(state_file): - data = toml.load(state_file) - data.setdefault("workflow", {})["status"] = "cancelled" - data["workflow"]["updated_at"] = datetime.now().isoformat() - with open(state_file, "w") as f: - toml.dump(data, f) + # Update workflow_state.toml via the canonical API so that + # [[jobs]] (latest record) and [workflow] stay consistent. + if os.path.isfile(os.path.join(e["dir"], "workflow_state.toml")): + jobs = get_jobs(e["dir"]) + if jobs: + last_job = jobs[-1] + input_file = last_job.get("input_file") + if input_file and last_job.get("status") in ("submitted", "completed"): + update_job(e["dir"], input_file, status="cancelled") + update_status(e["dir"], WorkflowStatus.CANCELLED) logger.info(f" Status set to 'cancelled' for JOB-ID {job_id}.") diff --git a/jqmc_workflow/_error_estimator.py b/jqmc_workflow/_error_estimator.py index 4b9cdbb9..b002d94c 100644 --- a/jqmc_workflow/_error_estimator.py +++ b/jqmc_workflow/_error_estimator.py @@ -164,6 +164,50 @@ def estimate_additional_steps( return additional +def read_accumulated_measurement_steps( + restart_chk_path: str, + warmup: int, + collect_steps: int = 0, +) -> int | None: + """Read the actual accumulated measurement steps from a jQMC checkpoint. + + Returns ``raw_mcmc_counter - collect_steps - warmup``, i.e. the number + of binnable measurement samples reflected in the checkpoint's + observable arrays (matching :pymeth:`MCMC.get_E` / + :pymeth:`GFMC_t.get_E` post-processing logic, where the public + ``mcmc_counter`` property already subtracts ``collect_steps``). + + This is the source of truth for the accumulated sample count when a + run was interrupted by ``max_time`` and only partially completed its + planned ``num_mcmc_steps`` -- in that case the planned step count + over-estimates the actual samples on disk. + + Args: + restart_chk_path: Path to ``restart.h5`` (merged checkpoint). + warmup: ``num_(gfmc_)mcmc_warmup_steps`` from the workflow. + collect_steps: ``num_gfmc_collect_steps`` for LRDMC, 0 for MCMC. + + Returns: + Effective accumulated measurement-step count (>= 0), or *None* + if the checkpoint cannot be read. + """ + try: + from jqmc._checkpoint import load_driver_config_from_checkpoint + except ImportError as exc: + logger.warning(f"jqmc not importable; cannot read mcmc_counter from {restart_chk_path}: {exc}") + return None + try: + cfg = load_driver_config_from_checkpoint(restart_chk_path, rank=0) + except Exception as exc: + logger.warning(f"Cannot read driver_config from {restart_chk_path}: {exc}") + return None + if "mcmc_counter" not in cfg: + logger.warning(f"driver_config in {restart_chk_path} has no 'mcmc_counter' key.") + return None + raw = int(cfg["mcmc_counter"]) + return max(raw - int(collect_steps) - int(warmup), 0) + + def suffixed_name(filename: str, index: int) -> str: """Insert an integer suffix before the file extension. diff --git a/jqmc_workflow/_input_generator.py b/jqmc_workflow/_input_generator.py index 5b1ee3f5..d202411c 100644 --- a/jqmc_workflow/_input_generator.py +++ b/jqmc_workflow/_input_generator.py @@ -37,6 +37,7 @@ # POSSIBILITY OF SUCH DAMAGE. import copy +import os from logging import getLogger import toml @@ -152,13 +153,21 @@ def generate_input_toml( f"Required parameter '{k}' in [{section}] was not set. Please provide it via the 'overrides' dict." ) + # Atomic write: tmpfile + fsync + os.replace. A partial input TOML + # would otherwise block the next workflow run with a parse error. + tmp = filename + ".tmp" if with_comments: text = _dump_with_comments(params, job_type) - with open(filename, "w") as f: + with open(tmp, "w") as f: f.write(text) + f.flush() + os.fsync(f.fileno()) else: - with open(filename, "w") as f: + with open(tmp, "w") as f: toml.dump(params, f) + f.flush() + os.fsync(f.fileno()) + os.replace(tmp, filename) logger.info(f"Generated {filename} (job_type={job_type})") return filename @@ -195,7 +204,11 @@ def _toml_value(v) -> str: if isinstance(v, bool): return "true" if v else "false" if isinstance(v, str): - return f'"{v}"' + # JSON-compatible escaping matches TOML basic-string rules + # (backslash, quotes, control characters). + import json + + return json.dumps(v, ensure_ascii=False) if isinstance(v, (int, float)): return str(v) if isinstance(v, dict): diff --git a/jqmc_workflow/_job.py b/jqmc_workflow/_job.py index 1acc1116..01f1dc11 100644 --- a/jqmc_workflow/_job.py +++ b/jqmc_workflow/_job.py @@ -102,7 +102,6 @@ def __init__( queue_label: str = "default", jobname: str = "jqmc-wf", run_id: str = "", - safe_mode: bool = False, ): self.data_transfer = Data_transfer( server_machine_name=server_machine_name, @@ -142,7 +141,6 @@ def __init__( self.run_id = run_id self.input_file = input_file self.output_file = output_file - self.safe_mode = safe_mode # -- Job state --------------------------------------------- self.max_job_submit = self.queue_data.get("max_job_submit", 1000) @@ -410,5 +408,8 @@ def job_acct(self) -> tuple[str, str, str] | None: # -- Helper ---------------------------------------------------- def _close_ssh(self): - self.server_machine.ssh_close() + # data_transfer owns the same Machine instance and recursively + # calls ssh_close() on it via Machines_handler; one call suffices. + # ssh_close is idempotent (no-op after the first call), so a + # second close here would just be wasted work. self.data_transfer.ssh_close() diff --git a/jqmc_workflow/_machine.py b/jqmc_workflow/_machine.py index a532bea6..266cdd66 100644 --- a/jqmc_workflow/_machine.py +++ b/jqmc_workflow/_machine.py @@ -36,10 +36,12 @@ # ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. +import concurrent.futures import os import pathlib import random import re +import shlex import shutil import stat import subprocess @@ -151,6 +153,13 @@ def ssh_open(self): logger.debug(f"SSH already open (id={id(self.ssh)})") return + # Note: this is a *synchronous* sleep, even when ``ssh_open`` is + # reached from an asyncio coroutine (e.g. ``_submit_and_wait``). + # It blocks the event loop for 3-6 s. Parallel workflows + # (Launcher / LRDMC_Ext) effectively serialise here. Replacing + # this whole layer with asyncssh + ``await asyncio.sleep`` would + # restore parallelism; until then, prefer fewer SSH opens per + # workflow rather than tighter polling intervals. rw = random.randint(3, 6) logger.info(f" Wait {rw}s before opening SSH to {self.name}") time.sleep(rw) @@ -215,11 +224,13 @@ def ssh_open(self): except paramiko.ssh_exception.SSHException: # Clean up the ProxyCommand from this failed attempt self._kill_proxy_process(proxy_cmd) - logger.warning(f"SSH connect failed (attempt {tt + 1}). Retrying in {self.ssh_retry_time}s.") - time.sleep(self.ssh_retry_time) if tt == self.ssh_retry_max_num - 1: + # Re-raise immediately on the final attempt; no point + # sleeping ssh_retry_time only to give up afterwards. logger.error("SSH connect failed after all retries.") raise + logger.warning(f"SSH connect failed (attempt {tt + 1}). Retrying in {self.ssh_retry_time}s.") + time.sleep(self.ssh_retry_time) self.sftp = self.ssh.open_sftp() self.ssh_status = True @@ -344,7 +355,7 @@ def run_command(self, command: str, execute_dir: str = None): return self._run_local(command_r) return self._run_remote(command_r) - def _run_local(self, command_r: str, max_retries: int = 10): + def _run_local(self, command_r: str, max_retries: int = 3): for attempt in range(max_retries): for sub_attempt in range(3): try: @@ -368,32 +379,64 @@ def _run_local(self, command_r: str, max_retries: int = 10): logger.warning(f"Local command timeout (sub-attempt {sub_attempt})") time.sleep(60) - logger.warning(f"Local command failed (attempt {attempt}). Retrying in {self.ssh_retry_time}s.") - time.sleep(self.ssh_retry_time) + if attempt < max_retries - 1: + logger.warning( + f"Local command failed (attempt {attempt + 1}/{max_retries}). Retrying in {self.ssh_retry_time}s." + ) + time.sleep(self.ssh_retry_time) raise RuntimeError(f"Local command failed after {max_retries} retries: {command_r}") + remote_command_timeout_sec = 1200 # match _run_local default + def _run_remote(self, command_r: str): + """Execute *command_r* on the remote host with a hard wall-time guard. + + ``recv_exit_status`` waits on a paramiko ``status_event`` that is + *not* connected to socket-level keepalive, so a dead SSH session + can hang the call indefinitely. We run the entire exec/read + sequence in a worker thread and enforce a timeout via + :class:`concurrent.futures.Future`; on timeout the SSH session is + torn down so the next call reconnects cleanly. + """ self.ssh_open() + + def _do_exec(): + try: + pstdin, pstdout, pstderr = self.ssh.exec_command(command=command_r) + except (paramiko.SSHException, OSError, EOFError): + # Connection may have died (e.g. keepalive timeout during + # a long asyncio.sleep between polls). Reconnect once. + logger.warning("SSH connection lost during exec_command; reconnecting...") + self.ssh_close() + self.ssh_open() + pstdin, pstdout, pstderr = self.ssh.exec_command(command=command_r) + try: + exit_status = pstdout.channel.recv_exit_status() + stdout = pstdout.read().decode("utf-8").strip() + stderr = pstderr.read().decode("utf-8").strip() + finally: + for ch in (pstdin, pstdout, pstderr): + try: + ch.close() + except Exception: + pass + return exit_status, stdout, stderr + + executor = ThreadPoolExecutor(max_workers=1) + future = executor.submit(_do_exec) try: - pstdin, pstdout, pstderr = self.ssh.exec_command(command=command_r) - except (paramiko.SSHException, OSError, EOFError): - # Connection may have died (e.g. keepalive timeout during - # a long asyncio.sleep between polls). Reconnect once. - logger.warning("SSH connection lost during exec_command; reconnecting...") - self.ssh_close() - self.ssh_open() - pstdin, pstdout, pstderr = self.ssh.exec_command(command=command_r) - try: - exit_status = pstdout.channel.recv_exit_status() - stdout = pstdout.read().decode("utf-8").strip() - stderr = pstderr.read().decode("utf-8").strip() + exit_status, stdout, stderr = future.result(timeout=self.remote_command_timeout_sec) + except concurrent.futures.TimeoutError as exc: + logger.error(f"Remote command timed out after {self.remote_command_timeout_sec}s: {command_r}") + try: + self.ssh_close() + except Exception: + pass + raise RuntimeError(f"Remote command timed out after {self.remote_command_timeout_sec}s: {command_r}") from exc finally: - for ch in (pstdin, pstdout, pstderr): - try: - ch.close() - except Exception: - pass + executor.shutdown(wait=False) + if exit_status != 0: logger.error(f"Remote command failed: {command_r}") logger.error(f"stdout={stdout}") @@ -423,13 +466,19 @@ def _sftp_lstat_with_retry(self, path: str, max_retries=3, timeout_sec=5.0): def is_file(self, file_name: str) -> bool: if self.machine_type == "local": return os.path.isfile(file_name) - fileattr = self._sftp_lstat_with_retry(file_name) + try: + fileattr = self._sftp_lstat_with_retry(file_name) + except (RuntimeError, OSError): + return False return stat.S_ISREG(fileattr.st_mode) def is_dir(self, dir_name: str) -> bool: if self.machine_type == "local": return os.path.isdir(dir_name) - fileattr = self._sftp_lstat_with_retry(dir_name) + try: + fileattr = self._sftp_lstat_with_retry(dir_name) + except (RuntimeError, OSError): + return False return stat.S_ISDIR(fileattr.st_mode) def exist(self, object_name: str) -> bool: @@ -451,7 +500,7 @@ def get_job_list_as_text(self): return stdout.split("\n") def delete_job(self, jobid): - stdout, _ = self.run_command(f"{self.jobdel} {jobid}") + stdout, _ = self.run_command(f"{self.jobdel} {shlex.quote(str(jobid))}") return stdout.split("\n") @@ -493,7 +542,7 @@ def _get_sftp_file(self, source, target, exclude_patterns): def _put_sftp_file(self, source, target, exclude_patterns): if exclude_patterns and any(re.match(p, os.path.basename(source)) for p in exclude_patterns): return - self.server_machine.run_command(f"mkdir -p {os.path.dirname(target)}") + self.server_machine.run_command(f"mkdir -p {shlex.quote(os.path.dirname(target))}") self.server_machine.ssh_open() self.server_machine.sftp.put(source, target) @@ -513,7 +562,7 @@ def _get_sftp_dir(self, source, target, exclude_patterns): self._get_sftp_dir(remote_path, local_path, exclude_patterns) def _put_sftp_dir(self, source, target, exclude_patterns): - self.server_machine.run_command(f"mkdir -p {target}") + self.server_machine.run_command(f"mkdir -p {shlex.quote(target)}") self.server_machine.ssh_open() sftp = self.server_machine.sftp for item in os.listdir(source): @@ -539,7 +588,7 @@ def _transfer(self, from_path, to_path, exclude_patterns, dir_transfer, directio # Ensure target directory exists to_dir = os.path.dirname(to_path) if not dir_transfer else to_path if direction == "put": - self.server_machine.run_command(f"mkdir -p {to_dir}") + self.server_machine.run_command(f"mkdir -p {shlex.quote(to_dir)}") else: os.makedirs(to_dir, exist_ok=True) diff --git a/jqmc_workflow/_output_parser.py b/jqmc_workflow/_output_parser.py index 48394b32..735e4aec 100644 --- a/jqmc_workflow/_output_parser.py +++ b/jqmc_workflow/_output_parser.py @@ -211,11 +211,12 @@ def repair_forces_from_output(work_dir: str) -> bool: if forces is None: return False - # Update the TOML - state = toml.load(state_path) + # Update the TOML atomically via the canonical state API. + from ._state import _write, read_state + + state = read_state(work_dir) state.setdefault("result", {})["forces"] = forces - with open(state_path, "w") as f: - toml.dump(state, f) + _write(work_dir, state) logger.info(f" Repaired forces in {work_dir} from {os.path.basename(last_out)}") return True @@ -228,24 +229,27 @@ def repair_forces_from_output(work_dir: str) -> bool: # "Optimization step = 1/10" or "Optimization step = 1/10." _RE_OPT_STEP = re.compile(r"Optimization\s+step\s*=\s*(\d+)\s*/\s*(\d+)") +# Numeric pattern that also matches ``nan`` / ``inf`` (any case). +# Required so a diverged QMC run surfaces as a non-finite float rather +# than being silently dropped from per-step parsing. +_NUM = r"[+-]?(?:\d+\.?\d*(?:[eE][+-]?\d+)?|nan|inf)" + # "E = -76.438901 +- 0.000123 Ha" (energy line) _RE_ENERGY = re.compile( - r"E\s*=\s*([+-]?\d+\.?\d*(?:[eE][+-]?\d+)?)" - r"\s*\+\-\s*" - r"(\d+\.?\d*(?:[eE][+-]?\d+)?)" + rf"E\s*=\s*({_NUM})\s*\+\-\s*({_NUM})", + re.IGNORECASE, ) # "Max f = 17.984 +- 0.330 Ha/a.u." or "Max f = 17.984 +- 0.330" _RE_MAX_FORCE = re.compile( - r"Max\s+f\s*=\s*([+-]?\d+\.?\d*(?:[eE][+-]?\d+)?)" - r"\s*\+\-\s*" - r"(\d+\.?\d*(?:[eE][+-]?\d+)?)" + rf"Max\s+f\s*=\s*({_NUM})\s*\+\-\s*({_NUM})", + re.IGNORECASE, ) # "Max of signal-to-noise of f = max(|f|/|std f|) = 126.871." _RE_SNR = re.compile( - r"Max of signal-to-noise of f\s*=\s*max\(\|f\|/\|std f\|\)\s*=\s*" - r"([-+]?\d+(?:\.\d+)?)" + rf"Max of signal-to-noise of f\s*=\s*max\(\|f\|/\|std f\|\)\s*=\s*({_NUM})", + re.IGNORECASE, ) # "Average of walker weights is 0.799. Ideal is ~ 0.800. Adjust epsilon_AS." @@ -382,6 +386,11 @@ def _find_input_files(work_dir: str) -> list: Reads ``workflow_state.toml`` ``[[jobs]]`` records and returns the ``input_file`` paths that exist on disk, ordered by ``step``. + + Only jobs in ``"fetched"`` or ``"completed"`` status are considered: + cancelled / failed / still-submitted jobs may have stale inputs + that would mislead downstream parsing (e.g. the wrong hamiltonian + file reference). """ state_path = os.path.join(work_dir, "workflow_state.toml") if not os.path.isfile(state_path): @@ -392,6 +401,8 @@ def _find_input_files(work_dir: str) -> list: return [] files = [] for job in state.get("jobs", []): + if job.get("status") not in ("fetched", "completed"): + continue name = job.get("input_file", "") if name: path = os.path.join(work_dir, name) @@ -496,6 +507,11 @@ def _find_output_files(work_dir: str) -> list: Reads ``workflow_state.toml`` ``[[jobs]]`` records and returns the ``output_file`` paths that exist on disk, ordered by ``step``. + + Only jobs in ``"fetched"`` or ``"completed"`` status are considered: + partial output left behind by cancelled / failed / still-submitted + jobs would otherwise pollute parser aggregations (timing breakdowns, + SNR series, force tables) with garbage data. """ state_path = os.path.join(work_dir, "workflow_state.toml") if not os.path.isfile(state_path): @@ -506,6 +522,8 @@ def _find_output_files(work_dir: str) -> list: return [] files = [] for job in state.get("jobs", []): + if job.get("status") not in ("fetched", "completed"): + continue name = job.get("output_file", "") if name: path = os.path.join(work_dir, name) diff --git a/jqmc_workflow/_phase.py b/jqmc_workflow/_phase.py index 312d20a9..55d35e61 100644 --- a/jqmc_workflow/_phase.py +++ b/jqmc_workflow/_phase.py @@ -221,12 +221,12 @@ def allowed_actions( ) -> list[str]: """Return the list of actions allowed for the given *phase* / *status*. - When *status* is ``FAILED`` only ``recover_*`` and ``rollback_phase`` - actions are kept. When *status* is ``RUNNING`` configuration actions - are excluded. + When *status* is ``FAILED`` or ``CANCELLED`` only ``recover_*`` and + ``rollback_phase`` actions are kept. When *status* is ``RUNNING`` + configuration actions are excluded. """ phase_actions = list(PHASE_ALLOWED_ACTIONS.get(phase, [])) - if status == WorkflowStatus.FAILED: + if status in (WorkflowStatus.FAILED, WorkflowStatus.CANCELLED): phase_actions = [a for a in phase_actions if a.startswith("recover_")] phase_actions.append("rollback_phase") if status == WorkflowStatus.RUNNING: diff --git a/jqmc_workflow/_state.py b/jqmc_workflow/_state.py index aa4442a5..8ba84993 100644 --- a/jqmc_workflow/_state.py +++ b/jqmc_workflow/_state.py @@ -17,6 +17,7 @@ completed -- scheduler reports job finished fetched -- results transferred back to local machine failed -- job failed + cancelled -- user cancelled via CLI before completion """ # Copyright (C) 2024- Kosuke Nakano @@ -80,6 +81,7 @@ class JobStatus(str, Enum): COMPLETED = "completed" FETCHED = "fetched" FAILED = "failed" + CANCELLED = "cancelled" class CompletionStatus(str, Enum): @@ -107,7 +109,9 @@ class CompletionStatus(str, Enum): def _now_iso() -> str: - return datetime.now().isoformat(timespec="seconds") + # Local time *with* tz suffix: stays unambiguous when state files are + # shared between machines in different zones or compared across DST. + return datetime.now().astimezone().isoformat(timespec="seconds") def create_state( @@ -118,9 +122,9 @@ def create_state( ) -> dict: """Create (or reset) workflow_state.toml in *directory*. - If the file already exists, the ``[estimation]`` and ``[[jobs]]`` - sections are preserved so that pilot-run results and job history - survive a restart. + If the file already exists, ``[estimation]``, ``[[jobs]]``, and + ``[input_fingerprints]`` are preserved so that pilot-run results, + job history, and the staleness baseline survive a restart. """ if status not in VALID_STATUSES: raise ValueError(f"Invalid status '{status}'. Must be one of {VALID_STATUSES}") @@ -129,6 +133,7 @@ def create_state( existing = read_state(directory) preserved_estimation = existing.get("estimation", {}) preserved_jobs = existing.get("jobs", []) + preserved_fingerprints = existing.get("input_fingerprints", {}) state = { "workflow": { @@ -144,6 +149,8 @@ def create_state( if preserved_estimation: state["estimation"] = preserved_estimation + if preserved_fingerprints: + state["input_fingerprints"] = preserved_fingerprints _write(directory, state) return state @@ -155,49 +162,163 @@ def read_state(directory: str) -> dict: if not os.path.isfile(path): return {} state = toml.load(path) - # Migrate legacy single [job] -> [[jobs]] list + # Migrate legacy single [job] -> [[jobs]] list. Persist the + # migration so subsequent reads don't repeat the conversion. if "job" in state and "jobs" not in state: old_job = state.pop("job") - if old_job: - state["jobs"] = [old_job] - else: - state["jobs"] = [] + state["jobs"] = [old_job] if old_job else [] + try: + _write(directory, state) + except OSError as exc: + logger.warning(f"Failed to persist legacy [job] migration in {path}: {exc}") return state -def _check_normal_termination(directory: str, jobs: list) -> list[str]: - """Check fetched output files for the ``Program ends`` marker. +def _has_program_ends(filepath: str) -> bool | None: + """Return ``True`` if *filepath*'s tail contains ``Program ends``. + + ``None`` if the file is absent or unreadable (caller decides how to + treat that -- :func:`_check_normal_termination` ignores absent files + as "nothing to assert", while :func:`reconcile_fetched_jobs` treats + them as "not yet finished"). Only the last 8 KiB is read since the + marker is always the final log line. + """ + if not os.path.isfile(filepath): + return None + try: + with open(filepath, errors="replace") as f: + f.seek(0, 2) + size = f.tell() + f.seek(max(0, size - 8192)) + tail = f.read() + except OSError: + return None + return "Program ends" in tail - Returns a list of output-file names that exist on disk but do **not** - contain the ``Program ends`` line -- a strong signal that the - computation was killed (e.g. wall-time expiration) before normal - termination. - Files that are absent, unreadable, or binary are silently skipped. +def _check_normal_termination(directory: str, jobs: list) -> list[str]: + """Return output-file names whose contents lack ``Program ends``. + + Only jobs that were *meant* to complete normally are inspected -- + namely status ``"fetched"`` or ``"completed"``. Jobs in + ``"submitted"``, ``"failed"``, or ``"cancelled"`` status are skipped + because their output is expected to be incomplete (still running, + already known-failed, or intentionally aborted), and a partial file + left over from such a job would otherwise produce a false-positive + "abnormal termination" verdict that flips the whole workflow to + FAILED. + + Files that are absent on disk are silently skipped -- they say + nothing about whether the remote computation ended normally. Files + present without the marker are reported as abnormal terminations + (e.g. wall-time kill, process crash). """ abnormal: list[str] = [] for job in jobs: + if job.get("status") not in ("fetched", "completed"): + continue output_file = job.get("output_file", "") if not output_file: continue - filepath = os.path.join(directory, output_file) - if not os.path.isfile(filepath): - continue # not fetched yet -- nothing to check - try: - with open(filepath, errors="replace") as f: - # Read only the tail (last 8 KiB) for efficiency; - # "Program ends ..." is always the last log line. - f.seek(0, 2) - size = f.tell() - f.seek(max(0, size - 8192)) - tail = f.read() - if "Program ends" not in tail: - abnormal.append(output_file) - except OSError: - continue # unreadable -- skip + result = _has_program_ends(os.path.join(directory, output_file)) + if result is False: + abnormal.append(output_file) return abnormal +def reconcile_fetched_jobs_recursive(directory: str) -> int: + """Run :func:`reconcile_fetched_jobs` on *directory* and every nested + sub-directory that contains its own ``workflow_state.toml``. + + LRDMC / MCMC / VMC pilots live in subdirectories (``_pilot_b/``, + ``_pilot_a/_pilot1/``, ``_pilot/``, ...) and each carries its own + state file. Calling :func:`reconcile_fetched_jobs` only on the + top-level workflow directory misses those -- so a pilot job that + finished on the cluster but whose state record is stuck on + ``"submitted"`` (e.g. orchestrator died between job completion and + fetch-finalize) would not be picked up before the production phase + tries to resume it via SSH. + + A malformed state.toml in any nested directory is logged and + skipped; the walk continues so a single corrupted pilot record + does not block production-level reconciliation. + + Returns the total number of jobs reconciled across all directories. + """ + total = 0 + for dirpath, dirnames, filenames in os.walk(directory): + if STATE_FILENAME not in filenames: + continue + try: + total += reconcile_fetched_jobs(dirpath) + except Exception as exc: + logger.warning( + f"reconcile_fetched_jobs_recursive: skipping {dirpath} due to error ({exc.__class__.__name__}: {exc})" + ) + return total + + +def reconcile_fetched_jobs(directory: str) -> int: + """Promote orphaned ``[[jobs]]`` records to ``"fetched"``. + + A job whose status is ``"submitted"`` or ``"completed"`` is promoted + to ``"fetched"`` when **both**: + + * its ``output_file`` is present locally with a ``Program ends`` + marker (the run finished normally on remote), AND + * at least one ``.h5`` checkpoint is present in the directory + (the workflow has a restart point to continue from). + + The ``.h5`` precondition prevents the next phase from crashing with + ``"no restart checkpoint found"`` when only the ``.out`` got + rsync'd locally but ``restart.h5`` is still on the remote. In + that case we leave the job ``"submitted"`` so the normal + ``_submit_and_wait`` resume path re-fetches via SSH. + + Handles the case where the workflow process was killed between + job completion and the fetch-finalize state update, while the + actual output and restart files have since landed locally + (e.g. via rsync or a separate fetch). + + Returns the number of jobs reconciled. + """ + state = read_state(directory) + jobs = state.get("jobs", []) + if not jobs or not os.path.isdir(directory): + return 0 + # Cheap one-shot listdir: avoids per-job globbing. Promotion is + # only safe if *some* checkpoint is on disk; we don't require an + # exact name match because workflows accept several conventions + # (``restart.h5``, ``lrdmc.h5``, ``hamiltonian_data_opt_step_*.h5``). + has_h5 = any(fn.endswith(".h5") for fn in os.listdir(directory)) + reconciled = 0 + for job in jobs: + status = job.get("status") + if status not in ("submitted", "completed"): + continue + output_file = job.get("output_file", "") + if not output_file: + continue + if _has_program_ends(os.path.join(directory, output_file)) is not True: + continue + if not has_h5: + logger.warning( + f"reconcile_fetched_jobs: not promoting {output_file} -- " + f"no .h5 checkpoint in {directory}; will let normal resume " + f"path try to fetch the missing checkpoint via SSH." + ) + continue + now = _now_iso() + job["status"] = "fetched" + job.setdefault("completed_at", now) + job["fetched_at"] = now + reconciled += 1 + if reconciled: + state.setdefault("workflow", {})["updated_at"] = _now_iso() + _write(directory, state) + return reconciled + + def validate_completion( directory: str, output_values: dict | None = None, @@ -320,9 +441,14 @@ def update_status( logger.warning(f"No workflow_state.toml in {directory}; creating minimal one.") state = {"workflow": {}, "jobs": [], "result": {}} - state.setdefault("workflow", {}) - state["workflow"]["status"] = status_str - state["workflow"]["updated_at"] = _now_iso() + wf = state.setdefault("workflow", {}) + # Populate required identity fields if missing (e.g. when called + # without a prior Container.create_state). + wf.setdefault("label", os.path.basename(os.path.abspath(directory)) or "workflow") + wf.setdefault("type", "Workflow") + wf.setdefault("created_at", _now_iso()) + wf["status"] = status_str + wf["updated_at"] = _now_iso() if phase is not None: state["workflow"]["phase"] = phase.value if hasattr(phase, "value") else phase @@ -664,8 +790,19 @@ def get_input_fingerprints(directory: str) -> dict[str, dict]: def _write(directory: str, state: dict): - """Write state dict to workflow_state.toml.""" + """Write state dict to workflow_state.toml atomically. + + Writes to a sibling ``.tmp`` file, ``fsync``s it, then ``os.replace``s + over the destination. This guarantees that ``workflow_state.toml`` + either reflects the previous successful write or the new state in + full -- never a truncated mix -- even if the process is killed + mid-write (SIGKILL, power loss, ...). + """ path = os.path.join(directory, STATE_FILENAME) os.makedirs(directory, exist_ok=True) - with open(path, "w") as f: + tmp = path + ".tmp" + with open(tmp, "w") as f: toml.dump(state, f) + f.flush() + os.fsync(f.fileno()) + os.replace(tmp, path) diff --git a/jqmc_workflow/_transfer.py b/jqmc_workflow/_transfer.py index 5f718d27..5d88c057 100644 --- a/jqmc_workflow/_transfer.py +++ b/jqmc_workflow/_transfer.py @@ -38,6 +38,7 @@ import fnmatch import glob import os +import shlex from logging import getLogger from ._machine import Machine, Machines_handler @@ -239,7 +240,13 @@ def get_objects(self, from_objects=None, exclude_patterns=None, *, work_dir=None # -- remove (local + remote) ---------------------------------- - def remove_objects(self, patterns: list[str], *, work_dir: str | None = None) -> None: + def remove_objects( + self, + patterns: list[str], + *, + work_dir: str | None = None, + protected_basenames: "frozenset[str] | set[str] | None" = None, + ) -> None: """Delete files matching *patterns* from local and (if remote) server. Matching is **recursive** -- each pattern is applied to *work_dir* @@ -252,16 +259,26 @@ def remove_objects(self, patterns: list[str], *, work_dir: str | None = None) -> the top-level directory **and** all subdirectories. work_dir (str, optional): Local directory. When *None*, falls back to ``os.getcwd()``. + protected_basenames (frozenset[str], optional): + Basenames that must never be deleted, even when matched by + a pattern. Used by :meth:`Workflow._cleanup_files` to + preserve ``workflow_state.toml`` against over-broad + patterns like ``"*.toml"``. """ local_cwd = os.path.abspath(work_dir) if work_dir else os.path.abspath(os.getcwd()) + protected = frozenset(protected_basenames or ()) # -- Local deletion (always, recursive) ------------------- for pattern in patterns: for fpath in sorted(glob.glob(os.path.join(local_cwd, "**", pattern), recursive=True)): - if os.path.isfile(fpath): - os.remove(fpath) - relpath = os.path.relpath(fpath, local_cwd) - logger.info(f" Cleanup: removed local file {relpath}") + if not os.path.isfile(fpath): + continue + if os.path.basename(fpath) in protected: + logger.warning(f" Cleanup: refusing to delete protected file {os.path.relpath(fpath, local_cwd)}") + continue + os.remove(fpath) + relpath = os.path.relpath(fpath, local_cwd) + logger.info(f" Cleanup: removed local file {relpath}") # -- Remote deletion (only for non-local machines) -------- if self.server_machine.machine_type == "local": @@ -278,9 +295,15 @@ def remove_objects(self, patterns: list[str], *, work_dir: str | None = None) -> return server_dir = local_cwd.replace(local_root, server_root) + # Build a `find ... -name X ! -name P1 ! -name P2 ... -delete` clause + # so protected files are spared remotely as well. + protect_clause = " ".join(f"! -name {shlex.quote(p)}" for p in sorted(protected)) for pattern in patterns: try: - self.server_machine.run_command(f"find {server_dir} -name '{pattern}' -type f -delete") + # Quote both the directory and the user-supplied glob to + # prevent shell-meta-character injection. + cmd = (f"find {shlex.quote(server_dir)} -name {shlex.quote(pattern)} {protect_clause} -type f -delete").rstrip() + self.server_machine.run_command(cmd) logger.info(f" Cleanup: removed remote files matching {pattern} (recursive)") except Exception as exc: logger.warning(f" Cleanup: failed to remove remote '{pattern}': {exc}") diff --git a/jqmc_workflow/launcher.py b/jqmc_workflow/launcher.py index 7fd94b5c..bea8c1ba 100644 --- a/jqmc_workflow/launcher.py +++ b/jqmc_workflow/launcher.py @@ -296,7 +296,7 @@ def get_session_state(self) -> dict: from ._state import get_workflow_summary workflows = {} - completed = failed = 0 + completed = failed = cancelled = 0 running_labels: list[str] = [] pending_labels: list[str] = [] @@ -311,6 +311,8 @@ def get_session_state(self) -> dict: completed += 1 elif s == "failed": failed += 1 + elif s == "cancelled": + cancelled += 1 elif s in ("running", "submitted"): running_labels.append(cw.label) else: @@ -322,6 +324,7 @@ def get_session_state(self) -> dict: "progress": { "completed": completed, "failed": failed, + "cancelled": cancelled, "running": running_labels, "pending": pending_labels, "total": len(self.workflows), @@ -417,7 +420,16 @@ async def async_launch(self): """ completed = set() failed = set() + skipped = set() # failed transitively because an upstream dep failed pending = set(self.workflows_by_label.keys()) + # label -> {"reason": str, "where": str, "kind": str} + failure_info: dict[str, dict] = {} + + def _where(label: str) -> str: + cw = self.workflows_by_label.get(label) + if cw is None: + return "(unknown)" + return getattr(cw, "project_dir", None) or getattr(cw, "dirname", None) or "(unknown)" logger.info("") logger.info("=" * 50) @@ -433,10 +445,17 @@ async def async_launch(self): for label in list(pending): deps = self.dependency_dict[label] # If any dep failed, this workflow cannot run - if any(d in failed for d in deps): - logger.error(f"[{label}] Skipping -- dependency failed: {[d for d in deps if d in failed]}") + failed_deps = [d for d in deps if d in failed] + if failed_deps: + logger.error(f"[{label}] Skipping -- dependency failed: {failed_deps}") pending.discard(label) failed.add(label) + skipped.add(label) + failure_info[label] = { + "kind": "skipped", + "reason": f"upstream dependency failed: {failed_deps}", + "where": _where(label), + } continue # All deps done? if all(d in completed for d in deps): @@ -450,12 +469,20 @@ async def async_launch(self): logger.info("-" * 50) logger.info(f" [{label}] Launching...") logger.info("-" * 50) - task = asyncio.create_task(self._run_workflow(label, cw)) + task = asyncio.create_task(self._run_workflow(label, cw), name=label) running[label] = task if not running: if pending: logger.error(f"Deadlock! Remaining: {pending}") + for label in pending: + failure_info[label] = { + "kind": "deadlock", + "reason": "deadlock: dependencies could not be resolved", + "where": _where(label), + } + failed.update(pending) + pending.clear() break break @@ -463,26 +490,50 @@ async def async_launch(self): done_tasks, _ = await asyncio.wait(running.values(), return_when=asyncio.FIRST_COMPLETED) for task in done_tasks: - # Find which label this task corresponds to - label = None - for lbl, t in list(running.items()): - if t is task: - label = lbl - break - if label is None: + label = task.get_name() + if label not in running: continue - del running[label] + # Task.exception() raises CancelledError on a cancelled + # task -- check cancellation BEFORE inspecting exception + # so the Launcher's main loop doesn't crash if a single + # workflow task is cancelled. + if task.cancelled(): + logger.warning(f"[{label}] CANCELLED") + failed.add(label) + failure_info[label] = { + "kind": "cancelled", + "reason": "task cancelled", + "where": _where(label), + } + continue + exc = task.exception() if exc: logger.error(f"[{label}] FAILED: {exc}") failed.add(label) + failure_info[label] = { + "kind": "exception", + "reason": f"{type(exc).__name__}: {exc}", + "where": _where(label), + } else: cw = self.workflows_by_label[label] if getattr(cw, "status", None) == "failed": - logger.error(f"[{label}] FAILED (status=failed)") + err = "" + try: + err = (cw.output_values or {}).get("error", "") + except Exception: + err = "" + msg = err or "workflow returned status=failed (no detail)" + logger.error(f"[{label}] FAILED (status=failed): {msg}") failed.add(label) + failure_info[label] = { + "kind": "status_failed", + "reason": msg, + "where": _where(label), + } else: logger.info(f"[{label}] Completed.") completed.add(label) @@ -493,8 +544,25 @@ async def async_launch(self): logger.info(" DAG execution summary") logger.info("-" * 50) logger.info(f" Completed : {len(completed)}") - logger.info(f" Failed : {len(failed)}") - logger.info(f" Skipped : {len(pending)}") + logger.info(f" Failed : {len(failed)} (of which skipped due to upstream: {len(skipped)})") + logger.info(f" Pending : {len(pending)}") + if failure_info: + logger.info("-" * 50) + logger.info(" Failure details") + logger.info("-" * 50) + # Show direct failures first, then transitively-skipped ones, + # so the root cause is easy to spot at the top. + order = {"exception": 0, "status_failed": 1, "cancelled": 2, "deadlock": 3, "skipped": 4} + for label in sorted(failure_info, key=lambda lb: (order.get(failure_info[lb]["kind"], 99), lb)): + info = failure_info[label] + logger.info(f" - [{label}] ({info['kind']})") + logger.info(f" where : {info['where']}") + # Truncate very long reason strings so the summary stays readable; + # the full message is already in the body of the log above. + reason = info["reason"] + if len(reason) > 400: + reason = reason[:400] + " ...[truncated]" + logger.info(f" reason: {reason}") logger.info("=" * 50) from ._header_footer import _print_footer diff --git a/jqmc_workflow/lrdmc_ext_workflow.py b/jqmc_workflow/lrdmc_ext_workflow.py index e1444646..6777ceaf 100644 --- a/jqmc_workflow/lrdmc_ext_workflow.py +++ b/jqmc_workflow/lrdmc_ext_workflow.py @@ -52,7 +52,7 @@ GFMC_MIN_COLLECT_STEPS, GFMC_MIN_WARMUP_STEPS, ) -from ._state import WorkflowStatus +from ._state import WorkflowStatus, read_state from .lrdmc_workflow import LRDMC_Workflow from .workflow import Container, Workflow @@ -118,10 +118,19 @@ class LRDMC_Ext_Workflow(Workflow): find its own optimal ``num_projection_per_measurement``. Set to *None* to disable auto-calibration (requires explicit *num_projection_per_measurement*). Activates GFMC_n mode. - num_projection_per_measurement (int, optional): + num_projection_per_measurement (int | dict[float, int] | list[dict], optional): GFMC projections per measurement. When given explicitly, automatic calibration is disabled and this value is used - for every ``alat``. Activates GFMC_n mode. + for every ``alat``. Activates GFMC_n mode. Accepted forms: + + * ``int`` -- the same value for every alat. + * ``dict[float, int]`` -- per-alat values; keys must cover + every alat in ``alat_list`` exactly. + * ``list[dict]`` -- per-alat values as records + ``{"alat": float, "nmpm": int}``. This form is TOML-safe + (no float dict keys) and is the recommended shape when + wired through ``ValueFrom`` from an upstream workflow. + Normalized to the dict form internally. non_local_move (str, optional): Non-local move treatment. Default from ``jqmc_miscs``. E_scf (float, optional): @@ -205,6 +214,12 @@ class LRDMC_Ext_Workflow(Workflow): Statistical error on ``extrapolated_energy`` (Ha). per_alat_results (dict): Per-alat energy/error results keyed by ``alat``. + nmpm_per_alat (list[dict]): + Averaged GFMC projections per alat as records + ``{"alat": float, "nmpm": int}``. Suitable for piping into + a downstream GFMC_n ``LRDMC_Ext_Workflow`` via ``ValueFrom`` + as ``num_projection_per_measurement``. Only present when + sub-run outputs could be parsed. errors (list[str]): Error messages for alat runs that failed. error (str): @@ -287,6 +302,19 @@ def __init__( # None -- GFMC_t mode (uses time_projection_tau) # int -- same value for every alat # dict -- per-alat values; keys must cover every alat in alat_list + # list of {"alat": float, "nmpm": int} -- TOML-safe wire form + # (used by ValueFrom upstream). Normalized to dict here. + if isinstance(num_projection_per_measurement, list): + try: + num_projection_per_measurement = { + float(entry["alat"]): int(entry["nmpm"]) for entry in num_projection_per_measurement + } + except (KeyError, TypeError) as exc: + raise ValueError( + f"num_projection_per_measurement list entries must be " + f"dicts with keys 'alat' and 'nmpm'; got " + f"{num_projection_per_measurement!r}" + ) from exc if isinstance(num_projection_per_measurement, dict): missing = [a for a in self.alat_list if a not in num_projection_per_measurement] if missing: @@ -359,7 +387,12 @@ def _make_lrdmc_workflow(self, alat): pilot_steps=self.pilot_steps, num_gfmc_projections=self.num_gfmc_projections, max_continuation=self.max_continuation, - cleanup_patterns=self.cleanup_patterns, + # Children do NOT clean up: LRDMC_Ext.run() needs every alat's + # restart.h5 *after* the children complete, for the + # extrapolation step. The parent Container handles cleanup + # recursively (via ``**/`` glob) after extrapolation + # has consumed the files. + cleanup_patterns=None, precision_mode=self.precision_mode, ) enc = Container( @@ -383,6 +416,51 @@ def configure(self) -> dict: "max_continuation": self.max_continuation, } + def can_resume_after_completed(self, proj_dir: str) -> bool: + """Return True if *any* child ``LRDMC_Workflow`` at any alat could + still benefit from more runs. + + ``Container`` consults this before: + + * short-circuiting on a previously completed workflow, and + * running ``_cleanup_files`` (which uses a recursive + ``**/`` glob that would otherwise delete the children's + ``restart.h5`` files that an unconverged child wants to keep). + + Returns True when any of the following hold for the *current* + ``alat_list``: + + * an alat in the list has no recorded ``energy_error`` yet + (e.g. the user extended ``alat_list`` after the prior run, so + this alat has never been executed); or + * an alat's recorded ``energy_error`` exceeds + ``target_error * 1.20``. + + Returns False in fixed-step mode (no target_error) or when + every alat already has an acceptable recorded ``energy_error``. + """ + if self.target_error is None or self.num_gfmc_projections is not None: + return False + for alat in self.alat_list: + alat_dir = os.path.join(proj_dir, f"lrdmc_alat_{alat:.3f}") + try: + result = read_state(alat_dir).get("result", {}) + except Exception as exc: + logger.warning( + f"can_resume_after_completed: cannot read state for " + f"alat={alat} ({exc.__class__.__name__}: {exc}); " + f"requesting resume to recover." + ) + return True + err = result.get("energy_error") + if err is None: + # No result for this alat yet -- definitely need to run it + # (e.g. user just added this value to alat_list). + return True + if err > self.target_error * 1.20: + return True + return False + async def run(self) -> tuple: """Run LRDMC at each alat, then extrapolate to a^2->0. @@ -397,6 +475,23 @@ async def run(self) -> tuple: """ self._ensure_project_dir() _wd = self.project_dir + + # When num_projection_per_measurement comes through ValueFrom, + # the launcher resolves it via setattr after __init__, so the + # __init__-time list-of-records normalization is bypassed. + # Re-apply it here so _make_lrdmc_workflow sees a dict. + if isinstance(self.num_projection_per_measurement, list): + try: + self.num_projection_per_measurement = { + float(entry["alat"]): int(entry["nmpm"]) for entry in self.num_projection_per_measurement + } + except (KeyError, TypeError) as exc: + raise ValueError( + f"num_projection_per_measurement list entries must be " + f"dicts with keys 'alat' and 'nmpm'; got " + f"{self.num_projection_per_measurement!r}" + ) from exc + sorted_alats = sorted(self.alat_list, reverse=True) # -- helper: run a single alat, return a uniform result tuple ------ @@ -431,7 +526,7 @@ async def _run_one(enc): logger.error(f"[{enc.label}] failed: {error}") errors.append(str(error)) continue - if status not in ("success", "completed", WorkflowStatus.COMPLETED): + if status != WorkflowStatus.COMPLETED: logger.error(f"[{enc.label}] returned status={status}") errors.append(f"{enc.label}: status={status}") continue @@ -473,6 +568,34 @@ async def _run_one(enc): return self.status, [], {"error": msg} self.output_values["per_alat_results"] = per_alat_results + + # Publish GFMC projections per alat as a TOML-safe list of + # ``{"alat": float, "nmpm": int}`` records. A downstream + # GFMC_n LRDMC_Ext_Workflow can consume this via ValueFrom and + # pass it back as ``num_projection_per_measurement`` (the + # ``__init__`` accepts this list form and normalizes to + # ``dict[float, int]``). Each child LRDMC_Workflow publishes + # ``num_projection_per_measurement`` in both GFMC_n (user/calib + # input) and GFMC_t (averaged measurement) modes. + out_values_by_alat: dict[float, dict] = { + float(_out_values["alat"]): _out_values + for _enc, _status, _out_files, _out_values, _error in all_results + if _error is None and _status == WorkflowStatus.COMPLETED and _out_values.get("alat") is not None + } + + nmpm_per_alat: list[dict] = [] + for alat in self.alat_list: + ov = out_values_by_alat.get(float(alat)) + nmpm_raw = ov.get("num_projection_per_measurement") if ov is not None else None + if nmpm_raw is None: + msg = f"Missing output_values['num_projection_per_measurement'] for alat={alat:.3f} in sub-workflow result." + logger.error(msg) + self.status = WorkflowStatus.FAILED + self.output_values["error"] = msg + return self.status, [], {"error": msg} + nmpm_per_alat.append({"alat": float(alat), "nmpm": max(int(nmpm_raw), 1)}) + self.output_values["nmpm_per_alat"] = nmpm_per_alat + self.output_files = restart_chks self.status = WorkflowStatus.COMPLETED return self.status, self.output_files, self.output_values @@ -484,25 +607,35 @@ def _extrapolate_energy(self, restart_chks: list[str]): tuple: ``(energy, error)`` or ``(None, None)``. """ - chk_args = " ".join(restart_chks) - cmd = ( - f"jqmc-tool lrdmc extrapolate-energy {chk_args} " - f"-p {self.polynomial_order} " - f"-b {self.num_gfmc_bin_blocks} " - f"-w {self.num_gfmc_warmup_steps} " - f"-c {self.num_gfmc_collect_steps}" - ) - logger.info(f" Running: {cmd}") + cmd = [ + "jqmc-tool", + "lrdmc", + "extrapolate-energy", + *restart_chks, + "-p", + str(self.polynomial_order), + "-b", + str(self.num_gfmc_bin_blocks), + "-w", + str(self.num_gfmc_warmup_steps), + "-c", + str(self.num_gfmc_collect_steps), + ] + logger.info(f" Running: {' '.join(cmd)}") try: result = subprocess.run( cmd, - shell=True, + shell=False, capture_output=True, text=True, + errors="replace", check=True, cwd=self.project_dir, ) return self._parse_extrapolation_output(result.stdout) + except FileNotFoundError as e: + logger.error(f"extrapolate-energy: '{cmd[0]}' not found on PATH ({e})") + return None, None except subprocess.CalledProcessError as e: logger.error(f"extrapolate-energy failed: {e.stderr}") return None, None diff --git a/jqmc_workflow/lrdmc_workflow.py b/jqmc_workflow/lrdmc_workflow.py index c029d4e0..e28ee742 100644 --- a/jqmc_workflow/lrdmc_workflow.py +++ b/jqmc_workflow/lrdmc_workflow.py @@ -60,6 +60,7 @@ estimate_additional_steps, estimate_required_steps, parse_net_time, + read_accumulated_measurement_steps, ) from ._input_generator import generate_input_toml, resolve_with_defaults from ._job import get_num_mpi, load_queue_data @@ -68,7 +69,7 @@ get_num_electrons, parse_survived_walkers_ratio, ) -from ._output_parser import parse_force_table +from ._output_parser import parse_force_table, parse_lrdmc_output from ._setting import ( GFMC_MIN_BIN_BLOCKS, GFMC_MIN_COLLECT_STEPS, @@ -79,6 +80,8 @@ WorkflowStatus, get_estimation, get_job_by_step, + read_state, + reconcile_fetched_jobs_recursive, set_estimation, validate_completion, ) @@ -480,6 +483,28 @@ def configure(self) -> dict: "max_continuation": self.max_continuation, } + def can_resume_after_completed(self, proj_dir: str) -> bool: + """Return True when a re-launch could still reduce ``energy_error`` toward target. + + A prior ``"completed"`` state may have been written by + :meth:`_launch_auto` after exhausting ``max_continuation`` without + meeting ``target_error``. When the user raises ``max_continuation`` + (or tightens ``target_error``) and relaunches, the recorded + ``[result].energy_error`` will still exceed ``target_error*1.20`` + and this method returns True so ``Container`` bypasses its + short-circuit and re-enters the production loop. + + Fixed-step mode (``num_gfmc_projections`` set) has no convergence + criterion and is never resumed automatically. + """ + if self.target_error is None or self.num_gfmc_projections is not None: + return False + result = read_state(proj_dir).get("result", {}) + err = result.get("energy_error") + if err is None: + return False + return err > self.target_error * 1.20 + async def run(self) -> tuple: """Run the LRDMC workflow. @@ -501,6 +526,17 @@ async def run(self) -> tuple: self._ensure_project_dir() _wd = self.project_dir + # Reconcile any orphaned "submitted"/"completed" job records whose + # output file already landed locally with a "Program ends" marker + # (e.g. workflow killed between job completion and fetch-finalize). + # Walks pilot subdirectories (``_pilot_a/``, ``_pilot_b/``) too, + # since each carries its own state file. Without this, Phase A + # would break out at the orphan record and the safety-net energy + # computation would read a stale earlier step. + n_reconciled = reconcile_fetched_jobs_recursive(_wd) + if n_reconciled: + logger.info(f" Reconciled {n_reconciled} job record(s) to 'fetched' from existing output.") + # -- Fixed-step mode --------------------------------------- if self.num_gfmc_projections is not None: return await self._launch_fixed_steps(_wd) @@ -588,7 +624,7 @@ async def _launch_fixed_steps(self, _wd): # Post-process energy (informational only, no convergence check) restart_chk = self._find_restart_chk(_wd) if restart_chk: - energy, error = self._compute_energy(restart_chk, work_dir=_wd, output_file=output_i) + energy, error = self._compute_energy(restart_chk, work_dir=_wd) if energy is not None: self.output_values["energy"] = energy self.output_values["energy_error"] = error @@ -622,7 +658,7 @@ async def _launch_fixed_steps(self, _wd): last_output = step_files[last_run][1] if last_run in step_files else None restart_chk = self._find_restart_chk(_wd) if restart_chk: - energy, error = self._compute_energy(restart_chk, work_dir=_wd, output_file=last_output) + energy, error = self._compute_energy(restart_chk, work_dir=_wd) if energy is not None: self.output_values["energy"] = energy self.output_values["energy_error"] = error @@ -650,6 +686,9 @@ async def _launch_fixed_steps(self, _wd): self.output_values["num_projection_per_measurement"] = self.num_projection_per_measurement else: self.output_values["time_projection_tau"] = self.time_projection_tau + avg_nmpm = self._resolve_avg_nmpm(_wd) + if avg_nmpm is not None: + self.output_values["num_projection_per_measurement"] = avg_nmpm if self.status != WorkflowStatus.FAILED: self.status = WorkflowStatus.COMPLETED @@ -814,7 +853,7 @@ async def _launch_auto(self, _wd): if not restart_chk: raise RuntimeError("No checkpoint found after pilot run. Cannot estimate required steps.") - _, pilot_error = self._compute_energy(restart_chk, work_dir=pilot_b_dir, output_file=output_pb) + _, pilot_error = self._compute_energy(restart_chk, work_dir=pilot_b_dir) if pilot_error is None: raise RuntimeError("Could not parse energy error from pilot run.") @@ -923,13 +962,22 @@ async def _launch_auto(self, _wd): logger.info( f" Target already achieved (cached): {cached_error:.6g} <= {self.target_error * 1.20:.6g} Ha (target*1.20)" ) + # Mode-specific key: avoid writing None (which TOML + # silently drops, breaking downstream readers). + if self._use_gfmc_n: + mode_extras = {"num_projection_per_measurement": self.num_projection_per_measurement} + else: + mode_extras = {"time_projection_tau": self.time_projection_tau} + avg_nmpm = self._resolve_avg_nmpm(_wd) + if avg_nmpm is not None: + mode_extras["num_projection_per_measurement"] = avg_nmpm self.output_values.update( energy=cached_energy, energy_error=cached_error, alat=self.alat, restart_chk=restart_chk or "", estimated_steps=estimated_steps, - num_projection_per_measurement=self.num_projection_per_measurement, + **mode_extras, ) if self.atomic_force and restart_chk: forces = self._compute_force(restart_chk, work_dir=_wd) @@ -970,46 +1018,56 @@ async def _launch_auto(self, _wd): # -- Phase B: re-estimate from accumulated data -- accumulated_measurement = 0 # measurement steps only (excl. warmup) if first_new_run > 1: - cached_accum = estimation.get("accumulated_measurement_steps") - if cached_accum is not None: - accumulated_measurement = int(cached_accum) - else: - accumulated_measurement = (first_new_run - 1) * max(estimated_steps - warmup, 0) - _re_chk = self._find_restart_chk(_wd) - if _re_chk: - _re_energy, _re_error = self._compute_energy(_re_chk, work_dir=_wd) - if _re_energy is not None and _re_error is not None: - if _re_error <= self.target_error * 1.20: - logger.info( - f" Target already met after prior runs: {_re_error:.6g} <= {self.target_error * 1.20:.6g} Ha" - ) - self.output_values.update( - energy=_re_energy, - energy_error=_re_error, - alat=self.alat, - restart_chk=_re_chk, - ) - if self.atomic_force: - forces = self._compute_force(_re_chk, work_dir=_wd) - if forces is not None: - self.output_values["forces"] = forces - first_new_run = self.max_continuation + 1 # skip loop - else: - _additional = estimate_additional_steps( - accumulated_measurement, - _re_error, - self.target_error, - ) - estimated_steps = _additional + warmup - logger.info( - f" Resuming after {first_new_run - 1} prior run(s): " - f"error={_re_error:.6g} Ha > target " - f"{self.target_error:.6g} Ha -> " - f"{estimated_steps} steps " - f"(measurement: {_additional}, warmup: {warmup}, " - f"accumulated measurement: {accumulated_measurement})" - ) + if _re_chk is None: + raise RuntimeError( + f"Phase B: {first_new_run - 1} prior run(s) marked fetched but no restart checkpoint found in {_wd}." + ) + # mcmc_counter in restart.h5 is the only trustworthy source + # for accumulated samples (planned step counts over-count + # when prior runs were cut short by max_time). + actual = read_accumulated_measurement_steps( + os.path.join(_wd, _re_chk), + warmup=warmup, + collect_steps=self.num_gfmc_collect_steps, + ) + if actual is None: + raise RuntimeError(f"Phase B: cannot read mcmc_counter from {_re_chk} in {_wd}.") + accumulated_measurement = actual + + _re_energy, _re_error = self._compute_energy_cached(_re_chk, work_dir=_wd, accumulated=actual) + if _re_energy is None or _re_error is None: + raise RuntimeError( + f"Phase B: compute-energy failed for {_re_chk} in {_wd}. Cannot decide whether to resume or stop." + ) + if _re_error <= self.target_error * 1.20: + logger.info(f" Target already met after prior runs: {_re_error:.6g} <= {self.target_error * 1.20:.6g} Ha") + self.output_values.update( + energy=_re_energy, + energy_error=_re_error, + alat=self.alat, + restart_chk=_re_chk, + ) + if self.atomic_force: + forces = self._compute_force(_re_chk, work_dir=_wd) + if forces is not None: + self.output_values["forces"] = forces + first_new_run = self.max_continuation + 1 # skip loop + else: + _additional = estimate_additional_steps( + accumulated_measurement, + _re_error, + self.target_error, + ) + estimated_steps = _additional + warmup + logger.info( + f" Resuming after {first_new_run - 1} prior run(s): " + f"error={_re_error:.6g} Ha > target " + f"{self.target_error:.6g} Ha -> " + f"{estimated_steps} steps " + f"(measurement: {_additional}, warmup: {warmup}, " + f"accumulated measurement: {accumulated_measurement})" + ) # -- Phase C: production loop -- _prev_run_steps = None @@ -1079,35 +1137,45 @@ async def _launch_auto(self, _wd): step=i, run_id=run_id_i, ) - accumulated_measurement += estimated_steps - warmup _prev_run_steps = estimated_steps last_run = i # -- Side-effects: compute energy from checkpoint (if any) -- restart_chk = self._find_restart_chk(_wd) - energy = error = None - if restart_chk: - energy, error = self._compute_energy(restart_chk, work_dir=_wd, output_file=output_i) - if energy is not None: - self.output_values["energy"] = energy - self.output_values["energy_error"] = error - self.output_values["alat"] = self.alat - self.output_values["restart_chk"] = restart_chk - logger.info(f" LRDMC energy (a={self.alat}): {energy} +- {error} Ha") - if self.atomic_force: - forces = self._compute_force(restart_chk, work_dir=_wd, output_file=output_i) - if forces is not None: - self.output_values["forces"] = forces + if restart_chk is None: + raise RuntimeError(f"Phase C: run {i} completed but no restart checkpoint found in {_wd}.") + # mcmc_counter in restart.h5 is the only trustworthy source for + # accumulated samples (planned step counts over-count when the + # run was cut short by max_time). + actual = read_accumulated_measurement_steps( + os.path.join(_wd, restart_chk), + warmup=warmup, + collect_steps=self.num_gfmc_collect_steps, + ) + if actual is None: + raise RuntimeError(f"Phase C: cannot read mcmc_counter from {restart_chk} in {_wd}.") + accumulated_measurement = actual + energy, error = self._compute_energy_cached(restart_chk, work_dir=_wd, accumulated=actual) + if energy is not None: + self.output_values["energy"] = energy + self.output_values["energy_error"] = error + self.output_values["alat"] = self.alat + self.output_values["restart_chk"] = restart_chk + logger.info(f" LRDMC energy (a={self.alat}): {energy} +- {error} Ha") + if self.atomic_force: + forces = self._compute_force(restart_chk, work_dir=_wd, output_file=output_i) + if forces is not None: + self.output_values["forces"] = forces - set_estimation( - _wd, - last_energy=energy, - last_energy_error=error, - accumulated_measurement_steps=accumulated_measurement, - last_num_gfmc_bin_blocks=self.num_gfmc_bin_blocks, - last_num_gfmc_warmup_steps=self.num_gfmc_warmup_steps, - last_num_gfmc_collect_steps=self.num_gfmc_collect_steps, - ) + set_estimation( + _wd, + last_energy=energy, + last_energy_error=error, + accumulated_measurement_steps=accumulated_measurement, + last_num_gfmc_bin_blocks=self.num_gfmc_bin_blocks, + last_num_gfmc_warmup_steps=self.num_gfmc_warmup_steps, + last_num_gfmc_collect_steps=self.num_gfmc_collect_steps, + ) # -- Termination decision -- single source of truth -- vstatus, vmsg = validate_completion( @@ -1150,7 +1218,7 @@ async def _launch_auto(self, _wd): last_output = step_files[last_run][1] if last_run in step_files else None restart_chk = self._find_restart_chk(_wd) if restart_chk: - energy, error = self._compute_energy(restart_chk, work_dir=_wd, output_file=last_output) + energy, error = self._compute_energy_cached(restart_chk, work_dir=_wd) if energy is not None: self.output_values["energy"] = energy self.output_values["energy_error"] = error @@ -1179,6 +1247,9 @@ async def _launch_auto(self, _wd): self.output_values["num_projection_per_measurement"] = self.num_projection_per_measurement else: self.output_values["time_projection_tau"] = self.time_projection_tau + avg_nmpm = self._resolve_avg_nmpm(_wd) + if avg_nmpm is not None: + self.output_values["num_projection_per_measurement"] = avg_nmpm if self.status != WorkflowStatus.FAILED: self.status = WorkflowStatus.COMPLETED @@ -1186,6 +1257,26 @@ async def _launch_auto(self, _wd): # -- Utility methods ------------------------------------------- + def _resolve_avg_nmpm(self, work_dir: str) -> int | None: + """Parse the GFMC_t output log for the averaged number of projections. + + Returns the rounded ``avg_num_projections`` as an int (>=1), or + None if the diagnostic is unavailable. Used to expose the + averaged nmpm via ``output_values["num_projection_per_measurement"]`` + so that downstream GFMC_n runs can consume it via ``ValueFrom``. + """ + try: + diag = parse_lrdmc_output(work_dir) + except Exception: + return None + avg = getattr(diag, "avg_num_projections", None) if diag is not None else None + if avg is None: + return None + try: + return max(int(round(float(avg))), 1) + except (TypeError, ValueError): + return None + def _find_restart_chk(self, work_dir: str) -> str | None: """Locate the LRDMC restart checkpoint file in *work_dir*.""" for pattern in ["restart.h5", "lrdmc.h5", "*.h5"]: @@ -1194,62 +1285,112 @@ def _find_restart_chk(self, work_dir: str) -> str | None: return os.path.basename(matches[-1]) return None - def _compute_energy(self, restart_chk: str, work_dir: str, output_file: str | None = None): - """Parse energy from *output_file* or run ``jqmc-tool lrdmc compute-energy``. + def _compute_energy_cached(self, restart_chk: str, work_dir: str, accumulated: int | None = None): + """Return (energy, error) using ``[estimation]`` cache when fresh. - When *output_file* is given the energy is read directly from - the ``jqmc`` stdout (``Total Energy: E = ... +- ... Ha.``). - This avoids the overhead of re-running ``jqmc-tool`` when - the post-processing parameters (-b, -w, -c) are the same as - in the input TOML -- which is always the case for a fresh run. + The cache (``last_energy`` / ``last_energy_error`` in + ``workflow_state.toml``) is considered fresh when the recorded + ``accumulated_measurement_steps`` matches the current + ``restart.h5`` ``mcmc_counter`` *and* the post-processing + parameters (``-b``, ``-w``, ``-c``) match the workflow's + current settings. On a hit, no subprocess is launched. - Falls back to ``jqmc-tool`` when *output_file* is *None* or - when stdout parsing fails. + On a miss, :meth:`_compute_energy` is invoked and the cache is + refreshed via :func:`set_estimation` so that subsequent + invocations within the same or later workflow runs short-circuit. + + Args: + restart_chk (str): + Checkpoint filename (basename). + work_dir (str): + Directory in which to run the command. + accumulated (int, optional): + Pre-read ``mcmc_counter`` from ``restart.h5``. Pass when + the caller has already computed it (Phase B / Phase C) to + avoid a redundant HDF5 read. + """ + if accumulated is None: + accumulated = read_accumulated_measurement_steps( + os.path.join(work_dir, restart_chk), + warmup=self.num_gfmc_warmup_steps, + collect_steps=self.num_gfmc_collect_steps, + ) + est = get_estimation(work_dir) + if ( + accumulated is not None + and est.get("last_energy") is not None + and est.get("last_energy_error") is not None + and est.get("accumulated_measurement_steps") == accumulated + and est.get("last_num_gfmc_bin_blocks") == self.num_gfmc_bin_blocks + and est.get("last_num_gfmc_warmup_steps") == self.num_gfmc_warmup_steps + and est.get("last_num_gfmc_collect_steps") == self.num_gfmc_collect_steps + ): + e, err = est["last_energy"], est["last_energy_error"] + logger.info( + f" Energy cached: E = {e} +- {err} Ha " + f"(acc={accumulated}, b={self.num_gfmc_bin_blocks}, " + f"w={self.num_gfmc_warmup_steps}, c={self.num_gfmc_collect_steps})" + ) + return e, err + energy, error = self._compute_energy(restart_chk, work_dir=work_dir) + if energy is not None and accumulated is not None: + set_estimation( + work_dir, + last_energy=energy, + last_energy_error=error, + accumulated_measurement_steps=accumulated, + last_num_gfmc_bin_blocks=self.num_gfmc_bin_blocks, + last_num_gfmc_warmup_steps=self.num_gfmc_warmup_steps, + last_num_gfmc_collect_steps=self.num_gfmc_collect_steps, + ) + return energy, error + + def _compute_energy(self, restart_chk: str, work_dir: str): + """Run ``jqmc-tool lrdmc compute-energy`` against *restart_chk*. + + Always invokes ``jqmc-tool`` so that the returned (energy, error) + carry full numerical precision. Parsing jqmc's printed + ``Total Energy: E = ... +- ... Ha.`` line is lossy (``%.5f`` + formatting), which is unsuitable for values persisted to + ``workflow_state.toml`` or compared against ``target_error``. Args: restart_chk (str): Checkpoint filename (basename). work_dir (str): Directory in which to run the command. - output_file (str, optional): - Stdout filename (basename) of the ``jqmc`` run. Returns: tuple: - ``(energy, error)`` or ``(None, None)``. + ``(energy, error)`` or ``(None, None)`` on failure. """ - # Fast path: parse from jqmc stdout - if output_file is not None: - out_path = os.path.join(work_dir, output_file) - if os.path.isfile(out_path): - try: - with open(out_path) as fh: - text = fh.read() - energy, error = self._parse_energy_output(text) - if energy is not None: - logger.info(f" Energy from {output_file} (jqmc-tool skipped): E = {energy} +- {error} Ha") - return energy, error - except OSError: - pass - - # Fallback: jqmc-tool - cmd = ( - f"jqmc-tool lrdmc compute-energy {restart_chk} " - f"-b {self.num_gfmc_bin_blocks} " - f"-w {self.num_gfmc_warmup_steps} " - f"-c {self.num_gfmc_collect_steps}" - ) - logger.info(f" Running: {cmd}") + cmd = [ + "jqmc-tool", + "lrdmc", + "compute-energy", + restart_chk, + "-b", + str(self.num_gfmc_bin_blocks), + "-w", + str(self.num_gfmc_warmup_steps), + "-c", + str(self.num_gfmc_collect_steps), + ] + logger.info(f" Running: {' '.join(cmd)}") try: result = subprocess.run( cmd, - shell=True, + shell=False, capture_output=True, text=True, + errors="replace", check=True, cwd=work_dir, ) return self._parse_energy_output(result.stdout) + except FileNotFoundError as e: + logger.error(f"compute-energy: '{cmd[0]}' not found on PATH ({e})") + return None, None except subprocess.CalledProcessError as e: logger.error(f"compute-energy failed: {e.stderr}") return None, None @@ -1307,19 +1448,26 @@ def _compute_force(self, restart_chk: str, work_dir: str, output_file: str | Non pass # Fallback: jqmc-tool - cmd = ( - f"jqmc-tool lrdmc compute-force {restart_chk} " - f"-b {self.num_gfmc_bin_blocks} " - f"-w {self.num_gfmc_warmup_steps} " - f"-c {self.num_gfmc_collect_steps}" - ) - logger.info(f" Running: {cmd}") + cmd = [ + "jqmc-tool", + "lrdmc", + "compute-force", + restart_chk, + "-b", + str(self.num_gfmc_bin_blocks), + "-w", + str(self.num_gfmc_warmup_steps), + "-c", + str(self.num_gfmc_collect_steps), + ] + logger.info(f" Running: {' '.join(cmd)}") try: result = subprocess.run( cmd, - shell=True, + shell=False, capture_output=True, text=True, + errors="replace", check=True, cwd=work_dir, ) @@ -1334,6 +1482,9 @@ def _compute_force(self, restart_chk: str, work_dir: str, output_file: str | Non f" Ha/bohr" ) return forces + except FileNotFoundError as e: + logger.error(f"compute-force: '{cmd[0]}' not found on PATH ({e})") + return None except subprocess.CalledProcessError as e: logger.error(f"compute-force failed: {e.stderr}") return None diff --git a/jqmc_workflow/mcmc_workflow.py b/jqmc_workflow/mcmc_workflow.py index 598cfa63..147f65b5 100644 --- a/jqmc_workflow/mcmc_workflow.py +++ b/jqmc_workflow/mcmc_workflow.py @@ -49,6 +49,7 @@ estimate_additional_steps, estimate_required_steps, parse_net_time, + read_accumulated_measurement_steps, ) from ._input_generator import generate_input_toml, resolve_with_defaults from ._job import get_num_mpi, load_queue_data @@ -58,6 +59,8 @@ WorkflowStatus, get_estimation, get_job_by_step, + read_state, + reconcile_fetched_jobs_recursive, set_estimation, validate_completion, ) @@ -340,6 +343,28 @@ def configure(self) -> dict: "max_continuation": self.max_continuation, } + def can_resume_after_completed(self, proj_dir: str) -> bool: + """Return True when a re-launch could still reduce ``energy_error`` toward target. + + A prior ``"completed"`` state may have been written by + :meth:`_launch_auto` after exhausting ``max_continuation`` without + meeting ``target_error``. When the user raises ``max_continuation`` + (or tightens ``target_error``) and relaunches, the recorded + ``[result].energy_error`` will still exceed ``target_error*1.05`` + and this method returns True so ``Container`` bypasses its + short-circuit and re-enters the production loop. + + Fixed-step mode (``num_mcmc_steps`` set) has no convergence + criterion and is never resumed automatically. + """ + if self.target_error is None or self.num_mcmc_steps is not None: + return False + result = read_state(proj_dir).get("result", {}) + err = result.get("energy_error") + if err is None: + return False + return err > self.target_error * 1.05 + async def run(self) -> tuple: """Run the MCMC workflow. @@ -360,6 +385,17 @@ async def run(self) -> tuple: self._ensure_project_dir() _wd = self.project_dir + # Reconcile any orphaned "submitted"/"completed" job records whose + # output file already landed locally with a "Program ends" marker + # (e.g. workflow killed between job completion and fetch-finalize). + # Walks the pilot subdirectory (``_pilot/``) too, since it carries + # its own state file. Without this, Phase A would break out at + # the orphan record and the safety-net energy computation would + # read a stale earlier step. + n_reconciled = reconcile_fetched_jobs_recursive(_wd) + if n_reconciled: + logger.info(f" Reconciled {n_reconciled} job record(s) to 'fetched' from existing output.") + # -- Fixed-step mode --------------------------------------- if self.num_mcmc_steps is not None: return await self._launch_fixed_steps(_wd) @@ -428,7 +464,7 @@ async def _launch_fixed_steps(self, _wd): # Post-process energy (informational only, no convergence check) restart_chk = self._find_restart_chk(_wd) if restart_chk: - energy, error = self._compute_energy(restart_chk, work_dir=_wd, output_file=output_i) + energy, error = self._compute_energy(restart_chk, work_dir=_wd) if energy is not None: self.output_values["energy"] = energy self.output_values["energy_error"] = error @@ -460,7 +496,7 @@ async def _launch_fixed_steps(self, _wd): last_output = step_files[last_run][1] if last_run in step_files else None restart_chk = self._find_restart_chk(_wd) if restart_chk: - energy, error = self._compute_energy(restart_chk, work_dir=_wd, output_file=last_output) + energy, error = self._compute_energy(restart_chk, work_dir=_wd) if energy is not None: self.output_values["energy"] = energy self.output_values["energy_error"] = error @@ -535,7 +571,7 @@ async def _launch_auto(self, _wd): if not restart_chk: raise RuntimeError("No checkpoint found after pilot run. Cannot estimate required steps.") - _, pilot_error = self._compute_energy(restart_chk, work_dir=pilot_dir, output_file=output_0) + _, pilot_error = self._compute_energy(restart_chk, work_dir=pilot_dir) if pilot_error is None: raise RuntimeError("Could not parse energy error from pilot run.") @@ -677,45 +713,54 @@ async def _launch_auto(self, _wd): # -- Phase B: re-estimate from accumulated data -- accumulated_measurement = 0 # measurement steps only (excl. warmup) if first_new_run > 1: - cached_accum = estimation.get("accumulated_measurement_steps") - if cached_accum is not None: - accumulated_measurement = int(cached_accum) - else: - accumulated_measurement = (first_new_run - 1) * max(estimated_steps - warmup, 0) - _re_chk = self._find_restart_chk(_wd) - if _re_chk: - _re_energy, _re_error = self._compute_energy(_re_chk, work_dir=_wd) - if _re_energy is not None and _re_error is not None: - if _re_error <= self.target_error * 1.05: - logger.info( - f" Target already met after prior runs: {_re_error:.6g} <= {self.target_error * 1.05:.6g} Ha" - ) - self.output_values.update( - energy=_re_energy, - energy_error=_re_error, - restart_chk=_re_chk, - ) - if self.atomic_force: - forces = self._compute_force(_re_chk, work_dir=_wd) - if forces is not None: - self.output_values["forces"] = forces - first_new_run = self.max_continuation + 1 # skip loop - else: - _additional = estimate_additional_steps( - accumulated_measurement, - _re_error, - self.target_error, - ) - estimated_steps = _additional + warmup - logger.info( - f" Resuming after {first_new_run - 1} prior run(s): " - f"error={_re_error:.6g} Ha > target " - f"{self.target_error:.6g} Ha -> " - f"{estimated_steps} steps " - f"(measurement: {_additional}, warmup: {warmup}, " - f"accumulated measurement: {accumulated_measurement})" - ) + if _re_chk is None: + raise RuntimeError( + f"Phase B: {first_new_run - 1} prior run(s) marked fetched but no restart checkpoint found in {_wd}." + ) + # mcmc_counter in restart.h5 is the only trustworthy source + # for accumulated samples (planned step counts over-count + # when prior runs were cut short by max_time). + actual = read_accumulated_measurement_steps( + os.path.join(_wd, _re_chk), + warmup=warmup, + ) + if actual is None: + raise RuntimeError(f"Phase B: cannot read mcmc_counter from {_re_chk} in {_wd}.") + accumulated_measurement = actual + + _re_energy, _re_error = self._compute_energy_cached(_re_chk, work_dir=_wd, accumulated=actual) + if _re_energy is None or _re_error is None: + raise RuntimeError( + f"Phase B: compute-energy failed for {_re_chk} in {_wd}. Cannot decide whether to resume or stop." + ) + if _re_error <= self.target_error * 1.05: + logger.info(f" Target already met after prior runs: {_re_error:.6g} <= {self.target_error * 1.05:.6g} Ha") + self.output_values.update( + energy=_re_energy, + energy_error=_re_error, + restart_chk=_re_chk, + ) + if self.atomic_force: + forces = self._compute_force(_re_chk, work_dir=_wd) + if forces is not None: + self.output_values["forces"] = forces + first_new_run = self.max_continuation + 1 # skip loop + else: + _additional = estimate_additional_steps( + accumulated_measurement, + _re_error, + self.target_error, + ) + estimated_steps = _additional + warmup + logger.info( + f" Resuming after {first_new_run - 1} prior run(s): " + f"error={_re_error:.6g} Ha > target " + f"{self.target_error:.6g} Ha -> " + f"{estimated_steps} steps " + f"(measurement: {_additional}, warmup: {warmup}, " + f"accumulated measurement: {accumulated_measurement})" + ) # -- Phase C: production loop -- _prev_run_steps = None @@ -785,33 +830,42 @@ async def _launch_auto(self, _wd): run_id=run_id_i, ) step_files[i] = (input_i, output_i, run_id_i) - accumulated_measurement += estimated_steps - warmup _prev_run_steps = estimated_steps last_run = i # -- Side-effects: compute energy from checkpoint (if any) -- restart_chk = self._find_restart_chk(_wd) - energy = error = None - if restart_chk: - energy, error = self._compute_energy(restart_chk, work_dir=_wd, output_file=output_i) - if energy is not None: - self.output_values["energy"] = energy - self.output_values["energy_error"] = error - self.output_values["restart_chk"] = restart_chk - logger.info(f" MCMC energy: {energy} +- {error} Ha") - if self.atomic_force: - forces = self._compute_force(restart_chk, work_dir=_wd, output_file=output_i) - if forces is not None: - self.output_values["forces"] = forces + if restart_chk is None: + raise RuntimeError(f"Phase C: run {i} completed but no restart checkpoint found in {_wd}.") + # mcmc_counter in restart.h5 is the only trustworthy source for + # accumulated samples (planned step counts over-count when the + # run was cut short by max_time). + actual = read_accumulated_measurement_steps( + os.path.join(_wd, restart_chk), + warmup=warmup, + ) + if actual is None: + raise RuntimeError(f"Phase C: cannot read mcmc_counter from {restart_chk} in {_wd}.") + accumulated_measurement = actual + energy, error = self._compute_energy_cached(restart_chk, work_dir=_wd, accumulated=actual) + if energy is not None: + self.output_values["energy"] = energy + self.output_values["energy_error"] = error + self.output_values["restart_chk"] = restart_chk + logger.info(f" MCMC energy: {energy} +- {error} Ha") + if self.atomic_force: + forces = self._compute_force(restart_chk, work_dir=_wd, output_file=output_i) + if forces is not None: + self.output_values["forces"] = forces - set_estimation( - _wd, - last_energy=energy, - last_energy_error=error, - accumulated_measurement_steps=accumulated_measurement, - last_num_mcmc_bin_blocks=self.num_mcmc_bin_blocks, - last_num_mcmc_warmup_steps=self.num_mcmc_warmup_steps, - ) + set_estimation( + _wd, + last_energy=energy, + last_energy_error=error, + accumulated_measurement_steps=accumulated_measurement, + last_num_mcmc_bin_blocks=self.num_mcmc_bin_blocks, + last_num_mcmc_warmup_steps=self.num_mcmc_warmup_steps, + ) # -- Termination decision -- single source of truth -- vstatus, vmsg = validate_completion( @@ -854,7 +908,7 @@ async def _launch_auto(self, _wd): last_output = step_files[last_run][1] if last_run in step_files else None restart_chk = self._find_restart_chk(_wd) if restart_chk: - energy, error = self._compute_energy(restart_chk, work_dir=_wd, output_file=last_output) + energy, error = self._compute_energy_cached(restart_chk, work_dir=_wd) if energy is not None: self.output_values["energy"] = energy self.output_values["energy_error"] = error @@ -892,53 +946,107 @@ def _find_restart_chk(self, work_dir: str) -> str | None: return os.path.basename(matches[-1]) return None - def _compute_energy(self, restart_chk: str, work_dir: str, output_file: str | None = None): - """Parse energy from *output_file* or run ``jqmc-tool mcmc compute-energy``. + def _compute_energy_cached(self, restart_chk: str, work_dir: str, accumulated: int | None = None): + """Return (energy, error) using ``[estimation]`` cache when fresh. + + The cache (``last_energy`` / ``last_energy_error`` in + ``workflow_state.toml``) is considered fresh when the recorded + ``accumulated_measurement_steps`` matches the current + ``restart.h5`` ``mcmc_counter`` *and* the post-processing + parameters (``-b``, ``-w``) match the workflow's current + settings. On a hit, no subprocess is launched. - When *output_file* is given the energy is read directly from - the ``jqmc`` stdout (``Total Energy: E = ... +- ... Ha.``). - Falls back to ``jqmc-tool`` when *output_file* is *None* or - when stdout parsing fails. + On a miss, :meth:`_compute_energy` is invoked and the cache is + refreshed via :func:`set_estimation` so that subsequent + invocations within the same or later workflow runs short-circuit. + + Args: + restart_chk (str): + Checkpoint filename (basename). + work_dir (str): + Directory in which to run the command. + accumulated (int, optional): + Pre-read ``mcmc_counter`` from ``restart.h5``. Pass when + the caller has already computed it (Phase B / Phase C) to + avoid a redundant HDF5 read. + """ + if accumulated is None: + accumulated = read_accumulated_measurement_steps( + os.path.join(work_dir, restart_chk), + warmup=self.num_mcmc_warmup_steps, + ) + est = get_estimation(work_dir) + if ( + accumulated is not None + and est.get("last_energy") is not None + and est.get("last_energy_error") is not None + and est.get("accumulated_measurement_steps") == accumulated + and est.get("last_num_mcmc_bin_blocks") == self.num_mcmc_bin_blocks + and est.get("last_num_mcmc_warmup_steps") == self.num_mcmc_warmup_steps + ): + e, err = est["last_energy"], est["last_energy_error"] + logger.info( + f" Energy cached: E = {e} +- {err} Ha " + f"(acc={accumulated}, b={self.num_mcmc_bin_blocks}, " + f"w={self.num_mcmc_warmup_steps})" + ) + return e, err + energy, error = self._compute_energy(restart_chk, work_dir=work_dir) + if energy is not None and accumulated is not None: + set_estimation( + work_dir, + last_energy=energy, + last_energy_error=error, + accumulated_measurement_steps=accumulated, + last_num_mcmc_bin_blocks=self.num_mcmc_bin_blocks, + last_num_mcmc_warmup_steps=self.num_mcmc_warmup_steps, + ) + return energy, error + + def _compute_energy(self, restart_chk: str, work_dir: str): + """Run ``jqmc-tool mcmc compute-energy`` against *restart_chk*. + + Always invokes ``jqmc-tool`` so that the returned (energy, error) + carry full numerical precision. Parsing jqmc's printed + ``Total Energy: E = ... +- ... Ha.`` line is lossy (``%.5f`` + formatting), which is unsuitable for values persisted to + ``workflow_state.toml`` or compared against ``target_error``. Args: restart_chk (str): Checkpoint filename (basename). work_dir (str): Directory in which to run the command. - output_file (str, optional): - Stdout filename (basename) of the ``jqmc`` run. Returns: tuple: - ``(energy, error)`` or ``(None, None)``. + ``(energy, error)`` or ``(None, None)`` on failure. """ - # Fast path: parse from jqmc stdout - if output_file is not None: - out_path = os.path.join(work_dir, output_file) - if os.path.isfile(out_path): - try: - with open(out_path) as fh: - text = fh.read() - energy, error = self._parse_energy_output(text) - if energy is not None: - logger.info(f" Energy from {output_file} (jqmc-tool skipped): E = {energy} +- {error} Ha") - return energy, error - except OSError: - pass - - # Fallback: jqmc-tool - cmd = f"jqmc-tool mcmc compute-energy {restart_chk} -b {self.num_mcmc_bin_blocks} -w {self.num_mcmc_warmup_steps}" - logger.info(f" Running: {cmd}") + cmd = [ + "jqmc-tool", + "mcmc", + "compute-energy", + restart_chk, + "-b", + str(self.num_mcmc_bin_blocks), + "-w", + str(self.num_mcmc_warmup_steps), + ] + logger.info(f" Running: {' '.join(cmd)}") try: result = subprocess.run( cmd, - shell=True, + shell=False, capture_output=True, text=True, + errors="replace", check=True, cwd=work_dir, ) return self._parse_energy_output(result.stdout) + except FileNotFoundError as e: + logger.error(f"compute-energy: '{cmd[0]}' not found on PATH ({e})") + return None, None except subprocess.CalledProcessError as e: logger.error(f"compute-energy failed: {e.stderr}") return None, None @@ -996,14 +1104,24 @@ def _compute_force(self, restart_chk: str, work_dir: str, output_file: str | Non pass # Fallback: jqmc-tool - cmd = f"jqmc-tool mcmc compute-force {restart_chk} -b {self.num_mcmc_bin_blocks} -w {self.num_mcmc_warmup_steps}" - logger.info(f" Running: {cmd}") + cmd = [ + "jqmc-tool", + "mcmc", + "compute-force", + restart_chk, + "-b", + str(self.num_mcmc_bin_blocks), + "-w", + str(self.num_mcmc_warmup_steps), + ] + logger.info(f" Running: {' '.join(cmd)}") try: result = subprocess.run( cmd, - shell=True, + shell=False, capture_output=True, text=True, + errors="replace", check=True, cwd=work_dir, ) @@ -1018,6 +1136,9 @@ def _compute_force(self, restart_chk: str, work_dir: str, output_file: str | Non f" Ha/bohr" ) return forces + except FileNotFoundError as e: + logger.error(f"compute-force: '{cmd[0]}' not found on PATH ({e})") + return None except subprocess.CalledProcessError as e: logger.error(f"compute-force failed: {e.stderr}") return None diff --git a/jqmc_workflow/template/machine_data.yaml b/jqmc_workflow/template/machine_data.yaml index 6e583040..339e91ed 100644 --- a/jqmc_workflow/template/machine_data.yaml +++ b/jqmc_workflow/template/machine_data.yaml @@ -1,8 +1,27 @@ # Machine definitions for jqmc-workflow # Edit this file to match your environment. -# Each machine needs at least: machine_type (local or remote), queuing (true or false), workspace_root (path). -# Remote machines also require: ssh_host (Host alias in ~/.ssh/config). -# The top-level key (e.g., "my-cluster") is a nickname; it does NOT have to match the SSH host. +# +# Required fields for every machine: +# machine_type : "local" or "remote" +# queuing : true or false +# workspace_root : absolute path where job working dirs live +# jobsubmit : command used to invoke the submit script (required +# even when queuing=false -- for a plain local run use +# "bash" or "sh"; for a scheduler use "qsub" / "sbatch") +# +# Required additionally when queuing: true: +# jobcheck : command that lists jobs in the queue (e.g. "qstat") +# jobdel : command that cancels a job by id (e.g. "qdel") +# jobnum_index : 0-based token index of the job-id field in +# jobsubmit's stdout (qsub prints "1234.host" -> 0; +# sbatch prints "Submitted batch job 1234" -> 3) +# +# Required additionally when machine_type: remote: +# ssh_host : Host alias in ~/.ssh/config (the top-level YAML key +# above is a nickname and does NOT need to match) +# +# Optional: +# jobacct : scheduler accounting command (e.g. "qstat -fx") # Local execution without a batch scheduler (synchronous). # queuing: false -> bash runs submit.sh synchronously; no PID tracking needed. @@ -10,7 +29,7 @@ localhost: machine_type: local queuing: false workspace_root: /home/username/jqmc_work - jobsubmit: "bash" + jobsubmit: "bash" # required: command that runs the submit script # Local execution without a batch scheduler (asynchronous / background). # queuing: true -> bash launches submit.sh in the background and prints PID; diff --git a/jqmc_workflow/vmc_workflow.py b/jqmc_workflow/vmc_workflow.py index 40bdb7c4..63a2b15c 100644 --- a/jqmc_workflow/vmc_workflow.py +++ b/jqmc_workflow/vmc_workflow.py @@ -55,6 +55,7 @@ WorkflowStatus, get_estimation, get_job_by_step, + reconcile_fetched_jobs_recursive, set_estimation, validate_completion, ) @@ -63,6 +64,39 @@ logger = getLogger("jqmc-workflow").getChild(__name__) +def _last_opt_energy_from_log(output_file: str) -> tuple[float | None, float | None]: + """Return ``(energy, energy_error)`` of the last VMC opt step in *output_file*. + + Single source of truth shared by :meth:`VMC_Workflow._parse_output` + and :meth:`VMC_Workflow._parse_last_opt_energy`. Both used to keep + their own copy of the ``E = ... +- ...`` regex, which drifted + independently and silently dropped ``nan``/``inf`` lines. This + helper instead delegates to the canonical + :func:`_output_parser._parse_vmc_log_text` parser so any future + format change lives in exactly one place. + + A ``nan``/``inf`` energy is returned as a (non-finite) float -- not + ``None`` -- so the caller's ``math.isfinite`` check in + :func:`validate_completion` can flag diverged runs. + """ + try: + with open(output_file, errors="replace") as f: + text = f.read() + except OSError: + return None, None + from ._output_parser import _parse_vmc_log_text + + steps = _parse_vmc_log_text(text) + # Prefer the last opt-step block that actually contains an ``E =`` + # line; the very last block may be a header-only stub from a + # partial write. Note: ``nan`` is not ``None`` so a diverged final + # step is still selected, which is the whole point. + for s in reversed(steps): + if s.energy is not None and s.energy_error is not None: + return s.energy, s.energy_error + return None, None + + class VMC_Workflow(Workflow): r"""VMC (Variational Monte Carlo) Jastrow / orbital optimisation workflow. @@ -469,6 +503,17 @@ async def run(self) -> tuple: self._ensure_project_dir() _wd = self.project_dir + # Reconcile any orphaned "submitted"/"completed" job records whose + # output file already landed locally with a "Program ends" marker + # (e.g. workflow killed between job completion and fetch-finalize). + # Walks the pilot subdirectory (``_pilot/``) too, since it carries + # its own state file. SNR / slope convergence checks parse those + # output files, so a stale "submitted" record would otherwise hide + # the latest data. + n_reconciled = reconcile_fetched_jobs_recursive(_wd) + if n_reconciled: + logger.info(f" Reconciled {n_reconciled} job record(s) to 'fetched' from existing output.") + # -- Fixed-step mode --------------------------------------- if self.num_mcmc_steps is not None: return await self._launch_fixed_steps(_wd) @@ -545,6 +590,12 @@ async def _launch_fixed_steps(self, _wd): logger.info(f" VMC production run {i}/{self.max_continuation} completed.") + # Refresh output_values["energy"] from THIS iteration's log + # before validating -- otherwise validate_completion sees a + # stale or unset energy and the non-finite check at + # _state.py:387 silently passes for diverged runs. + self._parse_output(os.path.join(_wd, output_i)) + # -- Abnormal-termination guard (single source of truth) -- # target_error=None -> only Program-ends / non-finite-energy # checks are active. VMC's SNR/slope convergence is decided @@ -805,6 +856,12 @@ async def _launch_auto(self, _wd): logger.info(f" VMC production run {i}/{self.max_continuation} completed.") + # Refresh output_values["energy"] from THIS iteration's log + # before validating -- otherwise validate_completion sees a + # stale or unset energy and the non-finite check at + # _state.py:387 silently passes for diverged runs. + self._parse_output(os.path.join(_wd, output_i)) + # -- Abnormal-termination guard (single source of truth) -- # target_error=None -> only Program-ends / non-finite-energy # checks; SNR/slope convergence is evaluated separately below. @@ -960,27 +1017,19 @@ def _find_restart_chk(self, work_dir: str) -> str | None: # -- Output parsing -------------------------------------------- def _parse_output(self, output_file=None): - """Extract the last optimization energy from *output_file*.""" - if output_file is None: - return - if not os.path.isfile(output_file): - return + """Extract the last optimization step's energy from *output_file*. - energy_pattern = re.compile(r"E\s*=\s*([+-]?\d+\.\d+(?:[eE][+-]?\d+)?)\s*\+\-\s*(\d+\.\d+(?:[eE][+-]?\d+)?)") - last_match = None - try: - with open(output_file) as f: - for line in f: - m = energy_pattern.search(line) - if m: - last_match = m - except Exception: + Delegates to :func:`_output_parser._parse_vmc_log_text` so there + is a single source of truth for VMC log parsing -- any future + format change (or fix like nan/inf support) lives in one place. + """ + if output_file is None or not os.path.isfile(output_file): return - - if last_match: - self.output_values["energy"] = float(last_match.group(1)) - self.output_values["energy_error"] = float(last_match.group(2)) - logger.info(f" VMC energy: {self.output_values['energy']} +- {self.output_values['energy_error']} Ha") + e, err = _last_opt_energy_from_log(output_file) + if e is not None and err is not None: + self.output_values["energy"] = e + self.output_values["energy_error"] = err + logger.info(f" VMC energy: {e} +- {err} Ha") @staticmethod def _parse_all_snr(output_file): @@ -1022,13 +1071,18 @@ def _parse_all_energies(output_file: str) -> list[tuple[float, float]]: if not os.path.isfile(output_file): return [] try: - with open(output_file) as f: + with open(output_file, errors="replace") as f: text = f.read() from ._output_parser import _parse_vmc_log_text steps = _parse_vmc_log_text(text) return [(s.energy, s.energy_error) for s in steps if s.energy is not None and s.energy_error is not None] + except OSError as exc: + logger.warning(f"_parse_all_energies: cannot read {output_file}: {exc}") + return [] except Exception: + # Log unexpected parser failures rather than swallowing silently. + logger.exception(f"_parse_all_energies: unexpected error parsing {output_file}") return [] @staticmethod @@ -1056,6 +1110,11 @@ def _fit_energy_slope( E = np.asarray(energies, dtype=float) sigma = np.asarray(energy_errors, dtype=float) + # Replace non-positive sigmas with the median positive sigma so + # they get a finite (non-inf) weight rather than dividing by zero. + positive = sigma[sigma > 0] + floor = float(np.median(positive)) if positive.size else 1.0 + sigma = np.where(sigma > 0, sigma, floor) w = 1.0 / sigma**2 k = np.arange(len(E), dtype=float) @@ -1072,29 +1131,14 @@ def _fit_energy_slope( @staticmethod def _parse_last_opt_energy(output_file): - """Parse the last ``E = +- `` from a VMC output file. - - Extracts the energy from the *last* optimization step, which - reflects the optimized wavefunction quality. + """Parse the last optimization step's energy from a VMC output file. Returns: tuple: - ``(energy, error)`` or ``(None, None)``. + ``(energy, error)`` or ``(None, None)``. ``nan``/``inf`` + are returned as :class:`float` (not ``None``), so callers + can apply ``math.isfinite`` to detect diverged runs. """ if not os.path.isfile(output_file): return None, None - - energy_pattern = re.compile(r"E\s*=\s*([+-]?\d+\.?\d*(?:[eE][+-]?\d+)?)\s*\+\-\s*(\d+\.?\d*(?:[eE][+-]?\d+)?)") - last_match = None - try: - with open(output_file) as f: - for line in f: - m = energy_pattern.search(line) - if m: - last_match = m - except Exception: - return None, None - - if last_match: - return float(last_match.group(1)), float(last_match.group(2)) - return None, None + return _last_opt_energy_from_log(output_file) diff --git a/jqmc_workflow/wf_workflow.py b/jqmc_workflow/wf_workflow.py index 4590e21d..0552aab6 100644 --- a/jqmc_workflow/wf_workflow.py +++ b/jqmc_workflow/wf_workflow.py @@ -40,7 +40,6 @@ # POSSIBILITY OF SUCH DAMAGE. import os -import shlex import subprocess from logging import getLogger @@ -132,8 +131,8 @@ def __init__( raise ValueError(f"ao_conv_to must be None, 'cart', or 'sphe', got {ao_conv_to!r}") self.ao_conv_to = ao_conv_to - def _build_command(self) -> str: - """Build the ``jqmc-tool trexio convert-to`` CLI command.""" + def _build_command(self) -> list[str]: + """Build the ``jqmc-tool trexio convert-to`` CLI command (argv list).""" cmd = ["jqmc-tool", "trexio", "convert-to", self.trexio_file] cmd += ["-o", self.hamiltonian_file] @@ -154,7 +153,7 @@ def _build_command(self) -> str: if self.ao_conv_to is not None: cmd += ["--ao-conv-to", str(self.ao_conv_to)] - return shlex.join(cmd) + return cmd def configure(self) -> dict: """Validate parameters and return configuration summary.""" @@ -179,20 +178,25 @@ async def run(self) -> tuple: _wd = self.project_dir command = self._build_command() - logger.info(f" Running: {command}") + logger.info(f" Running: {' '.join(command)}") try: result = subprocess.run( command, - shell=True, + shell=False, capture_output=True, text=True, + errors="replace", check=True, cwd=_wd, ) logger.info(result.stdout) if result.stderr: logger.warning(f"stderr: {result.stderr}") + except FileNotFoundError as e: + logger.error(f"Command failed: '{command[0]}' not found on PATH ({e})") + self.status = WorkflowStatus.FAILED + return self.status, [], {} except subprocess.CalledProcessError as e: logger.error(f"Command failed (rc={e.returncode}): {e.stderr}") self.status = WorkflowStatus.FAILED diff --git a/jqmc_workflow/workflow.py b/jqmc_workflow/workflow.py index 4cbedcc4..68f0574e 100644 --- a/jqmc_workflow/workflow.py +++ b/jqmc_workflow/workflow.py @@ -271,12 +271,27 @@ def _ensure_project_dir(self): if self.project_dir is None: self.project_dir = os.path.abspath(os.getcwd()) + # Protected file basenames that ``_cleanup_files`` must never delete, + # regardless of what the user puts in ``cleanup_patterns``. Losing + # any of these breaks workflow state, job history, or resume. + _PROTECTED_CLEANUP_BASENAMES = frozenset( + { + "workflow_state.toml", + "workflow_state.toml.tmp", + } + ) + def _cleanup_files(self): """Delete files matching *cleanup_patterns* from local and remote. Local files are always removed. Remote files are removed only when the workflow targets a remote machine (``server_machine_name`` is set and not ``"localhost"``). + + Protected files (``workflow_state.toml`` and its atomic-write + ``.tmp`` sibling) are *never* deleted, even when matched by an + over-broad pattern like ``"*.toml"`` -- losing the state file + would break job history and resume. """ if not self.cleanup_patterns: return @@ -292,16 +307,24 @@ def _cleanup_files(self): for pattern in self.cleanup_patterns: for fpath in sorted(_glob.glob(os.path.join(work_dir, "**", pattern), recursive=True)): - if os.path.isfile(fpath): - os.remove(fpath) - logger.info(f" Cleanup: removed local file {os.path.relpath(fpath, work_dir)}") + if not os.path.isfile(fpath): + continue + if os.path.basename(fpath) in self._PROTECTED_CLEANUP_BASENAMES: + logger.warning(f" Cleanup: refusing to delete protected file {os.path.relpath(fpath, work_dir)}") + continue + os.remove(fpath) + logger.info(f" Cleanup: removed local file {os.path.relpath(fpath, work_dir)}") return from ._transfer import Data_transfer dt = Data_transfer(server_machine_name) try: - dt.remove_objects(patterns=self.cleanup_patterns, work_dir=work_dir) + dt.remove_objects( + patterns=self.cleanup_patterns, + work_dir=work_dir, + protected_basenames=self._PROTECTED_CLEANUP_BASENAMES, + ) except Exception: dt.ssh_close() raise @@ -310,9 +333,12 @@ def _cleanup_files(self): # -- configure / run (new primary interface) --------------------- def configure(self) -> dict: - """Validate parameters and generate inputs (no execution). + """Return a summary dict of the workflow's parameters (no side effects). - Override in subclass. Returns a summary dict. + Concrete workflows override this to expose their key parameters + for logging / inspection. Parameter *validation* happens in + ``__init__`` (so that invalid configs fail before any I/O), and + input-file generation happens lazily inside ``run()``. """ return {} @@ -328,6 +354,20 @@ async def run(self) -> tuple: self._ensure_project_dir() return self.status, self.output_files, self.output_values + def can_resume_after_completed(self, proj_dir: str) -> bool: + """Return True if a re-launch from a ``"completed"`` state could improve the result. + + ``Container`` consults this before short-circuiting on a + previously completed workflow. Subclasses with a target-error + convergence criterion (LRDMC, MCMC) override this to allow a + bumped ``max_continuation`` or tightened ``target_error`` to + actually re-trigger production runs, instead of silently + accepting the prior under-converged result. + + Default: False (the workflow is genuinely done). + """ + return False + # -- Full lifecycle (backward-compatible) ---------------------- async def async_launch(self): @@ -337,6 +377,11 @@ async def async_launch(self): return await self.run() def launch(self): + """Synchronous entry point: ``asyncio.run(self.async_launch())``. + + Not callable from an already-running event loop (e.g. Jupyter); + use ``await self.async_launch()`` there, or install ``nest_asyncio``. + """ return asyncio.run(self.async_launch()) # -- Phased execution (MCP interactive mode) ------------------- @@ -374,6 +419,9 @@ async def async_submit(self, action: str = "run") -> dict: require_action(action, self.phase, self.status) self._ensure_project_dir() self.configure() + # Flip self.status so subsequent require_action / async_poll + # callers see the workflow as RUNNING rather than PENDING. + self.status = WorkflowStatus.RUNNING self._bg_task = asyncio.create_task(self.run()) return {"status": "submitted", "project_dir": self.project_dir} @@ -390,6 +438,10 @@ async def async_poll(self) -> dict: if not self._bg_task.done(): summary = get_workflow_summary(self.project_dir) if self.project_dir else {} return {"status": "running", **summary} + # Task.exception() raises CancelledError on a cancelled task, so + # check cancellation BEFORE inspecting the exception. + if self._bg_task.cancelled(): + return {"status": "cancelled"} if self._bg_task.exception() is not None: return {"status": "failed", "error": str(self._bg_task.exception())} return {"status": "completed"} @@ -403,7 +455,8 @@ async def async_collect(self) -> dict: Raises: RuntimeError: - If the workflow was not submitted or is still running. + If the workflow was not submitted, was cancelled, or is + still running. Exception: Re-raises the original exception if the workflow failed. """ @@ -411,6 +464,8 @@ async def async_collect(self) -> dict: raise RuntimeError("No workflow has been submitted. Call async_submit() first.") if not self._bg_task.done(): raise RuntimeError("Workflow is still running. Call async_poll() to check status.") + if self._bg_task.cancelled(): + raise RuntimeError("Workflow was cancelled before completion.") exc = self._bg_task.exception() if exc is not None: raise exc @@ -529,6 +584,12 @@ async def _submit_and_wait( if recorded.get("status") == "submitted": stored_job_id = recorded.get("job_id") + if not stored_job_id: + raise RuntimeError( + f"State has step in 'submitted' status but no job_id for {input_file}. " + f"Edit workflow_state.toml (e.g. remove the malformed record or set status='cancelled') " + f"to recover." + ) logger.info(f" Resuming previously submitted job {stored_job_id}") job = self._make_job(input_file, output_file, queue_label=queue_label, run_id=run_id) try: @@ -699,12 +760,24 @@ def __init__( # -- Preparation ----------------------------------------------- def _prepare(self): - """Create project dir, copy input files, write initial state.""" + """Create project dir, copy input files, write initial state. + + Re-entry behaviour: when ``existing_status`` is ``completed`` or + ``running``, input files in *project_dir* are left intact and no + new state is created. ``async_launch`` will subsequently decide + whether to short-circuit (default for ``completed``) or resume + (when the inner workflow opts in via + :meth:`Workflow.can_resume_after_completed`). + """ state = read_state(self.project_dir) existing_status = state.get("workflow", {}).get("status", "") if existing_status in ("completed", "running"): - logger.info(f"[{self.label}] Already {existing_status}. Delete project dir to restart from scratch.") + logger.info( + f"[{self.label}] Existing project dir with status='{existing_status}' " + f"found; will short-circuit or resume depending on workflow policy. " + f"Delete project dir to force a clean restart." + ) return if os.path.isdir(self.project_dir): @@ -816,19 +889,59 @@ def _validate_input_files(self, proj: str): ) def _compute_input_fingerprints(self) -> dict[str, dict]: - """Return ``{basename: {sha256: hex_digest}}`` for each resolved input file.""" + """Return ``{basename: {sha256: hex_digest | "missing"}}`` per input file. + + For each entry in ``self.input_files``: + + * ``FileFrom`` / ``ValueFrom`` placeholders are skipped (they have + no on-disk path yet -- Launcher resolves them before launch; + direct ``.launch()`` may still hold placeholders). + * Regular files are hashed in 1 MiB chunks. + * Directories are hashed by walking their contents in sorted + order -- each file's relative path and content are folded into + the hash so the digest changes when any nested file changes. + * Missing source paths are recorded as ``{"sha256": "missing"}`` + so that :meth:`_check_input_staleness` can distinguish + "input never existed" from "input was deleted since last run". + """ fingerprints: dict[str, dict] = {} for i, src in enumerate(self.input_files): + if _is_dependency(src): + continue src = str(src) if not os.path.isabs(src): src = os.path.join(self.root_dir, src) key = self._dst_basename(src, self.rename_input_files, i) - if os.path.exists(src): + if os.path.isfile(src): h = hashlib.sha256() with open(src, "rb") as f: for chunk in iter(lambda: f.read(1 << 20), b""): h.update(chunk) fingerprints[key] = {"sha256": h.hexdigest()} + elif os.path.isdir(src): + # Walk deterministically (sorted dirs and files) so the + # digest is reproducible across runs. Folding the + # relative path into the hash makes additions, removals, + # and renames all visible. + h = hashlib.sha256() + for root, dirs, files in os.walk(src): + dirs.sort() + for name in sorted(files): + p = os.path.join(root, name) + rel = os.path.relpath(p, src).encode("utf-8", errors="replace") + h.update(rel) + h.update(b"\0") + try: + with open(p, "rb") as f: + for chunk in iter(lambda: f.read(1 << 20), b""): + h.update(chunk) + except OSError: + # Unreadable entry -- fold a marker in so the + # digest still changes when readability flips. + h.update(b"") + fingerprints[key] = {"sha256": h.hexdigest()} + else: + fingerprints[key] = {"sha256": "missing"} return fingerprints def _check_input_staleness(self, proj: str) -> bool: @@ -872,20 +985,30 @@ async def async_launch(self): f"Delete '{self.dirname}/' to re-run with the updated inputs." ) - # Record input-file fingerprints after staleness check but - # before any execution, so that even interrupted runs have a - # baseline for the next invocation. - set_input_fingerprints(proj, self._compute_input_fingerprints()) - if prev_status == "completed": - logger.info(f"[{self.label}] Already completed, no re-run.") - self.status = WorkflowStatus.COMPLETED - self._collect_outputs() - return self.status, self.output_files, self.output_values - - # Validate required files before running. + if self.workflow.can_resume_after_completed(proj): + logger.info( + f"[{self.label}] Previously 'completed' but workflow indicates " + f"the result can still be improved (target not yet met); resuming." + ) + else: + logger.info(f"[{self.label}] Already completed, no re-run.") + self.status = WorkflowStatus.COMPLETED + self._collect_outputs() + # Record fingerprints even on short-circuit so the next + # launch's staleness check has an up-to-date baseline. + set_input_fingerprints(proj, self._compute_input_fingerprints()) + return self.status, self.output_files, self.output_values + + # Validate required files before running. This may raise; do it + # before updating the fingerprint baseline so that a failed + # validation leaves the previous baseline intact. self._validate_input_files(proj) + # Record input-file fingerprints after staleness check + validation + # but before execution, so even interrupted runs have a baseline. + set_input_fingerprints(proj, self._compute_input_fingerprints()) + # Run the workflow -- pass project_dir explicitly instead of # relying on os.chdir(). update_status(proj, WorkflowStatus.RUNNING) @@ -909,10 +1032,19 @@ async def async_launch(self): result_fields[f"result_{k}"] = v update_status(proj, WorkflowStatus.COMPLETED, **result_fields) # -- Post-completion cleanup -- - try: - self.workflow._cleanup_files() - except Exception as e: - logger.warning(f"[{self.label}] Cleanup failed (non-fatal): {e}") + # Only run cleanup when the workflow is *truly* done, i.e. + # ``can_resume_after_completed`` says no further runs would + # help. Otherwise restart.h5 / opt-step checkpoints would + # be deleted while still being needed for the next resume + # (e.g. when max_continuation was raised after exhausting + # the original budget without meeting target_error). + if not self.workflow.can_resume_after_completed(proj): + try: + self.workflow._cleanup_files() + except Exception as e: + logger.warning(f"[{self.label}] Cleanup failed (non-fatal): {e}") + else: + logger.info(f"[{self.label}] Skipping cleanup: workflow may still be resumed (target not yet met).") else: logger.error(error_msg) self.status = WorkflowStatus.FAILED @@ -926,19 +1058,32 @@ async def async_launch(self): return self.status, self.output_files, self.output_values + # Files emitted by the workflow plumbing itself (state, generated + # input TOMLs, submit scripts, scheduler stdout/stderr, accounting). + # They are recorded in [[jobs]] / [workflow] and are not meaningful + # downstream artefacts, so we exclude them from output_files. + _INTERNAL_OUTPUT_PREFIXES = ("input_", "output_", "submit_", "job_", "job_accounting_") + _INTERNAL_OUTPUT_FILES = ("workflow_state.toml", "workflow_state.toml.tmp") + def _collect_outputs(self): """Re-collect output info from state file (for already-completed runs).""" state = read_state(self.project_dir) self.output_values = state.get("result", {}) - # Gather all files in project dir as potential outputs if os.path.isdir(self.project_dir): self.output_files = [ f - for f in os.listdir(self.project_dir) - if os.path.isfile(os.path.join(self.project_dir, f)) and f != "workflow_state.toml" + for f in sorted(os.listdir(self.project_dir)) + if os.path.isfile(os.path.join(self.project_dir, f)) + and f not in self._INTERNAL_OUTPUT_FILES + and not f.startswith(self._INTERNAL_OUTPUT_PREFIXES) ] def launch(self): + """Synchronous entry point: ``asyncio.run(self.async_launch())``. + + Not callable from an already-running event loop (e.g. Jupyter); + use ``await self.async_launch()`` there, or install ``nest_asyncio``. + """ return asyncio.run(self.async_launch()) # -- Phased execution (delegates to inner Workflow) ------------ @@ -980,6 +1125,8 @@ async def async_poll(self) -> dict: if not self._bg_task.done(): summary = get_workflow_summary(self.project_dir) if self.project_dir else {} return {"status": "running", **summary} + if self._bg_task.cancelled(): + return {"status": "cancelled"} if self._bg_task.exception() is not None: return {"status": "failed", "error": str(self._bg_task.exception())} return {"status": "completed"} @@ -994,7 +1141,7 @@ async def async_collect(self) -> dict: Raises: RuntimeError: - If not submitted or still running. + If not submitted, cancelled, or still running. Exception: Re-raises the original exception if the workflow failed. """ @@ -1002,6 +1149,8 @@ async def async_collect(self) -> dict: raise RuntimeError(f"[{self.label}] Not submitted. Call async_submit() first.") if not self._bg_task.done(): raise RuntimeError(f"[{self.label}] Still running. Call async_poll() to check.") + if self._bg_task.cancelled(): + raise RuntimeError(f"[{self.label}] Workflow was cancelled before completion.") exc = self._bg_task.exception() if exc is not None: raise exc diff --git a/setup.cfg b/setup.cfg index 72b5763e..0dfc15a0 100644 --- a/setup.cfg +++ b/setup.cfg @@ -4,9 +4,9 @@ author = Kosuke Nakano author_email = kousuke_1123@icloud.com long_description = file: README.md long_description_content_type = text/markdown -url = https://github.com/kousuke-nakano/jQMC +url = https://github.com/jqmc-project/jQMC project_urls = - Bug tracker = https://github.com/kousuke-nakano/jQMC/issues + Bug tracker = https://github.com/jqmc-project/jQMC/issues Documentations = https://jQMC.readthedocs.io/en/latest/ classifiers = Intended Audience :: Science/Research @@ -37,6 +37,7 @@ install_requires = pyyaml >= 6.0.0 toml >= 0.10.2 typer >= 0.15.1 + click >= 8.0.0 tomlkit >= 0.13.2 uncertainties >= 3.2.2 matplotlib >= 3.10.1 diff --git a/tests/conftest.py b/tests/conftest.py index d390b0d5..d723850b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -70,6 +70,20 @@ def configure_precision(request): configure(mode) +@pytest.fixture(autouse=True) +def _isolated_cwd(tmp_path, monkeypatch): + """Run every test in its own pytest-managed tmp directory. + + Why: production code (e.g. ``MCMC.run_optimize``) writes artifacts such + as ``hamiltonian_data_opt_step_.h5`` to the current working directory + with fixed filenames. Under pytest-xdist, all workers share one cwd, so + concurrent tests collide on h5py's exclusive file lock (EWOULDBLOCK). + Per-test cwd isolation removes the collision and stops tests/ from + accumulating stale artifacts across runs. + """ + monkeypatch.chdir(tmp_path) + + def pytest_itemcollected(item): """Show reason for obsolete tests.""" obsolete_marker = item.get_closest_marker("obsolete") diff --git a/tests/test_checkpoint_gfmc.py b/tests/test_checkpoint_gfmc.py index aefdfe34..daa1a799 100644 --- a/tests/test_checkpoint_gfmc.py +++ b/tests/test_checkpoint_gfmc.py @@ -85,7 +85,9 @@ def _build_hamiltonian(trexio_file, jastrow_combo): ) if jastrow_combo == "1b+2b+3b+nn": - nn_jastrow_data = Jastrow_NN_data.init_from_structure(structure_data=structure_data) + nn_jastrow_data = Jastrow_NN_data.init_from_structure( + structure_data=structure_data, hidden_dim=2, num_layers=1, num_rbf=2, cutoff=5.0 + ) jastrow_data = Jastrow_data( jastrow_one_body_data=jastrow_one_body_data, diff --git a/tests/test_checkpoint_mcmc.py b/tests/test_checkpoint_mcmc.py index da868377..527d1211 100644 --- a/tests/test_checkpoint_mcmc.py +++ b/tests/test_checkpoint_mcmc.py @@ -93,7 +93,9 @@ def _build_hamiltonian(trexio_file, jastrow_combo): ) if jastrow_combo == "1b+2b+3b+nn": - nn_jastrow_data = Jastrow_NN_data.init_from_structure(structure_data=structure_data) + nn_jastrow_data = Jastrow_NN_data.init_from_structure( + structure_data=structure_data, hidden_dim=2, num_layers=1, num_rbf=2, cutoff=5.0 + ) jastrow_data = Jastrow_data( jastrow_one_body_data=jastrow_one_body_data, diff --git a/tests/test_hamiltonian.py b/tests/test_hamiltonian.py index 1fbc92d3..640721f1 100644 --- a/tests/test_hamiltonian.py +++ b/tests/test_hamiltonian.py @@ -162,7 +162,9 @@ def test_hamiltonian_hdf5(trexio_file, use_1b, use_2b, use_3b, use_nn, geminal_t nn_jastrow_data = None if use_nn: - nn_jastrow_data = Jastrow_NN_data.init_from_structure(structure_data=structure_data) + nn_jastrow_data = Jastrow_NN_data.init_from_structure( + structure_data=structure_data, hidden_dim=2, num_layers=1, num_rbf=2, cutoff=5.0 + ) jastrow_data = Jastrow_data( jastrow_one_body_data=jastrow_one_body_data, diff --git a/tests/test_jastrow.py b/tests/test_jastrow.py index 0d6b3245..41dd1837 100755 --- a/tests/test_jastrow.py +++ b/tests/test_jastrow.py @@ -1323,9 +1323,9 @@ def _build_jastrow_data_for_part_tests(j1b_type: str = "exp", j2b_type: str = "p if include_nn: jastrow_nn_data = Jastrow_NN_data.init_from_structure( structure_data=structure_data, - hidden_dim=16, - num_layers=2, - num_rbf=8, + hidden_dim=4, + num_layers=1, + num_rbf=4, cutoff=5.0, key=jax.random.PRNGKey(0), ) diff --git a/tests/test_jqmc_gfmc_bra.py b/tests/test_jqmc_gfmc_bra.py index b6710711..66d97cce 100755 --- a/tests/test_jqmc_gfmc_bra.py +++ b/tests/test_jqmc_gfmc_bra.py @@ -120,7 +120,7 @@ def test_jqmc_gfmc_n(trexio_file, with_1b_jastrow, with_2b_jastrow, with_3b_jast jastrow_nn_data = None if with_nn_jastrow: jastrow_nn_data = Jastrow_NN_data.init_from_structure( - structure_data=structure_data, hidden_dim=2, num_layers=1, cutoff=5.0, key=jax.random.PRNGKey(0) + structure_data=structure_data, hidden_dim=2, num_layers=1, num_rbf=2, cutoff=5.0, key=jax.random.PRNGKey(0) ) jastrow_data = Jastrow_data( diff --git a/tests/test_jqmc_gfmc_tau.py b/tests/test_jqmc_gfmc_tau.py index 77a7adc1..f3705900 100755 --- a/tests/test_jqmc_gfmc_tau.py +++ b/tests/test_jqmc_gfmc_tau.py @@ -121,7 +121,7 @@ def test_jqmc_gfmc_t(trexio_file, with_1b_jastrow, with_2b_jastrow, with_3b_jast jastrow_nn_data = None if with_nn_jastrow: jastrow_nn_data = Jastrow_NN_data.init_from_structure( - structure_data=structure_data, hidden_dim=2, num_layers=1, cutoff=5.0, key=jax.random.PRNGKey(0) + structure_data=structure_data, hidden_dim=2, num_layers=1, num_rbf=2, cutoff=5.0, key=jax.random.PRNGKey(0) ) jastrow_data = Jastrow_data( @@ -208,12 +208,19 @@ def test_jqmc_gfmc_t(trexio_file, with_1b_jastrow, with_2b_jastrow, with_3b_jast np.testing.assert_allclose(e_L2_debug, e_L2_jax, atol=atol, rtol=rtol) # average_projection_counter - # Both GFMC_t and _GFMC_t_debug now store local averages per rank. - apc_debug = gfmc_debug.average_projection_counter - apc_jax = gfmc_jax.average_projection_counter - assert not np.any(np.isnan(np.asarray(apc_debug))), "NaN detected in first argument" - assert not np.any(np.isnan(np.asarray(apc_jax))), "NaN detected in second argument" - np.testing.assert_allclose(apc_debug, apc_jax, atol=atol, rtol=rtol) + # Both GFMC_t and _GFMC_t_debug store local averages per rank. Production + # builds the branching cumprob via MPI allreduce + Exscan offset; debug + # uses a centralized rank-0 cumsum. The two paths are mathematically + # equivalent but not bit-identical, so boundary cases of searchsorted can + # permute walkers across ranks. Per-rank local apc is sensitive to that + # shuffling; the global mean across ranks is not. + apc_debug = np.asarray(gfmc_debug.average_projection_counter) + apc_jax = np.asarray(gfmc_jax.average_projection_counter) + assert not np.any(np.isnan(apc_debug)), "NaN detected in first argument" + assert not np.any(np.isnan(apc_jax)), "NaN detected in second argument" + apc_debug_global = np.mean(np.stack(mpi_comm.allgather(apc_debug), axis=0), axis=0) + apc_jax_global = np.mean(np.stack(mpi_comm.allgather(apc_jax), axis=0), axis=0) + np.testing.assert_allclose(apc_debug_global, apc_jax_global, atol=atol, rtol=rtol) # E E_debug, E_err_debug, Var_debug, Var_err_debug = gfmc_debug.get_E( diff --git a/tests/test_jqmc_mcmc.py b/tests/test_jqmc_mcmc.py index 3bf5f890..e4bc6685 100755 --- a/tests/test_jqmc_mcmc.py +++ b/tests/test_jqmc_mcmc.py @@ -127,7 +127,7 @@ def test_jqmc_mcmc(trexio_file, with_1b_jastrow, with_2b_jastrow, with_3b_jastro jastrow_nn_data = None if with_nn_jastrow: jastrow_nn_data = Jastrow_NN_data.init_from_structure( - structure_data=structure_data, hidden_dim=2, num_layers=1, cutoff=5.0 + structure_data=structure_data, hidden_dim=2, num_layers=1, num_rbf=2, cutoff=5.0 ) jastrow_data = Jastrow_data( @@ -278,7 +278,9 @@ def test_jqmc_vmc(trexio_file, monkeypatch): ) jastrow_twobody_data = Jastrow_two_body_data.init_jastrow_two_body_data(jastrow_2b_param=0.5, jastrow_2b_type="pade") jastrow_threebody_data = Jastrow_three_body_data.init_jastrow_three_body_data(orb_data=aos_data) - jastrow_nn_data = Jastrow_NN_data.init_from_structure(structure_data=structure_data, hidden_dim=5, num_layers=2, cutoff=5.0) + jastrow_nn_data = Jastrow_NN_data.init_from_structure( + structure_data=structure_data, hidden_dim=2, num_layers=1, num_rbf=2, cutoff=5.0 + ) jastrow_data = Jastrow_data( jastrow_one_body_data=jastrow_onebody_data, @@ -2354,9 +2356,10 @@ def test_trivial_2x2(self): S_matrix = np.array([[1.0]]) K_matrix = np.array([[-0.5]]) B_matrix = np.array([[-0.1]]) - c_vec, E_lm = MCMC.solve_linear_method(H_0, f_vec, S_matrix, K_matrix, B_matrix, epsilon=1e-10) + c_vec, E_lm, v0_sq = MCMC.solve_linear_method(H_0, f_vec, S_matrix, K_matrix, B_matrix, epsilon=1e-10) assert c_vec.shape == (1,) assert E_lm <= H_0 + 1e-10, f"E_lm={E_lm} should be <= H_0={H_0}" + assert 0.0 <= v0_sq <= 1.0 + 1e-10 def test_diagonal_known_solution(self): """Diagonal H, S: verify c_vec has correct shape and E_lm is valid.""" @@ -2366,10 +2369,11 @@ def test_diagonal_known_solution(self): S_matrix = np.diag(np.linspace(0.1, 1.0, p)) K_matrix = np.diag(np.linspace(-1.0, -0.1, p)) B_matrix = np.diag(np.linspace(-0.5, -0.05, p)) - c_vec, E_lm = MCMC.solve_linear_method(H_0, f_vec, S_matrix, K_matrix, B_matrix, epsilon=1e-10) + c_vec, E_lm, v0_sq = MCMC.solve_linear_method(H_0, f_vec, S_matrix, K_matrix, B_matrix, epsilon=1e-10) assert c_vec.shape == (p,) assert np.all(np.isfinite(c_vec)) assert np.isfinite(E_lm) + assert 0.0 <= v0_sq <= 1.0 + 1e-10 def test_epsilon_cutoff(self): """S eigenvalues below epsilon are cut; p' < p.""" @@ -2379,9 +2383,10 @@ def test_epsilon_cutoff(self): S_matrix = np.diag([1.0, 0.5, 1e-8, 1e-10]) K_matrix = np.eye(p) * (-0.5) B_matrix = np.eye(p) * (-0.1) - c_vec, E_lm = MCMC.solve_linear_method(H_0, f_vec, S_matrix, K_matrix, B_matrix, epsilon=1e-6) + c_vec, E_lm, v0_sq = MCMC.solve_linear_method(H_0, f_vec, S_matrix, K_matrix, B_matrix, epsilon=1e-6) assert c_vec.shape == (p,) assert np.isfinite(E_lm) + assert 0.0 <= v0_sq <= 1.0 + 1e-10 def test_all_diag_S_zero(self): """All diag(S) = 0 -> dgelscut removes all parameters -> zero update, E_lm == H_0.""" @@ -2391,9 +2396,10 @@ def test_all_diag_S_zero(self): S_matrix = np.zeros((p, p)) K_matrix = np.eye(p) * (-0.5) B_matrix = np.eye(p) * (-0.1) - c_vec, E_lm = MCMC.solve_linear_method(H_0, f_vec, S_matrix, K_matrix, B_matrix, epsilon=1e-6) + c_vec, E_lm, v0_sq = MCMC.solve_linear_method(H_0, f_vec, S_matrix, K_matrix, B_matrix, epsilon=1e-6) np.testing.assert_array_equal(c_vec, np.zeros(p)) assert E_lm == H_0 + assert v0_sq == 0.0 def test_v0_max_selection(self): """The eigenvector with max |v_0|^2 is selected.""" @@ -2404,9 +2410,10 @@ def test_v0_max_selection(self): S_matrix = np.eye(p) K_matrix = np.diag([-10.0, -0.1]) B_matrix = np.zeros((p, p)) - c_vec, E_lm = MCMC.solve_linear_method(H_0, f_vec, S_matrix, K_matrix, B_matrix, epsilon=1e-10) + c_vec, E_lm, v0_sq = MCMC.solve_linear_method(H_0, f_vec, S_matrix, K_matrix, B_matrix, epsilon=1e-10) assert c_vec.shape == (p,) assert np.isfinite(E_lm) + assert 0.0 <= v0_sq <= 1.0 + 1e-10 # --------------------------------------------------------------------------- @@ -2618,10 +2625,11 @@ def test_get_aH_and_solve_lm_debug_vs_production(): # Use the production matrices for both to verify the two implementations # produce identical results when given the exact same input. epsilon_lm = 1e-6 - c_debug, E_debug = _MCMC_debug.solve_linear_method(H_0_p, f_p, S_p, K_p, B_p, epsilon_lm) - c_prod, E_prod = MCMC.solve_linear_method(H_0_p, f_p, S_p, K_p, B_p, epsilon_lm) + c_debug, E_debug, v0_debug = _MCMC_debug.solve_linear_method(H_0_p, f_p, S_p, K_p, B_p, epsilon_lm) + c_prod, E_prod, v0_prod = MCMC.solve_linear_method(H_0_p, f_p, S_p, K_p, B_p, epsilon_lm) np.testing.assert_allclose(c_debug, c_prod, atol=atol, rtol=rtol) np.testing.assert_allclose(E_debug, E_prod, atol=atol, rtol=rtol) + np.testing.assert_allclose(v0_debug, v0_prod, atol=atol, rtol=rtol) jax.clear_caches() diff --git a/tests/test_lrdmc_force.py b/tests/test_lrdmc_force.py index ac413cdc..daa4766c 100755 --- a/tests/test_lrdmc_force.py +++ b/tests/test_lrdmc_force.py @@ -152,7 +152,7 @@ def test_lrdmc_force_with_SWCT_n(trexio_file: str, jastrow_parameters: dict, loc jastrow_nn_param = jastrow_parameters.get("jastrow_nn_param", False) if jastrow_nn_param: jastrow_nn_data = Jastrow_NN_data.init_from_structure( - structure_data=structure_data, hidden_dim=5, num_layers=2, cutoff=5.0 + structure_data=structure_data, hidden_dim=2, num_layers=1, num_rbf=2, cutoff=5.0 ) else: jastrow_nn_data = None @@ -269,7 +269,7 @@ def test_lrdmc_force_with_SWCT_t(trexio_file: str, jastrow_parameters: dict, loc jastrow_nn_param = jastrow_parameters.get("jastrow_nn_param", False) if jastrow_nn_param: jastrow_nn_data = Jastrow_NN_data.init_from_structure( - structure_data=structure_data, hidden_dim=5, num_layers=2, cutoff=5.0 + structure_data=structure_data, hidden_dim=2, num_layers=1, num_rbf=2, cutoff=5.0 ) else: jastrow_nn_data = None diff --git a/tests/test_mcmc_force.py b/tests/test_mcmc_force.py index 23cbd4d9..33f72c20 100755 --- a/tests/test_mcmc_force.py +++ b/tests/test_mcmc_force.py @@ -151,7 +151,7 @@ def test_mcmc_force_with_SWCT(trexio_file: str, jastrow_parameters: dict): jastrow_nn_param = jastrow_parameters.get("jastrow_nn_param", False) if jastrow_nn_param: jastrow_nn_data = Jastrow_NN_data.init_from_structure( - structure_data=structure_data, hidden_dim=5, num_layers=2, cutoff=5.0 + structure_data=structure_data, hidden_dim=2, num_layers=1, num_rbf=2, cutoff=5.0 ) else: jastrow_nn_data = None @@ -310,7 +310,7 @@ def test_mcmc_force_open_shell_finite(with_nn: bool): orb_data=aos_data, random_init=True, random_scale=1.0e-3 ) jastrow_nn_data = ( - Jastrow_NN_data.init_from_structure(structure_data=structure_data, hidden_dim=2, num_layers=1, cutoff=5.0) + Jastrow_NN_data.init_from_structure(structure_data=structure_data, hidden_dim=2, num_layers=1, num_rbf=2, cutoff=5.0) if with_nn else None ) diff --git a/tests/test_wave_function.py b/tests/test_wave_function.py index 59caa4e2..3cf469ac 100755 --- a/tests/test_wave_function.py +++ b/tests/test_wave_function.py @@ -432,7 +432,9 @@ def test_nodal_distance_analytic_vs_debug(trexio_file: str): jastrow_threebody_data = Jastrow_three_body_data.init_jastrow_three_body_data( orb_data=aos_data, random_init=True, random_scale=1.0e-3 ) - jastrow_nn_data = Jastrow_NN_data.init_from_structure(structure_data=structure_data, hidden_dim=5, num_layers=2, cutoff=5.0) + jastrow_nn_data = Jastrow_NN_data.init_from_structure( + structure_data=structure_data, hidden_dim=2, num_layers=1, num_rbf=2, cutoff=5.0 + ) jastrow_data = Jastrow_data( jastrow_one_body_data=jastrow_onebody_data,