Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
da288b7
Fix iteration handling for nrounds=0 in xgb.train
Sanidhyavijay24 Dec 10, 2025
23179cf
Removed trailing whitespace
Sanidhyavijay24 Dec 10, 2025
5e7b8e0
whitespace error fix again (last time sorry!!)
Sanidhyavijay24 Dec 10, 2025
0a05eee
Add regression test for nrounds=0
Sanidhyavijay24 Dec 10, 2025
e73f2b1
Merge branch 'master' into master
Sanidhyavijay24 Dec 10, 2025
615f9b3
Fix test for xgb.train with nrounds = 0
Sanidhyavijay24 Dec 10, 2025
296f1ba
Fix style/whitespace error
Sanidhyavijay24 Dec 10, 2025
d488100
Clean up test_basic.R by removing blank line
Sanidhyavijay24 Dec 11, 2025
946a511
Improve test coverage for xgb.train with nrounds = 0
Sanidhyavijay24 Dec 11, 2025
0a54297
Merge branch 'master' into master
Sanidhyavijay24 Dec 11, 2025
f67741a
Fixed use of single '&' for vector comparison, not '&&'
Sanidhyavijay24 Dec 11, 2025
0ed5267
Now handles NULL niter in continuation with callbacks
Sanidhyavijay24 Dec 12, 2025
03b1300
Merge branch 'master' into master
Sanidhyavijay24 Dec 12, 2025
8de55c3
Corrected case 3
Sanidhyavijay24 Dec 12, 2025
ee411c4
Merge branch 'master' into master
Sanidhyavijay24 Dec 12, 2025
5099fa3
Updated tests for CRAN compliance and cleanup
Sanidhyavijay24 Dec 13, 2025
2ebf867
Merge branch 'dmlc:master' into master
Sanidhyavijay24 Dec 17, 2025
15e340d
Fix predict to respect base_margin in xgb.DMatrix
Sanidhyavijay24 Dec 17, 2025
2a495dc
Added test for DMatrix base_margin prediction
Sanidhyavijay24 Dec 17, 2025
cdac55b
Simplify DMatrix margin fix
Sanidhyavijay24 Dec 17, 2025
4490396
Update test to use fresh objects (avoid caching)
Sanidhyavijay24 Dec 17, 2025
61bbadd
Update test to avoid caching in all cases
Sanidhyavijay24 Dec 17, 2025
7d99ab8
Remove unnecessary blank line in xgb.Booster.R
Sanidhyavijay24 Dec 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions R-package/R/xgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -362,14 +362,13 @@ predict.xgb.Booster <- function(object, newdata, missing = NA, outputmargin = FA
newdata <- validate.features(object, newdata)
}
is_dmatrix <- inherits(newdata, "xgb.DMatrix")
if (is_dmatrix && !is.null(base_margin)) {
stop(
"'base_margin' is not supported when passing 'xgb.DMatrix' as input.",
" Should be passed as argument to 'xgb.DMatrix' constructor."
)
}
if (is_dmatrix) {
rnames <- NULL

# FIX: If user passed a margin argument, apply it to the DMatrix.
if (!is.null(base_margin)) {
setinfo(newdata, "base_margin", base_margin)
Comment thread
Sanidhyavijay24 marked this conversation as resolved.
}
} else {
rnames <- row.names(newdata)
}
Expand Down
40 changes: 40 additions & 0 deletions R-package/tests/testthat/test_basic.R
Original file line number Diff line number Diff line change
Expand Up @@ -1231,3 +1231,43 @@ test_that("xgb.train works with nrounds=0 (serialization, continuation, callback
expect_equal(iter_cb, 5)
}
})

test_that("predict respects base_margin inside xgb.DMatrix", {
# Reuse global data variable 'train'
dtrain <- xgb.DMatrix(train$data, label = train$label, nthread = 1)

# Train a dummy model
bst <- xgb.train(
params = list(objective = "binary:logistic", nthread = 1),
data = dtrain,
nrounds = 1,
verbose = 0
)

# Case 1: Fresh DMatrix with margin 0.5
dtest1 <- xgb.DMatrix(train$data[1:10, ], label = train$label[1:10], nthread = 1)
setinfo(dtest1, "base_margin", rep(0.5, 10))
p1 <- predict(bst, dtest1)

# Case 2: Fresh DMatrix with margin 1.5
# Using a NEW object ensures no prediction caching interferes
dtest2 <- xgb.DMatrix(train$data[1:10, ], label = train$label[1:10], nthread = 1)
setinfo(dtest2, "base_margin", rep(1.5, 10))
p2 <- predict(bst, dtest2)

# Logic check: Different margins MUST yield different predictions
expect_false(isTRUE(all.equal(p1, p2)))

# Case 3: Explicit override via argument (The Fix)
# We use a FRESH DMatrix to ensure we test the argument wiring, not the
# C++ cache. If we reused dtest1, XGBoost would return the cached
# prediction from Case 1.
dtest3 <- xgb.DMatrix(train$data[1:10, ], label = train$label[1:10], nthread = 1)

# Even if dtest3 has margin 0.5 inside...
setinfo(dtest3, "base_margin", rep(0.5, 10))

# ... passing 1.5 as an argument should override it (and match p2)
p3 <- predict(bst, dtest3, base_margin = rep(1.5, 10))
expect_equal(p2, p3, tolerance = 1e-6)
})
Loading