Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion datafusion/execution/src/memory_pool/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,19 @@ pub trait MemoryPool: Send + Sync + std::fmt::Debug {
/// On error the `allocation` will not be increased in size
fn try_grow(&self, reservation: &MemoryReservation, additional: usize) -> Result<()>;

/// Ask the pool to free up to `target_bytes` by reclaiming memory from spillable consumers
/// that are already registered with it.
///
/// `exclude_consumer_id`, when set, is the id of the consumer issuing the request. It should
/// not be reclaimed from, so that an operator is never re-entered while it is in the middle
/// of an allocation.
///
/// Returns the number of bytes reclaimed. This default implementation reclaims nothing and
/// always returns `Ok(0)`.
fn reclaim(
    &self,
    _target_bytes: usize,
    _exclude_consumer_id: Option<usize>,
) -> Result<usize> {
    // Pools that do not override this method report that no memory was freed.
    const NOTHING_RECLAIMED: usize = 0;
    Ok(NOTHING_RECLAIMED)
}

/// Return the total amount of memory reserved
fn reserved(&self) -> usize;

Expand Down Expand Up @@ -240,11 +253,22 @@ pub enum MemoryLimit {
/// For help with allocation accounting, see the [`proxy`] module.
///
/// [proxy]: datafusion_common::utils::proxy
// `Debug` is implemented by hand for this type: `dyn MemoryReclaimer` carries no `Debug`
// bound, so `#[derive(Debug)]` cannot be used here (and it would conflict with the manual
// `impl std::fmt::Debug for MemoryConsumer`).
pub struct MemoryConsumer {
    /// Name of this consumer.
    name: String,
    /// Whether this consumer can spill to disk.
    can_spill: bool,
    /// Identifier assigned to this consumer when it is created.
    id: usize,
    /// Optional callback that can reclaim memory from this consumer when another consumer
    /// in the same pool is under pressure.
    reclaimer: Option<Arc<dyn MemoryReclaimer>>,
}

impl std::fmt::Debug for MemoryConsumer {
    /// Formats the consumer for diagnostics. The reclaim callback itself is not printed;
    /// a `has_reclaimer` flag records whether one is configured.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            name,
            can_spill,
            id,
            reclaimer,
        } = self;
        let has_reclaimer = reclaimer.is_some();

        let mut out = f.debug_struct("MemoryConsumer");
        out.field("name", name);
        out.field("can_spill", can_spill);
        out.field("id", id);
        out.field("has_reclaimer", &has_reclaimer);
        out.finish()
    }
}

impl PartialEq for MemoryConsumer {
Expand Down Expand Up @@ -283,6 +307,7 @@ impl MemoryConsumer {
name: name.into(),
can_spill: false,
id: Self::new_unique_id(),
reclaimer: None,
}
}

Expand All @@ -294,6 +319,7 @@ impl MemoryConsumer {
name: self.name.clone(),
can_spill: self.can_spill,
id: Self::new_unique_id(),
reclaimer: self.reclaimer.clone(),
}
}

Expand All @@ -307,6 +333,15 @@ impl MemoryConsumer {
Self { can_spill, ..self }
}

/// Attach a reclaim callback to this consumer. The callback can reclaim memory from this
/// consumer when another consumer in the same pool is under pressure.
pub fn with_reclaimer(mut self, reclaimer: Arc<dyn MemoryReclaimer>) -> Self {
    // Replaces any previously configured callback; every other field is kept as-is.
    self.reclaimer = Some(reclaimer);
    self
}

