diff --git a/R-package/R/xgb.Booster.R b/R-package/R/xgb.Booster.R index ebc4a9bb61f8..018ae17ce3c7 100644 --- a/R-package/R/xgb.Booster.R +++ b/R-package/R/xgb.Booster.R @@ -362,14 +362,13 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA newdata <- validate.features(object, newdata) } is_dmatrix <- inherits(newdata, "xgb.DMatrix") - if (is_dmatrix && !is.null(base_margin)) { - stop( - "'base_margin' is not supported when passing 'xgb.DMatrix' as input.", - " Should be passed as argument to 'xgb.DMatrix' constructor." - ) - } if (is_dmatrix) { rnames <- NULL + + # FIX: If user passed a margin argument, apply it to the DMatrix. + if (!is.null(base_margin)) { + setinfo(newdata, "base_margin", base_margin) + } } else { rnames <- row.names(newdata) } diff --git a/R-package/tests/testthat/test_basic.R b/R-package/tests/testthat/test_basic.R index 7bdf9a601d2a..4756f1d697cb 100644 --- a/R-package/tests/testthat/test_basic.R +++ b/R-package/tests/testthat/test_basic.R @@ -1231,3 +1231,43 @@ test_that("xgb.train works with nrounds=0 (serialization, continuation, callback expect_equal(iter_cb, 5) } }) + +test_that("predict respects base_margin inside xgb.DMatrix", { + # Reuse global data variable 'train' + dtrain <- xgb.DMatrix(train$data, label = train$label, nthread = 1) + + # Train a dummy model + bst <- xgb.train( + params = list(objective = "binary:logistic", nthread = 1), + data = dtrain, + nrounds = 1, + verbose = 0 + ) + + # Case 1: Fresh DMatrix with margin 0.5 + dtest1 <- xgb.DMatrix(train$data[1:10, ], label = train$label[1:10], nthread = 1) + setinfo(dtest1, "base_margin", rep(0.5, 10)) + p1 <- predict(bst, dtest1) + + # Case 2: Fresh DMatrix with margin 1.5 + # Using a NEW object ensures no prediction caching interferes + dtest2 <- xgb.DMatrix(train$data[1:10, ], label = train$label[1:10], nthread = 1) + setinfo(dtest2, "base_margin", rep(1.5, 10)) + p2 <- predict(bst, dtest2) + + # Logic check: Different margins MUST yield different predictions + expect_false(isTRUE(all.equal(p1, p2))) + + # Case 3: Explicit override via argument (The Fix) + # We use a FRESH DMatrix to ensure we test the argument wiring, not the + # C++ cache. If we reused dtest1, XGBoost would return the cached + # prediction from Case 1. + dtest3 <- xgb.DMatrix(train$data[1:10, ], label = train$label[1:10], nthread = 1) + + # Even if dtest3 has margin 0.5 inside... + setinfo(dtest3, "base_margin", rep(0.5, 10)) + + # ... passing 1.5 as an argument should override it (and match p2) + p3 <- predict(bst, dtest3, base_margin = rep(1.5, 10)) + expect_equal(p2, p3, tolerance = 1e-6) +})