From e1768850e33b0f551673641340dbc59154eb2a32 Mon Sep 17 00:00:00 2001 From: Jonas Rembser Date: Wed, 15 Apr 2026 21:32:54 +0200 Subject: [PATCH 1/4] [tmva][sofie] Implement `Gemm_Call_pullback` in terms of `Gemm_Call` This ensures that we use a consistent GEMM functions by wrapping the GEMM call in the pullback also with the wrapper. --- math/mathcore/inc/Math/CladDerivator.h | 35 +++++++++----------------- 1 file changed, 12 insertions(+), 23 deletions(-) diff --git a/math/mathcore/inc/Math/CladDerivator.h b/math/mathcore/inc/Math/CladDerivator.h index d5b46e444c996..7d547aef1f333 100644 --- a/math/mathcore/inc/Math/CladDerivator.h +++ b/math/mathcore/inc/Math/CladDerivator.h @@ -1114,45 +1114,34 @@ inline void Gemm_Call_pullback(float *output, bool transa, bool transb, int m, i bool *, int *, int *, int *, float *_d_alpha, float *_d_A, float *_d_B, float *_d_beta, float *_d_C) { + using ::TMVA::Experimental::SOFIE::Gemm_Call; + // TODO: // - fix and test the implementation for alpha != 1.0 if (alpha != 1.0f) { return; } - char ct = 't'; - char cn = 'n'; - // beta needs to be one because we want to add to _d_A and _d_B instead of // overwriting it. float one = 1.; - // Leading dimensions for the original storage (must match how sgemm_ is called in the primal) - const int lda_opA = transa ? k : m; // lda used with transa flag as in primal - const int ldb_opB = transb ? n : k; // ldb used with transb flag as in primal - - // Flags for op(A), op(B) - const char TA = transa ? ct : cn; - const char TB = transb ? ct : cn; - - // Flags for op(A)^T and op(B)^T - const char TAT = transa ? cn : ct; // (A^T)^T = A, A^T if A - const char TBT = transb ? cn : ct; // (B^T)^T = B, B^T if B - + // ---- dA ---- if (!transa) { - // dA += alpha * dY * op(B)^T (m x n) * (n x k) -> (m x k) - ::sgemm_(&cn, &TBT, &m, &k, &n, &alpha, _d_output, &m, B, &ldb_opB, &one, _d_A, &m); + // dA += dY * op(B)^T + Gemm_Call(_d_A, false, !transb, m, k, n, one, _d_output, B, one, _d_A); } else { - // dA (shape k x m) += alpha * op(B) * dY^T (k x n) * (n x m) -> (k x m) - ::sgemm_(&TB, &ct, &k, &m, &n, &alpha, B, &ldb_opB, _d_output, &m, &one, _d_A, &k); + // dA += op(B) * dY^T + Gemm_Call(_d_A, transb, true, k, m, n, one, B, _d_output, one, _d_A); } + // ---- dB ---- if (!transb) { - // dB += alpha * op(A)^T * dY (k x m) * (m x n) -> (k x n) - ::sgemm_(&TAT, &cn, &k, &n, &m, &alpha, A, &lda_opA, _d_output, &m, &one, _d_B, &k); + // dB += op(A)^T * dY + Gemm_Call(_d_B, !transa, false, k, n, m, one, A, _d_output, one, _d_B); } else { - // dB (shape n x k) += alpha * dY^T * op(A) (n x m) * (m x k) -> (n x k) - ::sgemm_(&ct, &TA, &n, &k, &m, &alpha, _d_output, &m, A, &lda_opA, &one, _d_B, &n); + // dB += dY^T * op(A) + Gemm_Call(_d_B, true, transa, n, k, m, one, _d_output, A, one, _d_B); } int sizeC = n * m; From f73bf827d66bc39934150a6fdfa5713ff477dfe3 Mon Sep 17 00:00:00 2001 From: Jonas Rembser Date: Thu, 16 Apr 2026 01:57:31 +0200 Subject: [PATCH 2/4] [CMake] Also consider fixture arguments in `ROOT_ADD_GTEST` This is helpful if such a test depends on input files that are created by another test. --- cmake/modules/RootMacros.cmake | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/cmake/modules/RootMacros.cmake b/cmake/modules/RootMacros.cmake index 9e1dfce102fef..3365f45ff0158 100644 --- a/cmake/modules/RootMacros.cmake +++ b/cmake/modules/RootMacros.cmake @@ -1932,6 +1932,7 @@ endfunction() # [TIMEOUT seconds] # [COPY_TO_BUILDDIR file1 file2] Copy listed files when ctest invokes the test. # [LIBRARIES lib1 lib2...] -- Libraries to link against +# [FIXTURES_SETUP ...] [FIXTURES_CLEANUP ...] [FIXTURES_REQUIRED ...] # [LABELS label1 label2...] -- Labels to annotate the test # [INCLUDE_DIRS label1 label2...] -- Extra target include directories # [REPEATS number] -- Repeats testsuite `number` times, stopping at the first failure. @@ -1943,7 +1944,7 @@ function(ROOT_ADD_GTEST test_suite) cmake_parse_arguments(ARG "WILLFAIL" "TIMEOUT;REPEATS;FAILREGEX" - "COPY_TO_BUILDDIR;LIBRARIES;LABELS;INCLUDE_DIRS;ENVIRONMENT" ${ARGN}) + "COPY_TO_BUILDDIR;LIBRARIES;LABELS;FIXTURES_SETUP;FIXTURES_CLEANUP;FIXTURES_REQUIRED;INCLUDE_DIRS;ENVIRONMENT" ${ARGN}) ROOT_GET_SOURCES(source_files . ${ARG_UNPARSED_ARGUMENTS}) # Note we cannot use ROOT_EXECUTABLE without user-specified set of LIBRARIES to link with. @@ -1979,6 +1980,18 @@ function(ROOT_ADD_GTEST test_suite) set(extra_command --gtest_repeat=${ARG_REPEATS} --gtest_break_on_failure) endif() + if (ARG_FIXTURES_SETUP) + set(fixtures_setup ${ARG_FIXTURES_SETUP}) + endif() + + if (ARG_FIXTURES_CLEANUP) + set(fixtures_cleanup ${ARG_FIXTURES_CLEANUP}) + endif() + + if (ARG_FIXTURES_REQUIRED) + set(fixtures_required ${ARG_FIXTURES_REQUIRED}) + endif() + ROOT_PATH_TO_STRING(name_with_path ${test_suite} PATH_SEPARATOR_REPLACEMENT "-") string(REPLACE "-test-" "-" clean_name_with_path ${name_with_path}) ROOT_ADD_TEST( @@ -1989,6 +2002,9 @@ function(ROOT_ADD_GTEST test_suite) ${willfail} TIMEOUT "${ARG_TIMEOUT}" LABELS "${ARG_LABELS}" + FIXTURES_SETUP ${fixtures_setup} + FIXTURES_CLEANUP ${fixtures_cleanup} + FIXTURES_REQUIRED ${fixtures_required} FAILREGEX "${ARG_FAILREGEX}" ENVIRONMENT "${ARG_ENVIRONMENT}" ) From e162a9d5d7b67b276cb19c9e14a32537b5dca5d8 Mon Sep 17 00:00:00 2001 From: Jonas Rembser Date: Fri, 17 Apr 2026 00:10:31 +0200 Subject: [PATCH 3/4] [tmva][sofie] Emit input tensor meta info in generated code Add some `constexpr` info in input tensor structure to the generated code, inside the same namespace as the session struct. It looks for example like this: ```c++ constexpr std::array dim_start{SingleDim{1}}; constexpr std::array dim_limit{SingleDim{1}}; constexpr std::array dim_delta{SingleDim{1}}; constexpr std::array inputTensorDims{ makeDims(dim_start), makeDims(dim_limit), makeDims(dim_delta) }; ``` This is helpful for validating inputs when using the SOFIE-generated code without knowing much about about the original ONNX file, and is the case for the use in RooFit. --- tmva/sofie/inc/TMVA/RModel.hxx | 1 + tmva/sofie/inc/TMVA/SOFIE_common.hxx | 36 +++++++++++++++++++++ tmva/sofie/src/RModel.cxx | 48 ++++++++++++++++++++++++++++ 3 files changed, 85 insertions(+) diff --git a/tmva/sofie/inc/TMVA/RModel.hxx b/tmva/sofie/inc/TMVA/RModel.hxx index ad28209de0c0d..ec4e1115b759d 100644 --- a/tmva/sofie/inc/TMVA/RModel.hxx +++ b/tmva/sofie/inc/TMVA/RModel.hxx @@ -197,6 +197,7 @@ protected: void GenerateSessionCode(); bool IsInputTensorShapeParam(std::string const &name) const; std::vector CollectTensorMemberNames(const std::string &input); + void GenerateRequiredInputTensorInfo(); public: const std::vector & GetInputTensorNames() const { return fInputTensorNames; } diff --git a/tmva/sofie/inc/TMVA/SOFIE_common.hxx b/tmva/sofie/inc/TMVA/SOFIE_common.hxx index 9b04109c21384..9f35cca5f7db3 100644 --- a/tmva/sofie/inc/TMVA/SOFIE_common.hxx +++ b/tmva/sofie/inc/TMVA/SOFIE_common.hxx @@ -844,6 +844,42 @@ struct MemoryResult { /// Greedy best-fit planner with coalescing free list. MemoryResult OrganizeMemory(const std::vector & tensorsInfo ); +// Simple Dimension classes ans helpers to add constexpr meta info on input +// tensors to the emitted code. +struct SingleDim { + enum class Kind { + Static, + Symbolic + }; + + Kind kind; + std::size_t dim; + std::string_view name; + + constexpr SingleDim(std::size_t v) : kind(Kind::Static), dim(v), name() {} + constexpr SingleDim(const char *v) : kind(Kind::Symbolic), dim(0), name(v) {} +}; + +struct TensorDims { + const SingleDim *data; + std::size_t size; + + constexpr std::size_t total_size() const + { + std::size_t result = 1; + for (std::size_t i = 0; i < size; ++i) { + result *= data[i].dim; + } + return result; + } +}; + +template +constexpr TensorDims makeDims(Arr const &arr) +{ + return TensorDims{arr.data(), arr.size()}; +} + } // namespace SOFIE } // namespace Experimental } // namespace TMVA diff --git a/tmva/sofie/src/RModel.cxx b/tmva/sofie/src/RModel.cxx index 738b5066cebf8..822b8ae2f4e32 100644 --- a/tmva/sofie/src/RModel.cxx +++ b/tmva/sofie/src/RModel.cxx @@ -1364,6 +1364,8 @@ void RModel::GenerateSessionCode() // end of session if (fUseSession && !fIsGNNComponent) { fGC += "}; // end of Session\n\n"; + + GenerateRequiredInputTensorInfo(); } fGC += doInferSignature + " {\n"; @@ -1691,6 +1693,52 @@ void RModel::PrintSummary() const { } } +/// To emit the dimensions of the input tensors as a data member of a session, +/// which is helpful when validating the inference inputs. +void RModel::GenerateRequiredInputTensorInfo() +{ + fGC += "\n// Input tensor dimensions\n"; + fGC += "using TMVA::Experimental::SOFIE::SingleDim;\n"; + fGC += "using TMVA::Experimental::SOFIE::TensorDims;\n"; + fGC += "using TMVA::Experimental::SOFIE::makeDims;\n\n"; + bool hasDynamicInputTensors = false; + + for (std::size_t iInput = 0; iInput < fInputTensorNames.size(); ++iInput) { + auto const &name = fInputTensorNames[iInput]; + if (IsDimInputTensor(name)) { + hasDynamicInputTensors = true; + } + std::vector shape = GetDimTensorShape(name); + fGC += "constexpr std::array dim_" + name + "{"; + for (std::size_t iDim = 0; iDim < shape.size(); ++iDim) { + auto const &dim = shape[iDim]; + if (dim.isParam) { + fGC += "SingleDim{\"" + dim.GetVal() + "\"}"; + } else { + fGC += "SingleDim{" + dim.GetVal() + "}"; + } + if (iDim != shape.size() - 1) { + fGC += ", "; + } + } + fGC += "};\n"; + } + fGC += "\nconstexpr std::array inputTensorDims{\n"; + for (std::size_t iInput = 0; iInput < fInputTensorNames.size(); ++iInput) { + auto const &name = fInputTensorNames[iInput]; + fGC += SP + "makeDims(dim_" + name + ")"; + if (iInput == fInputTensorNames.size() - 1) { + fGC += "\n"; + } else { + fGC += ",\n"; + } + } + fGC += "};\n"; + + fGC += + "\nconstexpr bool hasDynamicInputTensors{" + std::string{hasDynamicInputTensors ? "true" : "false"} + "};\n\n"; +} + void RModel::PrintRequiredInputTensors() const { std::cout << "Model requires following inputs:\n"; for (auto& inputInfo: fInputTensorInfos) { From 65eaa9e0ab51850dfd41184689297ed1954fbf3d Mon Sep 17 00:00:00 2001 From: Jonas Rembser Date: Thu, 16 Apr 2026 01:58:52 +0200 Subject: [PATCH 4/4] [RF] Add RooONNXFunction: ONNX-backed neural function support for RooFit Introduce `RooONNXFunction`, a new `RooAbsReal` implementation that enables native inference of ONNX models within RooFit. The class loads ONNX graphs, JIT-compiles them via TMVA SOFIE at runtime, and evaluates them efficiently with support for automatic differentiation through Clad. Key features: * Seamless integration of parametric neural networks into RooFit workflows * Runtime ONNX-to-C++ code generation via SOFIE (no hard dependency at link time, the SOFIE usage is an implementation detail) * Support for analytic gradients (codegen + Clad) * Support for serialization to RooWorkspace by embedding the ONNX payload as a binary blob A unit test is also implemented. This development unlocks new use cases such as neural simulation-based inference (SBI), likelihood surrogate models, and ML-driven parametric models. This commit addresses an item on the ROOT plan of work 2026. --- README/ReleaseNotes/v640/index.md | 18 + roofit/codegen/inc/RooFit/CodegenImpl.h | 4 +- roofit/codegen/src/CodegenImpl.cxx | 16 + roofit/roofit/CMakeLists.txt | 36 +- roofit/roofit/inc/LinkDef1.h | 2 + roofit/roofit/inc/RooONNXFunction.h | 88 +++++ roofit/roofit/src/RooONNXFunction.cxx | 379 +++++++++++++++++++++ roofit/roofit/test/CMakeLists.txt | 15 + roofit/roofit/test/create_onnx_model.py | 108 ++++++ roofit/roofit/test/testRooONNXFunction.cxx | 134 ++++++++ 10 files changed, 782 insertions(+), 18 deletions(-) create mode 100644 roofit/roofit/inc/RooONNXFunction.h create mode 100644 roofit/roofit/src/RooONNXFunction.cxx create mode 100644 roofit/roofit/test/create_onnx_model.py create mode 100644 roofit/roofit/test/testRooONNXFunction.cxx diff --git a/README/ReleaseNotes/v640/index.md b/README/ReleaseNotes/v640/index.md index 699c3a6e49fa3..97d4f1527b3fd 100644 --- a/README/ReleaseNotes/v640/index.md +++ b/README/ReleaseNotes/v640/index.md @@ -239,6 +239,8 @@ This is new and efficient bracketing root-finding algorithm. It combines bisecti ## RooFit +### General changes + - A new RooAbsPdf has been added: `RooStudentT`, which describes the location-scale student's t-distribution. - The `RooNumber::setRangeEpsRel()` and `RooNumber::setRangeEpsAbs()` have been introduced 2 years ago in 48637270a9113aa to customize range check behavior @@ -249,6 +251,22 @@ This is new and efficient bracketing root-finding algorithm. It combines bisecti - The constructors of **RooDataSet** and **RooDataHist** that combine datasets via `Index()` and `Import()` now validate that the import names correspond to existing states of the index category. If an imported data slice refers to a category label that is not defined in the index category, the constructor now throws an error. Previously, such labels were silently added as new category states, which could lead to inconsistent datasets when the state names were not synchronized with the model definition. This change prevents the creation of invalid combined datasets and surfaces configuration problems earlier. +### ONNX model integration via RooONNXFunction + +A new class `RooONNXFunction` has been introduced to enable the use of machine learning models in ONNX format directly within RooFit workflows. + +`RooONNXFunction` wraps an ONNX model as a `RooAbsReal`, allowing it to be used as a building block in likelihoods, fits, and statistical analyses without additional boilerplate code. The class supports models with one or more statically-shaped input tensors and a single scalar output. +The class was designed to share workspaces with neural functions for combined fits in RooFit-based frameworks written in C++. +Therefore, the `RooONNXFunction` doesn't depend on any Python packages and fully supports ROOT IO, + +**Key features:** + + * **Compiled inference via TMVA SOFIE:** The ONNX graph is translated into optimized C++ code at runtime using SOFIE, avoiding external runtime dependencies. + + * **Automatic differentiation with Clad:** Gradients of the model output with respect to RooFit parameters are generated automatically for efficient gradient-based minimization with RooFits `"codegen"` backend. + + * **Portable serialization:** The ONNX model is stored as part of the `RooONNXFunction` object and serialized with ROOT I/O. Upon reading a workspace, the inference code is regenerated automatically. + ### Deprecation of the the constant term optimization for legacy test statistic classes The **RooFit::Optimize()** option (constant term optimization) has been deprecated and will be removed in ROOT 6.42. diff --git a/roofit/codegen/inc/RooFit/CodegenImpl.h b/roofit/codegen/inc/RooFit/CodegenImpl.h index 8f253348ea74f..0db4591e6de9d 100644 --- a/roofit/codegen/inc/RooFit/CodegenImpl.h +++ b/roofit/codegen/inc/RooFit/CodegenImpl.h @@ -48,6 +48,7 @@ class RooLandau; class RooLognormal; class RooMultiPdf; class RooMultiVarGaussian; +class RooONNXFunction; class RooParamHistFunc; class RooPoisson; class RooPolyVar; @@ -111,12 +112,13 @@ void codegenImpl(RooHistFunc &arg, CodegenContext &ctx); void codegenImpl(RooHistPdf &arg, CodegenContext &ctx); void codegenImpl(RooLandau &arg, CodegenContext &ctx); void codegenImpl(RooLognormal &arg, CodegenContext &ctx); +void codegenImpl(RooMultiPdf &arg, CodegenContext &ctx); void codegenImpl(RooMultiVarGaussian &arg, CodegenContext &ctx); +void codegenImpl(RooONNXFunction &arg, CodegenContext &ctx); void codegenImpl(RooParamHistFunc &arg, CodegenContext &ctx); void codegenImpl(RooPoisson &arg, CodegenContext &ctx); void codegenImpl(RooPolyVar &arg, CodegenContext &ctx); void codegenImpl(RooPolynomial &arg, CodegenContext &ctx); -void codegenImpl(RooMultiPdf &arg, CodegenContext &ctx); void codegenImpl(RooProduct &arg, CodegenContext &ctx); void codegenImpl(RooRatio &arg, CodegenContext &ctx); void codegenImpl(RooRealIntegral &arg, CodegenContext &ctx); diff --git a/roofit/codegen/src/CodegenImpl.cxx b/roofit/codegen/src/CodegenImpl.cxx index 4fd12803b13b6..a103f94604bc8 100644 --- a/roofit/codegen/src/CodegenImpl.cxx +++ b/roofit/codegen/src/CodegenImpl.cxx @@ -42,6 +42,7 @@ #include #include #include +#include #include #include #include @@ -946,6 +947,21 @@ std::string codegenIntegralImpl(RooMultiVarGaussian &arg, int code, const char * return doubleToString(arg.analyticalIntegral(code, rangeName)); } +void codegenImpl(RooONNXFunction &arg, CodegenContext &ctx) +{ + std::stringstream ss; + ss << arg.outerWrapperName() << "("; + for (std::size_t i = 0; i < arg.nInputTensors(); ++i) { + ss << ctx.buildArg(arg.inputTensorList(i)) << std::endl; + if (i != arg.nInputTensors() - 1) { + ss << ", "; + } + } + ss << ")"; + + ctx.addResult(&arg, ss.str()); +} + std::string codegenIntegralImpl(RooPoisson &arg, int code, const char *rangeName, CodegenContext &ctx) { assert(code == 1 || code == 2); diff --git a/roofit/roofit/CMakeLists.txt b/roofit/roofit/CMakeLists.txt index 9acc215897036..be43bd46908a8 100644 --- a/roofit/roofit/CMakeLists.txt +++ b/roofit/roofit/CMakeLists.txt @@ -21,14 +21,13 @@ ROOT_STANDARD_LIBRARY_PACKAGE(RooFit RooBCPEffDecay.h RooBCPGenDecay.h RooBDecay.h + RooBMixDecay.h RooBernstein.h RooBifurGauss.h RooBlindTools.h - RooBMixDecay.h RooBreitWigner.h RooBukinPdf.h RooCBShape.h - RooCrystalBall.h RooCFunction1Binding.h RooCFunction2Binding.h RooCFunction3Binding.h @@ -36,37 +35,40 @@ ROOT_STANDARD_LIBRARY_PACKAGE(RooFit RooChebychev.h RooChi2MCSModule.h RooChiSquarePdf.h + RooCrystalBall.h RooDecay.h RooDstD0BG.h RooExponential.h - RooLegacyExpPoly.h - RooPowerSum.h RooFunctor1DBinding.h RooFunctorBinding.h + RooGExpModel.h RooGamma.h RooGaussExpTails.h - RooGaussian.h RooGaussModel.h - RooGExpModel.h + RooGaussian.h RooHistConstraint.h RooIntegralMorph.h RooJeffreysPrior.h + RooJohnson.h RooKeysPdf.h RooLagrangianMorphFunc.h RooLandau.h + RooLegacyExpPoly.h RooLognormal.h RooMathCoreReg.h + RooMomentMorph.h RooMomentMorphFunc.h RooMomentMorphFuncND.h - RooMomentMorph.h RooMultiBinomial.h RooNDKeysPdf.h RooNonCPEigenDecay.h RooNovosibirsk.h - RooParametricStepFunction.h + RooONNXFunction.h RooParamHistFunc.h + RooParametricStepFunction.h RooPoisson.h RooPolynomial.h + RooPowerSum.h RooPyBind.h RooSpline.h RooStepFunction.h @@ -80,21 +82,19 @@ ROOT_STANDARD_LIBRARY_PACKAGE(RooFit RooUnblindUniform.h RooUniform.h RooVoigtian.h - RooJohnson.h SOURCES src/Roo2DKeysPdf.cxx src/RooArgusBG.cxx src/RooBCPEffDecay.cxx src/RooBCPGenDecay.cxx src/RooBDecay.cxx + src/RooBMixDecay.cxx src/RooBernstein.cxx src/RooBifurGauss.cxx src/RooBlindTools.cxx - src/RooBMixDecay.cxx src/RooBreitWigner.cxx src/RooBukinPdf.cxx src/RooCBShape.cxx - src/RooCrystalBall.cxx src/RooCFunction1Binding.cxx src/RooCFunction2Binding.cxx src/RooCFunction3Binding.cxx @@ -102,24 +102,25 @@ ROOT_STANDARD_LIBRARY_PACKAGE(RooFit src/RooChebychev.cxx src/RooChi2MCSModule.cxx src/RooChiSquarePdf.cxx + src/RooCrystalBall.cxx src/RooDecay.cxx src/RooDstD0BG.cxx src/RooExponential.cxx - src/RooLegacyExpPoly.cxx - src/RooPowerSum.cxx src/RooFunctor1DBinding.cxx src/RooFunctorBinding.cxx + src/RooGExpModel.cxx src/RooGamma.cxx src/RooGaussExpTails.cxx - src/RooGaussian.cxx src/RooGaussModel.cxx - src/RooGExpModel.cxx + src/RooGaussian.cxx src/RooHistConstraint.cxx src/RooIntegralMorph.cxx src/RooJeffreysPrior.cxx + src/RooJohnson.cxx src/RooKeysPdf.cxx src/RooLagrangianMorphFunc.cxx src/RooLandau.cxx + src/RooLegacyExpPoly.cxx src/RooLognormal.cxx src/RooMathCoreReg.cxx src/RooMomentMorph.cxx @@ -129,10 +130,12 @@ ROOT_STANDARD_LIBRARY_PACKAGE(RooFit src/RooNDKeysPdf.cxx src/RooNonCPEigenDecay.cxx src/RooNovosibirsk.cxx - src/RooParametricStepFunction.cxx + src/RooONNXFunction.cxx src/RooParamHistFunc.cxx + src/RooParametricStepFunction.cxx src/RooPoisson.cxx src/RooPolynomial.cxx + src/RooPowerSum.cxx src/RooSpline.cxx src/RooStepFunction.cxx src/RooStudentT.cxx @@ -145,7 +148,6 @@ ROOT_STANDARD_LIBRARY_PACKAGE(RooFit src/RooUnblindUniform.cxx src/RooUniform.cxx src/RooVoigtian.cxx - src/RooJohnson.cxx DICTIONARY_OPTIONS "-writeEmptyRootPCM" LINKDEF diff --git a/roofit/roofit/inc/LinkDef1.h b/roofit/roofit/inc/LinkDef1.h index b808304c99b15..4a6c67192ec12 100644 --- a/roofit/roofit/inc/LinkDef1.h +++ b/roofit/roofit/inc/LinkDef1.h @@ -77,6 +77,8 @@ _nonInterfering.back().emplace_back(arg->GetName()); \ } \ } }"; +#pragma link C++ class RooONNXFunction+ ; + #pragma link C++ class RooFunctorBinding+ ; #pragma link C++ class RooFunctor1DBinding+ ; #pragma link C++ class RooFunctorPdfBinding+ ; diff --git a/roofit/roofit/inc/RooONNXFunction.h b/roofit/roofit/inc/RooONNXFunction.h new file mode 100644 index 0000000000000..278bfdfb18719 --- /dev/null +++ b/roofit/roofit/inc/RooONNXFunction.h @@ -0,0 +1,88 @@ +/* + * Project: RooFit + * Authors: + * Jonas Rembser, CERN 04/2026 + * + * Copyright (c) 2026, CERN + * + * Redistribution and use in source and binary forms, + * with or without modification, are permitted according to the terms + * listed in LICENSE (http://roofit.sourceforge.net/license.txt) + */ + +#ifndef RooFit_RooONNXFunction_h +#define RooFit_RooONNXFunction_h + +#include +#include + +#include + +class RooONNXFunction : public RooAbsReal { +public: + RooONNXFunction() = default; + + RooONNXFunction(const char *name, const char *title, const std::vector &inputTensors, + const std::string &onnxFile, const std::vector &inputNames = {}, + const std::vector> &inputShapes = {}); + + RooONNXFunction(const RooONNXFunction &other, const char *newName = nullptr); + + TObject *clone(const char *newName) const override { return new RooONNXFunction(*this, newName); } + + std::size_t nInputTensors() const { return _inputTensors.size(); } + RooArgList const &inputTensorList(int iTensor) const { return *(_inputTensors[iTensor]); } + + std::string funcName() const + { + initialize(); + return _funcName; + } + std::string outerWrapperName() const { return "TMVA_SOFIE_" + funcName() + "::roo_outer_wrapper"; } + +protected: + double evaluate() const override; + +private: + /// Build transient runtime backend on first use. + void initialize() const; + + /// Gather current RooFit inputs into a contiguous feature buffer. + void fillInputBuffer() const; + + struct RuntimeCache; + + std::vector> _inputTensors; ///< Inputs mapping to flattened input tensors. + std::vector _onnxBytes; ///< Persisted ONNX model bytes. + mutable std::shared_ptr _runtime; /// _inputBuffer; /// + void emplace() + { + any = std::make_any(); + ptr = std::any_cast(&any); + } + + void emplace(std::string const &typeName); +}; + +template +void doInferWithSessionVoidPtr(void *session, float const *input, float *out) +{ + doInfer(*reinterpret_cast(session), input, out); +} + +} // namespace RooFit::Detail + +#endif diff --git a/roofit/roofit/src/RooONNXFunction.cxx b/roofit/roofit/src/RooONNXFunction.cxx new file mode 100644 index 0000000000000..00f6e84969e4b --- /dev/null +++ b/roofit/roofit/src/RooONNXFunction.cxx @@ -0,0 +1,379 @@ +/* + * Project: RooFit + * Authors: + * Jonas Rembser, CERN 04/2026 + * + * Copyright (c) 2026, CERN + * + * Redistribution and use in source and binary forms, + * with or without modification, are permitted according to the terms + * listed in LICENSE (http://roofit.sourceforge.net/license.txt) + */ + +#include + +#include +#include + +#include +#include + +/** + \file RooONNXFunction.cxx + \class RooONNXFunction + \ingroup Roofit + + RooONNXFunction wraps an ONNX model as a RooAbsReal, allowing it to be used as + a building block in likelihoods, fits, and statistical analyses without + additional boilerplate code. The class supports models with **one or more + statically-shaped input tensors** and a **single scalar output**. The class + was designed to share workspaces with neural functions for combined fits in + RooFit-based frameworks written in C++. Therefore, the RooONNXFunction doesn't + depend on any Python packages and fully supports ROOT IO, + + The ONNX model is evaluated through compiled C++ code generated at runtime + using **TMVA SOFIE**. Automatic differentiation is supported via **Clad**, + allowing RooFit to access analytical gradients for fast minimization with + Minuit 2. + + The ONNX model is stored internally as a byte payload and serialized together + with the RooONNXFunction object using ROOT I/O. Upon reading from a file or + workspace, the runtime backend is rebuilt automatically. + + ### Input handling + + The model inputs are provided as a list of tensors, where each tensor is + represented by a RooArgList of RooAbsReal objects. The order of the inputs + defines the feature ordering passed to the ONNX model. + Optionally, users can validate that the ONNX model has the expected input + + ### Example (C++) + + \code + // Define input variables + RooRealVar x{"x", "x", 0.0}; + RooRealVar y{"y", "y", 0.0}; + RooRealVar z{"z", "z", 0.0}; + + // Construct ONNX function, building the std::vector in-place + RooONNXFunction func{ + "func", "func", + {{x, y}, {z}}, + "model.onnx" + }; + + // Evaluate + double val = func.getVal(); + std::cout << "Model output: " << val << std::endl; + \endcode + + ### Example (Python) + + \code{.py} + import ROOT + + # Define variables + x = ROOT.RooRealVar("x", "x", 0.0) + y = ROOT.RooRealVar("y", "y", 0.0) + z = ROOT.RooRealVar("z", "z", 0.0) + + # Create ONNX function + func = ROOT.RooONNXFunction( + "func", "func", + [[x, y], [z]], + "model.onnx" + ) + + # Evaluate + print("Model output:", func.getVal()) + \endcode + + */ + +namespace { + +std::vector fileToBytes(std::string const &filePath) +{ + // Read file into byte vector + std::ifstream file(filePath, std::ios::binary); + if (!file) { + std::ostringstream os; + os << "failed to open file '" << filePath << "'"; + throw std::runtime_error(os.str()); + } + + file.seekg(0, std::ios::end); + const std::streamsize size = file.tellg(); + file.seekg(0, std::ios::beg); + + if (size <= 0) { + std::ostringstream os; + os << "file '" << filePath << "' is empty"; + throw std::runtime_error(os.str()); + } + + std::vector bytes(static_cast(size)); + file.read(reinterpret_cast(bytes.data()), size); + + if (!file) { + std::ostringstream os; + os << "error while reading file '" << filePath << "'"; + throw std::runtime_error(os.str()); + } + + return bytes; +} + +template +Fn resolveLazy(std::string const &name, const char *code) +{ + static Fn fn = nullptr; + static std::once_flag flag; + + std::call_once(flag, [&] { + // Try to declare the code + if (!gInterpreter->Declare(code)) { + throw std::runtime_error(std::string("ROOT JIT Declare failed for code defining ") + name); + } + + // Try to resolve the symbol + void *symbol = reinterpret_cast(gInterpreter->ProcessLine((name + ";").c_str())); + + if (!symbol) { + throw std::runtime_error(std::string("ROOT JIT failed to resolve symbol: ") + name); + } + + fn = reinterpret_cast(symbol); + + if (!fn) { + throw std::runtime_error(std::string("ROOT JIT produced null function pointer for: ") + name); + } + }); + + return fn; +} + +template +std::string toPtrString(T *ptr, std::string const &castType) +{ + return TString::Format("reinterpret_cast<%s>(0x%zx)", (castType + "*").c_str(), reinterpret_cast(ptr)) + .Data(); +} + +} // namespace + +void RooFit::Detail::AnyWithVoidPtr::emplace(std::string const &typeName) +{ + auto anyPtrSession = toPtrString(this, "RooFit::Detail::AnyWithVoidPtr"); + gInterpreter->ProcessLine((anyPtrSession + "->emplace<" + typeName + ">();").c_str()); +} + +struct RooONNXFunction::RuntimeCache { + using Func = void (*)(void *, float const *, float *); + + RooFit::Detail::AnyWithVoidPtr _session; + RooFit::Detail::AnyWithVoidPtr _d_session; + Func _func; +}; + +/** + Construct a RooONNXFunction from an ONNX model file. + + \param name Name of the RooFit object + \param title Title of the RooFit object + \param inputTensors Vector of RooArgList, each representing one input tensor. + The variables in each RooArgList match to each flattened input tensor. + \param onnxFile Path to the ONNX model file. The file is read and stored + internally as a byte payload for persistence with RooWorkspace. + \param inputNames Optional list of ONNX input node names. If provided, these + are used to validate that the ONNX model has the structure expected by + your RooFit code. + \param inputShapes Optional list of tensor shapes corresponding to each input + tensor. If provided, these are used to validate that the ONNX models + input tensors have the shape that you expect. If omitted, only the + total size of each tensor is checked. + */ +RooONNXFunction::RooONNXFunction(const char *name, const char *title, const std::vector &inputTensors, + const std::string &onnxFile, const std::vector & /*inputNames*/, + const std::vector> & /*inputShapes*/) + : RooAbsReal{name, title}, _onnxBytes{fileToBytes(onnxFile)} +{ + for (std::size_t i = 0; i < inputTensors.size(); ++i) { + std::string istr = std::to_string(i); + _inputTensors.emplace_back( + std::make_unique(("!inputs_" + istr).c_str(), ("Input tensor " + istr).c_str(), this)); + _inputTensors.back()->addTyped(inputTensors[i]); + } +} + +RooONNXFunction::RooONNXFunction(const RooONNXFunction &other, const char *newName) + : RooAbsReal{other, newName}, _onnxBytes{other._onnxBytes}, _runtime{other._runtime} +{ + for (std::size_t i = 0; i < other._inputTensors.size(); ++i) { + _inputTensors.emplace_back(std::make_unique("!inputs", this, *other._inputTensors[i])); + } +} + +void RooONNXFunction::fillInputBuffer() const +{ + _inputBuffer.clear(); + _inputBuffer.reserve(_inputTensors.size()); + + for (auto const &tensorList : _inputTensors) { + for (auto const *real : static_range_cast(*tensorList)) { + _inputBuffer.push_back(static_cast(real->getVal(tensorList->nset()))); + } + } +} + +void RooONNXFunction::initialize() const +{ + if (_runtime) { + return; + } + + _runtime = std::make_unique(); + + // We are jitting the SOFIE invocation lazily at runtime, to avoid the + // link-time dependency to the SOFIE parser library. + if (gSystem->Load("libROOTTMVASofieParser") < 0) { + throw std::runtime_error("RooONNXFunction: cannot load ONNX file since SOFIE ONNX parser is missing." + " Please build ROOT with tmva-sofie=ON."); + } + using OnnxToCpp = std::string (*)(std::uint8_t const *, std::size_t, const char *); + auto onnxToCppWithSofie = resolveLazy("_RooONNXFunction_onnxToCppWithSofie", + R"( +#include "TMVA/RModelParser_ONNX.hxx" + +std::string _RooONNXFunction_onnxToCppWithSofie(std::uint8_t const *onnxBytes, std::size_t onnxBytesSize, const char *outputName) +{ + namespace SOFIE = TMVA::Experimental::SOFIE; + + std::string buffer{reinterpret_cast(onnxBytes), onnxBytesSize}; + std::istringstream stream{buffer}; + + SOFIE::RModel rmodel = SOFIE::RModelParser_ONNX{}.Parse(stream, outputName); + rmodel.SetOptimizationLevel(SOFIE::OptimizationLevel::kBasic); + rmodel.Generate(SOFIE::Options::kNoWeightFile); + + std::stringstream ss{}; + rmodel.PrintGenerated(ss); + return ss.str(); +} +)"); + + static int counter = 0; + _funcName = "roo_onnx_func_" + std::to_string(counter); + std::string namespaceName = "TMVA_SOFIE_" + _funcName + ""; + counter++; + + std::string modelCode = onnxToCppWithSofie(_onnxBytes.data(), _onnxBytes.size(), _funcName.c_str()); + gInterpreter->Declare(modelCode.c_str()); + + // Declare string to the interpreter, where the %%NAMESPACE%% placeholder + // will first be replaced by the namespace for the emitted code. + auto declareWithNamespace = [&](std::string codeTemplate) { + const std::string placeholder = "%%NAMESPACE%%"; + size_t pos = 0; + + while ((pos = codeTemplate.find(placeholder, pos)) != std::string::npos) { + codeTemplate.replace(pos, placeholder.length(), namespaceName); + pos += namespaceName.length(); + } + + gInterpreter->Declare(codeTemplate.c_str()); + }; + + declareWithNamespace(R"( + +namespace %%NAMESPACE%% { + +float roo_inner_wrapper(Session const &session, float const *input) +{ + float out = 0.; + doInfer(session, input, &out); + return out; +} + +float roo_wrapper(Session const &session, float const *input) +{ + return roo_inner_wrapper(session, input); +} + +} // namespace %%NAMESPACE%% + +)"); + + std::string sessionName = "::TMVA_SOFIE_" + _funcName + "::Session"; + + _runtime->_session.emplace(sessionName); + auto ptrSession = toPtrString(_runtime->_session.ptr, sessionName); + + std::stringstream ss2; + ss2 << "static_cast(RooFit::Detail::doInferWithSessionVoidPtr<" + << sessionName << ">" << ");"; + _runtime->_func = reinterpret_cast(gInterpreter->ProcessLine(ss2.str().c_str())); + + // hardcode the gradient for now + _runtime->_d_session.emplace(sessionName); + auto ptrDSession = toPtrString(_runtime->_d_session.ptr, sessionName); + + gInterpreter->Declare("#include "); + + gInterpreter->ProcessLine(("clad::gradient(" + namespaceName + "::roo_wrapper, \"input\");").c_str()); + + declareWithNamespace(R"( +namespace %%NAMESPACE%% { + +double roo_outer_wrapper(double const *input) { + auto &session = *)" + + ptrSession + R"(; + float inputFlt[inputTensorDims[0].total_size()]; + for (std::size_t i = 0; i < std::size(inputFlt); ++i) { + inputFlt[i] = input[i]; + } + return roo_inner_wrapper(session, inputFlt); +} + +} // namespace %%NAMESPACE%% + +namespace clad::custom_derivatives { + +namespace %%NAMESPACE%% { + +void roo_outer_wrapper_pullback(double const *input, double d_y, double *d_input) { + + using namespace ::%%NAMESPACE%%; + + float inputFlt[inputTensorDims[0].total_size()]; + float d_inputFlt[::std::size(inputFlt)]; + for (::std::size_t i = 0; i < ::std::size(inputFlt); ++i) { + inputFlt[i] = input[i]; + d_inputFlt[i] = d_input[i]; + } + auto *session = )" + ptrSession + + R"(; + auto *d_session = )" + + ptrDSession + R"(; + roo_inner_wrapper_pullback(*session, inputFlt, d_y, d_session, d_inputFlt); + for (::std::size_t i = 0; i < ::std::size(inputFlt); ++i) { + d_input[i] += d_inputFlt[i]; + } +} + +} // namespace %%NAMESPACE%% + +} // namespace clad::custom_derivatives + +)"); +} + +double RooONNXFunction::evaluate() const +{ + initialize(); + fillInputBuffer(); + + float out = 0.f; + _runtime->_func(_runtime->_session.ptr, _inputBuffer.data(), &out); + return static_cast(out); +} diff --git a/roofit/roofit/test/CMakeLists.txt b/roofit/roofit/test/CMakeLists.txt index b13998cbb4de6..b9635e81651d6 100644 --- a/roofit/roofit/test/CMakeLists.txt +++ b/roofit/roofit/test/CMakeLists.txt @@ -38,3 +38,18 @@ else() endif() add_subdirectory(vectorisedPDFs) + +if(tmva-sofie) + ROOT_FIND_PYTHON_MODULE(torch) + ROOT_FIND_PYTHON_MODULE(onnx) + + if (ROOT_TORCH_FOUND AND ROOT_ONNX_FOUND) + ROOT_ADD_TEST(roofit-create-onnx-model COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/create_onnx_model.py + FIXTURES_SETUP roofit-onnx-model + ) + ROOT_ADD_GTEST(testRooONNXFunction testRooONNXFunction.cxx LIBRARIES RooFit + FIXTURES_REQUIRED roofit-onnx-model + ) + endif() + +endif() diff --git a/roofit/roofit/test/create_onnx_model.py b/roofit/roofit/test/create_onnx_model.py new file mode 100644 index 0000000000000..ebac98a0c1eb7 --- /dev/null +++ b/roofit/roofit/test/create_onnx_model.py @@ -0,0 +1,108 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim + + +# ---- 1) Define a small fully-connected regression model ---- +class SmallMLP(nn.Module): + def __init__(self, in_features=10, hidden=32, out_features=1): + super().__init__() + self.net = nn.Sequential( + nn.Linear(in_features, hidden), + nn.ReLU(), + nn.Linear(hidden, hidden), + nn.ReLU(), + nn.Linear(hidden, out_features), + ) + + def forward(self, x): + return self.net(x) + + +def write_onnx_model(onnx_path): + + # ---- 2) Create synthetic regression data ---- + torch.manual_seed(0) + + num_samples = 1000 + in_features = 10 + + X = torch.randn(num_samples, in_features) + + # True function: linear combination + noise + true_w = torch.randn(in_features, 1) + true_b = torch.randn(1) + + y = X @ true_w + true_b + 0.1 * torch.randn(num_samples, 1) + + dataset = torch.utils.data.TensorDataset(X, y) + loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True) + + # ---- 3) Train the model ---- + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + model = SmallMLP(in_features=10, hidden=32, out_features=1).to(device) + + criterion = nn.MSELoss() + optimizer = optim.Adam(model.parameters(), lr=1e-3) + + epochs = 20 + for epoch in range(epochs): + model.train() + total_loss = 0.0 + + for xb, yb in loader: + xb, yb = xb.to(device), yb.to(device) + + optimizer.zero_grad() + pred = model(xb) + loss = criterion(pred, yb) + loss.backward() + optimizer.step() + + total_loss += loss.item() * xb.size(0) + + avg_loss = total_loss / len(dataset) + print(f"Epoch {epoch + 1}/{epochs} - loss: {avg_loss:.6f}") + + # ---- 4) Export the trained model to ONNX ---- + model.eval() + + # Create a torch.export program in the parent process. + # This does not require importing ONNX. + example_input = (torch.randn(1, 10, device=device),) + exported = torch.export.export(model, example_input) + + torch.onnx.export( + exported, + args=(), + f=onnx_path, + external_data=False, + dynamo=True, + ) + + return model + + +def main(): + + onnx_path = "regression_mlp.onnx" + + model = write_onnx_model(onnx_path) + + x = torch.tensor([[0.1] * 10], requires_grad=True) + + y = model(x) + + y.backward() + + print("prediction:", y.item()) + print("input gradient:", x.grad) + + np.savetxt("regression_mlp_pred.txt", y.detach().numpy()) + np.savetxt("regression_mlp_grad.txt", x.grad.detach().numpy()) + + +if __name__ == "__main__": + main() diff --git a/roofit/roofit/test/testRooONNXFunction.cxx b/roofit/roofit/test/testRooONNXFunction.cxx new file mode 100644 index 0000000000000..b184eb82030ea --- /dev/null +++ b/roofit/roofit/test/testRooONNXFunction.cxx @@ -0,0 +1,134 @@ +// Tests for the RooONNXFunction +// Authors: Jonas Rembser, CERN 2026 + +#include +#include +#include +#include +#include +#include + +#include + +#include + +#include + +namespace { + +std::vector readDoublesFromFile(const std::string &filename) +{ + std::vector values; + std::ifstream file(filename); + + if (!file) { + std::cerr << "Error: Could not open file " << filename << "\n"; + return values; + } + + double x; + while (file >> x) { + values.push_back(x); + } + + return values; +} + +} // namespace + +/// Basic test for the evaluation of a RooONNXFunction with a single input +/// vector. +TEST(RooONNXFunction, Basic) +{ + double refPred = readDoublesFromFile("regression_mlp_pred.txt")[0]; + + RooArgList args; + for (int i = 0; i < 10; ++i) { + auto v = std::make_unique(std::to_string(i).c_str(), "", 0.1, -10.0, 10.0); + args.addOwned(std::move(v)); + } + + RooONNXFunction roo_func{"func", "", {args}, "regression_mlp.onnx"}; + + EXPECT_NEAR(roo_func.getVal(), refPred, 1e-5); +} + +// Test the serialization to RooWorkspace. The ONNX payload will be embedded in +// the RooWorkspace as a binary blob. +TEST(RooONNXFunction, Basic_RooWorkspace) +{ + RooHelpers::LocalChangeMsgLevel chmsglvl{RooFit::WARNING, 0u, RooFit::ObjectHandling, true}; + + // Write to RooWorkspace + { + RooArgList args; + for (int i = 0; i < 10; ++i) { + auto v = std::make_unique(std::to_string(i).c_str(), "", 0.1, -10.0, 10.0); + args.addOwned(std::move(v)); + } + + RooONNXFunction roo_func{"func", "", {args}, "regression_mlp.onnx"}; + RooWorkspace ws{"ws"}; + ws.import(roo_func); + ws.writeToFile("RooONNXFunction_Basic.root"); + } + + // Read back and validate + std::unique_ptr file{TFile::Open("RooONNXFunction_Basic.root")}; + RooWorkspace *ws = dynamic_cast(file->Get("ws")); + auto *roo_func = dynamic_cast(ws->function("func")); + + double refPred = readDoublesFromFile("regression_mlp_pred.txt")[0]; + EXPECT_NEAR(roo_func->getVal(), refPred, 1e-5); +} + +#ifdef ROOFIT_CLAD +/// Basic test for getting the analytic gradient of a RooONNXFunction with a +/// single input vector. +TEST(RooONNXFunction, Basic_CodegenAD) +{ + RooHelpers::LocalChangeMsgLevel chmsglvl{RooFit::WARNING, 0u, RooFit::Fitting, true}; + + double refPred = readDoublesFromFile("regression_mlp_pred.txt")[0]; + std::vector refGrad = readDoublesFromFile("regression_mlp_grad.txt"); + + RooArgList args; + for (int i = 0; i < 10; ++i) { + auto v = std::make_unique(std::to_string(i).c_str(), "", 0.1, -10.0, 10.0); + args.addOwned(std::move(v)); + } + + RooONNXFunction roo_func{"func", "", {args}, "regression_mlp.onnx"}; + + RooDataSet data("data", "data", {}); + + RooFit::Experimental::RooEvaluatorWrapper roo_final{roo_func, &data, false, "", nullptr, false}; + + EXPECT_NEAR(roo_final.getVal(), refPred, 1e-5); + + roo_final.generateGradient(); + + std::vector output_vec(10); + + roo_final.gradient(output_vec.data()); + roo_final.setUseGeneratedFunctionCode(true); + // For debugging + // roo_final.writeDebugMacro("codegen"); + + for (int i = 0; i < 10; ++i) { + EXPECT_NEAR(output_vec[i], refGrad[i], 1e-5); + } + + // Zero out gradient output buffer and recalculate, just to check that no + // internal state in not reset. + for (int i = 0; i < 10; ++i) { + output_vec[i] = 0.; + } + + roo_final.gradient(output_vec.data()); + + for (int i = 0; i < 10; ++i) { + EXPECT_NEAR(output_vec[i], refGrad[i], 1e-5); + } +} +#endif