Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/data/ellpack_page.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2019-2024, XGBoost contributors
* Copyright 2019-2026, XGBoost contributors
*/
#ifndef XGBOOST_USE_CUDA

Expand Down
16 changes: 4 additions & 12 deletions src/data/ellpack_page.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
#include "../common/compressed_iterator.h" // for CompressedIterator
#include "../common/cuda_context.cuh" // for CUDAContext
#include "../common/cuda_rt_utils.h" // for SetDevice
#include "../common/cuda_stream.h" // for DefaultStream
#include "../common/cuda_stream.h" // for StreamRef
#include "../common/hist_util.cuh" // for HistogramCuts
#include "../common/ref_resource_view.cuh" // for MakeFixedVecWithCudaMalloc
#include "../common/transform_iterator.h" // for MakeIndexTransformIter
Expand All @@ -32,6 +32,8 @@
namespace xgboost {
EllpackPage::EllpackPage() : impl_{new EllpackPageImpl{}} {}

EllpackPageImpl::EllpackPageImpl() = default;

EllpackPage::EllpackPage(Context const* ctx, DMatrix* dmat, const BatchParam& param)
: impl_{new EllpackPageImpl{ctx, dmat, param}} {}

Expand Down Expand Up @@ -500,17 +502,7 @@ EllpackPageImpl::EllpackPageImpl(Context const* ctx, GHistIndexMatrix const& pag
this->monitor_.Stop("CopyGHistToEllpack");
}

EllpackPageImpl::~EllpackPageImpl() noexcept(false) {
// Sync the stream to make sure all running CUDA kernels finish before deallocation.
auto status = curt::DefaultStream().Sync(false);
if (status != cudaSuccess) {
auto str = cudaGetErrorString(status);
// For external-memory, throwing here can trigger a series of calls to
// `std::terminate` by various destructors. For now, we just log the error.
LOG(WARNING) << "Ran into CUDA error:" << str << "\nXGBoost is likely to abort.";
}
dh::safe_cuda(status);
}
EllpackPageImpl::~EllpackPageImpl() noexcept(false) = default;

// A functor that copies the data from one EllpackPage to another.
template <typename IterT>
Expand Down
4 changes: 2 additions & 2 deletions src/data/ellpack_page.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2019-2025, XGBoost Contributors
* Copyright 2019-2026, XGBoost Contributors
*/
#ifndef XGBOOST_DATA_ELLPACK_PAGE_CUH_
#define XGBOOST_DATA_ELLPACK_PAGE_CUH_
Expand Down Expand Up @@ -186,7 +186,7 @@ class EllpackPageImpl {
* This is used in the external memory case. An empty ELLPACK page is constructed with its content
* set later by the reader.
*/
EllpackPageImpl() = default;
EllpackPageImpl();

/**
* @brief Constructor from existing ellpack matrics.
Expand Down
2 changes: 1 addition & 1 deletion src/data/ellpack_page.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2017-2023 by XGBoost Contributors
* Copyright 2017-2026, XGBoost Contributors
*/
#ifndef XGBOOST_DATA_ELLPACK_PAGE_H_
#define XGBOOST_DATA_ELLPACK_PAGE_H_
Expand Down
31 changes: 15 additions & 16 deletions src/data/ellpack_page_raw_format.cu
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
/**
* Copyright 2019-2025, XGBoost contributors
* Copyright 2019-2026, XGBoost contributors
*/
#include <dmlc/registry.h>

#include <cstddef> // for size_t
#include <vector> // for vector

#include "../common/cuda_rt_utils.h"
#include "../common/cuda_context.cuh" // for CUDAContext
#include "../common/cuda_stream.h" // for Event
#include "../common/io.h" // for AlignedResourceReadStream, AlignedFileWriteStream
#include "../common/ref_resource_view.cuh" // for MakeFixedVecWithCudaMalloc
Expand All @@ -21,7 +21,7 @@ DMLC_REGISTRY_FILE_TAG(ellpack_page_raw_format);
namespace {
// Function to support system without HMM or ATS
template <typename T>
[[nodiscard]] bool ReadDeviceVec(common::AlignedResourceReadStream* fi,
[[nodiscard]] bool ReadDeviceVec(Context const* ctx, common::AlignedResourceReadStream* fi,
common::RefResourceView<T>* vec) {
xgboost_NVTX_FN_RANGE();

Expand All @@ -42,7 +42,7 @@ template <typename T>

*vec = common::MakeFixedVecWithCudaMalloc<T>(n);
dh::safe_cuda(
cudaMemcpyAsync(vec->data(), ptr, n_bytes, cudaMemcpyDefault, curt::DefaultStream()));
cudaMemcpyAsync(vec->data(), ptr, n_bytes, cudaMemcpyDefault, ctx->CUDACtx()->Stream()));
return true;
Comment on lines 43 to 46
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ReadDeviceVec now takes a Context const* and immediately uses ctx->CUDACtx()->Stream() for cudaMemcpyAsync, but it never validates that ctx is non-null / CUDA, nor does it ensure the current CUDA device matches ctx's ordinal. This can lead to invalid stream/device usage when the caller's current device differs from ctx. Add a CHECK(ctx && ctx->IsCUDA()) and set the device (e.g. curt::SetDevice(ctx->Ordinal())) before allocating/copying.

Copilot uses AI. Check for mistakes.
}
} // namespace
Expand All @@ -62,7 +62,7 @@ template <typename T>
RET_IF_NOT(fi->Read(&impl->info.row_stride));

if (this->param_.prefetch_copy || !has_hmm_ats_) {
RET_IF_NOT(ReadDeviceVec(fi, &impl->gidx_buffer));
RET_IF_NOT(ReadDeviceVec(ctx_, fi, &impl->gidx_buffer));
} else {
Comment on lines 64 to 66
Copy link

Copilot AI Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

EllpackPageRawFormat::Read uses ctx_ (stored in the format object) for device copies and stream sync, but there’s no precondition check that ctx_ is non-null / CUDA or that the active device matches ctx_. Since the constructor now accepts an arbitrary context pointer, add validation (and potentially a device set) at the start of Read/Write to prevent null deref or cross-device stream usage.

Copilot uses AI. Check for mistakes.
RET_IF_NOT(common::ReadVec(fi, &impl->gidx_buffer));
}
Expand All @@ -73,7 +73,7 @@ template <typename T>

impl->SetCuts(this->cuts_);

curt::DefaultStream().Sync();
ctx_->CUDACtx()->Stream().Sync();
return true;
}

Expand All @@ -87,14 +87,13 @@ template <typename T>
bytes += fo->Write(impl->is_dense);
bytes += fo->Write(impl->info.row_stride);
std::vector<common::CompressedByteT> h_gidx_buffer;
Context ctx = Context{}.MakeCUDA(curt::CurrentDevice());
// write data into the h_gidx_buffer
[[maybe_unused]] auto h_accessor = impl->GetHostEllpack(&ctx, &h_gidx_buffer);
[[maybe_unused]] auto h_accessor = impl->GetHostEllpack(ctx_, &h_gidx_buffer);
bytes += common::WriteVec(fo, h_gidx_buffer);
bytes += fo->Write(impl->base_rowid);
bytes += fo->Write(impl->NumSymbols());

curt::DefaultStream().Sync();
ctx_->CUDACtx()->Stream().Sync();
return bytes;
}

