From 46524378075c7fbe261319ee95dbdfd269fc31e1 Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Sat, 11 Apr 2026 04:43:27 +0000 Subject: [PATCH 1/4] Add OpenVM2 SWIRL soundness support --- pyproject.toml | 5 + soundcalc/custom/__init__.py | 1 + soundcalc/custom/swirl/__init__.py | 18 ++ soundcalc/custom/swirl/calculator.py | 433 +++++++++++++++++++++++++++ soundcalc/custom/swirl/circuit.py | 115 +++++++ soundcalc/main.py | 3 +- soundcalc/report_md.py | 24 +- soundcalc/zkvms/openvm2/__init__.py | 80 +++++ soundcalc/zkvms/openvm2/openvm2.toml | 77 +++++ 9 files changed, 748 insertions(+), 8 deletions(-) create mode 100644 soundcalc/custom/__init__.py create mode 100644 soundcalc/custom/swirl/__init__.py create mode 100644 soundcalc/custom/swirl/calculator.py create mode 100644 soundcalc/custom/swirl/circuit.py create mode 100644 soundcalc/zkvms/openvm2/__init__.py create mode 100644 soundcalc/zkvms/openvm2/openvm2.toml diff --git a/pyproject.toml b/pyproject.toml index 22b15a5..9181290 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,5 +14,10 @@ dependencies = [] [tool.setuptools.packages.find] where = ["."] +[dependency-groups] +dev = [ + "pytest>=8.4.2", +] + diff --git a/soundcalc/custom/__init__.py b/soundcalc/custom/__init__.py new file mode 100644 index 0000000..3ed9c69 --- /dev/null +++ b/soundcalc/custom/__init__.py @@ -0,0 +1 @@ +"""Custom proof-system integrations.""" diff --git a/soundcalc/custom/swirl/__init__.py b/soundcalc/custom/swirl/__init__.py new file mode 100644 index 0000000..db2ade4 --- /dev/null +++ b/soundcalc/custom/swirl/__init__.py @@ -0,0 +1,18 @@ +from soundcalc.custom.swirl.calculator import ( + SWIRLLogUpSecurityParameters, + SWIRLSystemParams, + SWIRLWhirProximityMode, + build_swirl_system_params, + calculate_swirl_soundness, +) +from soundcalc.custom.swirl.circuit import SWIRLCircuit, SWIRLCircuitConfig + +__all__ = [ + "SWIRLCircuit", + "SWIRLCircuitConfig", + "SWIRLLogUpSecurityParameters", + "SWIRLSystemParams", + "SWIRLWhirProximityMode", + "build_swirl_system_params", + "calculate_swirl_soundness", +] diff --git a/soundcalc/custom/swirl/calculator.py b/soundcalc/custom/swirl/calculator.py new file mode 100644 index 0000000..9ce6615 --- /dev/null +++ b/soundcalc/custom/swirl/calculator.py @@ -0,0 +1,433 @@ +from __future__ import annotations + +import math +from dataclasses import dataclass + +from soundcalc.common.fields import FieldParams +from soundcalc.common.utils import apply_grinding +from soundcalc.pcs.whir import WHIR +from soundcalc.proxgaps.johnson_bound import JohnsonBoundRegime +from soundcalc.proxgaps.proxgaps_regime import ProximityGapsRegime + + +SWIRL_SECURITY_BITS_TARGET = 100 +SWIRL_WHIR_K = 4 +SWIRL_WHIR_MAX_LOG_FINAL_POLY_LEN = 10 +SWIRL_QUERY_PHASE_POW_BITS = 20 +SWIRL_MAX_CONSTRAINT_DEGREE = 4 + + +@dataclass(frozen=True) +class SWIRLWhirRoundConfig: + num_queries: int + + +@dataclass(frozen=True) +class SWIRLWhirProximityMode: + kind: str + m: int | None = None + + def build_regime(self, field: FieldParams) -> ProximityGapsRegime: + if self.kind == "unique": + return SWIRLUniqueDecodingRegime(field) + if self.kind == "list": + if self.m is None: + raise ValueError("list-decoding mode requires multiplicity m") + return SWIRLListDecodingRegime(field, self.m) + raise ValueError(f"Unknown SWIRL proximity mode: {self.kind}") + + def whir_query_security_bits(self, num_queries: int, log_inv_rate: int) -> float: + rho = 2.0 ** (-log_inv_rate) + if self.kind == "unique": + max_agreement = (1.0 + rho) / 2.0 + elif self.kind == "list": + if self.m is None: + raise ValueError("list-decoding mode requires multiplicity m") + max_agreement = math.sqrt(rho * (1.0 + 1.0 / self.m)) + 1e-6 + else: + raise ValueError(f"Unknown SWIRL proximity mode: {self.kind}") + + max_agreement = max(max_agreement, math.ldexp(1.0, -1022)) + return -(num_queries * math.log2(max_agreement)) + + +@dataclass(frozen=True) +class SWIRLWhirConfig: + k: int + rounds: list[SWIRLWhirRoundConfig] + mu_pow_bits: int + query_phase_pow_bits: int + folding_pow_bits: int + proximity: SWIRLWhirProximityMode + + +@dataclass(frozen=True) +class SWIRLLogUpSecurityParameters: + """ + SWIRL's interaction LogUp bound expressed through the shared error-to-bits path. + """ + + max_interaction_count: int + log_max_message_length: int + pow_bits: int + + def max_message_length(self) -> int: + return 1 << self.log_max_message_length + + def get_soundness_error(self, challenge_field_size: int, list_size: float = 1.0) -> float: + return ( + 2.0 + * self.max_interaction_count + * self.max_message_length() + / (challenge_field_size * list_size) + ) + + def get_soundness_bits(self, challenge_field_size: int, list_size: float = 1.0) -> float: + grounded_error = apply_grinding( + self.get_soundness_error(challenge_field_size, list_size), + self.pow_bits, + ) + return -math.log2(grounded_error) + + +@dataclass(frozen=True) +class SWIRLSystemParams: + l_skip: int + n_stack: int + w_stack: int + log_blowup: int + whir: SWIRLWhirConfig + logup: SWIRLLogUpSecurityParameters + max_constraint_degree: int + + def log_stacked_height(self) -> int: + return self.l_skip + self.n_stack + + +@dataclass(frozen=True) +class SWIRLWhirDetails: + mu_batching_bits: float + fold_rbr_bits: float + proximity_gaps_bits: float + sumcheck_bits: float + ood_rbr_bits: float + shift_rbr_bits: float + query_bits: float + gamma_batching_bits: float + + +@dataclass(frozen=True) +class SWIRLSoundnessResult: + logup_bits: float + gkr_sumcheck_bits: float + gkr_batching_bits: float + zerocheck_sumcheck_bits: float + constraint_batching_bits: float + stacked_reduction_bits: float + whir_bits: float + whir_details: SWIRLWhirDetails + total_bits: float + + +class SWIRLUniqueDecodingRegime(ProximityGapsRegime): + """ + SWIRL's unique-decoding WHIR bound uses `n / |F|` for the proximity-gap term. + + This differs from soundcalc's generic UDR MCA bound, so it stays isolated here. + """ + + def identifier(self) -> str: + return "SWIRL-UDR" + + def get_proximity_parameter(self, rate: float, dimension: int) -> float: + return (1.0 - rate) / 2.0 + + def get_max_list_size(self, rate: float, dimension: int) -> int: + return 1 + + def get_error_powers(self, rate: float, dimension: int, batch_size: int) -> float: + if batch_size <= 1: + return 0.0 + return self.get_error_linear(rate, dimension) * (batch_size - 1) + + def get_error_linear(self, rate: float, dimension: int) -> float: + code_length = dimension / rate + return code_length / self.field.F + + def get_error_multilinear(self, rate: float, dimension: int, batch_size: int) -> float: + if batch_size <= 1: + return 0.0 + return self.get_error_linear(rate, dimension) * math.ceil(math.log2(batch_size)) + + +class SWIRLListDecodingRegime(JohnsonBoundRegime): + """ + SWIRL uses the default BCHKS25 closed-form `a`-bound with an explicit multiplicity `m`. + + The shared Johnson-bound implementation already contains the default `a`-bound algebra. + We only override the parts where SWIRL fixes `m` directly and uses `D_Y` as the list-size + proxy for subsequent soundness terms. + """ + + def __init__(self, field: FieldParams, m: int): + super().__init__(field) + self.explicit_m = max(m, 1) + + def identifier(self) -> str: + return f"SWIRL-LDR(m={self.explicit_m})" + + def get_proximity_parameter(self, rate: float, dimension: int) -> float: + sqrt_rate = math.sqrt(rate) + return 1.0 - sqrt_rate - (sqrt_rate / (2.0 * self.explicit_m)) + + def get_m(self, rate: float, dimension: int) -> int: + return self.explicit_m + + def get_max_list_size(self, rate: float, dimension: int) -> float: + sqrt_rate = math.sqrt(rate) + return (self.explicit_m + 0.5) / sqrt_rate + + +def _challenge_field_bits(field: FieldParams) -> float: + return field.field_extension_degree * math.log2(field.p) + + +def _log2_add(log2_x: float, log2_y: float) -> float: + hi, lo = (log2_x, log2_y) if log2_x >= log2_y else (log2_y, log2_x) + return hi + math.log2(1.0 + (2.0 ** (lo - hi))) + + +def _combine_security_bits(bits_a: float, bits_b: float) -> float: + return -_log2_add(-bits_a, -bits_b) + + +def _n_logup_bound( + l_skip: int, + num_airs: int, + max_interactions_per_air: int, + max_log_height: int, + max_interaction_count: int, +) -> int: + field_bound = math.ceil(math.log2(max_interaction_count)) - l_skip + param_bound = ( + math.ceil(math.log2(num_airs)) + + math.ceil(math.log2(max_interactions_per_air)) + + max_log_height + - l_skip + ) + return min(field_bound, param_bound) + + +def _whir_sumcheck_security(challenge_field_bits: float, sub_round: int, folding_pow_bits: int) -> float: + sumcheck_degree = 2.0 if sub_round == 0 else 3.0 + return challenge_field_bits - math.log2(sumcheck_degree) + folding_pow_bits + + +def _whir_gamma_batching_security( + challenge_field_bits: float, + batch_size: int, + list_size: float, +) -> float: + return challenge_field_bits - math.log2(batch_size) - math.log2(list_size) + + +def _whir_ood_security( + challenge_field_bits: float, + log_degree_at_round_start: int, + list_size: float, +) -> float: + return challenge_field_bits - log_degree_at_round_start + 1.0 - 2.0 * math.log2(list_size) + + +def build_swirl_system_params( + *, + l_skip: int, + n_stack: int, + w_stack: int, + log_blowup: int, + folding_pow_bits: int, + mu_pow_bits: int, + proximity: SWIRLWhirProximityMode, + logup: SWIRLLogUpSecurityParameters, + security_bits_target: int = SWIRL_SECURITY_BITS_TARGET, +) -> SWIRLSystemParams: + protocol_security_level = security_bits_target - SWIRL_QUERY_PHASE_POW_BITS + log_stacked_height = l_skip + n_stack + num_rounds = math.ceil( + max(log_stacked_height - SWIRL_WHIR_MAX_LOG_FINAL_POLY_LEN, 0) / SWIRL_WHIR_K + ) + + rounds: list[SWIRLWhirRoundConfig] = [] + log_inv_rate = log_blowup + for _round in range(num_rounds): + per_query_bits = proximity.whir_query_security_bits(1, log_inv_rate) + num_queries = math.ceil(protocol_security_level / per_query_bits) + rounds.append(SWIRLWhirRoundConfig(num_queries=num_queries)) + log_inv_rate += SWIRL_WHIR_K - 1 + + return SWIRLSystemParams( + l_skip=l_skip, + n_stack=n_stack, + w_stack=w_stack, + log_blowup=log_blowup, + whir=SWIRLWhirConfig( + k=SWIRL_WHIR_K, + rounds=rounds, + mu_pow_bits=mu_pow_bits, + query_phase_pow_bits=SWIRL_QUERY_PHASE_POW_BITS, + folding_pow_bits=folding_pow_bits, + proximity=proximity, + ), + logup=logup, + max_constraint_degree=SWIRL_MAX_CONSTRAINT_DEGREE, + ) + + +def calculate_swirl_soundness( + *, + params: SWIRLSystemParams, + field: FieldParams, + whir: WHIR, + max_num_constraints_per_air: int, + num_airs: int, + max_log_trace_height: int, + num_trace_columns: int, + max_interactions_per_air: int, +) -> SWIRLSoundnessResult: + challenge_field_bits = _challenge_field_bits(field) + n_logup = _n_logup_bound( + params.l_skip, + num_airs, + max_interactions_per_air, + max_log_trace_height, + params.logup.max_interaction_count, + ) + + regime = params.whir.proximity.build_regime(field) + mu_batching_bits = -math.log2(whir._get_batching_error(regime)) + initial_list_size = whir._get_list_size_for_iteration_and_round(0, 0, regime) + log2_list_size = math.log2(initial_list_size) + + logup_bits = params.logup.get_soundness_bits( + field.F, + initial_list_size, + ) + + gkr_sumcheck_bits = challenge_field_bits - math.log2(3.0) + gkr_batching_bits = challenge_field_bits + + univariate_degree = (params.max_constraint_degree + 1) * ((1 << params.l_skip) - 1) + multilinear_degree = params.max_constraint_degree + 1 + zerocheck_sumcheck_bits = challenge_field_bits - math.log2(max(univariate_degree, multilinear_degree)) + + n_max = max_log_trace_height - params.l_skip + poly_degree_sum = ((1 << params.l_skip) - 1) + n_max + poly_identity_bits = challenge_field_bits - math.log2(poly_degree_sum) + zerocheck_bits = log2_list_size + min(zerocheck_sumcheck_bits, poly_identity_bits) + + lambda_batching_bits = challenge_field_bits - math.log2(max_num_constraints_per_air) + mu_constraint_bits = challenge_field_bits - math.log2(3.0 * num_airs) + constraint_batching_bits = log2_list_size + min(lambda_batching_bits, mu_constraint_bits) + + stacked_batching_bits = challenge_field_bits - math.log2(2.0 * num_trace_columns) + stacked_univariate_bits = challenge_field_bits - math.log2(2.0 * ((1 << params.l_skip) - 1)) + stacked_multilinear_bits = challenge_field_bits - 1.0 + stacked_reduction_bits = log2_list_size + min( + stacked_batching_bits, + stacked_univariate_bits, + stacked_multilinear_bits, + ) + + min_query_bits = math.inf + min_proximity_gaps_bits = math.inf + min_sumcheck_bits = math.inf + min_ood_bits = math.inf + min_gamma_batching_bits = math.inf + min_fold_rbr_bits = math.inf + min_shift_rbr_bits = math.inf + min_whir_bits = mu_batching_bits + + for round_index, round_config in enumerate(params.whir.rounds): + for sub_round in range(params.whir.k): + fold_bits = -math.log2(whir._epsilon_fold(round_index, sub_round + 1, regime)) + min_fold_rbr_bits = min(min_fold_rbr_bits, fold_bits) + min_whir_bits = min(min_whir_bits, fold_bits) + + rate, dimension = whir._get_code_for_iteration_and_round(round_index, sub_round + 1) + proximity_error = regime.get_error_powers(rate, dimension, 2) + proximity_bits = -math.log2(proximity_error) + params.whir.folding_pow_bits + min_proximity_gaps_bits = min(min_proximity_gaps_bits, proximity_bits) + + sumcheck_bits = _whir_sumcheck_security( + challenge_field_bits, + sub_round, + params.whir.folding_pow_bits, + ) + min_sumcheck_bits = min(min_sumcheck_bits, sumcheck_bits) + + query_bits = ( + params.whir.proximity.whir_query_security_bits( + round_config.num_queries, + whir.log_inv_rates[round_index], + ) + + params.whir.query_phase_pow_bits + ) + min_query_bits = min(min_query_bits, query_bits) + + next_list_size = whir._get_list_size_for_iteration_and_round( + round_index, + params.whir.k, + regime, + ) + gamma_batching_bits = _whir_gamma_batching_security( + challenge_field_bits, + round_config.num_queries + 1, + next_list_size, + ) + min_gamma_batching_bits = min(min_gamma_batching_bits, gamma_batching_bits) + + shift_rbr_bits = _combine_security_bits(query_bits, gamma_batching_bits) + min_shift_rbr_bits = min(min_shift_rbr_bits, shift_rbr_bits) + min_whir_bits = min(min_whir_bits, shift_rbr_bits) + + if round_index < whir.num_iterations - 1: + ood_bits = _whir_ood_security( + challenge_field_bits, + whir.log_degrees[round_index + 1], + next_list_size, + ) + min_ood_bits = min(min_ood_bits, ood_bits) + min_whir_bits = min(min_whir_bits, ood_bits) + + whir_details = SWIRLWhirDetails( + mu_batching_bits=mu_batching_bits, + fold_rbr_bits=min_fold_rbr_bits, + proximity_gaps_bits=min_proximity_gaps_bits, + sumcheck_bits=min_sumcheck_bits, + ood_rbr_bits=min_ood_bits, + shift_rbr_bits=min_shift_rbr_bits, + query_bits=min_query_bits, + gamma_batching_bits=min_gamma_batching_bits, + ) + + total_bits = min( + logup_bits, + gkr_sumcheck_bits, + gkr_batching_bits, + zerocheck_bits, + constraint_batching_bits, + stacked_reduction_bits, + min_whir_bits, + ) + + return SWIRLSoundnessResult( + logup_bits=logup_bits, + gkr_sumcheck_bits=gkr_sumcheck_bits, + gkr_batching_bits=gkr_batching_bits, + zerocheck_sumcheck_bits=zerocheck_bits, + constraint_batching_bits=constraint_batching_bits, + stacked_reduction_bits=stacked_reduction_bits, + whir_bits=min_whir_bits, + whir_details=whir_details, + total_bits=total_bits, + ) diff --git a/soundcalc/custom/swirl/circuit.py b/soundcalc/custom/swirl/circuit.py new file mode 100644 index 0000000..f5bea7a --- /dev/null +++ b/soundcalc/custom/swirl/circuit.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +from dataclasses import dataclass + +from soundcalc.common.fields import FieldParams +from soundcalc.custom.swirl.calculator import SWIRLSoundnessResult, SWIRLSystemParams, calculate_swirl_soundness +from soundcalc.pcs.whir import WHIR +from soundcalc.zkvms.circuit import Circuit, CircuitConfig + + +@dataclass(frozen=True) +class SWIRLCircuitConfig: + name: str + pcs: WHIR + field: FieldParams + params: SWIRLSystemParams + max_num_constraints_per_air: int + num_airs: int + max_log_trace_height: int + num_trace_columns: int + max_interactions_per_air: int + + +class SWIRLCircuit(Circuit): + def __init__(self, config: SWIRLCircuitConfig): + super().__init__(CircuitConfig(name=config.name, pcs=config.pcs, field=config.field, udr_only=True)) + self.params = config.params + self.max_num_constraints_per_air = config.max_num_constraints_per_air + self.num_airs = config.num_airs + self.max_log_trace_height = config.max_log_trace_height + self.num_trace_columns = config.num_trace_columns + self.max_interactions_per_air = config.max_interactions_per_air + self.protocol_label = "SWIRL" + self._soundness_result: SWIRLSoundnessResult | None = None + + def get_soundness_result(self) -> SWIRLSoundnessResult: + if self._soundness_result is None: + self._soundness_result = calculate_swirl_soundness( + params=self.params, + field=self.field, + whir=self.pcs, + max_num_constraints_per_air=self.max_num_constraints_per_air, + num_airs=self.num_airs, + max_log_trace_height=self.max_log_trace_height, + num_trace_columns=self.num_trace_columns, + max_interactions_per_air=self.max_interactions_per_air, + ) + return self._soundness_result + + def get_security_levels(self) -> dict[str, dict[str, float]]: + result = self.get_soundness_result() + levels = { + "logup": round(result.logup_bits, 1), + "gkr_sumcheck": round(result.gkr_sumcheck_bits, 1), + "gkr_batching": round(result.gkr_batching_bits, 1), + "zerocheck_sumcheck": round(result.zerocheck_sumcheck_bits, 1), + "constraint_batching": round(result.constraint_batching_bits, 1), + "stacked_reduction": round(result.stacked_reduction_bits, 1), + "whir": round(result.whir_bits, 1), + "whir.query": round(result.whir_details.query_bits, 1), + "whir.proximity_gaps": round(result.whir_details.proximity_gaps_bits, 1), + "whir.sumcheck": round(result.whir_details.sumcheck_bits, 1), + "whir.fold_rbr": round(result.whir_details.fold_rbr_bits, 1), + "whir.ood_rbr": round(result.whir_details.ood_rbr_bits, 1), + "whir.gamma_batching": round(result.whir_details.gamma_batching_bits, 1), + "whir.shift_rbr": round(result.whir_details.shift_rbr_bits, 1), + "whir.mu_batching": round(result.whir_details.mu_batching_bits, 1), + "total": round(result.total_bits, 1), + } + return {"SWIRL": levels} + + def get_parameter_summary(self) -> str: + lines = [ + "", + "```", + " protocol_family : SWIRL", + " pcs : WHIR", + f" field : {self.field.to_string()}", + f" l_skip : {self.params.l_skip}", + f" n_stack : {self.params.n_stack}", + f" w_stack : {self.params.w_stack}", + f" log_blowup : {self.params.log_blowup}", + f" max_constraint_degree : {self.params.max_constraint_degree}", + f" whir_queries : {[round_config.num_queries for round_config in self.params.whir.rounds]}", + f" whir_folding_pow_bits : {self.params.whir.folding_pow_bits}", + f" whir_query_phase_pow_bits : {self.params.whir.query_phase_pow_bits}", + f" whir_mu_pow_bits : {self.params.whir.mu_pow_bits}", + f" max_constraints_per_air : {self.max_num_constraints_per_air}", + f" num_airs : {self.num_airs}", + f" max_log_trace_height : {self.max_log_trace_height}", + f" num_trace_columns : {self.num_trace_columns}", + f" max_interactions_per_air : {self.max_interactions_per_air}", + "```", + ] + return "\n".join(lines) + + def get_report_parameter_lines(self) -> list[str]: + return [ + "- Proof system: SWIRL", + "- Inner PCS: WHIR", + f"- Field: {self.field.to_string()}", + f"- `l_skip`: {self.params.l_skip}", + f"- `n_stack`: {self.params.n_stack}", + f"- `w_stack`: {self.params.w_stack}", + f"- Log blowup: {self.params.log_blowup}", + f"- WHIR queries per round: {[round_config.num_queries for round_config in self.params.whir.rounds]}", + f"- WHIR folding PoW (bits): {self.params.whir.folding_pow_bits}", + f"- WHIR query-phase PoW (bits): {self.params.whir.query_phase_pow_bits}", + f"- WHIR μ PoW (bits): {self.params.whir.mu_pow_bits}", + f"- Max constraints per AIR: {self.max_num_constraints_per_air}", + f"- Number of AIRs: {self.num_airs}", + f"- Max log trace height: {self.max_log_trace_height}", + f"- Number of trace columns: {self.num_trace_columns}", + f"- Max interactions per AIR: {self.max_interactions_per_air}", + ] diff --git a/soundcalc/main.py b/soundcalc/main.py index d2e2eee..7175a6d 100644 --- a/soundcalc/main.py +++ b/soundcalc/main.py @@ -6,7 +6,7 @@ from __future__ import annotations -from soundcalc.zkvms import risc0, miden, zisk, dummy_whir, pico, openvm, airbender, sp1 +from soundcalc.zkvms import risc0, miden, zisk, dummy_whir, pico, openvm, openvm2, airbender, sp1 from soundcalc import report_cli, report_md # All zkVM loaders @@ -17,6 +17,7 @@ ("DummyWHIR", dummy_whir.load), ("Pico", pico.load), ("OpenVM", openvm.load), + ("OpenVM 2.0", openvm2.load), ("Airbender", airbender.load), ("SP1", sp1.load), ] diff --git a/soundcalc/report_md.py b/soundcalc/report_md.py index e1de5f4..a64270d 100644 --- a/soundcalc/report_md.py +++ b/soundcalc/report_md.py @@ -33,7 +33,7 @@ class zkVMSummary: pcs: str num_circuits: int weakest_circuit_name: str - security_bits: int + security_bits: float security_regime: str final_proof_size_kib: int @@ -92,8 +92,16 @@ def _field_label(field) -> str: return "Unknown" +def _format_security_value(value: Any) -> str: + if isinstance(value, float): + return f"{value:.1f}" + return str(value) + + def _pcs_label(circuit: Circuit) -> str: """Get the PCS type label for a circuit.""" + if hasattr(circuit, "protocol_label"): + return getattr(circuit, "protocol_label") if isinstance(circuit.pcs, FRI): return "FRI" elif isinstance(circuit.pcs, WHIR): @@ -229,6 +237,8 @@ def _lookup_parameter_lines(circuit: Circuit) -> list[str]: def _get_parameter_lines(circuit: Circuit) -> list[str]: """Get parameter lines for a circuit.""" + if hasattr(circuit, "get_report_parameter_lines"): + return circuit.get_report_parameter_lines() if isinstance(circuit.pcs, FRI): lines = _fri_parameter_lines(circuit) elif isinstance(circuit.pcs, WHIR): @@ -315,12 +325,12 @@ def row_has_single_value(row: dict[str, Any]) -> bool: row_values = [row_name] if isinstance(row_data, dict): for col in columns[1:]: - row_values.append(str(row_data.get(col, "—"))) + row_values.append(_format_security_value(row_data.get(col, "—"))) else: # Non-dict value sits under the 'total' column when present. for col in columns[1:]: if col == "total": - row_values.append(str(row_data)) + row_values.append(_format_security_value(row_data)) else: row_values.append("—") md_table += "| " + " | ".join(row_values) + " |\n" @@ -366,7 +376,7 @@ def _build_zkvm_report(zkvm: zkVM, multi_circuit: bool = False) -> str: lines.append(f"| Metric | Value | Relevant circuit | Notes |") lines.append(f"| --- | --- | --- | --- |") lines.append(f"| Final proof size (worst case) | **{int(overview['final_proof_size_kib'])} KiB** | {final_circuit_link} | |") - lines.append(f"| Final bits of security | **{overview['min_security_bits']} bits** | {offending_circuit_link} | Regime: {overview['best_regime']} |") + lines.append(f"| Final bits of security | **{_format_security_value(overview['min_security_bits'])} bits** | {offending_circuit_link} | Regime: {overview['best_regime']} |") lines.append("") lines.append("## Circuits") @@ -435,7 +445,7 @@ def _build_summary_report(zkvms: list[zkVM]) -> str: "", "How to read this report:", "- Click on zkVM names to view detailed individual reports", - "- Security shows the best bits of security across regimes (UDR/JBR)", + "- Security shows the best bits of security across the reported regimes", "", "## Overview", "", @@ -454,7 +464,7 @@ def _build_summary_report(zkvms: list[zkVM]) -> str: lines.append( f"| [{s.name}]({report_filename}) " f"| {version_str} " - f"| **{s.security_bits}** bits ({s.security_regime}) " + f"| **{_format_security_value(s.security_bits)}** bits ({s.security_regime}) " f"| {s.final_proof_size_kib} KiB " f"| {s.pcs} | {s.field} | {s.num_circuits} | {s.weakest_circuit_name} |" ) @@ -463,7 +473,7 @@ def _build_summary_report(zkvms: list[zkVM]) -> str: "", "## Notes", "", - "- **Security**: Best bits of security across UDR (Unique Decoding) and JBR (Johnson Bound) regimes", + "- **Security**: Best bits of security across the reported regimes", "- **Weakest Circuit**: Circuit determining the overall security level", "- **Proof Size**: Final proof size in KiB (1 KiB = 1024 bytes)", "", diff --git a/soundcalc/zkvms/openvm2/__init__.py b/soundcalc/zkvms/openvm2/__init__.py new file mode 100644 index 0000000..68a748a --- /dev/null +++ b/soundcalc/zkvms/openvm2/__init__.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +from pathlib import Path + +import toml + +from soundcalc.common.fields import parse_field +from soundcalc.custom.swirl import ( + SWIRLCircuit, + SWIRLCircuitConfig, + SWIRLLogUpSecurityParameters, + SWIRLWhirProximityMode, + build_swirl_system_params, +) +from soundcalc.pcs.whir import WHIR, WHIRConfig +from soundcalc.zkvms.zkvm import zkVM + + +def load() -> zkVM: + with open(Path(__file__).parent / "openvm2.toml", "r") as f: + config = toml.load(f) + + field = parse_field(config["zkevm"]["field"]) + hash_size_bits = config["zkevm"]["hash_size_bits"] + logup = SWIRLLogUpSecurityParameters( + max_interaction_count=config["swirl"]["logup_max_interaction_count"], + log_max_message_length=config["swirl"]["logup_log_max_message_length"], + pow_bits=config["swirl"]["logup_pow_bits"], + ) + + circuits = [] + for section in config.get("circuits", []): + if section["whir_proximity"] == "unique": + proximity = SWIRLWhirProximityMode(kind="unique") + else: + proximity = SWIRLWhirProximityMode(kind="list", m=section["whir_m"]) + + params = build_swirl_system_params( + l_skip=section["l_skip"], + n_stack=section["n_stack"], + w_stack=section["w_stack"], + log_blowup=section["log_blowup"], + folding_pow_bits=section["whir_folding_pow_bits"], + mu_pow_bits=section["whir_mu_pow_bits"], + proximity=proximity, + logup=logup, + ) + whir = WHIR(WHIRConfig( + hash_size_bits=hash_size_bits, + log_inv_rate=params.log_blowup, + num_iterations=len(params.whir.rounds), + folding_factor=params.whir.k, + field=field, + log_degree=params.log_stacked_height(), + batch_size=params.w_stack, + power_batching=True, + grinding_batching_phase=params.whir.mu_pow_bits, + constraint_degree=section["constraint_degree"], + grinding_bits_folding=[ + [params.whir.folding_pow_bits] * params.whir.k + for _ in params.whir.rounds + ], + num_queries=[round_config.num_queries for round_config in params.whir.rounds], + grinding_bits_queries=[params.whir.query_phase_pow_bits] * len(params.whir.rounds), + num_ood_samples=[1] * max(len(params.whir.rounds) - 1, 0), + grinding_bits_ood=[0] * max(len(params.whir.rounds) - 1, 0), + )) + circuits.append(SWIRLCircuit(SWIRLCircuitConfig( + name=section["name"], + pcs=whir, + field=field, + params=params, + max_num_constraints_per_air=section["max_constraints_per_air"], + num_airs=section["num_airs"], + max_log_trace_height=section["max_log_trace_height"], + num_trace_columns=section["num_trace_columns"], + max_interactions_per_air=section["max_interactions_per_air"], + ))) + + return zkVM(config["zkevm"]["name"], circuits=circuits, version=config["zkevm"].get("version")) diff --git a/soundcalc/zkvms/openvm2/openvm2.toml b/soundcalc/zkvms/openvm2/openvm2.toml new file mode 100644 index 0000000..0b071aa --- /dev/null +++ b/soundcalc/zkvms/openvm2/openvm2.toml @@ -0,0 +1,77 @@ +[zkevm] +name = "OpenVM2" +protocol_family = "SWIRL" +version = "2.0.0-beta" +field = "BabyBear^4" +hash_size_bits = 256 + +[swirl] +logup_max_interaction_count = 2013265921 +logup_log_max_message_length = 7 +logup_pow_bits = 18 + +[[circuits]] +name = "app" +l_skip = 4 +n_stack = 20 +w_stack = 2048 +log_blowup = 1 +whir_folding_pow_bits = 5 +whir_mu_pow_bits = 15 +whir_proximity = "unique" +constraint_degree = 3 +max_constraints_per_air = 5000 +num_airs = 100 +max_log_trace_height = 24 +num_trace_columns = 30000 +max_interactions_per_air = 1000 + +[[circuits]] +name = "leaf" +l_skip = 4 +n_stack = 17 +w_stack = 2048 +log_blowup = 2 +whir_folding_pow_bits = 4 +whir_mu_pow_bits = 13 +whir_proximity = "unique" +constraint_degree = 3 +max_constraints_per_air = 1000 +num_airs = 50 +max_log_trace_height = 20 +num_trace_columns = 2000 +max_interactions_per_air = 100 + +[[circuits]] +name = "internal" +l_skip = 2 +n_stack = 17 +w_stack = 512 +log_blowup = 3 +whir_folding_pow_bits = 18 +whir_mu_pow_bits = 20 +whir_proximity = "list" +whir_m = 2 +constraint_degree = 3 +max_constraints_per_air = 1000 +num_airs = 50 +max_log_trace_height = 19 +num_trace_columns = 2000 +max_interactions_per_air = 100 + +[[circuits]] +name = "root" +l_skip = 2 +n_stack = 19 +w_stack = 9 +log_blowup = 4 +whir_folding_pow_bits = 20 +whir_mu_pow_bits = 20 +whir_proximity = "list" +whir_m = 2 +constraint_degree = 3 +max_constraints_per_air = 1000 +num_airs = 50 +max_log_trace_height = 21 +num_trace_columns = 2000 +max_interactions_per_air = 100 From 5867036f5c660bbda9780b21ae0ab8409b15ac7b Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Mon, 13 Apr 2026 22:18:42 +0000 Subject: [PATCH 2/4] Refactor SWIRL to reuse shared WHIR soundness --- reports/openvm2.md | 129 +++++++++++++++++++++++++++ reports/summary.md | 10 +-- soundcalc/custom/swirl/calculator.py | 89 ++++++------------ soundcalc/pcs/whir.py | 33 +++---- soundcalc/zkvms/openvm2/openvm2.toml | 6 +- 5 files changed, 178 insertions(+), 89 deletions(-) create mode 100644 reports/openvm2.md diff --git a/reports/openvm2.md b/reports/openvm2.md new file mode 100644 index 0000000..f85e2c0 --- /dev/null +++ b/reports/openvm2.md @@ -0,0 +1,129 @@ +# 📊 OpenVM2 (v2.0.0-beta) + +How to read this report: +- Table rows correspond to security regimes +- Table columns correspond to proof system components +- Cells show bits of security per component +- Proof size estimates are indicative (1 KiB = 1024 bytes) + +## zkVM Overview + +| Metric | Value | Relevant circuit | Notes | +| --- | --- | --- | --- | +| Final proof size (worst case) | **140 KiB** | [root](#root) | | +| Final bits of security | **100.0 bits** | [leaf](#leaf) | Regime: SWIRL | + +## Circuits + +- [app](#app) +- [leaf](#leaf) +- [internal](#internal) +- [root](#root) + +## app + +**Parameters:** +- Proof system: SWIRL +- Inner PCS: WHIR +- Field: BabyBear⁴ +- `l_skip`: 4 +- `n_stack`: 20 +- `w_stack`: 2048 +- Log blowup: 1 +- WHIR queries per round: [193, 88, 81, 81] +- WHIR folding PoW (bits): 5 +- WHIR query-phase PoW (bits): 20 +- WHIR μ PoW (bits): 15 +- Max constraints per AIR: 5000 +- Number of AIRs: 100 +- Max log trace height: 24 +- Number of trace columns: 30000 +- Max interactions per AIR: 1000 + +**Proof Size:** 24165 KiB (expected) / 24272 KiB (worst case) + +| regime | total | constraint_batching | gkr_batching | gkr_sumcheck | logup | stacked_reduction | whir | whir.fold_rbr | whir.gamma_batching | whir.mu_batching | whir.ood_rbr | whir.proximity_gaps | whir.query | whir.shift_rbr | whir.sumcheck | zerocheck_sumcheck | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| SWIRL | 100.1 | 111.3 | 123.6 | 122.0 | 102.7 | 107.8 | 100.1 | 104.6 | 116.0 | 102.6 | 104.6 | 104.6 | 100.1 | 100.1 | 127.0 | 117.4 | + + +## leaf + +**Parameters:** +- Proof system: SWIRL +- Inner PCS: WHIR +- Field: BabyBear⁴ +- `l_skip`: 4 +- `n_stack`: 17 +- `w_stack`: 2048 +- Log blowup: 2 +- WHIR queries per round: [118, 84, 81] +- WHIR folding PoW (bits): 4 +- WHIR query-phase PoW (bits): 20 +- WHIR μ PoW (bits): 13 +- Max constraints per AIR: 1000 +- Number of AIRs: 50 +- Max log trace height: 20 +- Number of trace columns: 2000 +- Max interactions per AIR: 100 + +**Proof Size:** 14775 KiB (expected) / 14840 KiB (worst case) + +| regime | total | constraint_batching | gkr_batching | gkr_sumcheck | logup | stacked_reduction | whir | whir.fold_rbr | whir.gamma_batching | whir.mu_batching | whir.ood_rbr | whir.proximity_gaps | whir.query | whir.shift_rbr | whir.sumcheck | zerocheck_sumcheck | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| SWIRL | 100.0 | 113.7 | 123.6 | 122.0 | 102.7 | 111.7 | 100.0 | 105.6 | 116.7 | 102.6 | 107.6 | 105.6 | 100.0 | 100.0 | 126.0 | 117.4 | + + +## internal + +**Parameters:** +- Proof system: SWIRL +- Inner PCS: WHIR +- Field: BabyBear⁴ +- `l_skip`: 2 +- `n_stack`: 17 +- `w_stack`: 512 +- Log blowup: 3 +- WHIR queries per round: [68, 30, 20] +- WHIR folding PoW (bits): 18 +- WHIR query-phase PoW (bits): 20 +- WHIR μ PoW (bits): 20 +- Max constraints per AIR: 1000 +- Number of AIRs: 50 +- Max log trace height: 19 +- Number of trace columns: 2000 +- Max interactions per AIR: 100 + +**Proof Size:** 2164 KiB (expected) / 2186 KiB (worst case) + +| regime | total | constraint_batching | gkr_batching | gkr_sumcheck | logup | stacked_reduction | whir | whir.fold_rbr | whir.gamma_batching | whir.mu_batching | whir.ood_rbr | whir.proximity_gaps | whir.query | whir.shift_rbr | whir.sumcheck | zerocheck_sumcheck | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| SWIRL | 100.1 | 116.5 | 123.6 | 122.0 | 105.5 | 114.5 | 100.1 | 103.1 | 112.9 | 102.1 | 101.0 | 103.1 | 100.1 | 100.1 | 134.2 | 122.1 | + + +## root + +**Parameters:** +- Proof system: SWIRL +- Inner PCS: WHIR +- Field: BabyBear⁴ +- `l_skip`: 2 +- `n_stack`: 18 +- `w_stack`: 18 +- Log blowup: 4 +- WHIR queries per round: [57, 28, 19] +- WHIR folding PoW (bits): 20 +- WHIR query-phase PoW (bits): 20 +- WHIR μ PoW (bits): 20 +- Max constraints per AIR: 1000 +- Number of AIRs: 50 +- Max log trace height: 21 +- Number of trace columns: 2000 +- Max interactions per AIR: 100 + +**Proof Size:** 121 KiB (expected) / 140 KiB (worst case) + +| regime | total | constraint_batching | gkr_batching | gkr_sumcheck | logup | stacked_reduction | whir | whir.fold_rbr | whir.gamma_batching | whir.mu_batching | whir.ood_rbr | whir.proximity_gaps | whir.query | whir.shift_rbr | whir.sumcheck | zerocheck_sumcheck | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| SWIRL | 100.5 | 116.2 | 123.6 | 122.0 | 105.3 | 114.2 | 100.5 | 105.3 | 113.2 | 107.2 | 100.5 | 105.3 | 100.7 | 100.7 | 136.5 | 121.8 | + diff --git a/reports/summary.md b/reports/summary.md index cd413a6..3e9a603 100644 --- a/reports/summary.md +++ b/reports/summary.md @@ -2,20 +2,16 @@ How to read this report: - Click on zkVM names to view detailed individual reports -- Security shows the best bits of security across regimes (UDR/JBR) +- Security shows the best bits of security across the reported regimes ## Overview | zkVM | Version | Security | Proof Size | PCS | Field | Circuits | Weakest Circuit | |------|---------|----------|------------|-----|-------|----------|-----------------| -| [Airbender](airbender.md) | — | **64** bits (UDR) | 1951 KiB | FRI | M31⁴ | 1 | generalized_circuit | -| [OpenVM](openvm.md) | 1.5.0 | **100** bits (UDR) | 8231 KiB | FRI | BabyBear⁴ | 3 | app | -| [Pico](pico.md) | — | **53** bits (JBR) | 281 KiB | FRI | KoalaBear⁴ | 5 | riscv | -| [SP1](sp1.md) | — | **98** bits (UDR) | 1001 KiB | Unknown | KoalaBear⁴ | 4 | wrap | -| [ZisK](zisk.md) | 0.16.1 | **128** bits (JBR) | 313 KiB | FRI | Goldilocks³ | 44 | Dma | +| [OpenVM2](openvm2.md) | 2.0.0-beta | **100.0** bits (SWIRL) | 140 KiB | SWIRL | BabyBear⁴ | 4 | leaf | ## Notes -- **Security**: Best bits of security across UDR (Unique Decoding) and JBR (Johnson Bound) regimes +- **Security**: Best bits of security across the reported regimes - **Weakest Circuit**: Circuit determining the overall security level - **Proof Size**: Final proof size in KiB (1 KiB = 1024 bytes) diff --git a/soundcalc/custom/swirl/calculator.py b/soundcalc/custom/swirl/calculator.py index 9ce6615..e7d5570 100644 --- a/soundcalc/custom/swirl/calculator.py +++ b/soundcalc/custom/swirl/calculator.py @@ -43,7 +43,7 @@ def whir_query_security_bits(self, num_queries: int, log_inv_rate: int) -> float elif self.kind == "list": if self.m is None: raise ValueError("list-decoding mode requires multiplicity m") - max_agreement = math.sqrt(rho * (1.0 + 1.0 / self.m)) + 1e-6 + max_agreement = math.sqrt(rho) * (1.0 + 1.0 / (2.0 * self.m)) else: raise ValueError(f"Unknown SWIRL proximity mode: {self.kind}") @@ -201,42 +201,20 @@ def _combine_security_bits(bits_a: float, bits_b: float) -> float: return -_log2_add(-bits_a, -bits_b) -def _n_logup_bound( - l_skip: int, - num_airs: int, - max_interactions_per_air: int, - max_log_height: int, - max_interaction_count: int, -) -> int: - field_bound = math.ceil(math.log2(max_interaction_count)) - l_skip - param_bound = ( - math.ceil(math.log2(num_airs)) - + math.ceil(math.log2(max_interactions_per_air)) - + max_log_height - - l_skip - ) - return min(field_bound, param_bound) - - -def _whir_sumcheck_security(challenge_field_bits: float, sub_round: int, folding_pow_bits: int) -> float: - sumcheck_degree = 2.0 if sub_round == 0 else 3.0 - return challenge_field_bits - math.log2(sumcheck_degree) + folding_pow_bits - - -def _whir_gamma_batching_security( +def _whir_sumcheck_security( challenge_field_bits: float, - batch_size: int, list_size: float, + folding_pow_bits: int, ) -> float: - return challenge_field_bits - math.log2(batch_size) - math.log2(list_size) + return challenge_field_bits - math.log2(3.0) - math.log2(list_size) + folding_pow_bits -def _whir_ood_security( +def _whir_gamma_batching_security( challenge_field_bits: float, - log_degree_at_round_start: int, + batch_size: int, list_size: float, ) -> float: - return challenge_field_bits - log_degree_at_round_start + 1.0 - 2.0 * math.log2(list_size) + return challenge_field_bits - math.log2(batch_size) - math.log2(list_size) def build_swirl_system_params( @@ -295,14 +273,6 @@ def calculate_swirl_soundness( max_interactions_per_air: int, ) -> SWIRLSoundnessResult: challenge_field_bits = _challenge_field_bits(field) - n_logup = _n_logup_bound( - params.l_skip, - num_airs, - max_interactions_per_air, - max_log_trace_height, - params.logup.max_interaction_count, - ) - regime = params.whir.proximity.build_regime(field) mu_batching_bits = -math.log2(whir._get_batching_error(regime)) initial_list_size = whir._get_list_size_for_iteration_and_round(0, 0, regime) @@ -348,6 +318,14 @@ def calculate_swirl_soundness( min_whir_bits = mu_batching_bits for round_index, round_config in enumerate(params.whir.rounds): + current_list_size = whir._get_list_size_for_iteration_and_round(round_index, 0, regime) + sumcheck_bits = _whir_sumcheck_security( + challenge_field_bits, + current_list_size, + params.whir.folding_pow_bits, + ) + min_sumcheck_bits = min(min_sumcheck_bits, sumcheck_bits) + for sub_round in range(params.whir.k): fold_bits = -math.log2(whir._epsilon_fold(round_index, sub_round + 1, regime)) min_fold_rbr_bits = min(min_fold_rbr_bits, fold_bits) @@ -358,27 +336,15 @@ def calculate_swirl_soundness( proximity_bits = -math.log2(proximity_error) + params.whir.folding_pow_bits min_proximity_gaps_bits = min(min_proximity_gaps_bits, proximity_bits) - sumcheck_bits = _whir_sumcheck_security( - challenge_field_bits, - sub_round, - params.whir.folding_pow_bits, - ) - min_sumcheck_bits = min(min_sumcheck_bits, sumcheck_bits) - - query_bits = ( - params.whir.proximity.whir_query_security_bits( - round_config.num_queries, - whir.log_inv_rates[round_index], - ) - + params.whir.query_phase_pow_bits - ) + query_bits = -math.log2(whir._epsilon_query(round_index, regime)) min_query_bits = min(min_query_bits, query_bits) - next_list_size = whir._get_list_size_for_iteration_and_round( - round_index, - params.whir.k, - regime, - ) + if round_index == whir.num_iterations - 1: + min_shift_rbr_bits = min(min_shift_rbr_bits, query_bits) + min_whir_bits = min(min_whir_bits, query_bits) + continue + + next_list_size = whir._get_list_size_for_iteration_and_round(round_index + 1, 0, regime) gamma_batching_bits = _whir_gamma_batching_security( challenge_field_bits, round_config.num_queries + 1, @@ -390,14 +356,9 @@ def calculate_swirl_soundness( min_shift_rbr_bits = min(min_shift_rbr_bits, shift_rbr_bits) min_whir_bits = min(min_whir_bits, shift_rbr_bits) - if round_index < whir.num_iterations - 1: - ood_bits = _whir_ood_security( - challenge_field_bits, - whir.log_degrees[round_index + 1], - next_list_size, - ) - min_ood_bits = min(min_ood_bits, ood_bits) - min_whir_bits = min(min_whir_bits, ood_bits) + ood_bits = -math.log2(whir._epsilon_out(round_index + 1, regime)) + min_ood_bits = min(min_ood_bits, ood_bits) + min_whir_bits = min(min_whir_bits, ood_bits) whir_details = SWIRLWhirDetails( mu_batching_bits=mu_batching_bits, diff --git a/soundcalc/pcs/whir.py b/soundcalc/pcs/whir.py index ad93713..389030c 100644 --- a/soundcalc/pcs/whir.py +++ b/soundcalc/pcs/whir.py @@ -607,6 +607,23 @@ def _get_batching_error(self, regime: ProximityGapsRegime) -> float: epsilon = apply_grinding(epsilon, self.grinding_batching_phase) return epsilon + def _epsilon_query(self, iteration: int, regime: ProximityGapsRegime) -> float: + """ + Returns the query-only error (1-delta_i)^{t_i} for the given iteration, + including the per-query grinding. + """ + + assert 0 <= iteration < self.num_iterations, "Iteration out of bounds" + + t = self.num_queries[iteration] + delta = self._get_delta_for_iteration(iteration, regime) + + assert 0 < delta < 1.0, f"Invalid delta {delta} for iteration {iteration}" + + epsilon = (1.0 - delta) ** t + epsilon = apply_grinding(epsilon, self.grinding_bits_queries[iteration]) + return epsilon + def _epsilon_fold( self, iteration: int, round: int, regime: ProximityGapsRegime ) -> float: @@ -699,21 +716,7 @@ def _epsilon_final(self, regime: ProximityGapsRegime) -> float: Returns the error epsilon^fin from the paper (Theorem 5.2 in WHIR paper). """ - t_final = self.num_queries[-1] - grinding_bits = self.grinding_bits_queries[-1] - - # the error is (1-delta_{M-1})^{t_{M-1}} - delta = self._get_delta_for_iteration(self.num_iterations - 1, regime) - - # Sanity Check: If delta is 1.0, the code has no redundancy, and security is 0. - # (Technically error=0, but this implies a broken config). - assert 0 < delta < 1.0, f"Invalid delta {delta} for final round" - - epsilon = (1.0 - delta) ** t_final - - # grinding - epsilon = apply_grinding(epsilon, grinding_bits) - return epsilon + return self._epsilon_query(self.num_iterations - 1, regime) def _get_log_grinding_overhead(self) -> float: """ diff --git a/soundcalc/zkvms/openvm2/openvm2.toml b/soundcalc/zkvms/openvm2/openvm2.toml index 0b071aa..77cc5b3 100644 --- a/soundcalc/zkvms/openvm2/openvm2.toml +++ b/soundcalc/zkvms/openvm2/openvm2.toml @@ -62,13 +62,13 @@ max_interactions_per_air = 100 [[circuits]] name = "root" l_skip = 2 -n_stack = 19 -w_stack = 9 +n_stack = 18 +w_stack = 18 log_blowup = 4 whir_folding_pow_bits = 20 whir_mu_pow_bits = 20 whir_proximity = "list" -whir_m = 2 +whir_m = 1 constraint_degree = 3 max_constraints_per_air = 1000 num_airs = 50 From 96436007e5963e08f139da3c639de75010f4379b Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Mon, 13 Apr 2026 22:26:16 +0000 Subject: [PATCH 3/4] chore: add comment about root --- soundcalc/zkvms/openvm2/openvm2.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/soundcalc/zkvms/openvm2/openvm2.toml b/soundcalc/zkvms/openvm2/openvm2.toml index 77cc5b3..134708d 100644 --- a/soundcalc/zkvms/openvm2/openvm2.toml +++ b/soundcalc/zkvms/openvm2/openvm2.toml @@ -59,6 +59,7 @@ max_log_trace_height = 19 num_trace_columns = 2000 max_interactions_per_air = 100 +# the root aggregation layer is only used for STARK-to-SNARK recursion [[circuits]] name = "root" l_skip = 2 From 018a728cb42e374a32c170302e0c2b95ab28324d Mon Sep 17 00:00:00 2001 From: Jonathan Wang <31040440+jonathanpwang@users.noreply.github.com> Date: Mon, 13 Apr 2026 22:33:35 +0000 Subject: [PATCH 4/4] Use shared WHIR soundness errors in SWIRL --- soundcalc/custom/swirl/calculator.py | 96 ++++++++++++++-------------- soundcalc/pcs/whir.py | 74 +++++++++++++++++---- 2 files changed, 110 insertions(+), 60 deletions(-) diff --git a/soundcalc/custom/swirl/calculator.py b/soundcalc/custom/swirl/calculator.py index e7d5570..4cbf375 100644 --- a/soundcalc/custom/swirl/calculator.py +++ b/soundcalc/custom/swirl/calculator.py @@ -274,7 +274,12 @@ def calculate_swirl_soundness( ) -> SWIRLSoundnessResult: challenge_field_bits = _challenge_field_bits(field) regime = params.whir.proximity.build_regime(field) - mu_batching_bits = -math.log2(whir._get_batching_error(regime)) + whir_errors = whir.get_pcs_soundness_errors(regime) + + mu_batching_bits = math.inf + if whir_errors.batching_error is not None: + mu_batching_bits = -math.log2(whir_errors.batching_error) + initial_list_size = whir._get_list_size_for_iteration_and_round(0, 0, regime) log2_list_size = math.log2(initial_list_size) @@ -308,57 +313,54 @@ def calculate_swirl_soundness( stacked_multilinear_bits, ) - min_query_bits = math.inf - min_proximity_gaps_bits = math.inf - min_sumcheck_bits = math.inf - min_ood_bits = math.inf - min_gamma_batching_bits = math.inf - min_fold_rbr_bits = math.inf - min_shift_rbr_bits = math.inf - min_whir_bits = mu_batching_bits - - for round_index, round_config in enumerate(params.whir.rounds): - current_list_size = whir._get_list_size_for_iteration_and_round(round_index, 0, regime) - sumcheck_bits = _whir_sumcheck_security( + query_bits_by_round = tuple(-math.log2(error) for error in whir_errors.query_errors) + fold_bits = tuple( + -math.log2(error) + for round_errors in whir_errors.fold_errors + for error in round_errors + ) + ood_bits = tuple(-math.log2(error) for error in whir_errors.ood_errors) + + sumcheck_bits = tuple( + _whir_sumcheck_security( challenge_field_bits, - current_list_size, + whir._get_list_size_for_iteration_and_round(round_index, 0, regime), params.whir.folding_pow_bits, ) - min_sumcheck_bits = min(min_sumcheck_bits, sumcheck_bits) - - for sub_round in range(params.whir.k): - fold_bits = -math.log2(whir._epsilon_fold(round_index, sub_round + 1, regime)) - min_fold_rbr_bits = min(min_fold_rbr_bits, fold_bits) - min_whir_bits = min(min_whir_bits, fold_bits) - - rate, dimension = whir._get_code_for_iteration_and_round(round_index, sub_round + 1) - proximity_error = regime.get_error_powers(rate, dimension, 2) - proximity_bits = -math.log2(proximity_error) + params.whir.folding_pow_bits - min_proximity_gaps_bits = min(min_proximity_gaps_bits, proximity_bits) - - query_bits = -math.log2(whir._epsilon_query(round_index, regime)) - min_query_bits = min(min_query_bits, query_bits) - - if round_index == whir.num_iterations - 1: - min_shift_rbr_bits = min(min_shift_rbr_bits, query_bits) - min_whir_bits = min(min_whir_bits, query_bits) - continue - - next_list_size = whir._get_list_size_for_iteration_and_round(round_index + 1, 0, regime) - gamma_batching_bits = _whir_gamma_batching_security( + for round_index in range(whir.num_iterations) + ) + proximity_gaps_bits = tuple( + -math.log2(regime.get_error_powers(*whir._get_code_for_iteration_and_round(round_index, sub_round + 1), 2)) + + params.whir.folding_pow_bits + for round_index in range(whir.num_iterations) + for sub_round in range(params.whir.k) + ) + gamma_bits_by_round = tuple( + _whir_gamma_batching_security( challenge_field_bits, - round_config.num_queries + 1, - next_list_size, + whir.num_queries[round_index] + 1, + whir._get_list_size_for_iteration_and_round(round_index + 1, 0, regime), ) - min_gamma_batching_bits = min(min_gamma_batching_bits, gamma_batching_bits) - - shift_rbr_bits = _combine_security_bits(query_bits, gamma_batching_bits) - min_shift_rbr_bits = min(min_shift_rbr_bits, shift_rbr_bits) - min_whir_bits = min(min_whir_bits, shift_rbr_bits) - - ood_bits = -math.log2(whir._epsilon_out(round_index + 1, regime)) - min_ood_bits = min(min_ood_bits, ood_bits) - min_whir_bits = min(min_whir_bits, ood_bits) + for round_index in range(whir.num_iterations - 1) + ) + shift_bits = tuple( + _combine_security_bits(query_bits_by_round[round_index], gamma_bits_by_round[round_index]) + for round_index in range(whir.num_iterations - 1) + ) + (query_bits_by_round[-1],) + + min_query_bits = min(query_bits_by_round) + min_proximity_gaps_bits = min(proximity_gaps_bits) + min_sumcheck_bits = min(sumcheck_bits) + min_ood_bits = min(ood_bits, default=math.inf) + min_gamma_batching_bits = min(gamma_bits_by_round, default=math.inf) + min_fold_rbr_bits = min(fold_bits) + min_shift_rbr_bits = min(shift_bits) + min_whir_bits = min( + mu_batching_bits, + min_fold_rbr_bits, + min_ood_bits, + min_shift_rbr_bits, + ) whir_details = SWIRLWhirDetails( mu_batching_bits=mu_batching_bits, diff --git a/soundcalc/pcs/whir.py b/soundcalc/pcs/whir.py index 389030c..6cc1c16 100644 --- a/soundcalc/pcs/whir.py +++ b/soundcalc/pcs/whir.py @@ -308,6 +308,17 @@ class WHIRConfig: # (This is useful to pin fixed parameters in TOML configs.) gap_to_radius: Optional[float] = None + +@dataclass(frozen=True) +class WHIRSoundnessErrors: + batching_error: Optional[float] + fold_errors: tuple[tuple[float, ...], ...] + query_errors: tuple[float, ...] + ood_errors: tuple[float, ...] + shift_errors: tuple[float, ...] + final_error: float + + class WHIR(PCS): """ WHIR Polynomial Commitment Scheme. @@ -455,19 +466,18 @@ def get_pcs_security_levels(self, regime: ProximityGapsRegime) -> dict[str, int] """ Returns PCS-specific security levels for a given regime. """ + errors = self.get_pcs_soundness_errors(regime) levels: dict[str, int] = {} # add an error from the batching step - if self.batch_size > 1: - epsilon_batch = self._get_batching_error(regime) - levels["batching"] = get_bits_of_security_from_error(epsilon_batch) + if errors.batching_error is not None: + levels["batching"] = get_bits_of_security_from_error(errors.batching_error) # Initial Iteration (i=0) # # Construction 5.1: "1. Initial sumcheck... For l = 1...k0" # This iteration only contains folding (sumcheck), no OOD/Shift. - for round_s in range(1, self.folding_factor + 1): - epsilon = self._epsilon_fold(iteration=0, round=round_s, regime=regime) + for round_s, epsilon in enumerate(errors.fold_errors[0], start=1): levels[f"fold(i=0,s={round_s})"] = get_bits_of_security_from_error(epsilon) # Main Loop (i=1 to M-1) @@ -476,29 +486,67 @@ def get_pcs_security_levels(self, regime: ProximityGapsRegime) -> dict[str, int] # For each iteration i = 1, ... M - 1: OOD errors, shift errors, fold errors for iteration in range(1, self.num_iterations): # out of domain samples - epsilon_ood = self._epsilon_out(iteration, regime) + epsilon_ood = errors.ood_errors[iteration - 1] levels[f"OOD(i={iteration})"] = get_bits_of_security_from_error(epsilon_ood) # shift queries - epsilon_shift = self._epsilon_shift(iteration, regime) + epsilon_shift = errors.shift_errors[iteration - 1] levels[f"Shift(i={iteration})"] = get_bits_of_security_from_error( epsilon_shift ) # sum check (one error for each round) - for round in range(1, self.folding_factor + 1): - epsilon = self._epsilon_fold(iteration, round, regime) - levels[f"fold(i={iteration},s={round})"] = ( + for round_s, epsilon in enumerate(errors.fold_errors[iteration], start=1): + levels[f"fold(i={iteration},s={round_s})"] = ( get_bits_of_security_from_error(epsilon) ) # final error # Construction 5.1: "3. Check final polynomial..." - epsilon_final = self._epsilon_final(regime) - levels["fin"] = get_bits_of_security_from_error(epsilon_final) - + levels["fin"] = get_bits_of_security_from_error(errors.final_error) return levels + def get_pcs_soundness_errors( + self, regime: ProximityGapsRegime + ) -> WHIRSoundnessErrors: + """ + Returns the exact WHIR soundness errors without converting them to integer bits. + """ + + batching_error = None + if self.batch_size > 1: + batching_error = self._get_batching_error(regime) + + fold_errors = tuple( + tuple( + self._epsilon_fold(iteration, round_s, regime) + for round_s in range(1, self.folding_factor + 1) + ) + for iteration in range(self.num_iterations) + ) + query_errors = tuple( + self._epsilon_query(iteration, regime) + for iteration in range(self.num_iterations) + ) + ood_errors = tuple( + self._epsilon_out(iteration, regime) + for iteration in range(1, self.num_iterations) + ) + shift_errors = tuple( + self._epsilon_shift(iteration, regime) + for iteration in range(1, self.num_iterations) + ) + final_error = self._epsilon_final(regime) + + return WHIRSoundnessErrors( + batching_error=batching_error, + fold_errors=fold_errors, + query_errors=query_errors, + ood_errors=ood_errors, + shift_errors=shift_errors, + final_error=final_error, + ) + def _get_code_for_iteration_and_round( self, iteration: int, round: int ) -> tuple[float, int]: