Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["Chris Rackauckas <accounts@chrisrackauckas.com>"]
version = "9.25.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
Expand All @@ -17,6 +18,8 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SimpleNonlinearSolve = "727e6d20-b764-4bd8-a329-72de5adea6c7"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"

Expand Down Expand Up @@ -54,6 +57,8 @@ RecursiveArrayTools = "3.35, 4"
Reexport = "1.2"
SafeTestsets = "0.1"
SciMLBase = "2.115, 3.1"
Setfield = "1"
SimpleNonlinearSolve = "1, 2"
StableRNGs = "1"
StaticArrays = "1.9.8"
Statistics = "1"
Expand All @@ -63,7 +68,6 @@ Test = "1"
julia = "1.10"

[extras]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
Expand All @@ -78,4 +82,4 @@ StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["ADTypes", "Aqua", "ExplicitImports", "FastBroadcast", "LinearAlgebra", "LinearSolve", "OrdinaryDiffEq", "Pkg", "SafeTestsets", "StableRNGs", "Statistics", "StochasticDiffEq", "Test"]
test = ["Aqua", "ExplicitImports", "FastBroadcast", "LinearAlgebra", "LinearSolve", "OrdinaryDiffEq", "Pkg", "SafeTestsets", "StableRNGs", "Statistics", "StochasticDiffEq", "Test"]
7 changes: 5 additions & 2 deletions src/JumpProcesses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ using StaticArrays: StaticArrays, SVector, setindex
using Base.Threads: Threads
using Base.FastMath: add_fast

using SimpleNonlinearSolve: SimpleNonlinearSolve, SimpleNewtonRaphson
using ADTypes: ADTypes, AutoFiniteDiff

# Import functions we extend from Base
import Base: size, getindex, setindex!, length, similar, show, merge!, merge

Expand All @@ -40,7 +43,7 @@ using DiffEqBase: DiffEqBase, CallbackSet, ContinuousCallback, DAEFunction,
ODESolution, ReturnCode, SDEFunction, SDEProblem, add_tstop!,
deleteat!, isinplace, remake, savevalues!, step!,
u_modified!
using SciMLBase: SciMLBase, DEIntegrator
using SciMLBase: SciMLBase, DEIntegrator, NonlinearProblem

abstract type AbstractJump end
abstract type AbstractMassActionJump <: AbstractJump end
Expand Down Expand Up @@ -131,7 +134,7 @@ export SSAStepper

# leaping:
include("simple_regular_solve.jl")
export SimpleTauLeaping, SimpleExplicitTauLeaping, EnsembleGPUKernel
export SimpleTauLeaping, SimpleExplicitTauLeaping, SimpleAdaptiveTauLeaping, NewtonImplicitSolver, TrapezoidalImplicitSolver, EnsembleGPUKernel

# spatial:
include("spatial/spatial_massaction_jump.jl")
Expand Down
290 changes: 290 additions & 0 deletions src/simple_regular_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,27 @@ end

SimpleExplicitTauLeaping(; epsilon = 0.05) = SimpleExplicitTauLeaping(epsilon)

# Define solver type hierarchy
abstract type AbstractImplicitSolver end
struct NewtonImplicitSolver <: AbstractImplicitSolver end
struct TrapezoidalImplicitSolver <: AbstractImplicitSolver end

# Adaptive tau-leaping solver
struct SimpleAdaptiveTauLeaping{T <: AbstractFloat} <: DiffEqBase.DEAlgorithm
epsilon::T # Error control parameter for tau selection
solver::AbstractImplicitSolver # Solver type for implicit method
eigenvalue_check::Bool # Enable eigenvalue-based stiffness detection
stiffness_ratio_threshold::T # # Stiffness ratio threshold
implicit_epsilon_factor::T # Scaling factor for implicit tau-selection
end

