16 changes: 12 additions & 4 deletions include/xgboost/multi_target_tree_model.h
@@ -1,5 +1,5 @@
/**
* Copyright 2023-2025, XGBoost contributors
* Copyright 2023-2026, XGBoost contributors
*
* @brief Core data structure for multi-target trees.
*/
@@ -58,6 +58,10 @@ class MultiTargetTree : public Model {
HostDeviceVector<float> weights_;
// Output weights.
HostDeviceVector<float> leaf_weights_;
// Loss change for each node.
HostDeviceVector<float> loss_chg_;
// Sum of hessians for each node (coverage).
HostDeviceVector<float> sum_hess_;

[[nodiscard]] linalg::VectorView<float const> NodeWeight(bst_node_t nidx) const {
auto beg = nidx * this->NumSplitTargets();
@@ -81,16 +85,20 @@ class MultiTargetTree : public Model {
MultiTargetTree& operator=(MultiTargetTree&& that) = delete;

/**
* @brief Set the weight for the root.
* @brief Set the weight and statistics for the root.
*
* @param weight The weight vector for the root node.
* @param sum_hess The sum of hessians for the root node (coverage).
*/
void SetRoot(linalg::VectorView<float const> weight);
void SetRoot(linalg::VectorView<float const> weight, float sum_hess);
/**
* @brief Expand a leaf into a split node.
*/
void Expand(bst_node_t nidx, bst_feature_t split_idx, float split_cond, bool default_left,
linalg::VectorView<float const> base_weight,
linalg::VectorView<float const> left_weight,
linalg::VectorView<float const> right_weight);
linalg::VectorView<float const> right_weight, float loss_chg, float sum_hess,
float left_sum, float right_sum);
/** @see RegTree::SetLeaves */
void SetLeaves(std::vector<bst_node_t> leaves, common::Span<float const> weights);
/** @brief Copy base weight into leaf weight for a non-reduced multi-target tree. */
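Note on the storage scheme introduced above: `loss_chg_` and `sum_hess_` are plain per-node arrays indexed by node id, parallel to the existing weight vectors. A minimal Python sketch of that bookkeeping (hypothetical names, not the actual `MultiTargetTree` API): `Expand` records the gain and coverage at the parent and seeds the two children with their own coverage, matching how the hist updater computes the parent's coverage as the sum of the children's.

```python
class MultiTargetTreeSketch:
    """Toy per-node statistics storage; illustrative only."""

    def __init__(self, root_sum_hess: float) -> None:
        # One entry per node, indexed by node id; the root is node 0.
        self.loss_chg = [0.0]            # gain of the split at each node
        self.sum_hess = [root_sum_hess]  # coverage (hessian sum) of each node

    def expand(self, nidx: int, loss_chg: float,
               left_sum: float, right_sum: float) -> tuple[int, int]:
        """Turn leaf `nidx` into a split node; return the new child ids."""
        self.loss_chg[nidx] = loss_chg
        self.sum_hess[nidx] = left_sum + right_sum
        left, right = len(self.loss_chg), len(self.loss_chg) + 1
        # Children start as leaves: no gain yet, coverage from the split.
        self.loss_chg += [0.0, 0.0]
        self.sum_hess += [left_sum, right_sum]
        return left, right


tree = MultiTargetTreeSketch(root_sum_hess=1024.0)
left, right = tree.expand(0, loss_chg=3.5, left_sum=600.0, right_sum=424.0)
assert tree.sum_hess[0] == tree.sum_hess[left] + tree.sum_hess[right]
```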
19 changes: 13 additions & 6 deletions include/xgboost/tree_model.h
@@ -1,5 +1,5 @@
/**
* Copyright 2014-2025, XGBoost Contributors
* Copyright 2014-2026, XGBoost Contributors
*
* @brief model structure for tree
* \author Tianqi Chen
@@ -322,11 +322,17 @@ class RegTree : public Model {
bst_node_t leaf_right_child = kInvalidNodeId);
/**
* @brief Expands a leaf node into two additional leaf nodes for a multi-target tree.
*
* @param gain The gain (loss change) from this split.
* @param sum_hess The sum of hessians for the parent node (coverage).
* @param left_sum The sum of hessians for the left child (coverage).
* @param right_sum The sum of hessians for the right child (coverage).
*/
void ExpandNode(bst_node_t nidx, bst_feature_t split_index, float split_cond, bool default_left,
linalg::VectorView<float const> base_weight,
linalg::VectorView<float const> left_weight,
linalg::VectorView<float const> right_weight);
linalg::VectorView<float const> right_weight, float loss_chg, float sum_hess,
float left_sum, float right_sum);
/**
* @brief Set all leaf weights for a multi-target tree.
*
@@ -407,13 +413,14 @@
*/
[[nodiscard]] bst_node_t GetDepth(bst_node_t nidx) const;
/**
* @brief Set the root weight for a multi-target tree.
* @brief Set the root weight and statistics for a multi-target tree.
*
* @param weight Internal split weight, with size equals to reduced targets.
* @param weight Internal split weight, with size equal to the number of reduced targets.
* @param sum_hess The sum of hessians for the root node (coverage).
*/
void SetRoot(linalg::VectorView<float const> weight) {
void SetRoot(linalg::VectorView<float const> weight, float sum_hess) {
CHECK(IsMultiTarget());
return this->p_mt_tree_->SetRoot(weight);
return this->p_mt_tree_->SetRoot(weight, sum_hess);
}
/**
* @brief Get the maximum depth.
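The user-visible effect of the two header changes: gain- and cover-based importance no longer abort for multi-target trees. A quick end-to-end sketch through the Python API (data and parameters are illustrative; `multi_output_tree` requires the `hist` tree method):

```python
import xgboost as xgb
from sklearn.datasets import make_regression

X, y = make_regression(n_samples=512, n_features=8, n_targets=3, random_state=0)
booster = xgb.train(
    {"multi_strategy": "multi_output_tree", "tree_method": "hist"},
    xgb.DMatrix(X, y),
    num_boost_round=8,
)
# Before this change, every type except "weight" raised for multi-target trees.
for imp in ("weight", "gain", "total_gain", "cover", "total_cover"):
    scores = booster.get_score(importance_type=imp)
    assert scores and all(v >= 0 for v in scores.values())
```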
129 changes: 118 additions & 11 deletions python-package/xgboost/testing/multi_target.py
@@ -2,7 +2,7 @@

# pylint: disable=unbalanced-tuple-unpacking
from types import ModuleType
from typing import Dict, Optional, Tuple
from typing import Any, Dict, Optional, Tuple

import numpy as np
import pytest
@@ -11,6 +11,7 @@
make_multilabel_classification,
make_regression,
)
from sklearn.metrics.pairwise import cosine_similarity

import xgboost.testing as tm

@@ -384,12 +385,13 @@ def run() -> Booster:


def run_column_sampling(device: Device) -> None:
"""Test with column sampling."""
"""Test column sampling with feature importance for multi-target trees."""
n_features = 32
n_targets = 3
X, y = make_regression(
n_samples=1024, n_features=n_features, random_state=1994, n_targets=3
n_samples=1024, n_features=n_features, random_state=1994, n_targets=n_targets
)
# First half is valid, second half is 0.
# The first half of the features has weight; the second half has zero weight (not sampled).
feature_weights = np.zeros(shape=(n_features, 1), dtype=np.float32)
feature_weights[: n_features // 2] = 1.0 / (n_features / 2)
Xy = QuantileDMatrix(X, y, feature_weights=feature_weights)
@@ -401,13 +403,28 @@ def run_column_sampling(device: Device) -> None:
"colsample_bynode": 0.4,
}
booster = train(params, Xy, num_boost_round=16)
fscores = booster.get_fscore()
# sampled
for f in range(0, n_features // 2):
assert f"f{f}" in fscores
# not sampled
for f in range(n_features // 2, n_features):
assert f"f{f}" not in fscores

# Test all importance types
for importance_type in ["weight", "gain", "total_gain", "cover", "total_cover"]:
scores = booster.get_score(importance_type=importance_type)
assert len(scores) > 0, f"No scores for {importance_type}"

# Sampled features (first half) should be in scores
for f in range(0, n_features // 2):
assert f"f{f}" in scores, f"f{f} not in {importance_type} scores"

# Non-sampled features (second half) should NOT be in scores
for f in range(n_features // 2, n_features):
assert (
f"f{f}" not in scores
), f"f{f} should not be in {importance_type} scores"

# Verify values are scalars and non-negative
for feat, score in scores.items():
assert isinstance(
score, float
), f"Score should be scalar, got {type(score)}"
assert score >= 0, f"Negative {importance_type} for {feat}: {score}"


def run_grow_policy(device: Device, grow_policy: str) -> None:
@@ -426,3 +443,93 @@ def run_grow_policy(device: Device, grow_policy: str) -> None:

evals_result = train_result(params, Xy, num_rounds=10)
assert non_increasing(evals_result["train"]["rmse"])


def run_mixed_strategy(device: Device) -> None:
"""Test mixed multi_strategy with ResetStrategy callback."""
X, y = make_classification(
n_samples=1024, n_informative=8, n_classes=3, random_state=1994
)
Xy = DMatrix(data=X, label=y)

booster = train(
{
"num_parallel_tree": 4,
"num_class": 3,
"objective": "multi:softprob",
"multi_strategy": "multi_output_tree",
"device": device,
"debug_synchronize": True,
"base_score": 0,
},
num_boost_round=16,
dtrain=Xy,
callbacks=[ResetStrategy()],
)

# Test model slicing: iterating the booster yields one slice per boosting round
assert len(list(booster)) == 16

# Test that sliced predictions sum to full prediction
predt = booster.predict(Xy, output_margin=True)
predt_sum = np.zeros(predt.shape)
for t in booster:
predt_sum += t.predict(Xy, output_margin=True)
np.testing.assert_allclose(predt, predt_sum, atol=1e-5)

# Test feature importance works with mixed trees
for importance_type in ["weight", "gain", "total_gain", "cover", "total_cover"]:
scores = booster.get_score(importance_type=importance_type)
assert len(scores) > 0
for score in scores.values():
assert isinstance(score, float)
assert score >= 0


def run_feature_importance_strategy_compare(device: Device) -> None:
"""Different strategies produce similar feature importance ratios."""
n_features = 16
X, y = make_classification(
n_samples=2048, n_features=n_features, n_informative=10, n_classes=4,
random_state=1994,
)
Xy = DMatrix(data=X, label=y)

base_params: Dict[str, Any] = {
"num_class": 4,
"objective": "multi:softprob",
"device": device,
"debug_synchronize": True,
"max_depth": 5,
}

# Train models with different strategies
boosters = [
train({**base_params, "multi_strategy": "multi_output_tree"}, Xy, num_boost_round=32),
train({**base_params, "multi_strategy": "one_output_per_tree"}, Xy, num_boost_round=32),
train(
{**base_params, "multi_strategy": "multi_output_tree"},
Xy,
num_boost_round=32,
callbacks=[ResetStrategy()],
),
]

def get_normalized_importance(booster: Booster, importance_type: str) -> np.ndarray:
"""Get feature importance as normalized array (sums to 1)."""
scores = booster.get_score(importance_type=importance_type)
arr = np.array([scores.get(f"f{i}", 0.0) for i in range(n_features)])
return arr / arr.sum() if arr.sum() > 0 else arr

for importance_type in ["weight", "gain", "total_gain", "cover", "total_cover"]:
imps = [get_normalized_importance(b, importance_type) for b in boosters]

# Check that importances are not exactly the same (different strategies)
assert not np.allclose(imps[0], imps[1])
assert not np.allclose(imps[0], imps[2])

# Check that normalized importances are similar (correlated)
# All strategies should have reasonably similar importance patterns
assert cosine_similarity([imps[0]], [imps[1]])[0, 0] > 0.9
assert cosine_similarity([imps[0]], [imps[2]])[0, 0] > 0.9
assert cosine_similarity([imps[1]], [imps[2]])[0, 0] > 0.9
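`ResetStrategy` is imported from elsewhere in this testing module; as used above, it switches `multi_strategy` between rounds so the ensemble mixes both tree kinds. A hypothetical sketch of such a callback, assuming that behaviour (the class name and alternation rule are illustrative, not the module's actual implementation):

```python
import xgboost as xgb


class ResetStrategySketch(xgb.callback.TrainingCallback):
    """Alternate the tree strategy every boosting round (assumed behaviour)."""

    def after_iteration(self, model: xgb.Booster, epoch: int, evals_log) -> bool:
        strategy = "one_output_per_tree" if epoch % 2 == 0 else "multi_output_tree"
        model.set_param("multi_strategy", strategy)
        return False  # returning True would stop training early
```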
6 changes: 3 additions & 3 deletions src/gbm/gbtree.h
@@ -1,5 +1,5 @@
/**
* Copyright 2014-2025, XGBoost Contributors
* Copyright 2014-2026, XGBoost Contributors
* \file gbtree.cc
* \brief gradient boosted tree implementation.
* \author Tianqi Chen
@@ -256,15 +256,15 @@ class GBTree : public GradientBooster {
if constexpr (tree::IsScalarTree<decltype(tree)>()) {
gain_map[split] += tree.Stat(nidx).loss_chg;
} else {
LOG(FATAL) << "gain/total_gain " << MTNotImplemented();
gain_map[split] += tree.LossChg(nidx);
}
});
} else if (importance_type == "cover" || importance_type == "total_cover") {
add_score([&](auto const& tree, bst_node_t nidx, bst_feature_t split) {
if constexpr (tree::IsScalarTree<decltype(tree)>()) {
gain_map[split] += tree.Stat(nidx).sum_hess;
} else {
LOG(FATAL) << "cover/total_cover " << MTNotImplemented();
gain_map[split] += tree.SumHess(nidx);
}
});
} else {
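With the two `LOG(FATAL)` branches gone, multi-target trees feed the same accumulation as scalar trees: `total_gain`/`total_cover` sum `loss_chg`/`sum_hess` over every split on a feature, while `gain`/`cover` divide that total by the split count (the `weight` importance). A sketch of that invariant through the public API, on an illustrative multi-target regression setup:

```python
import numpy as np
import xgboost as xgb
from sklearn.datasets import make_regression

X, y = make_regression(n_samples=512, n_features=8, n_targets=3, random_state=0)
booster = xgb.train(
    {"multi_strategy": "multi_output_tree", "tree_method": "hist"},
    xgb.DMatrix(X, y),
    num_boost_round=8,
)
weight = booster.get_score(importance_type="weight")
gain = booster.get_score(importance_type="gain")
total_gain = booster.get_score(importance_type="total_gain")
for f in gain:  # average gain == total gain / number of splits on the feature
    np.testing.assert_allclose(gain[f], total_gain[f] / weight[f], rtol=1e-5)
```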
4 changes: 2 additions & 2 deletions src/predictor/gbtree_view.h
@@ -1,5 +1,5 @@
/**
* Copyright 2025, XGBoost Contributors
* Copyright 2025-2026, XGBoost Contributors
*/
#pragma once

@@ -51,7 +51,7 @@ class GBTreeModelView {
for (bst_tree_t tree_idx = this->tree_begin; tree_idx < this->tree_end; ++tree_idx) {
auto const& p_tree = model.trees[tree_idx];
if (p_tree->IsMultiTarget()) {
auto tree = tree::MultiTargetTreeView{device, p_tree.get()};
auto tree = tree::MultiTargetTreeView{device, need_stat, p_tree.get()};
this->n_nodes += tree.Size();
trees.emplace_back(tree);
} else {
17 changes: 11 additions & 6 deletions src/tree/gpu_hist/expand_entry.cuh
@@ -135,9 +135,9 @@ struct MultiExpandEntry {
MultiSplitCandidate split;

common::Span<float> base_weight;
// Sum Hessian of the first target. Used as a surrogate for node size.
double left_fst_hess{0};
double right_fst_hess{0};
// Sum of hessians across all targets for left/right children.
double left_sum{0};
double right_sum{0};

MultiExpandEntry() = default;

@@ -168,9 +168,14 @@
return true;
}

__device__ void UpdateFirstHessian(GradientPairPrecise const& lg, GradientPairPrecise const& rg) {
this->left_fst_hess = lg.GetHess();
this->right_fst_hess = rg.GetHess();
/**
* @brief Update hessian statistics.
* @param left_hess Sum of hessians across all targets for left child.
* @param right_hess Sum of hessians across all targets for right child.
*/
__device__ void UpdateHessian(double left_hess, double right_hess) {
this->left_sum = left_hess;
this->right_sum = right_hess;
}

friend std::ostream& operator<<(std::ostream& os, MultiExpandEntry const& entry);
12 changes: 5 additions & 7 deletions src/tree/gpu_hist/multi_evaluate_splits.cu
@@ -384,7 +384,7 @@ void MultiHistEvaluator::EvaluateSplits(Context const *ctx,

bool l = true, r = true;
float parent_gain = 0;
GradientPairPrecise lg_fst, rg_fst;
double left_hess = 0, right_hess = 0; // Sum of child hessians across all targets
auto eta = shared_inputs.param.learning_rate;

for (bst_target_t t = 0; t < n_targets; ++t) {
@@ -417,16 +417,14 @@
left_weight[t] = CalcWeight(shared_inputs.param, lg.GetGrad(), lg.GetHess()) * eta;
}

if (t == 0) {
lg_fst = lg;
rg_fst = rg;
}
left_hess += lg.GetHess();
right_hess += rg.GetHess();
}

// Set up the output entry with spans pointing to persistent weight storage
out_splits[nidx_in_set] = {nidx, input.depth, best_split, base_weight};
out_splits[nidx_in_set].split.loss_chg -= parent_gain;
out_splits[nidx_in_set].UpdateFirstHessian(lg_fst, rg_fst);
out_splits[nidx_in_set].UpdateHessian(left_hess, right_hess);

if (l || r) {
out_splits[nidx_in_set].split.loss_chg = -std::numeric_limits<float>::max();
@@ -438,7 +436,7 @@ void MultiHistEvaluator::ApplyTreeSplit(Context const *ctx, RegTree const *p_tree,
common::Span<MultiExpandEntry const> d_candidates,
bst_target_t n_targets) {
// Assign the node sums here, for the next evaluate split call.
auto mt_tree = MultiTargetTreeView{ctx->Device(), p_tree};
auto mt_tree = MultiTargetTreeView{ctx->Device(), false, p_tree};
auto max_in_it = dh::MakeIndexTransformIter([=] __device__(std::size_t i) -> bst_node_t {
return std::max(mt_tree.LeftChild(d_candidates[i].nidx),
mt_tree.RightChild(d_candidates[i].nidx));
13 changes: 12 additions & 1 deletion src/tree/hist/evaluate_splits.h
@@ -681,8 +681,19 @@ class HistMultiEvaluator {
linalg::MakeVec(candidate.split.right_sum.data(), candidate.split.right_sum.size());
CalcWeight(*param_, right_sum, param_->learning_rate, right_weight);

// Compute the loss_chg and the hessian sums for the parent and children
float loss_chg = candidate.split.loss_chg;
// Sum hessians across all targets for each child
float left_sum_hess = 0.0f, right_sum_hess = 0.0f;
for (std::size_t t = 0; t < candidate.split.left_sum.size(); ++t) {
left_sum_hess += candidate.split.left_sum[t].GetHess();
right_sum_hess += candidate.split.right_sum[t].GetHess();
}
float sum_hess = left_sum_hess + right_sum_hess;

p_tree->ExpandNode(candidate.nid, candidate.split.SplitIndex(), candidate.split.split_value,
candidate.split.DefaultLeft(), base_weight, left_weight, right_weight);
candidate.split.DefaultLeft(), base_weight, left_weight, right_weight,
loss_chg, sum_hess, left_sum_hess, right_sum_hess);

CHECK(p_tree->IsMultiTarget());
auto mt_tree = p_tree->HostMtView();
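The coverage arithmetic in this hunk, spelled out: each child's cover is its hessian sum accumulated over all targets (rather than just the first target, as the GPU evaluator previously tracked), and the parent's cover is the two children added together. In NumPy terms, with illustrative numbers:

```python
import numpy as np

# Per-target hessian sums for the two children of a split (3 targets).
left_sum = np.array([210.0, 195.5, 201.25])
right_sum = np.array([90.0, 104.5, 98.75])

left_cover = left_sum.sum()              # left_sum_hess in the C++ loop
right_cover = right_sum.sum()            # right_sum_hess
parent_cover = left_cover + right_cover  # sum_hess recorded for the split node
# Equivalent to summing every (child, target) hessian entry once.
assert np.isclose(parent_cover, np.concatenate([left_sum, right_sum]).sum())
```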