Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 33 additions & 6 deletions baml_language/crates/bex_engine/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ use ::core::sync::atomic::AtomicBool;
use async_trait::async_trait;
use bex_events::{EventKind, FunctionEnd, FunctionEvent, FunctionStart, SpanContext};
pub use bex_events::{HostSpanContext, RuntimeEvent, SpanId};
use bex_external_types::Handle;
pub use bex_external_types::{BexExternalValue, EpochGuard, Ty, TypeName, UnionMetadata};
use bex_heap::BexHeap;
// Re-export GcStats for users of the engine
Expand Down Expand Up @@ -116,8 +117,12 @@ pub struct UserFunctionInfo {
}

/// Result of an external future.
///
/// We carry a GC-stable handle rather than a raw `HeapPtr` because the VM can
/// hit a GC safepoint while awaiting the result. The handle resolves to the
/// future's current location when the async task completes.
struct FutureResult {
id: HeapPtr,
id: Handle,
result: Result<BexExternalValue, EngineError>,
}

Expand Down Expand Up @@ -195,6 +200,9 @@ pub enum EngineError {
#[error("Future channel closed unexpectedly")]
FutureChannelClosed,

#[error("Future handle {slab_key} became invalid during execution")]
FutureHandleInvalid { slab_key: usize },

#[error("VM internal error: {0}")]
VmInternalError(bex_vm::errors::VmInternalError),

Expand Down Expand Up @@ -1185,6 +1193,17 @@ impl BexEngine {
Ok(())
}

/// Resolve a future's [`Handle`] to a `HeapPtr` under the VM's current
/// epoch guard.
///
/// Callers use this immediately before touching the future object, so the
/// pointer they get reflects whatever `object_ptr` reports at that moment
/// rather than a pointer captured earlier.
///
/// # Errors
///
/// Returns [`EngineError::FutureHandleInvalid`] carrying the handle's slab
/// key when `object_ptr` yields `None`.
// NOTE(review): presumably `None` means the handle was released or its
// object was collected — confirm against `Handle::object_ptr`.
fn resolve_future_handle(
    vm: &ActiveHeapPermit<BexVm>,
    future: &Handle,
) -> Result<HeapPtr, EngineError> {
    future
        .object_ptr(&vm.epoch_guard())
        // Build the error lazily: the slab key is only needed on failure.
        .ok_or_else(|| EngineError::FutureHandleInvalid {
            slab_key: future.slab_key(),
        })
}

/// Drive the VM to completion, dispatching sys-ops, awaits, span
/// notifications, and early-yield events.
///
Expand Down Expand Up @@ -1362,14 +1381,15 @@ impl BexEngine {
Self::cancellation_safepoint(cancel, &abort_handles)?;

// Async operation — wrap in Abortable and spawn.
let future_handle = self.heap.create_handle(id);
let pending_futures = pending_futures.clone();
let (abort_handle, abort_reg) =
futures::future::AbortHandle::new_pair();
let abortable = futures::future::Abortable::new(
async move {
let result = fut.await;
let _ = pending_futures.send(FutureResult {
id,
id: future_handle,
result: result.map_err(EngineError::from),
});
},
Expand All @@ -1389,17 +1409,21 @@ impl BexEngine {
}

VmExecState::Await(future_id) => {
let awaited_future = self.heap.create_handle(future_id);
Self::cancellation_safepoint(cancel, &abort_handles)?;
vm = self.gc_safepoint(vm).await;

// First, drain any already-completed futures.
while let Ok(future) = processed_futures.try_recv() {
let future_ptr = Self::resolve_future_handle(&vm, &future.id)?;
let awaited_future_ptr =
Self::resolve_future_handle(&vm, &awaited_future)?;
let external = future.result?;
let value = self.convert_external_to_vm_value(&mut vm, external);
vm.fulfil_future(future.id, value)
vm.fulfil_future(future_ptr, value)
.map_err(EngineError::VmInternalError)?;

if future.id == future_id {
if future_ptr == awaited_future_ptr {
continue 'vm_exec;
}
}
Expand All @@ -1419,15 +1443,18 @@ impl BexEngine {
future = processed_futures.recv() => {
let future = future
.ok_or(EngineError::FutureChannelClosed)?;
let future_ptr = Self::resolve_future_handle(&vm, &future.id)?;
let awaited_future_ptr =
Self::resolve_future_handle(&vm, &awaited_future)?;
let external = future.result?;
let value = self.convert_external_to_vm_value(
&mut vm,
external,
);
vm.fulfil_future(future.id, value)
vm.fulfil_future(future_ptr, value)
.map_err(EngineError::VmInternalError)?;

if future.id == future_id {
if future_ptr == awaited_future_ptr {
break;
}
}
Expand Down
31 changes: 31 additions & 0 deletions baml_language/crates/bex_engine/tests/early_yield.rs
Original file line number Diff line number Diff line change
Expand Up @@ -266,3 +266,34 @@ async fn objects_allocated_after_mid_call_gc_survive() {
"post-GC allocations must not corrupt pre-GC ones, and vice versa"
);
}

/// Regression check: an async future must still complete correctly when a
/// GC is triggered at the `Await` safepoint.
///
/// The BAML program allocates a small array on every loop iteration, which
/// is intended to build enough allocation pressure to force a minor
/// collection just before the engine blocks on `sleep()`. The call must
/// then return `42`.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn sleep_future_survives_gc_safepoint() {
    let program = r#"
function alloc_then_sleep(n: int) -> int {
let i = 0;
while (i < n) {
let _ = [i, i + 1, i + 2];
i += 1;
}
baml.sys.sleep(1);
42
}
"#;
    let engine = make_engine(program);

    // Loop count passed to the program as `n`.
    let args = vec![BexExternalValue::Int(12_000)];
    let call_ctx = FunctionCallContextBuilder::new(sys_types::CallId::next()).build();

    let outcome = engine
        .call_function("alloc_then_sleep", args, call_ctx, true)
        .await;
    let returned = outcome.expect("sleep future must survive the GC safepoint");

    assert_eq!(returned, BexExternalValue::Int(42));
}