# Stiffness detection uses a dynamic threshold epsilon * sum(u) for propensity ratios,
# as inspired by Cao et al. (2007), Section III.B. Optional eigenvalue-based check
# uses the Jacobian's eigenvalue ratio. implicit_epsilon_factor=10.0 relaxes tau-selection
# for implicit tau-leaping, per Cao et al. (2007), Section III.A.
SimpleAdaptiveTauLeaping(; epsilon=0.05, solver=NewtonImplicitSolver(), eigenvalue_check=false, stiffness_ratio_threshold=1e4, implicit_epsilon_factor=10.0) =
SimpleAdaptiveTauLeaping(epsilon, solver, eigenvalue_check, stiffness_ratio_threshold, implicit_epsilon_factor)

function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg)
if !(jump_prob.aggregator isa PureLeaping)
@warn "When using $alg, please pass PureLeaping() as the aggregator to the \
Expand Down Expand Up @@ -69,6 +90,20 @@ function _process_saveat(saveat, tspan, save_start, save_end)
return saveat_vec, _save_start, _save_end
end

# Validation for adaptive tau-leaping
function validate_pure_leaping_inputs(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping)
if !(jump_prob.aggregator isa PureLeaping)
@warn "When using $alg, please pass PureLeaping() as the aggregator to the \
JumpProblem, i.e. call JumpProblem(::DiscreteProblem, PureLeaping(),...). \
Passing $(jump_prob.aggregator) is deprecated and will be removed in the next breaking release."
end
isempty(jump_prob.jump_callback.continuous_callbacks) &&
isempty(jump_prob.jump_callback.discrete_callbacks) &&
isempty(jump_prob.constant_jumps) &&
isempty(jump_prob.variable_jumps) &&
jump_prob.massaction_jump !== nothing
end

function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleTauLeaping;
seed = nothing, dt = error("dt is required for SimpleTauLeaping."),
saveat = nothing, save_start = nothing, save_end = nothing)
Expand Down Expand Up @@ -405,6 +440,261 @@ function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleExplicitTauLeaping;
return sol
end


# Compute tau for implicit tau-leaping with relaxed error control
# Reference: Cao et al. (2007), J. Chem. Phys. 126, 224101, Section III.A
function compute_tau_implicit(u, rate_cache, nu, hor, p, t, epsilon, rate, dtmin,
max_hor, max_stoich, numjumps, implicit_epsilon_factor)
tau_explicit = compute_tau(
u, rate_cache, nu, hor, p, t, epsilon, rate, dtmin, max_hor, max_stoich, numjumps)
u_predict = float.(u)
rate(rate_cache, u, p, t)
for j in 1:numjumps
for spec_idx in 1:size(nu, 1)
u_predict[spec_idx] += nu[spec_idx, j] * rate_cache[j] * tau_explicit
end
end
u_predict .= max.(u_predict, zero(eltype(u_predict)))
relaxed_epsilon = epsilon * implicit_epsilon_factor
tau = compute_tau(u_predict, rate_cache, nu, hor, p, t + tau_explicit,
relaxed_epsilon, rate, dtmin, max_hor, max_stoich, numjumps)
return max(tau, dtmin)
end

# Define residual for implicit equation
# Newton: u_new = u_current + sum_j nu_j * a_j(u_new) * tau (Cao et al., 2004)
# Trapezoidal: u_new = u_current + sum_j nu_j * (a_j(u_current) + a_j(u_new))/2 * tau
function implicit_equation!(resid, u_new, params)
u_current, rate_cache, nu, p, t, tau, rate, numjumps, solver = params
rate(rate_cache, u_new, p, t + tau)
resid .= u_new .- u_current
if isa(solver, NewtonImplicitSolver)
for j in 1:numjumps
for spec_idx in 1:size(nu, 1)
resid[spec_idx] -= nu[spec_idx, j] * rate_cache[j] * tau # Cao et al. (2004)
end
end
else # TrapezoidalImplicitSolver
rate_current = similar(rate_cache)
rate(rate_current, u_current, p, t)
half = one(eltype(rate_cache)) / 2
for j in 1:numjumps
for spec_idx in 1:size(nu, 1)
resid[spec_idx] -= nu[spec_idx, j] * half * (rate_cache[j] + rate_current[j]) * tau
end
end
end
resid .= max.(resid, -u_new) # Ensure non-negative solution
end

