-
-
Notifications
You must be signed in to change notification settings - Fork 8.9k
Use context for CUDA external memory DMatrix. #12137
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
e2028d8
d9fbab9
7e8a6cc
115dea7
5fb7aeb
54dd6a7
cd03fbc
1242879
65e7486
101add0
63370db
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,12 +1,12 @@ | ||
| /** | ||
| * Copyright 2019-2025, XGBoost contributors | ||
| * Copyright 2019-2026, XGBoost contributors | ||
| */ | ||
| #include <dmlc/registry.h> | ||
|
|
||
| #include <cstddef> // for size_t | ||
| #include <vector> // for vector | ||
|
|
||
| #include "../common/cuda_rt_utils.h" | ||
| #include "../common/cuda_context.cuh" // for CUDAContext | ||
| #include "../common/cuda_stream.h" // for Event | ||
| #include "../common/io.h" // for AlignedResourceReadStream, AlignedFileWriteStream | ||
| #include "../common/ref_resource_view.cuh" // for MakeFixedVecWithCudaMalloc | ||
|
|
@@ -21,7 +21,7 @@ DMLC_REGISTRY_FILE_TAG(ellpack_page_raw_format); | |
| namespace { | ||
| // Function to support system without HMM or ATS | ||
| template <typename T> | ||
| [[nodiscard]] bool ReadDeviceVec(common::AlignedResourceReadStream* fi, | ||
| [[nodiscard]] bool ReadDeviceVec(Context const* ctx, common::AlignedResourceReadStream* fi, | ||
| common::RefResourceView<T>* vec) { | ||
| xgboost_NVTX_FN_RANGE(); | ||
|
|
||
|
|
@@ -42,7 +42,7 @@ template <typename T> | |
|
|
||
| *vec = common::MakeFixedVecWithCudaMalloc<T>(n); | ||
| dh::safe_cuda( | ||
| cudaMemcpyAsync(vec->data(), ptr, n_bytes, cudaMemcpyDefault, curt::DefaultStream())); | ||
| cudaMemcpyAsync(vec->data(), ptr, n_bytes, cudaMemcpyDefault, ctx->CUDACtx()->Stream())); | ||
| return true; | ||
| } | ||
| } // namespace | ||
|
|
@@ -62,7 +62,7 @@ template <typename T> | |
| RET_IF_NOT(fi->Read(&impl->info.row_stride)); | ||
|
|
||
| if (this->param_.prefetch_copy || !has_hmm_ats_) { | ||
| RET_IF_NOT(ReadDeviceVec(fi, &impl->gidx_buffer)); | ||
| RET_IF_NOT(ReadDeviceVec(ctx_, fi, &impl->gidx_buffer)); | ||
| } else { | ||
|
Comment on lines
64
to
66
|
||
| RET_IF_NOT(common::ReadVec(fi, &impl->gidx_buffer)); | ||
| } | ||
|
|
@@ -73,7 +73,7 @@ template <typename T> | |
|
|
||
| impl->SetCuts(this->cuts_); | ||
|
|
||
| curt::DefaultStream().Sync(); | ||
| ctx_->CUDACtx()->Stream().Sync(); | ||
| return true; | ||
| } | ||
|
|
||
|
|
@@ -87,14 +87,13 @@ template <typename T> | |
| bytes += fo->Write(impl->is_dense); | ||
| bytes += fo->Write(impl->info.row_stride); | ||
| std::vector<common::CompressedByteT> h_gidx_buffer; | ||
| Context ctx = Context{}.MakeCUDA(curt::CurrentDevice()); | ||
| // write data into the h_gidx_buffer | ||
| [[maybe_unused]] auto h_accessor = impl->GetHostEllpack(&ctx, &h_gidx_buffer); | ||
| [[maybe_unused]] auto h_accessor = impl->GetHostEllpack(ctx_, &h_gidx_buffer); | ||
| bytes += common::WriteVec(fo, h_gidx_buffer); | ||
| bytes += fo->Write(impl->base_rowid); | ||
| bytes += fo->Write(impl->NumSymbols()); | ||
|
|
||
| curt::DefaultStream().Sync(); | ||
| ctx_->CUDACtx()->Stream().Sync(); | ||
| return bytes; | ||
| } | ||
|
|
||
|
|
@@ -104,21 +103,21 @@ template <typename T> | |
| auto* impl = page->Impl(); | ||
| CHECK(this->cuts_->cut_values_.DeviceCanRead()); | ||
|
|
||
| auto ctx = Context{}.MakeCUDA(curt::CurrentDevice()); | ||
| auto stream = ctx_->CUDACtx()->Stream(); | ||
|
|
||
| auto dispatch = [&] { | ||
| fi->Read(&ctx, page, this->param_.prefetch_copy || !this->has_hmm_ats_); | ||
| fi->Read(ctx_, page, this->param_.prefetch_copy || !this->has_hmm_ats_); | ||
| impl->SetCuts(this->cuts_); | ||
| }; | ||
|
|
||
| if (ConsoleLogger::GlobalVerbosity() == ConsoleLogger::LogVerbosity::kDebug) { | ||
| curt::Event start{false}, stop{false}; | ||
| float milliseconds = 0; | ||
| start.Record(ctx.CUDACtx()->Stream()); | ||
| start.Record(stream); | ||
|
|
||
| dispatch(); | ||
|
|
||
| stop.Record(ctx.CUDACtx()->Stream()); | ||
| stop.Record(stream); | ||
| stop.Sync(); | ||
| dh::safe_cuda(cudaEventElapsedTime(&milliseconds, start, stop)); | ||
| double n_bytes = page->Impl()->MemCostBytes(); | ||
|
|
@@ -128,7 +127,7 @@ template <typename T> | |
| dispatch(); | ||
| } | ||
|
|
||
| curt::DefaultStream().Sync(); | ||
| stream.Sync(); | ||
|
|
||
| return true; | ||
| } | ||
|
|
@@ -137,8 +136,8 @@ template <typename T> | |
| EllpackHostCacheStream* fo) const { | ||
| xgboost_NVTX_FN_RANGE_C(3, 252, 198); | ||
|
|
||
| bool new_page = fo->Write(page); | ||
| curt::DefaultStream().Sync(); | ||
| bool new_page = fo->Write(ctx_, page); | ||
| ctx_->CUDACtx()->Stream().Sync(); | ||
|
|
||
| if (new_page) { | ||
| auto cache = fo->Share(); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ReadDeviceVecnow takes aContext const*and immediately usesctx->CUDACtx()->Stream()forcudaMemcpyAsync, but it never validates thatctxis non-null / CUDA, nor does it ensure the current CUDA device matchesctx's ordinal. This can lead to invalid stream/device usage when the caller's current device differs fromctx. Add aCHECK(ctx && ctx->IsCUDA())and set the device (e.g.curt::SetDevice(ctx->Ordinal())) before allocating/copying.