diff --git a/crates/witness/src/lib.rs b/crates/witness/src/lib.rs index 3710134..d72cca6 100644 --- a/crates/witness/src/lib.rs +++ b/crates/witness/src/lib.rs @@ -9,6 +9,7 @@ use rayon::{ prelude::ParallelSliceMut, }; use std::{ + any::Any, ops::{Deref, DerefMut, Index}, slice::{Chunks, ChunksMut}, sync::Arc, @@ -33,7 +34,6 @@ pub enum InstancePaddingStrategy { Custom(Arc u64 + Send + Sync>), } -#[derive(Clone)] pub struct RowMajorMatrix { inner: p3::matrix::dense::RowMajorMatrix, // num_row is the real instance BEFORE padding @@ -41,9 +41,34 @@ pub struct RowMajorMatrix { log2_num_rotation: usize, is_padded: bool, padding_strategy: InstancePaddingStrategy, + // Optional opaque handle to device-resident storage that mirrors `inner.values`. + // This lets GPU-side code keep an associated buffer/layout without forcing witness + // to depend on a concrete device runtime. There is no automatic host<->device sync: + // host-side mutation invalidates this cache. + device_backing: Option, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum DeviceMatrixLayout { + /// Device buffer is laid out identically to `inner.values`. + RowMajor, + /// Device buffer stores the same logical matrix in column-major order. + ColMajor, +} + +#[derive(Clone)] +struct DeviceMatrixBacking { + // Type-erased device handle owned outside of witness. An `Arc` keeps clone-free + // sharing cheap for readers while the matrix itself remains host-data owned. + storage: Arc, + layout: DeviceMatrixLayout, } impl RowMajorMatrix { + fn invalidate_device_backing(&mut self) { + self.device_backing = None; + } + pub fn rand(rng: &mut R, rows: usize, cols: usize) -> Self where Standard: Distribution, @@ -56,6 +81,7 @@ impl RowMajorMat is_padded: true, log2_num_rotation: 0, padding_strategy: InstancePaddingStrategy::Default, + device_backing: None, } } pub fn empty() -> Self { @@ -65,6 +91,7 @@ impl RowMajorMat log2_num_rotation: 0, is_padded: true, padding_strategy: InstancePaddingStrategy::Default, + device_backing: None, } } @@ -130,6 +157,7 @@ impl RowMajorMat log2_num_rotation, is_padded: matches!(padding_strategy, InstancePaddingStrategy::Default), padding_strategy, + device_backing: None, } } @@ -148,6 +176,7 @@ impl RowMajorMat log2_num_rotation: 0, is_padded: matches!(padding_strategy, InstancePaddingStrategy::Default), padding_strategy, + device_backing: None, } } @@ -166,6 +195,46 @@ impl RowMajorMat next_pow2_instance_padding(self.num_instances()) - self.num_instances() } + /// Attach opaque device-resident storage for callers that materialize this witness + /// on accelerators. Witness keeps only metadata here so GPU integrations can cache a + /// buffer next to the host matrix without introducing device-specific dependencies. + /// + /// There is no automatic host<->device synchronization. The backing is only valid + /// while the host-side matrix contents and shape remain unchanged. Any mutable access + /// to the matrix clears this metadata conservatively. + pub fn set_device_backing( + &mut self, + storage: D, + layout: DeviceMatrixLayout, + ) { + self.device_backing = Some(DeviceMatrixBacking { + storage: Arc::new(storage), + layout, + }); + } + + /// Explicitly drop any attached device metadata. + pub fn clear_device_backing(&mut self) { + self.invalidate_device_backing(); + } + + /// Whether this matrix currently has device metadata attached. + pub fn has_device_backing(&self) -> bool { + self.device_backing.is_some() + } + + /// Report how the attached device buffer is laid out, if present. + pub fn device_backing_layout(&self) -> Option { + self.device_backing.as_ref().map(|backing| backing.layout) + } + + /// Downcast the opaque device handle to the concrete type stored by the caller. + pub fn device_backing_ref(&self) -> Option<&D> { + self.device_backing + .as_ref() + .and_then(|backing| backing.storage.downcast_ref::()) + } + // return raw num_instances without rotation pub fn num_instances(&self) -> usize { self.num_rows @@ -182,18 +251,21 @@ impl RowMajorMat } pub fn iter_mut(&mut self) -> ChunksMut<'_, T> { + self.invalidate_device_backing(); let num_rotation = Self::num_rotation(self.log2_num_rotation); let max_range = self.num_instances() * num_rotation * self.n_col(); self.inner.values[..max_range].chunks_mut(num_rotation * self.inner.width) } pub fn par_batch_iter_mut(&mut self, num_rows: usize) -> rayon::slice::ChunksMut<'_, T> { + self.invalidate_device_backing(); let num_rotation = Self::num_rotation(self.log2_num_rotation); let max_range = self.num_instances() * self.n_col() * num_rotation; self.inner.values[..max_range].par_chunks_mut(num_rows * num_rotation * self.inner.width) } pub fn padding_by_strategy(&mut self) { + self.invalidate_device_backing(); let num_rotation = Self::num_rotation(self.log2_num_rotation); let start_index = self.num_instances() * num_rotation * self.n_col(); @@ -224,7 +296,8 @@ impl RowMajorMat pub fn pad_to_height(&mut self, new_height: usize, fill: T) { let (cur_height, n_cols) = (self.height(), self.n_col()); assert!(new_height >= cur_height); - self.values.par_extend( + self.invalidate_device_backing(); + self.inner.values.par_extend( (0..(new_height - cur_height) * n_cols) .into_par_iter() .map(|_| fill), @@ -232,6 +305,19 @@ impl RowMajorMat } } +impl Clone for RowMajorMatrix { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + num_rows: self.num_rows, + log2_num_rotation: self.log2_num_rotation, + is_padded: self.is_padded, + padding_strategy: self.padding_strategy.clone(), + device_backing: None, + } + } +} + impl RowMajorMatrix { pub fn to_mles<'a, E: ff_ext::ExtensionField>( &self, @@ -299,6 +385,7 @@ impl Deref for RowMajorMatrix DerefMut for RowMajorMatrix { fn deref_mut(&mut self) -> &mut Self::Target { + self.device_backing = None; &mut self.inner } } @@ -326,3 +413,37 @@ macro_rules! set_fixed_val { $ins[$field.0] = $val; }; } + +#[cfg(test)] +mod tests { + use super::{DeviceMatrixLayout, InstancePaddingStrategy, RowMajorMatrix}; + use p3::goldilocks::Goldilocks; + + #[test] + fn clone_clears_device_backing() { + let mut matrix = RowMajorMatrix::::new(2, 2, InstancePaddingStrategy::Default); + matrix.set_device_backing(vec![1_u8, 2, 3], DeviceMatrixLayout::RowMajor); + + let cloned = matrix.clone(); + + assert!(matrix.has_device_backing()); + assert!(!cloned.has_device_backing()); + } + + #[test] + fn mutable_access_invalidates_device_backing() { + let mut matrix = RowMajorMatrix::::new(2, 2, InstancePaddingStrategy::Default); + matrix.set_device_backing(vec![1_u8, 2, 3], DeviceMatrixLayout::RowMajor); + + let _ = matrix.iter_mut(); + assert!(!matrix.has_device_backing()); + + matrix.set_device_backing(vec![1_u8, 2, 3], DeviceMatrixLayout::RowMajor); + matrix.pad_to_height(4, Goldilocks::default()); + assert!(!matrix.has_device_backing()); + + matrix.set_device_backing(vec![1_u8, 2, 3], DeviceMatrixLayout::RowMajor); + let _ = &mut *matrix; + assert!(!matrix.has_device_backing()); + } +}