# Solve implicit equation using SimpleNonlinearSolve
function solve_implicit(u_current, rate_cache, nu, p, t, tau, rate, numjumps, solver)
u_new = convert(Vector{float(eltype(u_current))}, u_current)
prob = NonlinearProblem(implicit_equation!, u_new, (u_current, rate_cache, nu, p, t, tau, rate, numjumps, solver))
sol = solve(prob, SimpleNewtonRaphson(autodiff=AutoFiniteDiff()); abstol=1e-6, reltol=1e-6)
return sol.u, sol.retcode == ReturnCode.Success
end


# Compute Jacobian for eigenvalue-based stiffness detection
# Reference: Cao et al. (2007), Section III.B
function compute_jacobian(u, rate, numjumps, numspecies, p, t)
T = float(eltype(u))
sqrteps = sqrt(eps(T))
J = zeros(T, numjumps, numspecies)
rate_cache = zeros(T, numjumps)
rate(rate_cache, u, p, t)
rate_plus = zeros(T, numjumps)
u_plus = float.(copy(u))
for i in 1:numspecies
h_i = sqrteps * max(abs(u[i]), one(T))
u_plus[i] = u[i] + h_i
rate(rate_plus, u_plus, p, t)
for j in 1:numjumps
J[j, i] = (rate_plus[j] - rate_cache[j]) / h_i
end
u_plus[i] = u[i]
end
return J
end

# Stiffness detection using propensity ratio or eigenvalues
# Reference: Cao et al. (2007), Section III.B
function is_stiff(rate_cache, u, epsilon, eigenvalue_check, stiffness_ratio_threshold, p, t, rate, numjumps, numspecies)
non_zero_rates = [rate for rate in rate_cache if rate > 0]
if length(non_zero_rates) <= 1
return false
end
if eigenvalue_check
J = compute_jacobian(u, rate, numjumps, numspecies, p, t)
eigvals = real.(LinearAlgebra.eigvals(J))
non_zero_eigvals = [abs(λ) for λ in eigvals if abs(λ) > 1e-10]
if length(non_zero_eigvals) <= 1
return false
end
max_eig = maximum(non_zero_eigvals)
min_eig = minimum(non_zero_eigvals)
return max_eig / min_eig > stiffness_ratio_threshold # Stiffness ratio threshold, Petzold (1983), SIAM J. Sci. Stat. Comput. 4(1), 136–148
else
max_rate = maximum(non_zero_rates)
min_rate = minimum(non_zero_rates)
threshold = epsilon * sum(u)
return max_rate / min_rate > threshold # Propensity ratio threshold, Cao et al. (2007), J. Chem. Phys. 126, 224101, Section III.B
end
end

function simple_adaptive_tau_leaping_loop!(
prob, alg, u_current, u_new, t_current, t_end, p, rng,
rate, nu, hor, max_hor, max_stoich, numjumps, numspecies, epsilon,
dtmin, saveat_times, usave, tsave, du, counts, rate_cache, rate_effective, maj,
solver, eigenvalue_check, stiffness_ratio_threshold, implicit_epsilon_factor,
save_end)
save_idx = 1

while t_current < t_end
rate(rate_cache, u_current, p, t_current)
if all(<=(0), rate_cache)
t_current = t_end
break
end
use_implicit = is_stiff(rate_cache, u_current, epsilon, eigenvalue_check,
stiffness_ratio_threshold, p, t_current, rate, numjumps, numspecies)
tau = if use_implicit
compute_tau_implicit(u_current, rate_cache, nu, hor, p, t_current,
epsilon, rate, dtmin, max_hor, max_stoich, numjumps,
implicit_epsilon_factor)
else
compute_tau(u_current, rate_cache, nu, hor, p, t_current,
epsilon, rate, dtmin, max_hor, max_stoich, numjumps)
end
tau = min(tau, t_end - t_current)
if !isempty(saveat_times) && save_idx <= length(saveat_times) &&
t_current + tau > saveat_times[save_idx]
tau = saveat_times[save_idx] - t_current
end

if use_implicit
u_new_float, converged = solve_implicit(u_current, rate_cache, nu, p,
t_current, tau, rate, numjumps, solver)
if !converged
tau /= 2
continue
end
rate(rate_cache, u_new_float, p, t_current + tau)
end

