diff --git a/baml_language/crates/bex_engine/src/lib.rs b/baml_language/crates/bex_engine/src/lib.rs index 6e9a9732f7..8665fc7174 100644 --- a/baml_language/crates/bex_engine/src/lib.rs +++ b/baml_language/crates/bex_engine/src/lib.rs @@ -82,6 +82,7 @@ use ::core::sync::atomic::AtomicBool; use async_trait::async_trait; use bex_events::{EventKind, FunctionEnd, FunctionEvent, FunctionStart, SpanContext}; pub use bex_events::{HostSpanContext, RuntimeEvent, SpanId}; +use bex_external_types::Handle; pub use bex_external_types::{BexExternalValue, EpochGuard, Ty, TypeName, UnionMetadata}; use bex_heap::BexHeap; // Re-export GcStats for users of the engine @@ -116,8 +117,12 @@ pub struct UserFunctionInfo { } /// Result of an external future. +/// +/// We carry a GC-stable handle rather than a raw `HeapPtr` because the VM can +/// hit a GC safepoint while awaiting the result. The handle resolves to the +/// future's current location when the async task completes. struct FutureResult { - id: HeapPtr, + id: Handle, result: Result, } @@ -195,6 +200,9 @@ pub enum EngineError { #[error("Future channel closed unexpectedly")] FutureChannelClosed, + #[error("Future handle {slab_key} became invalid during execution")] + FutureHandleInvalid { slab_key: usize }, + #[error("VM internal error: {0}")] VmInternalError(bex_vm::errors::VmInternalError), @@ -1185,6 +1193,17 @@ impl BexEngine { Ok(()) } + fn resolve_future_handle( + vm: &ActiveHeapPermit, + future: &Handle, + ) -> Result { + future + .object_ptr(&vm.epoch_guard()) + .ok_or(EngineError::FutureHandleInvalid { + slab_key: future.slab_key(), + }) + } + /// Drive the VM to completion, dispatching sys-ops, awaits, span /// notifications, and early-yield events. /// @@ -1362,6 +1381,7 @@ impl BexEngine { Self::cancellation_safepoint(cancel, &abort_handles)?; // Async operation — wrap in Abortable and spawn. + let future_handle = self.heap.create_handle(id); let pending_futures = pending_futures.clone(); let (abort_handle, abort_reg) = futures::future::AbortHandle::new_pair(); @@ -1369,7 +1389,7 @@ impl BexEngine { async move { let result = fut.await; let _ = pending_futures.send(FutureResult { - id, + id: future_handle, result: result.map_err(EngineError::from), }); }, @@ -1389,17 +1409,21 @@ impl BexEngine { } VmExecState::Await(future_id) => { + let awaited_future = self.heap.create_handle(future_id); Self::cancellation_safepoint(cancel, &abort_handles)?; vm = self.gc_safepoint(vm).await; // First, drain any already-completed futures. while let Ok(future) = processed_futures.try_recv() { + let future_ptr = Self::resolve_future_handle(&vm, &future.id)?; + let awaited_future_ptr = + Self::resolve_future_handle(&vm, &awaited_future)?; let external = future.result?; let value = self.convert_external_to_vm_value(&mut vm, external); - vm.fulfil_future(future.id, value) + vm.fulfil_future(future_ptr, value) .map_err(EngineError::VmInternalError)?; - if future.id == future_id { + if future_ptr == awaited_future_ptr { continue 'vm_exec; } } @@ -1419,15 +1443,18 @@ impl BexEngine { future = processed_futures.recv() => { let future = future .ok_or(EngineError::FutureChannelClosed)?; + let future_ptr = Self::resolve_future_handle(&vm, &future.id)?; + let awaited_future_ptr = + Self::resolve_future_handle(&vm, &awaited_future)?; let external = future.result?; let value = self.convert_external_to_vm_value( &mut vm, external, ); - vm.fulfil_future(future.id, value) + vm.fulfil_future(future_ptr, value) .map_err(EngineError::VmInternalError)?; - if future.id == future_id { + if future_ptr == awaited_future_ptr { break; } } diff --git a/baml_language/crates/bex_engine/tests/early_yield.rs b/baml_language/crates/bex_engine/tests/early_yield.rs index fef942b6e4..613fabd7a2 100644 --- a/baml_language/crates/bex_engine/tests/early_yield.rs +++ b/baml_language/crates/bex_engine/tests/early_yield.rs @@ -266,3 +266,34 @@ async fn objects_allocated_after_mid_call_gc_survive() { "post-GC allocations must not corrupt pre-GC ones, and vice versa" ); } + +/// Regression test for async future completion across a GC triggered at the +/// `Await` safepoint. The loop creates enough allocation pressure to force a +/// minor collection right before the engine blocks on `sleep()`. +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn sleep_future_survives_gc_safepoint() { + let source = r#" + function alloc_then_sleep(n: int) -> int { + let i = 0; + while (i < n) { + let _ = [i, i + 1, i + 2]; + i += 1; + } + baml.sys.sleep(1); + 42 + } + "#; + + let engine = make_engine(source); + let value = engine + .call_function( + "alloc_then_sleep", + vec![BexExternalValue::Int(12_000)], + FunctionCallContextBuilder::new(sys_types::CallId::next()).build(), + true, + ) + .await + .expect("sleep future must survive the GC safepoint"); + + assert_eq!(value, BexExternalValue::Int(42)); +}