Expand All @@ -104,21 +103,21 @@ template <typename T>
auto* impl = page->Impl();
CHECK(this->cuts_->cut_values_.DeviceCanRead());

auto ctx = Context{}.MakeCUDA(curt::CurrentDevice());
auto stream = ctx_->CUDACtx()->Stream();

auto dispatch = [&] {
fi->Read(&ctx, page, this->param_.prefetch_copy || !this->has_hmm_ats_);
fi->Read(ctx_, page, this->param_.prefetch_copy || !this->has_hmm_ats_);
impl->SetCuts(this->cuts_);
};

if (ConsoleLogger::GlobalVerbosity() == ConsoleLogger::LogVerbosity::kDebug) {
curt::Event start{false}, stop{false};
float milliseconds = 0;
start.Record(ctx.CUDACtx()->Stream());
start.Record(stream);

dispatch();

stop.Record(ctx.CUDACtx()->Stream());
stop.Record(stream);
stop.Sync();
dh::safe_cuda(cudaEventElapsedTime(&milliseconds, start, stop));
double n_bytes = page->Impl()->MemCostBytes();
Expand All @@ -128,7 +127,7 @@ template <typename T>
dispatch();
}

curt::DefaultStream().Sync();
stream.Sync();

return true;
}
Expand All @@ -137,8 +136,8 @@ template <typename T>
EllpackHostCacheStream* fo) const {
xgboost_NVTX_FN_RANGE_C(3, 252, 198);

bool new_page = fo->Write(page);
curt::DefaultStream().Sync();
bool new_page = fo->Write(ctx_, page);
ctx_->CUDACtx()->Stream().Sync();

if (new_page) {
auto cache = fo->Share();
Expand Down
7 changes: 5 additions & 2 deletions src/data/ellpack_page_raw_format.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,17 @@ class EllpackPageRawFormat : public SparsePageFormat<EllpackPage> {
BatchParam param_;
// Supports CUDA HMM or ATS
bool has_hmm_ats_{false};
Context const* ctx_;

public:
explicit EllpackPageRawFormat(std::shared_ptr<common::HistogramCuts const> cuts, DeviceOrd device,
explicit EllpackPageRawFormat(Context const* ctx,
std::shared_ptr<common::HistogramCuts const> cuts, DeviceOrd device,
BatchParam param, bool has_hmm_ats)
: cuts_{std::move(cuts)},
device_{device},
param_{std::move(param)},
has_hmm_ats_{has_hmm_ats} {}
has_hmm_ats_{has_hmm_ats},
ctx_{ctx} {}
[[nodiscard]] bool Read(EllpackPage* page, common::AlignedResourceReadStream* fi) override;
[[nodiscard]] std::size_t Write(EllpackPage const& page,
common::AlignedFileWriteStream* fo) override;
Expand Down
43 changes: 29 additions & 14 deletions src/data/ellpack_page_source.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2019-2025, XGBoost contributors
* Copyright 2019-2026, XGBoost contributors
*/
#include <algorithm> // for max
#include <cstddef> // for size_t
Expand Down Expand Up @@ -145,9 +145,8 @@ class EllpackHostCacheStreamImpl {
ptr_ = k;
}

[[nodiscard]] bool Write(EllpackPage const& page) {
[[nodiscard]] bool Write(Context const* ctx, EllpackPage const& page) {
auto impl = page.Impl();
auto ctx = Context{}.MakeCUDA(dh::CurrentDevice());

this->cache_->sizes_orig.push_back(page.Impl()->MemCostBytes());
auto orig_ptr = this->cache_->sizes_orig.size() - 1;
Expand Down Expand Up @@ -219,10 +218,10 @@ class EllpackHostCacheStreamImpl {
dc::CuMemParams c_out;
std::size_t constexpr kChunkSize = 1ul << 21;
auto params = dc::CompressSnappy(
&ctx, old_impl->gidx_buffer.ToSpan().subspan(n_h_bytes, n_comp_bytes), &tmp, kChunkSize);
ctx, old_impl->gidx_buffer.ToSpan().subspan(n_h_bytes, n_comp_bytes), &tmp, kChunkSize);
common::RefResourceView<std::uint8_t> c_buf = dc::CoalesceCompressedBuffersToHost(
ctx.CUDACtx()->Stream(), this->cache_->pool, params, tmp, &c_out);
auto c_page = dc::MakeSnappyDecomprMgr(ctx.CUDACtx()->Stream(), this->cache_->pool,
ctx->CUDACtx()->Stream(), this->cache_->pool, params, tmp, &c_out);
auto c_page = dc::MakeSnappyDecomprMgr(ctx->CUDACtx()->Stream(), this->cache_->pool,
std::move(c_out), c_buf.ToSpan());
CHECK_EQ(c_page.DecompressedBytes() + new_impl->gidx_buffer.size_bytes(), n_bytes);

Expand Down Expand Up @@ -264,13 +263,13 @@ class EllpackHostCacheStreamImpl {
// Push a new page
auto n_bytes = this->cache_->buffer_bytes.at(this->cache_->h_pages.size());
auto n_samples = this->cache_->buffer_rows.at(this->cache_->h_pages.size());
auto new_impl = std::make_unique<EllpackPageImpl>(&ctx, impl->CutsShared(), impl->IsDense(),
auto new_impl = std::make_unique<EllpackPageImpl>(ctx, impl->CutsShared(), impl->IsDense(),
impl->info.row_stride, n_samples);
new_impl->SetBaseRowId(impl->base_rowid);
new_impl->SetNumSymbols(impl->NumSymbols());
new_impl->gidx_buffer =
common::MakeFixedVecWithCudaMalloc<common::CompressedByteT>(&ctx, n_bytes, 0);
auto offset = new_impl->Copy(&ctx, impl, 0);
common::MakeFixedVecWithCudaMalloc<common::CompressedByteT>(ctx, n_bytes, 0);
auto offset = new_impl->Copy(ctx, impl, 0);

this->cache_->offsets.push_back(offset);

Expand All @@ -284,7 +283,7 @@ class EllpackHostCacheStreamImpl {
CHECK(!this->cache_->h_pages.empty());
CHECK_EQ(cache_idx, this->cache_->h_pages.size() - 1);
auto& new_impl = this->cache_->h_pages.back();
auto offset = new_impl->Copy(&ctx, impl, this->cache_->offsets.back());
auto offset = new_impl->Copy(ctx, impl, this->cache_->offsets.back());
this->cache_->offsets.back() += offset;
}

Expand Down Expand Up @@ -382,10 +381,24 @@ void EllpackHostCacheStream::Read(Context const* ctx, EllpackPage* page, bool pr
this->p_impl_->Read(ctx, page, prefetch_copy);
}

[[nodiscard]] bool EllpackHostCacheStream::Write(EllpackPage const& page) {
return this->p_impl_->Write(page);
[[nodiscard]] bool EllpackHostCacheStream::Write(Context const* ctx, EllpackPage const& page) {
return this->p_impl_->Write(ctx, page);
}

/**
* EllpackFormatPolicy
*/
template <typename S>
void EllpackFormatPolicy<S>::DestroyPage(std::shared_ptr<S>* page) const {
if (page && ctx_) {
ctx_->CUDACtx()->Stream().Sync();
}
page->reset();
}

template void EllpackFormatPolicy<EllpackPage>::DestroyPage(
std::shared_ptr<EllpackPage>* page) const;

/**
* EllpackCacheStreamPolicy
*/
Expand Down Expand Up @@ -528,13 +541,14 @@ void EllpackPageSourceImpl<F>::Fetch() {
// This is not read from cache so we still need it to be synced with sparse page source.
CHECK_EQ(this->Iter(), this->source_->Iter());
auto const& csr = this->source_->Page();
this->DestroyPage(&this->page_);
this->page_.reset(new EllpackPage{});
auto* impl = this->page_->Impl();
Context ctx = Context{}.MakeCUDA(this->Device().ordinal);
if (this->GetCuts()->HasCategorical()) {
CHECK(!this->feature_types_.empty());
}
*impl = EllpackPageImpl{&ctx, this->GetCuts(), *csr, is_dense_, row_stride_, feature_types_};
*impl =
EllpackPageImpl{this->Ctx(), this->GetCuts(), *csr, is_dense_, row_stride_, feature_types_};
this->page_->SetBaseRowId(csr->base_rowid);
LOG(INFO) << "Generated an Ellpack page with size: "
<< common::HumanMemUnit(impl->MemCostBytes())
Expand Down Expand Up @@ -573,6 +587,7 @@ void ExtEllpackPageSourceImpl<F>::Fetch() {
bst_idx_t row_stride = GetRowCounts(this->ctx_, value, row_counts_span,
dh::GetDevice(this->ctx_), this->missing_);
CHECK_LE(row_stride, this->ext_info_.row_stride);
this->DestroyPage(&this->page_);
this->page_.reset(new EllpackPage{});
*this->page_->Impl() = EllpackPageImpl{this->ctx_,
value,
Expand Down
23 changes: 16 additions & 7 deletions src/data/ellpack_page_source.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2019-2025, XGBoost Contributors
* Copyright 2019-2026, XGBoost Contributors
*/

#ifndef XGBOOST_DATA_ELLPACK_PAGE_SOURCE_H_
Expand Down Expand Up @@ -164,7 +164,7 @@ class EllpackHostCacheStream {
* @return Whether a new cache page is create. False if the new page is appended to the
* previous one.
*/
[[nodiscard]] bool Write(EllpackPage const& page);
[[nodiscard]] bool Write(Context const* ctx, EllpackPage const& page);
};

namespace detail {
Expand All @@ -177,6 +177,7 @@ class EllpackFormatPolicy {
std::shared_ptr<common::HistogramCuts const> cuts_{nullptr};
DeviceOrd device_;
bool has_hmm_{curt::SupportsPageableMem()};
Context const* ctx_{nullptr};

EllpackCacheInfo cache_info_;
static_assert(std::is_same_v<S, EllpackPage>);
Expand Down Expand Up @@ -214,11 +215,12 @@ class EllpackFormatPolicy {

[[nodiscard]] auto CreatePageFormat(BatchParam const& param) const {
CHECK_EQ(cuts_->cut_values_.Device(), device_);
std::unique_ptr<FormatT> fmt{new EllpackPageRawFormat{cuts_, device_, param, has_hmm_}};
std::unique_ptr<FormatT> fmt{new EllpackPageRawFormat{ctx_, cuts_, device_, param, has_hmm_}};
return fmt;
}
void SetCuts(std::shared_ptr<common::HistogramCuts const> cuts, DeviceOrd device,
EllpackCacheInfo cinfo) {
void SetCuts(Context const* ctx, std::shared_ptr<common::HistogramCuts const> cuts,
DeviceOrd device, EllpackCacheInfo cinfo) {
this->ctx_ = ctx;
std::swap(this->cuts_, cuts);
this->device_ = device;
CHECK(this->device_.IsCUDA());
Expand All @@ -230,6 +232,8 @@ class EllpackFormatPolicy {
}
[[nodiscard]] auto Device() const { return this->device_; }
[[nodiscard]] auto const& CacheInfo() { return this->cache_info_; }
[[nodiscard]] auto Ctx() const { return this->ctx_; }
void DestroyPage(std::shared_ptr<S>* page) const;
};

template <typename S, template <typename> typename F>
Expand Down Expand Up @@ -311,7 +315,7 @@ class EllpackPageSourceImpl : public PageSourceIncMixIn<EllpackPage, F> {
feature_types_{feature_types} {
this->source_ = source;
cuts->SetDevice(ctx->Device());
this->SetCuts(std::move(cuts), ctx->Device(), cinfo);
this->SetCuts(ctx, std::move(cuts), ctx->Device(), cinfo);
this->Fetch();
}

Expand Down Expand Up @@ -353,7 +357,7 @@ class ExtEllpackPageSourceImpl : public ExtQantileSourceMixin<EllpackPage, Forma
info_{info},
ext_info_{std::move(ext_info)} {
cuts->SetDevice(ctx->Device());
this->SetCuts(std::move(cuts), ctx->Device(), cinfo);
this->SetCuts(ctx, std::move(cuts), ctx->Device(), cinfo);
CHECK(!this->cache_info_->written);
this->source_->Reset();
CHECK(this->source_->Next());
Expand Down Expand Up @@ -383,6 +387,11 @@ using ExtEllpackPageSource =
ExtEllpackPageSourceImpl<EllpackMmapStreamPolicy<EllpackPage, EllpackFormatPolicy>>;

#if !defined(XGBOOST_USE_CUDA)
template <typename S>
inline void EllpackFormatPolicy<S>::DestroyPage(std::shared_ptr<S>* page) const {
page->reset();
}

template <typename F>
inline void EllpackPageSourceImpl<F>::Fetch() {
// silent the warning about unused variables.
Expand Down
Loading
Loading