rate_effective .= max.(rate_cache .* tau, zero(eltype(rate_cache)))
for j in eachindex(counts)
if rate_effective[j] <= zero(eltype(rate_effective))
counts[j] = zero(eltype(counts))
else
counts[j] = pois_rand(rng, rate_effective[j])
end
end
du .= zero(eltype(du))
for j in 1:numjumps
for (spec_idx, stoch) in maj.net_stoch[j]
du[spec_idx] += stoch * counts[j]
end
end
u_new .= u_current .+ du
if any(<(0), u_new)
tau /= 2
continue
end
t_new = t_current + tau

if isempty(saveat_times) ||
(save_idx <= length(saveat_times) && t_new >= saveat_times[save_idx])
push!(usave, copy(u_new))
push!(tsave, t_new)
if !isempty(saveat_times) && t_new >= saveat_times[save_idx]
save_idx += 1
end
end

u_current .= u_new
t_current = t_new
end

if save_end && (isempty(tsave) || tsave[end] != t_end)
push!(usave, copy(u_current))
push!(tsave, t_end)
end
end

# Adaptive tau-leaping solver
# Reference: Cao et al. (2007), Cao et al. (2004), Cao et al. (2006)
function DiffEqBase.solve(jump_prob::JumpProblem, alg::SimpleAdaptiveTauLeaping;
seed = nothing,
dtmin = nothing,
saveat = nothing, save_start = nothing, save_end = nothing)
validate_pure_leaping_inputs(jump_prob, alg) ||
error("SimpleAdaptiveTauLeaping can only be used with PureLeaping JumpProblem with a MassActionJump.")

(; prob, rng) = jump_prob
(seed !== nothing) && seed!(rng, seed)

maj = jump_prob.massaction_jump
numjumps = get_num_majumps(maj)
rj = jump_prob.regular_jump
rate = rj !== nothing ? rj.rate : massaction_rate(maj, numjumps)
u0 = copy(prob.u0)
tspan = prob.tspan
p = prob.p

if dtmin === nothing
dtmin = 1e-10 * one(typeof(tspan[2]))
end

saveat_times, save_start, save_end = _process_saveat(saveat, tspan, save_start, save_end)

u_current = copy(u0)
u_new = similar(u0)
t_current = tspan[1]
if save_start
usave = [copy(u0)]
tsave = [tspan[1]]
else
usave = typeof(u0)[]
tsave = typeof(tspan[1])[]
end
rate_cache = zeros(float(eltype(u0)), numjumps)
rate_effective = similar(rate_cache)
counts = zero(rate_cache)
du = similar(u0)
t_end = tspan[2]
epsilon = alg.epsilon
solver = alg.solver
eigenvalue_check = alg.eigenvalue_check
stiffness_ratio_threshold = alg.stiffness_ratio_threshold
implicit_epsilon_factor = alg.implicit_epsilon_factor
numspecies = length(u0)

nu = zeros(float(eltype(u0)), length(u0), numjumps)
for j in 1:numjumps
for (spec_idx, stoch) in maj.net_stoch[j]
nu[spec_idx, j] = stoch
end
end
reactant_stoch = maj.reactant_stoch
hor = compute_hor(reactant_stoch, numjumps)
max_hor, max_stoich = precompute_reaction_conditions(
reactant_stoch, hor, numspecies, numjumps)

simple_adaptive_tau_leaping_loop!(
prob, alg, u_current, u_new, t_current, t_end, p, rng,
rate, nu, hor, max_hor, max_stoich, numjumps, numspecies, epsilon,
dtmin, saveat_times, usave, tsave, du, counts, rate_cache, rate_effective, maj,
solver, eigenvalue_check, stiffness_ratio_threshold, implicit_epsilon_factor,
save_end)

sol = DiffEqBase.build_solution(prob, alg, tsave, usave,
calculate_error = false,
interp = DiffEqBase.ConstantInterpolation(tsave, usave))
return sol
end

struct EnsembleGPUKernel{Backend} <: SciMLBase.EnsembleAlgorithm
backend::Backend
cpu_offload::Float64
Expand Down
Loading
Loading