diff --git a/include/xgboost/cache.h b/include/xgboost/cache.h index 32e1b21ac3f6..c0a89aef949b 100644 --- a/include/xgboost/cache.h +++ b/include/xgboost/cache.h @@ -78,7 +78,8 @@ class DMatrixCache { auto p_fmat = queue_.front(); auto it = container_.find(p_fmat); CHECK(it != container_.cend()); - if (it->second.ref.expired()) { + // Re-new the cache if this has never been read. + if (it->second.ref.expired() || !it->second.ref.lock()->Info().HasBeenRead()) { expired.push_back(it->first); } else { remained.push(it->first); @@ -101,7 +102,7 @@ class DMatrixCache { void ClearExcess() { this->CheckConsistent(); - // clear half of the entries to prevent repeatingly clearing cache. + // clear half of the entries to prevent repeatedly clearing cache. std::size_t half_size = max_size_ / 2; while (queue_.size() >= half_size && !queue_.empty()) { auto p_fmat = queue_.front(); diff --git a/include/xgboost/data.h b/include/xgboost/data.h index d3bc2074ad15..e6b86b977e92 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -150,10 +150,10 @@ class MetaInfo { * \param fo The output stream. */ void SaveBinary(dmlc::Stream* fo) const; - /*! - * \brief Set information in the meta info with array interface. - * \param key The key of the information. - * \param interface_str String representation of json format array interface. + /** + * @brief Set data in the meta info with array interface. + * @param key The key of the information. + * @param interface_str String representation of json format array interface. */ void SetInfo(Context const& ctx, StringView key, StringView interface_str); @@ -218,6 +218,9 @@ class MetaInfo { * @brief Setter for categories. */ void Cats(std::shared_ptr cats); + // Flag to indicate whether one needs to refresh the DMatrix cache. + void SetReadFlag(bool has_been_read) { this->has_been_read_ = has_been_read; } + [[nodiscard]] bool HasBeenRead() const { return this->has_been_read_; } private: void SetInfoFromHost(Context const* ctx, StringView key, Json arr); @@ -226,6 +229,7 @@ class MetaInfo { /*! \brief argsort of labels */ mutable std::vector label_order_cache_; bool has_categorical_{false}; + bool has_been_read_{false}; std::shared_ptr cats_; }; @@ -740,6 +744,7 @@ class DMatrix { template <> inline BatchSet DMatrix::GetBatches() { + this->Info().SetReadFlag(true); return GetRowBatches(); } @@ -760,31 +765,37 @@ inline bool DMatrix::PageExists() const { template <> inline BatchSet DMatrix::GetBatches(Context const*) { + this->Info().SetReadFlag(true); return GetRowBatches(); } template <> inline BatchSet DMatrix::GetBatches(Context const* ctx) { + this->Info().SetReadFlag(true); return GetColumnBatches(ctx); } template <> inline BatchSet DMatrix::GetBatches(Context const* ctx) { + this->Info().SetReadFlag(true); return GetSortedColumnBatches(ctx); } template <> inline BatchSet DMatrix::GetBatches(Context const* ctx, BatchParam const& param) { + this->Info().SetReadFlag(true); return GetEllpackBatches(ctx, param); } template <> inline BatchSet DMatrix::GetBatches(Context const* ctx, BatchParam const& param) { + this->Info().SetReadFlag(true); return GetGradientIndex(ctx, param); } template <> inline BatchSet DMatrix::GetBatches(Context const* ctx, BatchParam const& param) { + this->Info().SetReadFlag(true); return GetExtBatches(ctx, param); } } // namespace xgboost diff --git a/src/data/data.cc b/src/data/data.cc index fa0545d09b08..57dd5e377d09 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -513,6 +513,8 @@ void CopyTensorInfoImpl(Context const* ctx, Json arr_interface, linalg::TensorSetReadFlag(false); + Json j_interface = Json::Load(interface_str); bool is_cuda{false}; if (IsA(j_interface)) {