v0.2.2 #81

Merged: kousuke-nakano merged 118 commits into rc from main on Jun 2, 2026

@kousuke-nakano

Collaborator

v0.2.2

kousuke-nakano and others added 30 commits April 24, 2026 17:17
Introduce a zone-based mixed precision system that allows each computational component (AO/MO evaluation, Jastrow, determinant, Coulomb, kinetic energy, MCMC, GFMC, optimization, I/O) to run independently in float32 or float64. The default mode (`"full"`) keeps all zones at float64 for backward compatibility, while `"mixed"` sets low-risk zones to float32 and keeps numerically sensitive zones at float64.

Core changes:
- New module `jqmc/_precision.py` (`get_dtype`, `configure`, `get_tolerance`)
- All 15 source modules parameterized with `get_dtype("zone")`
- EPS constants made dtype-aware via `get_eps()` in `_setting.py`
- TOML `[precision]` section parsed in `jqmc_cli.py`
- `jqmc_workflow` (VMC/MCMC/LRDMC) accepts a `precision_mode` parameter
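
The zone lookup described above can be pictured with a minimal sketch. This is not the code in `jqmc/_precision.py`: the table names (`_FULL`, `_MIXED`), the zone subset, the string dtypes, and the function bodies are all assumptions made for illustration. Only the function names `configure` and `get_dtype` and the mode names `"full"` / `"mixed"` come from the commit message.

```python
# Hypothetical sketch of a zone-based precision registry; not the jqmc code.
# Zone names and dtype assignments below are illustrative assumptions.
_FULL = {"orb_eval": "float64", "jastrow": "float64", "determinant": "float64"}
_MIXED = {"orb_eval": "float32", "jastrow": "float32", "determinant": "float64"}

_active = dict(_FULL)  # default mode "full": every zone stays float64


def configure(mode: str) -> None:
    """Select the zone-to-dtype table for the whole run."""
    tables = {"full": _FULL, "mixed": _MIXED}
    if mode not in tables:
        raise ValueError(f"unknown precision mode: {mode!r}")
    _active.clear()
    _active.update(tables[mode])


def get_dtype(zone: str) -> str:
    """Return the dtype name currently assigned to a zone."""
    return _active[zone]


configure("mixed")
print(get_dtype("orb_eval"), get_dtype("determinant"))
```

Each module asks for its own zone by name, so switching the mode changes dtypes in one place without touching the numerical kernels.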

Test infrastructure:
- Dtype-aware tolerances via `get_tolerance(zone, level)`
- `@numerical_diff` marker: skip FD tests in mixed mode
- `@external_reference` marker: skip TurboRVB tests in mixed mode
- Zone-corrected tolerances for cross-zone gradient comparisons

Documentation and examples:
- Module docstrings with Precision Zone annotations
- User guide (`doc/notes/mixed_precision.md`)
- Sphinx autodoc entries for `_precision.py` and `_setting.py`
- TOML `[precision]` examples in `jqmc-example01`
- Workflow example (`run_pes_pipeline.py`) with `PRECISION_MODE` config
jqmc/jqmc_gfmc.py to remove implicit reliance on JAX dtype propagation.
Behavior is unchanged (gfmc zone defaults to float64 in both full/mixed
modes); this improves precision-zone clarity and resilience against
future JAX promotion-rule changes.
Implement zone-aware fp32/fp64 dtype propagation across the
orb_eval / jastrow / geminal / coulomb (fp32 in mixed mode) and
determinant / kinetic / wavefunction / IO (fp64) precision zones,
so that mode=mixed actually delivers the intended speedup over
mode=full instead of being collapsed back to fp64 by JAX type
promotion.

Implementation
- atomic_orbital.py: cast AO data fields and r_carts at zone entry
  (incl. _atomic_center_carts_prim_jnp and friends; covers both
  cart and sphe paths).
- molecular_orbital.py: cast mo_coefficients at zone entry.
- jastrow_factor.py: cast J1/J2/J3 r_carts and variational params
  to the jastrow zone dtype; J3-NN params handled via tree_map.
- determinant.py: cast geminal inputs / lambda blocks to the
  geminal zone dtype, including the MCMC fast-update row/column
  kernels.
- coulomb_potential.py:
  - cast inputs at zone entry across the bare/local/non-local ECP
    paths (~15 np.array sites).
  - cast wf_ratio_up / wf_ratio_dn back to the coulomb zone in
    compute_ecp_non_local_part_all_pairs_jax_weights_grid_points,
    sealing fp64 leak from determinant/jastrow into V_ecp pytree.
- structure.py: mark dtype as static for _get_min_dist_rel_R_cart_jnp
  (@partial(jit, static_argnames=("dtype",))). This is required because
  Python scalar types like jnp.float32 cannot be traced as abstract
  arrays.
- jqmc_mcmc.py (_update_electron_positions): cast lax.cond branches in
  v / u construction to geminal_inv.dtype. The geminal-diff branch lives
  in the geminal zone (fp64) while jax.nn.one_hot defaults to fp32, so
  in mixed precision the cond branches disagreed on dtype.

- jqmc_mcmc.py (_update_electron_positions_only_up_electron): cast the
  Sherman-Morrison rank-1 inverse update back to geminal_inv.dtype so
  _accepted_fun and _rejected_fun branches of lax.cond agree.

- jqmc_gfmc.py (debug-vs-production e_L check at the tmove branch):
  replace hard-coded `assert_almost_equal(..., decimal=6)` with
  `assert_allclose(..., atol, rtol)` driven by `get_tolerance_min` over
  the e_L path zones. The previous hard-coded decimal silently ignored
  the surrounding `rtol_debug_vs_production` setting and broke under
  mixed precision.

Test tolerance fixes:

- tests/test_jqmc_mcmc.py, tests/test_jqmc_gfmc_tau.py,
  tests/test_jqmc_gfmc_bra.py: switch e_L / e_L2 / w_L / ln|Psi| /
  H_0/f/S/K/B comparisons from get_tolerance("mcmc"/"gfmc", "strict")
  to get_tolerance_min(<path zones>, "strict"). The mcmc/gfmc zones
  are fp64-pinned even in mixed mode, but the e_L computation path
  crosses orb_eval/jastrow/geminal/coulomb/kinetic which are fp32 in
  mixed mode, so tolerances must be bounded by the weakest zone.
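
The "weakest zone" rule can be sketched as follows. The tolerance values, the zone table, and both function bodies are hypothetical; the real `get_tolerance` also takes a level argument such as `"strict"`, which this sketch omits for brevity.

```python
# Hypothetical tolerance table keyed by dtype name; values are illustrative.
_TOL = {"float64": (1e-8, 1e-6), "float32": (1e-4, 1e-3)}  # (atol, rtol)
# Hypothetical mixed-mode zone assignment for a few zones.
_ZONES = {"mcmc": "float64", "orb_eval": "float32", "kinetic": "float32"}


def get_tolerance(zone):
    """Tolerance implied by the dtype of a single zone."""
    return _TOL[_ZONES[zone]]


def get_tolerance_min(zones):
    """Loosest (atol, rtol) over all zones crossed by the computation path."""
    tols = [get_tolerance(z) for z in zones]
    return max(a for a, _ in tols), max(r for _, r in tols)


# The mcmc zone alone is fp64, but the e_L path also crosses fp32 zones.
print(get_tolerance("mcmc"))
print(get_tolerance_min(["mcmc", "orb_eval", "kinetic"]))
```

A comparison of `e_L` would then use the second pair, since a quantity can be no more accurate than the lowest-precision zone it passed through.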
…oid catastrophic cancellation

When mixed-precision zones (orb_eval, jastrow, coulomb) are set to float32,
relative coordinate differences r - R (and r_i - r_j) suffer catastrophic
cancellation for systems with large absolute coordinates (e.g. R ~ 50 Bohr
loses ~6 digits of precision in float32). This propagated through
exp(-Z*r^2), (x+eps)^nx, and 1/|r-R|, producing energy errors of several
hartree (e.g. -134 vs -137 Ha) far beyond expected float32 round-off.
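
The cancellation can be reproduced with the standard library alone. In this sketch the helper `to_f32` and the coordinate values are assumptions chosen for illustration. It compares casting the coordinates to float32 before subtracting with subtracting in float64 and casting only the small difference.

```python
import struct


def to_f32(x: float) -> float:
    """Round a Python float (binary64) to the nearest binary32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


R = 50.0        # nuclear coordinate in Bohr (illustrative)
r = 50.0001234  # electron coordinate close to that nucleus (illustrative)
exact = 1.234e-4

# Cast first, subtract second: near 50, float32 values sit on a ~3.8e-6 grid,
# so the rounding error of r is large relative to r - R.
diff_cast_first = to_f32(to_f32(r) - to_f32(R))

# Subtract in float64 first, then cast the small difference to float32.
diff_cast_last = to_f32(r - R)

rel_err_first = abs(diff_cast_first - exact) / exact
rel_err_last = abs(diff_cast_last - exact) / exact
print(rel_err_first, rel_err_last)
```

With these numbers the cast-first difference is off at the percent level, while the cast-last difference keeps float32 round-off accuracy.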
…tion of log|det|

Diagnostic scripts in bug/fp32/ showed that with the previous mixed-precision
defaults (orb_eval=fp32, geminal=fp32), local energies on a 32-up/32-dn ECP
system were biased vs full fp64. The bias came from the kinetic term: AO matrix
entries had only fp32 round-off, but the determinant amplified that into a log|det|
error, which fed back into both T and ECP non-local matrix elements.

Two minimal changes restore safety while keeping the heavy AO kernels in fp32:

1. molecular_orbital.compute_MOs now upcasts the (small) MO matmul to the
   determinant-zone dtype (fp64 in mixed mode). The expensive AO evaluation
   still runs in orb_eval (fp32), but the MO matrix returned to the
   determinant / kinetic / energy paths is fp64, breaking the chain
   "fp32 AO -> fp32 geminal -> noisy log|det|".

2. _precision._DEFAULTS_MIXED["geminal"] is changed from float32 to float64.
   The geminal matrix is the input to LU/det; even ~1e-7 entry noise blows
   up to O(1) errors in log|det|. The diag_04 AGP sweep additionally
   shows that the AO-direct (AGP) form is ~4x more sensitive to geminal=fp32
   than the MO (JSD) form, so this default is required for both ansatz types.
Users now choose only `"full"` or `"mixed"`; per-zone overrides have been removed from TOML, workflow parameters, and `configure()`. Zone assignments now live in `_FULL_PRECISION` and `_MIXED_PRECISION` dictionaries, renamed from `_DEFAULTS_*`. Developers who need per-zone control can edit those dictionaries directly or use `_set_zone()`.
… modules with 3 design principles

Brings all selectable-precision modules into compliance with the three design principles documented in `jqmc/_precision.py`:

  P1: zone <-> owning module is 1:1; a module consults only its own zones.
  P2: a module may own multiple zones, named for purpose not dtype.
  P3a (frozen args): parameter names must not be rebound. Forwarding is dtype-neutral; cast at the arithmetic use site.
  P3b (local cast at point of arithmetic): cast operands to the function's own zone immediately before consumption. For catastrophic cancellation `(r - R)`: reconstruct in caller-supplied precision then `.astype(zone)`, never hardcode `jnp.float64` in the reconstruction.

Main changes:
    - Remove parameter rebinds (`arg = jnp.asarray(arg, dtype=...)` at function entry) across `swct`, `atomic_orbital`, `hamiltonians`, `coulomb_potential`, `jastrow_factor`, and `determinant` (including JIT and debug helpers). Arguments are now forwarded dtype-neutral; casting happens at the arithmetic use site via `arg.astype(dtype_jnp)` or via new locals such as `A_old_inv_z`, `G_inv`, and `r_up_z` to avoid repeated casts without rebinding the parameter.
    - Replace hardcoded `jnp.float64` in `r - R` reconstructions with caller-supplied precision: `(r - R).astype(zone_dtype)` in `coulomb_potential`, `jastrow_factor`, and the `compute_AOs` / analytic grad-lap kernels in `atomic_orbital`.
    - Add explicit `.astype(get_dtype_jnp("mo_grad_lap"))` on AO-zone outputs before `jnp.dot(mo_coefficients, ...)` in `compute_MOs_laplacian`, `compute_MOs_grad`, and `_compute_MOs_grad_autodiff`. Runtime behavior is unchanged, since JAX promotion already produced the correct dtype; this change is for zone-leak auditability.
    - Extend the "no hardcoded dtypes" exemption in `_precision.py` to cover basis-data storage accessors (`_*_jnp` properties on dataclasses whose underlying field is `npt.NDArray[np.float64]`), with a one-line justification comment at each accessor body. These are lift-only NumPy-to-`jax.Array` adapters; storage is fp64 by construction.
    - Cosmetic: fix `molecular_orbital` module docstring zone names (`orb_eval` / `kinetic` -> `mo_eval` / `mo_grad_lap`) and relax `swct` debug-helper return type hints from `npt.NDArray[np.float64]` to plain `np.ndarray` for selectable-precision zones. Correct the `evaluate_swct_domega` JIT-function return type to `jax.Array`.
Precision zone changes (`jqmc/_precision.py`):

  - `ao_grad_lap` -> `ao_grad` (fp32 in mixed) + `ao_lap` (fp64 in mixed)
  - `mo_grad_lap` -> `mo_grad` (fp64) + `mo_lap` (fp64)
  - `jastrow_grad_lap`: kept as a single zone (fp32 in mixed)
  - `det_grad_lap`: kept as a single zone (fp64 in mixed)

`ao` is split with an actual dtype difference because the analytic Laplacian kernel contains catastrophic-cancellation terms that fp32 cannot resolve, while the gradient kernel is safe in fp32 (see `bug/fp32/diag_07_ao_grad_vs_lap_split.py` for the diagnostic). The shared helper `_compute_S_l_m_and_grad_lap` is pinned to `ao_lap` (fp64), and `_compute_AOs_grad_analytic_sphe` downcasts the gradient output at the use site per Principle 3b.
updates:
- [github.com/astral-sh/ruff-pre-commit: v0.15.11 → v0.15.12](astral-sh/ruff-pre-commit@v0.15.11...v0.15.12)
Add mixed-precision (fp32/fp64) support with selectable per-zone dtype
Replace the per-walker Python loop in `_generate_init_electron_configurations` with a vectorized NumPy implementation: the deterministic atom assignment is replayed once, and all spherical offsets are drawn in a single batched call. The original implementation is kept as `_generate_init_electron_configurations_debug` for tests.

Replace `[fold_in(key, nw) for nw in range(num_walkers)]` with batched `jax.random.split(key, num_walkers)` in `MCMC`, `_MCMC_debug`, `GFMC_t`, `_GFMC_t_debug`, `GFMC_n`, and `_GFMC_n_debug`. Also remove the per-walker `bincount` + `logger.debug` loop from the same six `__init__` blocks.

Add tests covering position uniqueness, owner/per-atom-count agreement with the reference, dimer singlet anti-alignment, and per-atom charge neutrality across all reachable `S`.
The hand-rolled jackknife standard deviation in the MCMC/GFMC estimators used

    Var = <x^2> - <x>^2

which suffers from float64 catastrophic cancellation when `M_total = num_mcmc_bin_blocks * nw * num_ranks` becomes large.

Switch to the two-pass centered formulation in all production paths:
- `MCMC.get_E`, `get_aF`, `get_gF`
- `GFMC_t.get_E`, `get_aF` (small-bin and MPI-scatter)
- `GFMC_n.get_E`, `get_aF` (small-bin and MPI-scatter)

`_MCMC_debug`, `_GFMC_t_debug`, and `_GFMC_n_debug` are unchanged; they already use `np.std()` (internally two-pass). The existing debug <-> production agreement tests cover this fix as a regression test.
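
The difference between the two formulations can be seen in a small standalone sketch. The function names and data are hypothetical, and the example triggers the cancellation with a large common offset rather than a large `M_total`, purely to keep it short. The jackknife error is taken as `sqrt((M - 1) / M * sum((x_i - mean)^2))` over the `M` jackknife samples.

```python
import math


def jackknife_std_naive(x):
    """Jackknife std via Var = <x^2> - <x>^2 (cancellation-prone)."""
    m = len(x)
    mean = sum(x) / m
    var = sum(v * v for v in x) / m - mean * mean
    return math.sqrt((m - 1) * max(var, 0.0))  # clamp a negative variance


def jackknife_std_two_pass(x):
    """Jackknife std via the centered sum of squares (two passes over x)."""
    m = len(x)
    mean = sum(x) / m                       # first pass: mean
    ssq = sum((v - mean) ** 2 for v in x)   # second pass: centered squares
    return math.sqrt((m - 1) / m * ssq)


# Jackknife samples with a large common offset (illustrative values).
samples = [1e9 + v for v in (4.0, 7.0, 13.0, 16.0)]
print(jackknife_std_naive(samples), jackknife_std_two_pass(samples))
```

The two-pass result equals `sqrt(67.5)`, whereas the naive form subtracts two numbers near `1e18` and loses the variance entirely.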
Vectorize walker init: electron configs and PRNG keys
Use two-pass centered sum-of-squares in jackknife std
Forward `Kinetic_streaming_state.j3_state` through `_compute_ratio_Jastrow_part_{rank1_update,split_spin}` so the discretized kinetic mesh and ECP non-local kernels can skip per-step AO / W / U / `cross_vec` recomputation, saving `O(n_ao^2 * N_e)` per projection step. Legacy and observation paths keep `j3_state=None`.
- Add `compute_AOs_value_grad_lap` and `compute_MOs_value_grad_lap`, which share the heavy block (`exp` / polynomial chain / `S_l_m`) across value, gradient, and Laplacian evaluation instead of recomputing it three times.
- Migrate the streaming-advance hot path (Det/J3 single-electron updates), the fast/init paths, and the J3 forward grad/lap path to the fused API via `Geminal_data.compute_orb_value_grad_lap_api` and `_three_body_orb_apis`.
pre-commit-ci Bot and others added 29 commits May 21, 2026 23:11
updates:
- [github.com/astral-sh/ruff-pre-commit: v0.15.12 → v0.15.13](astral-sh/ruff-pre-commit@v0.15.12...v0.15.13)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: kousuke-nakano <37653569+kousuke-nakano@users.noreply.github.com>
* Update docs.

* Optimize J2 ratio: O(N^2) baseline -> O(N * N_grid) per-grid sums

* Introduce J3 state carry in MCMC update.

* Implement the slim J3 carry for MCMC WF update.

* Refactor atomic_orbital: coalesce primitive gather and lift r-R to unique-atom rank

Address the L1 LSU wavefront-pipe bottleneck (~90% of peak) identified
in `_compute_AOs_cart`, `_compute_AOs_sphe`, and their Laplacians.

Coalesced bucket reduction:
  - Introduce `PrimBucketLayout` (NamedTuple) and rewrite
    `_build_prim_buckets_by_K` to return a permanent bucket-K-major
    permutation of the primitive axis.
  - Apply the permutation inside every `_*_prim_*` accessor on
    `AOs_cart_data` / `AOs_sphe_data`, so all primitive arrays are
    stored bucket-K-major.
  - `_reduce_primitives_to_aos` now reduces each K bucket via
    `dynamic_slice + reshape + reduce_sum` instead of a fancy
    primitive-index gather, collapsing L1 sectors-per-request from
    ~10 toward the ideal of 4.

r-R at unique-atom rank:
  - Compute `r-R` and `r^2` at unique-atom rank
    (`n_atoms_unique << num_ao_prim`) and gather to primitive rank
    via `_nucleus_index_prim_jnp`.
  - Eliminates the `(num_ao_prim, n_elec, 3)` broadcast intermediate
    that drove most of the register-spill local-memory traffic.
  - Applied to `_compute_AOs_cart`, `_compute_AOs_sphe`,
    `_compute_AOs_laplacian_analytic_cart`, and
    `_compute_AOs_laplacian_analytic_sphe` (the spherical Laplacian
    additionally computes `grad_S . r_R` at unique rank).

* Fixed a trivial bug atomic_orbital: permute primitive arrays along the last axis

`_exponents_jnp` / `_coefficients_jnp` applied the bucket-K-major
permutation via `arr[self._prim_perm_np]`, which targets axis 0. This
is fine for the AO eval kernels (1D primitive arrays), but
`collect_param_grads` reads these accessors on a vmap-batched
`AOs_*_data` whose `exponents` / `coefficients` leaves are 2D
`(num_walkers, num_ao_prim)`; the take then mangled the walker axis
into the primitive axis and produced shape mismatches downstream.

Switch to `jnp.take(arr, self._prim_perm_np, axis=-1)` so the
permutation is always applied to the primitive axis regardless of
prepended batch dims. AO-eval results are bit-exact identical for the
1D inputs used by `_compute_AOs_*`.

* Fixed a bug in AOs. return raw-order exponents/coefficients from public accessors

The public ao_exponents/ao_coefficients on Geminal_data and
Jastrow_three_body_data were leaking the bucket-K-major permuted
form (_exponents_jnp/_coefficients_jnp), which broke round-trips
through the basis-optimization API (ShellPrimMap.symmetrize,
with_updated_ao_*, apply_block_update) and produced double-mangled
values.

Public accessors now return basis-natural order via
jnp.asarray(self.exponents/coefficients). The bucket-K permutation
stays strictly inside the AO eval kernel boundary.

Also revert the jnp.take(..., axis=-1) workaround introduced in
561ca40 in atomic_orbital._{exponents,coefficients}_jnp.

* Revert "Fixed a bug in AOs. return raw-order exponents/coefficients from public accessors"

This reverts commit f38dd21.

* Revert "Fixed a trivial bug atomic_orbital: permute primitive arrays along the last axis"

This reverts commit 561ca40.

* Revert "Refactor atomic_orbital: coalesce primitive gather and lift r-R to unique-atom rank"

This reverts commit 43bfa00.

* Improve GFMC_n on-the-fly `E_scf` update, which left bad-regime samples as permanent contamination

The on-the-fly `E_scf` update in `GFMC_n` / `_GFMC_n_debug` was gated by `(i_mcmc_step + 1) % mcmc_interval == 0`, which delayed the first update until iteration `eq_steps + mcmc_interval - 1`. Since `mcmc_interval = num_mcmc_steps / 100`, the first-update step grew linearly with the number of steps `N` (`num_mcmc_steps`). As a result, `stored_w_L` values projected with the possibly far-off initial `E_scf` were pinned as permanent K-window-product outliers in `__G_L`, and the on-the-fly jackknife never recovered as `N` grew: sigma stayed flat instead of falling as `1/sqrt(N)`. The same outlier-laden tail also leaked into the final `get_E()` mixed estimator whenever the user-set `num_gfmc_warmup_steps` was smaller than the run-dependent bad-regime length. This was introduced in `c287272` (2025-02-05), when the `E_scf` update logic was first added inside the display-throttle block.

Fix:
- Remove the `mcmc_interval` throttle from the update logic so that `E_scf` is updated every step during the rapid phase (`i_mcmc_step < mcmc_interval`) and every `mcmc_interval` steps thereafter. The MPI cost increase is bounded by about `N / 100` extra broadcasts, which is negligible relative to projection time.
- Derive the jackknife warmup cap from the actual bad-regime boundary: `GFMC_ON_THE_FLY_COLLECT_STEPS + GFMC_ON_THE_FLY_BIN_BLOCKS` (`= 20` with the current constants), instead of the magic `eq_steps = 20`.
- Remove `GFMC_ON_THE_FLY_WARMUP_STEPS` from `_setting.py` because it is no longer used; it is replaced by `mcmc_interval` as the rapid/throttle boundary.
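
The change in update schedule can be sketched as two predicates. Both functions are hypothetical simplifications: they ignore the `eq_steps` offset mentioned above, and the exact form of the new condition is an assumption based on the description ("every step during the rapid phase and every `mcmc_interval` steps thereafter").

```python
def should_update_old(i_mcmc_step: int, mcmc_interval: int) -> bool:
    """Previous gating: update only on multiples of mcmc_interval."""
    return (i_mcmc_step + 1) % mcmc_interval == 0


def should_update_new(i_mcmc_step: int, mcmc_interval: int) -> bool:
    """Update every step in the rapid phase, then every mcmc_interval steps."""
    if i_mcmc_step < mcmc_interval:
        return True
    return (i_mcmc_step + 1) % mcmc_interval == 0


num_mcmc_steps = 1000
mcmc_interval = num_mcmc_steps // 100  # 10 steps
old = [i for i in range(num_mcmc_steps) if should_update_old(i, mcmc_interval)]
new = [i for i in range(num_mcmc_steps) if should_update_new(i, mcmc_interval)]

# First update step under each rule, and the number of extra updates.
print(old[0], new[0], len(new) - len(old))
```

Under these assumptions the first update moves from step 9 to step 0, at the cost of 9 extra updates, in line with the stated bound of about `N / 100`.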

* Fix continuation step overestimation in LRDMC/MCMC workflows

`accumulated_measurement` was being incremented by the planned `estimated_steps - warmup` after every `_submit_and_wait`, regardless of whether the run actually completed all planned branching steps. When a GFMC run was cut short by `max_time` ("Break the branching loop"), the workflow overcounted the accumulated samples, inflating `estimate_additional_steps` and causing the target step count to grow on each continuation.
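
A minimal sketch of the corrected bookkeeping, assuming the workflow can read the number of steps a run actually completed. The function name and its arguments are hypothetical and do not correspond to the `jqmc_workflow` API.

```python
def update_accumulated(accumulated: int, planned_steps: int,
                       completed_steps: int, warmup: int) -> int:
    """Credit only the measurement steps that actually ran.

    Hypothetical helper: a run cut short by max_time reports fewer
    completed steps than planned, and only those beyond warmup count.
    """
    ran = min(planned_steps, completed_steps)
    return accumulated + max(ran - warmup, 0)


# A run planned for 1000 steps but stopped after 400 (warmup = 100).
print(update_accumulated(0, 1000, 400, 100))
# The same run when it completes all planned steps.
print(update_accumulated(0, 1000, 1000, 100))
```

Crediting the planned count in the first case would record 900 steps instead of 300, which is the overcount described above.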

* Improve jqmc-workflow behaviors.

  - Reconcile orphan `submitted`/`completed` jobs to `fetched` on re-entry
  - Always use `jqmc-tool` for energy (drop stdout `%.5f` fast-path)
  - Cache (energy, error) in `[estimation]` keyed on accumulated steps + post-proc params
  - Allow re-entry into `completed` LRDMC/MCMC when target_error is not yet met

* Clean up jqmc-workflow. Solved many trivial bugs.

* Improve jqmc-workflow error reporting and fix silent failures

- `launcher`: include the per-job reason and path in the DAG execution summary so failed runs are easier to diagnose
- `vmc_workflow`: detect non-finite energies at each optimization step by parsing before `validate_completion`, and consolidate VMC log parsing through `_output_parser._parse_vmc_log_text`
- `_output_parser`: extend the energy / force / SNR regexes to match `nan` / `inf` so diverged runs surface as non-finite floats instead of being silently dropped from convergence checks
- `lrdmc_ext_workflow`: read `num_projection_per_measurement` from each sub-workflow's `output_values` (`GFMC_n` / `lrdmc-bra` path) before falling back to the `GFMC_t` output-file diagnostic, which does not exist in bra mode
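
The regex change in `_output_parser` can be illustrated with a small standalone parser. The log-line format, the pattern, and the function names here are assumptions; the real jqmc output and regexes may differ.

```python
import math
import re

# Hypothetical float pattern that also accepts "nan" and "inf".
_FLOAT = r"[-+]?(?:\d+\.?\d*(?:[eE][-+]?\d+)?|nan|inf)"
# Hypothetical log line of the form "E = <value> +- <error>".
_ENERGY_RE = re.compile(rf"E\s*=\s*({_FLOAT})\s*\+-\s*({_FLOAT})", re.IGNORECASE)


def parse_energy(line):
    """Return (energy, error) as floats, or None if the line does not match."""
    m = _ENERGY_RE.search(line)
    if m is None:
        return None
    return float(m.group(1)), float(m.group(2))


def is_diverged(parsed):
    """True if a parsed (energy, error) pair contains a non-finite value."""
    return parsed is not None and not all(math.isfinite(v) for v in parsed)


print(parse_energy("E = -1.2345 +- 0.0012"))
print(is_diverged(parse_energy("E = nan +- nan")))
```

Without the `nan|inf` alternatives the second line would not match at all, so a diverged step would be dropped rather than flagged.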

* Add `|v_0|^2 < 0.9` fallback to plain SR in LM optimization

`solve_linear_method` now returns `v0_sq_best` alongside `(c_vec, E_lm)`. The LM caller falls back to `theta = 0.1 * g_sr` when the selected eigenvector's overlap with the current wavefunction is small, preventing NaN energies from updates that lie outside the linear regime. This also removes the now-redundant `|v_0|^2 < 0.01` warning inside the solver. Tests and `Debug_MCMC.solve_linear_method` have been updated for the new return value.
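
The fallback rule can be sketched as a small selection function. The name `choose_update`, its signature, and the use of plain lists are hypothetical; only the `0.9` threshold and the `0.1 * g_sr` step come from the commit message.

```python
def choose_update(c_vec, v0_sq_best, g_sr, threshold=0.9, sr_step=0.1):
    """Return (update, method): the LM update unless |v_0|^2 is too small.

    Hypothetical helper: c_vec is the LM update, g_sr the SR direction.
    """
    if v0_sq_best < threshold:
        # Small overlap with the current wavefunction: fall back to plain SR.
        return [sr_step * g for g in g_sr], "sr"
    return list(c_vec), "lm"


c_vec = [0.3, -0.2]
g_sr = [5.0, -5.0]
print(choose_update(c_vec, 0.95, g_sr))
print(choose_update(c_vec, 0.50, g_sr))
```

A small `|v_0|^2` signals that the proposed LM step lies outside the linear regime, so the more conservative SR step is taken instead.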

* Make tall CG SR solve distributed: `psum` instead of `all_gather(X)`

Per-rank memory no longer scales with `mpi_size`, fixing weak-scaling OOM. Keep the legacy kernel as dead code pending verification on real hardware.

* Remove legacy tall-CG kernel; promote distributed version to canonical name

The distributed `psum`-based tall-CG kernel has now been verified: it eliminates the OOM issue and preserves results. Remove the legacy `all_gather`-based kernel, its driver method, and the equivalence test, and rename the distributed implementation to the canonical names.

Net behavior is unchanged from the prior commit.

* LRDMC_Workflow: publish averaged nmpm in GFMC_t mode

Expose parsed `avg_num_projections` as `output_values["num_projection_per_measurement"]` in tau mode, matching the `GFMC_n` key. This simplifies `
…imer_projection_init`. It was previously leaked into Net GFMC time when `comput_position_deriv=True`.
…den_dim=2, num_layers=1, num_rbf=2 (test_jastrow: 4/1/4 for numerical-diff).
…ranks (allgather) instead of per-rank local mean because the production and debug functions yield different cums at the digit level.
timeout-minutes 720 -> 1440
…on (allreduce+Exscan) cumprob is not bit-identical, and at fp32+np>2 the gap crosses searchsorted boundaries and permutes walkers across ranks, breaking the rank-local w_L/e_L/e_L2 compare (physics is still validated via get_E/get_aF MPI-aggregated checks).
@kousuke-nakano kousuke-nakano merged commit 262c64c into rc Jun 2, 2026
2 checks passed