/// Returns true if this allocation can spill to disk
pub fn can_spill(&self) -> bool {
self.can_spill
Expand All @@ -317,6 +352,11 @@ impl MemoryConsumer {
&self.name
}

/// Returns a shared handle to the reclaim callback configured for this consumer, or `None`
/// when no callback has been set.
pub fn reclaimer(&self) -> Option<Arc<dyn MemoryReclaimer>> {
    // Only the `Arc` handle is cloned; the callback itself is shared.
    self.reclaimer.as_ref().map(Arc::clone)
}

/// Registers this [`MemoryConsumer`] with the provided [`MemoryPool`] returning
/// a [`MemoryReservation`] that can be used to grow or shrink the memory reservation
pub fn register(self, pool: &Arc<dyn MemoryPool>) -> MemoryReservation {
Expand All @@ -331,6 +371,12 @@ impl MemoryConsumer {
}
}

/// Callback implemented by spillable operators that can synchronously reclaim existing
/// reservations when another consumer in the same pool is under pressure.
pub trait MemoryReclaimer: Send + Sync {
/// Asks the implementer to reclaim memory, with `target_bytes` as the requested amount.
///
/// Returns the number of bytes reclaimed; callers such as `TrackConsumersPool::reclaim`
/// add this value to their running total.
///
/// NOTE(review): presumably an implementation may return less than `target_bytes`
/// (including `0`) when it cannot free that much — confirm the intended contract.
fn reclaim(&self, target_bytes: usize) -> Result<usize>;
}

/// A registration of a [`MemoryConsumer`] with a [`MemoryPool`].
///
/// Calls [`MemoryPool::unregister`] on drop to return any memory to
Expand Down
157 changes: 155 additions & 2 deletions datafusion/execution/src/memory_pool/pool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,16 @@
// under the License.

use crate::memory_pool::{
MemoryConsumer, MemoryLimit, MemoryPool, MemoryReservation, human_readable_size,
MemoryConsumer, MemoryLimit, MemoryPool, MemoryReclaimer, MemoryReservation,
human_readable_size,
};
use datafusion_common::HashMap;
use datafusion_common::{DataFusionError, Result, resources_datafusion_err};
use log::debug;
use parking_lot::Mutex;
use std::{
num::NonZeroUsize,
sync::Arc,
sync::atomic::{AtomicUsize, Ordering},
};

Expand Down Expand Up @@ -269,12 +271,24 @@ fn insufficient_capacity_err(
)
}

// `Debug` is written by hand for this type: `dyn MemoryReclaimer` has no `Debug` bound, so
// the impl cannot be derived (and a derive would conflict with the manual
// `impl std::fmt::Debug for TrackedConsumer`).
struct TrackedConsumer {
    /// Name of the tracked consumer.
    name: String,
    /// Whether the tracked consumer can spill.
    can_spill: bool,
    /// Bytes currently reserved by the consumer.
    reserved: AtomicUsize,
    /// Peak number of bytes recorded for the consumer.
    peak: AtomicUsize,
    /// Reclaim callback copied from the consumer at registration, if it has one.
    reclaimer: Option<Arc<dyn MemoryReclaimer>>,
}

impl std::fmt::Debug for TrackedConsumer {
    /// Formats the tracked consumer for diagnostics. `reserved` and `peak` are printed as
    /// the values returned by the accessor methods, and the reclaim callback is reduced to
    /// a `has_reclaimer` flag.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let reserved = self.reserved();
        let peak = self.peak();
        let has_reclaimer = self.reclaimer.is_some();

        let mut out = f.debug_struct("TrackedConsumer");
        out.field("name", &self.name);
        out.field("can_spill", &self.can_spill);
        out.field("reserved", &reserved);
        out.field("peak", &peak);
        out.field("has_reclaimer", &has_reclaimer);
        out.finish()
    }
}

impl TrackedConsumer {
Expand Down Expand Up @@ -428,6 +442,7 @@ impl<I: MemoryPool> MemoryPool for TrackConsumersPool<I> {
can_spill: consumer.can_spill(),
reserved: Default::default(),
peak: Default::default(),
reclaimer: consumer.reclaimer(),
},
);

Expand Down Expand Up @@ -488,6 +503,50 @@ impl<I: MemoryPool> MemoryPool for TrackConsumersPool<I> {
Ok(())
}

/// Reclaims memory from tracked consumers by invoking their reclaim callbacks until
/// `target_bytes` has been reclaimed or no candidate is left.
///
/// A consumer is a candidate when it has a callback, can spill, currently reserves a
/// non-zero amount and is not the consumer named by `exclude_consumer_id`. Candidates are
/// asked in order of largest reservation first, with the smaller consumer id breaking ties.
///
/// Returns the sum of the byte counts reported by the callbacks. An error from any callback
/// is returned immediately.
fn reclaim(
    &self,
    target_bytes: usize,
    exclude_consumer_id: Option<usize>,
) -> Result<usize> {
    if target_bytes == 0 {
        return Ok(0);
    }

    // Snapshot the eligible consumers. The lock guard is a temporary of this statement, so
    // it is released before any callback is invoked below.
    let mut candidates: Vec<_> = self
        .tracked_consumers
        .lock()
        .iter()
        .filter_map(|(id, tracked)| {
            let callback = tracked.reclaimer.as_ref()?;
            let held = tracked.reserved();
            let eligible =
                tracked.can_spill && held != 0 && exclude_consumer_id != Some(*id);
            eligible.then(|| (*id, held, Arc::clone(callback)))
        })
        .collect();

    // Largest reservation first; ties resolved by ascending consumer id.
    candidates.sort_by_key(|(id, held, _)| (std::cmp::Reverse(*held), *id));

    let mut reclaimed = 0usize;
    let mut pending = candidates.into_iter();
    while reclaimed < target_bytes {
        let Some((_, _, callback)) = pending.next() else {
            break;
        };
        // Each callback is only asked for what is still missing.
        reclaimed += callback.reclaim(target_bytes - reclaimed)?;
    }
    Ok(reclaimed)
}

/// Returns the total reserved memory as reported by the wrapped `inner` pool.
fn reserved(&self) -> usize {
self.inner.reserved()
}
Expand All @@ -513,6 +572,24 @@ mod tests {
use insta::{Settings, allow_duplicates, assert_snapshot};
use std::sync::Arc;

/// Test-only [`MemoryReclaimer`] that reclaims memory by shrinking a reservation it shares
/// with the test body.
#[derive(Debug)]
struct TestReclaimer {
// Shared slot holding the reservation to shrink. The tests create it as `None` and fill
// it in once the consumer has been registered.
reservation: Arc<Mutex<Option<Arc<MemoryReservation>>>>,
}

impl MemoryReclaimer for TestReclaimer {
    /// Shrinks the shared reservation by at most `target_bytes` and reports how many bytes
    /// were released. Reports `0` when no reservation has been stored yet.
    fn reclaim(&self, target_bytes: usize) -> Result<usize> {
        // Clone the handle out of the slot in its own statement so the slot's lock is
        // released before the reservation is touched.
        let handle = self.reservation.lock().clone();
        match handle {
            None => Ok(0),
            Some(reservation) => {
                let reclaimed = target_bytes.min(reservation.size());
                if reclaimed != 0 {
                    reservation.shrink(reclaimed);
                }
                Ok(reclaimed)
            }
        }
    }
}

fn make_settings() -> Settings {
let mut settings = Settings::clone_current();
settings.add_filter(
Expand Down Expand Up @@ -811,4 +888,80 @@ mod tests {
r1#[ID](can spill: false) consumed 20.0 B, peak 20.0 B.
");
}

// With two spillable consumers holding 100 and 60 bytes, a request for 80 bytes is served
// entirely by the larger one; the smaller one is left untouched.
#[test]
fn test_tracked_consumers_pool_reclaim_prefers_largest_consumer() {
    let pool: Arc<dyn MemoryPool> = Arc::new(TrackConsumersPool::new(
        GreedyMemoryPool::new(200),
        NonZeroUsize::new(3).unwrap(),
    ));

    // Registers a spillable consumer whose reclaimer shrinks its own reservation, then
    // grows that reservation to `bytes`.
    let spillable = |name: &str, bytes: usize| {
        let handle = Arc::new(Mutex::new(None));
        let reclaimer = Arc::new(TestReclaimer {
            reservation: Arc::clone(&handle),
        });
        let reservation = Arc::new(
            MemoryConsumer::new(name)
                .with_can_spill(true)
                .with_reclaimer(reclaimer)
                .register(&pool),
        );
        *handle.lock() = Some(Arc::clone(&reservation));
        reservation.grow(bytes);
        reservation
    };

    let first = spillable("spillable-1", 100);
    let second = spillable("spillable-2", 60);

    let reclaimed = pool.reclaim(80, None).unwrap();

    assert_eq!(reclaimed, 80);
    assert_eq!(first.size(), 20);
    assert_eq!(second.size(), 60);
}

// When the larger consumer is the requester it is excluded, so only the 60 bytes held by
// the other consumer can be reclaimed even though 80 were requested.
#[test]
fn test_tracked_consumers_pool_reclaim_excludes_requester() {
    let pool: Arc<dyn MemoryPool> = Arc::new(TrackConsumersPool::new(
        GreedyMemoryPool::new(200),
        NonZeroUsize::new(3).unwrap(),
    ));

    // Registers a spillable consumer whose reclaimer shrinks its own reservation, then
    // grows that reservation to `bytes`.
    let spillable = |name: &str, bytes: usize| {
        let handle = Arc::new(Mutex::new(None));
        let reclaimer = Arc::new(TestReclaimer {
            reservation: Arc::clone(&handle),
        });
        let reservation = Arc::new(
            MemoryConsumer::new(name)
                .with_can_spill(true)
                .with_reclaimer(reclaimer)
                .register(&pool),
        );
        *handle.lock() = Some(Arc::clone(&reservation));
        reservation.grow(bytes);
        reservation
    };

    let first = spillable("spillable-1", 100);
    let second = spillable("spillable-2", 60);

    let requester_id = first.consumer().id();
    let reclaimed = pool.reclaim(80, Some(requester_id)).unwrap();

    assert_eq!(reclaimed, 60);
    assert_eq!(first.size(), 100);
    assert_eq!(second.size(), 0);
}
}
Loading