diff --git a/qdp/DEVELOPMENT.md b/qdp/DEVELOPMENT.md index fc927022d3..42ad61c4b4 100644 --- a/qdp/DEVELOPMENT.md +++ b/qdp/DEVELOPMENT.md @@ -84,6 +84,12 @@ cargo test --workspace cd .. ``` +**Encoding / pipeline dtype:** `qdp_core::Encoding::supports_f32` gates whether +`PipelineConfig::normalize()` keeps `dtype = Float32` for the synthetic pipeline. It reflects +**which encoders implement `encode_batch_f32` today** (currently amplitude only), not every +encoding that might eventually get a batch f32 path. When angle/basis gain real batch f32 +support, widen `supports_f32` and adjust tests accordingly. + Run Python tests: ```bash diff --git a/qdp/qdp-core/src/encoding/mod.rs b/qdp/qdp-core/src/encoding/mod.rs index d795ca4a75..2d09b51460 100644 --- a/qdp/qdp-core/src/encoding/mod.rs +++ b/qdp/qdp-core/src/encoding/mod.rs @@ -63,6 +63,7 @@ use crate::dlpack::DLManagedTensor; use crate::gpu::PipelineContext; use crate::gpu::memory::{GpuStateVector, PinnedHostBuffer}; use crate::reader::StreamingDataReader; +use crate::types::Encoding; use crate::{MahoutError, QdpEngine, Result}; /// 512MB staging buffer for large Parquet row groups (reduces fragmentation) @@ -370,22 +371,23 @@ pub(crate) fn encode_from_parquet( num_qubits: usize, encoding_method: &str, ) -> Result<*mut DLManagedTensor> { - match encoding_method { - "amplitude" => { + let encoding = Encoding::from_str_ci(encoding_method)?; + match encoding { + Encoding::Amplitude => { crate::profile_scope!("Mahout::EncodeAmplitudeFromParquet"); stream_encode(engine, path, num_qubits, amplitude::AmplitudeEncoder) } - "angle" => { + Encoding::Angle => { crate::profile_scope!("Mahout::EncodeAngleFromParquet"); stream_encode(engine, path, num_qubits, angle::AngleEncoder) } - "basis" => { + Encoding::Basis => { crate::profile_scope!("Mahout::EncodeBasisFromParquet"); stream_encode(engine, path, num_qubits, basis::BasisEncoder) } _ => Err(MahoutError::NotImplemented(format!( "Encoding method '{}' not supported for 
streaming", - encoding_method + encoding.as_str() ))), } } diff --git a/qdp/qdp-core/src/gpu/encodings/iqp.rs b/qdp/qdp-core/src/gpu/encodings/iqp.rs index c6ecf17624..33d18cfaf0 100644 --- a/qdp/qdp-core/src/gpu/encodings/iqp.rs +++ b/qdp/qdp-core/src/gpu/encodings/iqp.rs @@ -23,6 +23,7 @@ use crate::error::{MahoutError, Result}; use crate::gpu::memory::{GpuStateVector, Precision}; use cudarc::driver::CudaDevice; use std::sync::Arc; +use std::sync::OnceLock; #[cfg(target_os = "linux")] use crate::gpu::memory::map_allocation_error; @@ -405,3 +406,18 @@ impl QuantumEncoder for IqpEncoder { } } } + +static IQP_FULL: OnceLock<IqpEncoder> = OnceLock::new(); +static IQP_Z_ONLY: OnceLock<IqpEncoder> = OnceLock::new(); + +/// Shared `'static` IQP encoder (full ZZ). Used by [`crate::Encoding::encoder`](crate::Encoding::encoder). +#[must_use] +pub fn iqp_full_encoder() -> &'static IqpEncoder { + IQP_FULL.get_or_init(IqpEncoder::full) +} + +/// Shared `'static` IQP-Z encoder. Used by [`crate::Encoding::encoder`](crate::Encoding::encoder). 
+#[must_use] +pub fn iqp_z_encoder() -> &'static IqpEncoder { + IQP_Z_ONLY.get_or_init(IqpEncoder::z_only) +} diff --git a/qdp/qdp-core/src/gpu/encodings/mod.rs b/qdp/qdp-core/src/gpu/encodings/mod.rs index fa6362d4cf..3f256e68a2 100644 --- a/qdp/qdp-core/src/gpu/encodings/mod.rs +++ b/qdp/qdp-core/src/gpu/encodings/mod.rs @@ -58,7 +58,7 @@ pub fn validate_qubit_count(num_qubits: usize) -> Result<()> { /// Quantum encoding strategy interface /// Implemented by: AmplitudeEncoder, AngleEncoder, BasisEncoder -pub trait QuantumEncoder: Send + Sync { +pub trait QuantumEncoder: Send + Sync + 'static { /// Encode classical data to quantum state on GPU fn encode( &self, @@ -181,21 +181,5 @@ pub mod phase; pub use amplitude::AmplitudeEncoder; pub use angle::AngleEncoder; pub use basis::BasisEncoder; -pub use iqp::IqpEncoder; +pub use iqp::{IqpEncoder, iqp_full_encoder, iqp_z_encoder}; pub use phase::PhaseEncoder; - -/// Create encoder by name: "amplitude", "angle", "basis", "iqp", or "iqp-z" -pub fn get_encoder(name: &str) -> Result<Box<dyn QuantumEncoder>> { - match name.to_lowercase().as_str() { - "amplitude" => Ok(Box::new(AmplitudeEncoder)), - "angle" => Ok(Box::new(AngleEncoder)), - "basis" => Ok(Box::new(BasisEncoder)), - "iqp" => Ok(Box::new(IqpEncoder::full())), - "iqp-z" => Ok(Box::new(IqpEncoder::z_only())), - "phase" => Ok(Box::new(PhaseEncoder)), - _ => Err(crate::error::MahoutError::InvalidInput(format!( - "Unknown encoder: {}. 
Available: amplitude, angle, basis, iqp, iqp-z, phase", - name - ))), - } -} diff --git a/qdp/qdp-core/src/gpu/mod.rs b/qdp/qdp-core/src/gpu/mod.rs index 7e16be7be3..73c4d46285 100644 --- a/qdp/qdp-core/src/gpu/mod.rs +++ b/qdp/qdp-core/src/gpu/mod.rs @@ -31,7 +31,7 @@ pub(crate) mod cuda_ffi; #[cfg(target_os = "linux")] pub use buffer_pool::{PinnedBufferHandle, PinnedBufferPool}; -pub use encodings::{AmplitudeEncoder, AngleEncoder, BasisEncoder, QuantumEncoder, get_encoder}; +pub use encodings::{AmplitudeEncoder, AngleEncoder, BasisEncoder, QuantumEncoder}; pub use memory::GpuStateVector; pub use pipeline::run_dual_stream_pipeline; diff --git a/qdp/qdp-core/src/lib.rs b/qdp/qdp-core/src/lib.rs index 3db9accdc1..2c1d1c5f49 100644 --- a/qdp/qdp-core/src/lib.rs +++ b/qdp/qdp-core/src/lib.rs @@ -31,12 +31,14 @@ pub mod readers; #[cfg(feature = "remote-io")] pub mod remote; pub mod tf_proto; +pub mod types; #[macro_use] mod profiling; pub use error::{MahoutError, Result, cuda_error_to_string}; pub use gpu::memory::Precision; pub use reader::{NullHandling, handle_float64_nulls}; +pub use types::{Dtype, Encoding}; // Throughput/latency pipeline runner: single path using QdpEngine and encode_batch in Rust. 
#[cfg(target_os = "linux")] @@ -52,7 +54,6 @@ use std::ffi::c_void; use std::sync::Arc; use crate::dlpack::DLManagedTensor; -use crate::gpu::get_encoder; use cudarc::driver::CudaDevice; #[cfg(target_os = "linux")] @@ -160,7 +161,8 @@ impl QdpEngine { ) -> Result<*mut DLManagedTensor> { crate::profile_scope!("Mahout::Encode"); - let encoder = get_encoder(encoding_method)?; + let encoding = Encoding::from_str_ci(encoding_method)?; + let encoder = encoding.encoder(); let state_vector = encoder.encode(&self.device, data, num_qubits)?; let state_vector = state_vector.to_precision(&self.device, self.precision)?; let dlpack_ptr = { @@ -205,10 +207,23 @@ impl QdpEngine { sample_size: usize, num_qubits: usize, encoding_method: &str, + ) -> Result<*mut DLManagedTensor> { + let encoding = Encoding::from_str_ci(encoding_method)?; + self.encode_batch_for_pipeline(batch_data, num_samples, sample_size, num_qubits, encoding) + } + + /// Same as [`encode_batch`](Self::encode_batch) with a resolved [`Encoding`] (no string parse). + pub(crate) fn encode_batch_for_pipeline( + &self, + batch_data: &[f64], + num_samples: usize, + sample_size: usize, + num_qubits: usize, + encoding: Encoding, ) -> Result<*mut DLManagedTensor> { crate::profile_scope!("Mahout::EncodeBatch"); - let encoder = get_encoder(encoding_method)?; + let encoder = encoding.encoder(); let state_vector = encoder.encode_batch( &self.device, batch_data, @@ -230,10 +245,29 @@ impl QdpEngine { sample_size: usize, num_qubits: usize, encoding_method: &str, + ) -> Result<*mut DLManagedTensor> { + let encoding = Encoding::from_str_ci(encoding_method)?; + self.encode_batch_f32_for_pipeline( + batch_data, + num_samples, + sample_size, + num_qubits, + encoding, + ) + } + + /// Same as [`encode_batch_f32`](Self::encode_batch_f32) with a resolved [`Encoding`]. 
+ pub(crate) fn encode_batch_f32_for_pipeline( + &self, + batch_data: &[f32], + num_samples: usize, + sample_size: usize, + num_qubits: usize, + encoding: Encoding, ) -> Result<*mut DLManagedTensor> { crate::profile_scope!("Mahout::EncodeBatchF32"); - let encoder = get_encoder(encoding_method)?; + let encoder = encoding.encoder(); let state_vector = encoder.encode_batch_f32( &self.device, batch_data, @@ -263,8 +297,9 @@ impl QdpEngine { encoding_method: &str, ) -> Result<()> { crate::profile_scope!("Mahout::RunDualStreamEncode"); - match encoding_method.to_lowercase().as_str() { - "amplitude" => { + let encoding = Encoding::from_str_ci(encoding_method)?; + match encoding { + Encoding::Amplitude => { gpu::encodings::amplitude::AmplitudeEncoder::run_amplitude_dual_stream_pipeline( &self.device, host_data, @@ -273,7 +308,7 @@ impl QdpEngine { } _ => Err(MahoutError::InvalidInput(format!( "run_dual_stream_encode supports only 'amplitude' for now, got '{}'", - encoding_method + encoding.as_str() ))), } } @@ -507,7 +542,8 @@ impl QdpEngine { validate_cuda_input_ptr(&self.device, input_d)?; - let encoder = get_encoder(encoding_method)?; + let encoding = Encoding::from_str_ci(encoding_method)?; + let encoder = encoding.encoder(); let state_vector = unsafe { encoder.encode_from_gpu_ptr(&self.device, input_d, input_len, num_qubits, stream) }?; @@ -841,7 +877,8 @@ impl QdpEngine { validate_cuda_input_ptr(&self.device, input_batch_d)?; - let encoder = get_encoder(encoding_method)?; + let encoding = Encoding::from_str_ci(encoding_method)?; + let encoder = encoding.encoder(); let batch_state_vector = unsafe { encoder.encode_batch_from_gpu_ptr( &self.device, diff --git a/qdp/qdp-core/src/pipeline_runner.rs b/qdp/qdp-core/src/pipeline_runner.rs index 42bb5cc655..49da3964f1 100644 --- a/qdp/qdp-core/src/pipeline_runner.rs +++ b/qdp/qdp-core/src/pipeline_runner.rs @@ -24,9 +24,11 @@ use std::time::Instant; use crate::QdpEngine; use crate::dlpack::DLManagedTensor; use 
crate::error::{MahoutError, Result}; +use crate::gpu::memory::Precision; use crate::io; use crate::reader::{NullHandling, StreamingDataReader}; use crate::readers::ParquetStreamingReader; +use crate::types::Encoding; /// Configuration for throughput/latency pipeline runs (Python run_throughput_pipeline_py). #[derive(Clone, Debug)] @@ -35,24 +37,31 @@ pub struct PipelineConfig { pub num_qubits: u32, pub batch_size: usize, pub total_batches: usize, - pub encoding_method: String, + pub encoding: Encoding, pub seed: Option, pub warmup_batches: usize, pub null_handling: NullHandling, - pub float32_pipeline: bool, + /// Pipeline element dtype for synthetic batch fill and `encode_batch` dispatch. + /// + /// If [`Encoding::supports_f32`](crate::types::Encoding::supports_f32) is false for the + /// chosen [`encoding`](PipelineConfig::encoding), [`normalize`](PipelineConfig::normalize) + /// downgrades this to [`Precision::Float64`] (see `types` module docs: batch f32 is wired + /// only for encodings with a real `encode_batch_f32` today). + pub dtype: Precision, pub prefetch_depth: usize, } impl PipelineConfig { - /// Normalizes the configuration, such as falling back to f64 if f32 is requested - /// but the encoding doesn't support it. + /// Normalizes the configuration: if `dtype` is float32 but the encoding cannot use the + /// f32 batch encode path ([`Encoding::supports_f32`](crate::types::Encoding::supports_f32)), + /// falls back to float64. 
pub fn normalize(&mut self) { - if self.float32_pipeline && !encoding_supports_f32(&self.encoding_method) { + if matches!(self.dtype, Precision::Float32) && !self.encoding.supports_f32() { log::info!( - "float32_pipeline requested but encoding '{}' does not support f32; falling back to f64", - self.encoding_method + "float32 pipeline requested but encoding '{}' does not support f32; falling back to f64", + self.encoding.as_str() ); - self.float32_pipeline = false; + self.dtype = Precision::Float64; } } } @@ -64,11 +73,11 @@ impl Default for PipelineConfig { num_qubits: 16, batch_size: 64, total_batches: 100, - encoding_method: "amplitude".to_string(), + encoding: Encoding::Amplitude, seed: None, warmup_batches: 0, null_handling: NullHandling::FillZero, - float32_pipeline: false, + dtype: Precision::Float64, prefetch_depth: 16, } } @@ -99,12 +108,6 @@ pub trait BatchProducer: Send + 'static { fn produce(&mut self, recycled: Option) -> Result>; } -/// Returns true if the given encoding method has a native f32 GPU kernel. -/// Used to auto-gate `float32_pipeline` so unsupported encodings fall back to f64. 
-fn encoding_supports_f32(encoding_method: &str) -> bool { - matches!(encoding_method.to_lowercase().as_str(), "amplitude") -} - pub struct SyntheticProducer { pub config: PipelineConfig, pub vector_len: usize, @@ -131,16 +134,16 @@ impl BatchProducer for SyntheticProducer { } let mut data = match recycled { - Some(BatchData::F32(mut buf)) if self.config.float32_pipeline => { + Some(BatchData::F32(mut buf)) if matches!(self.config.dtype, Precision::Float32) => { buf.resize(self.config.batch_size * self.vector_len, 0.0); BatchData::F32(buf) } - Some(BatchData::F64(mut buf)) if !self.config.float32_pipeline => { + Some(BatchData::F64(mut buf)) if matches!(self.config.dtype, Precision::Float64) => { buf.resize(self.config.batch_size * self.vector_len, 0.0); BatchData::F64(buf) } _ => { - if self.config.float32_pipeline { + if matches!(self.config.dtype, Precision::Float32) { BatchData::F32(vec![0.0f32; self.config.batch_size * self.vector_len]) } else { BatchData::F64(vec![0.0f64; self.config.batch_size * self.vector_len]) @@ -366,7 +369,7 @@ pub struct PipelineIterator { impl PipelineIterator { pub fn new_synthetic(engine: QdpEngine, mut config: PipelineConfig) -> Result { config.normalize(); - let vector_len = vector_len(config.num_qubits, &config.encoding_method); + let vector_len = vector_len(config.num_qubits, config.encoding); let producer = SyntheticProducer::new(config.clone(), vector_len); let prefetch_depth = config.prefetch_depth; let (rx, recycle_tx, _producer_handle) = spawn_producer(producer, prefetch_depth)?; @@ -393,13 +396,16 @@ impl PipelineIterator { config.normalize(); let path = path.as_ref(); let (data, num_samples, sample_size) = read_file_by_extension(path, config.null_handling)?; - let vector_len = vector_len(config.num_qubits, &config.encoding_method); + let vector_len = vector_len(config.num_qubits, config.encoding); // Dimension validation at construction. 
if sample_size != vector_len { return Err(MahoutError::InvalidInput(format!( "File feature length {} does not match vector_len {} for num_qubits={}, encoding={}", - sample_size, vector_len, config.num_qubits, config.encoding_method + sample_size, + vector_len, + config.num_qubits, + config.encoding.as_str() ))); } if data.len() != num_samples * sample_size { @@ -454,7 +460,7 @@ impl PipelineIterator { Some(DEFAULT_PARQUET_ROW_GROUP_SIZE), config.null_handling, )?; - let vector_len = vector_len(config.num_qubits, &config.encoding_method); + let vector_len = vector_len(config.num_qubits, config.encoding); // Read first chunk to learn sample_size; reuse as initial buffer. const INITIAL_CHUNK_CAP: usize = 64 * 1024; @@ -474,7 +480,10 @@ impl PipelineIterator { if sample_size != vector_len { return Err(MahoutError::InvalidInput(format!( "File feature length {} does not match vector_len {} for num_qubits={}, encoding={}", - sample_size, vector_len, config.num_qubits, config.encoding_method + sample_size, + vector_len, + config.num_qubits, + config.encoding.as_str() ))); } @@ -511,19 +520,19 @@ impl PipelineIterator { Err(_) => return Ok(None), }; let ptr = match &batch.data { - BatchData::F64(buf) => self.engine.encode_batch( + BatchData::F64(buf) => self.engine.encode_batch_for_pipeline( buf, batch.batch_n, batch.sample_size, batch.num_qubits, - &self.config.encoding_method, + self.config.encoding, )?, - BatchData::F32(buf) => self.engine.encode_batch_f32( + BatchData::F32(buf) => self.engine.encode_batch_f32_for_pipeline( buf, batch.batch_n, batch.sample_size, batch.num_qubits, - &self.config.encoding_method, + self.config.encoding, )?, }; let _ = self.recycle_tx.lock().unwrap().send(batch.data); @@ -532,38 +541,33 @@ impl PipelineIterator { } /// Vector length per sample for given encoding (used by pipeline and iterator). 
-pub fn vector_len(num_qubits: u32, encoding_method: &str) -> usize { - let n = num_qubits as usize; - match encoding_method.to_lowercase().as_str() { - "angle" => n, - "basis" => 1, - _ => 1 << n, // amplitude - } +pub fn vector_len(num_qubits: u32, encoding: Encoding) -> usize { + encoding.vector_len(num_qubits) } /// Deterministic sample generation matching Python utils.build_sample (amplitude/angle/basis). -fn fill_sample(seed: u64, out: &mut [f64], encoding_method: &str, num_qubits: usize) -> Result<()> { +fn fill_sample(seed: u64, out: &mut [f64], encoding: Encoding, num_qubits: usize) -> Result<()> { let len = out.len(); if len == 0 { return Ok(()); } - match encoding_method.to_lowercase().as_str() { - "basis" => { + match encoding { + Encoding::Basis => { // For basis encoding, use 2^num_qubits as the state space size for mask calculation let state_space_size = 1 << num_qubits; let mask = (state_space_size - 1) as u64; let idx = seed & mask; out[0] = idx as f64; } - "angle" => { + Encoding::Angle => { let scale = (2.0 * PI) / len as f64; for (i, v) in out.iter_mut().enumerate() { let mixed = (i as u64 + seed) % (len as u64); *v = mixed as f64 * scale; } } - _ => { - // amplitude + Encoding::Amplitude | Encoding::Iqp | Encoding::IqpZ | Encoding::Phase => { + // amplitude-like synthetic pattern let mask = (len - 1) as u64; let scale = 1.0 / len as f64; for (i, v) in out.iter_mut().enumerate() { @@ -600,7 +604,7 @@ fn fill_batch_inplace( let _ = fill_sample( seed_base + i as u64, &mut batch_buf[offset..offset + vector_len], - &config.encoding_method, + config.encoding, config.num_qubits as usize, ); } @@ -610,29 +614,28 @@ fn fill_batch_inplace( fn fill_sample_f32( seed: u64, out: &mut [f32], - encoding_method: &str, + encoding: Encoding, num_qubits: usize, ) -> Result<()> { let len = out.len(); if len == 0 { return Ok(()); } - match encoding_method.to_lowercase().as_str() { - "basis" => { + match encoding { + Encoding::Basis => { let state_space_size = 1 << 
num_qubits; let mask = (state_space_size - 1) as u64; let idx = seed & mask; out[0] = idx as f32; } - "angle" => { + Encoding::Angle => { let scale = (2.0 * std::f32::consts::PI) / len as f32; for (i, v) in out.iter_mut().enumerate() { let mixed = (i as u64 + seed) % (len as u64); *v = mixed as f32 * scale; } } - _ => { - // amplitude + Encoding::Amplitude | Encoding::Iqp | Encoding::IqpZ | Encoding::Phase => { let mask = (len - 1) as u64; let scale = 1.0 / len as f32; for (i, v) in out.iter_mut().enumerate() { @@ -660,7 +663,7 @@ fn fill_batch_inplace_f32( let _ = fill_sample_f32( seed_base + i as u64, &mut batch_buf[offset..offset + vector_len], - &config.encoding_method, + config.encoding, config.num_qubits as usize, ); } @@ -683,20 +686,20 @@ pub fn run_throughput_pipeline(config: &PipelineConfig) -> Result Result Result engine.encode_batch( + BatchData::F64(buf) => engine.encode_batch_for_pipeline( buf, batch.batch_n, batch.sample_size, batch.num_qubits, - &config.encoding_method, + config.encoding, )?, - BatchData::F32(buf) => engine.encode_batch_f32( + BatchData::F32(buf) => engine.encode_batch_f32_for_pipeline( buf, batch.batch_n, batch.sample_size, batch.num_qubits, - &config.encoding_method, + config.encoding, )?, }; unsafe { release_dlpack(ptr) }; @@ -781,12 +784,12 @@ mod tests { let config = PipelineConfig { num_qubits: 5, batch_size: 8, - encoding_method: encoding_method.to_string(), + encoding: Encoding::from_str_ci(encoding_method).unwrap(), seed: Some(123), ..Default::default() }; - let vector_len = vector_len(config.num_qubits, &config.encoding_method); + let vector_len = vector_len(config.num_qubits, config.encoding); // Test edge cases: 0 and batch_size-1 for batch_idx in [0, config.batch_size - 1, 7] { @@ -802,12 +805,12 @@ mod tests { let config = PipelineConfig { num_qubits: 5, batch_size: 8, - encoding_method: encoding_method.to_string(), + encoding: Encoding::from_str_ci(encoding_method).unwrap(), seed: Some(123), ..Default::default() }; - 
let vector_len = vector_len(config.num_qubits, &config.encoding_method); + let vector_len = vector_len(config.num_qubits, config.encoding); let batch0 = generate_batch(&config, 0, vector_len); let batch1 = generate_batch(&config, 1, vector_len); @@ -849,12 +852,12 @@ mod tests { let config = PipelineConfig { num_qubits: 5, batch_size: 8, - encoding_method: "amplitude".to_string(), + encoding: Encoding::Amplitude, seed: None, ..Default::default() }; - let vector_len = vector_len(config.num_qubits, &config.encoding_method); + let vector_len = vector_len(config.num_qubits, config.encoding); let batch = generate_batch(&config, 0, vector_len); assert_eq!(batch.len(), config.batch_size * vector_len); @@ -868,12 +871,12 @@ mod tests { let config = PipelineConfig { num_qubits: 5, batch_size: 1, - encoding_method: "amplitude".to_string(), + encoding: Encoding::Amplitude, seed: Some(123), ..Default::default() }; - let vector_len = vector_len(config.num_qubits, &config.encoding_method); + let vector_len = vector_len(config.num_qubits, config.encoding); let batch = generate_batch(&config, 0, vector_len); assert_eq!(batch.len(), vector_len); @@ -891,7 +894,7 @@ mod tests { let config_lower = PipelineConfig { num_qubits: 5, batch_size: 8, - encoding_method: "amplitude".to_string(), + encoding: Encoding::Amplitude, seed: Some(123), ..Default::default() }; @@ -899,12 +902,12 @@ mod tests { let config_upper = PipelineConfig { num_qubits: 5, batch_size: 8, - encoding_method: "AMPLITUDE".to_string(), + encoding: Encoding::from_str_ci("AMPLITUDE").unwrap(), seed: Some(123), ..Default::default() }; - let vector_len = vector_len(config_lower.num_qubits, &config_lower.encoding_method); + let vector_len = vector_len(config_lower.num_qubits, config_lower.encoding); let batch_lower = generate_batch(&config_lower, 0, vector_len); let batch_upper = generate_batch(&config_upper, 0, vector_len); assert_eq!(batch_lower, batch_upper); @@ -915,12 +918,12 @@ mod tests { let config = PipelineConfig 
{ num_qubits: 5, batch_size: 8, - encoding_method: "amplitude".to_string(), + encoding: Encoding::Amplitude, seed: Some(123), ..Default::default() }; - let vector_len = vector_len(config.num_qubits, &config.encoding_method); + let vector_len = vector_len(config.num_qubits, config.encoding); for batch_idx in 0..5 { let batch = generate_batch(&config, batch_idx, vector_len); @@ -940,12 +943,12 @@ mod tests { let config = PipelineConfig { num_qubits: 5, batch_size: 8, - encoding_method: "amplitude".to_string(), + encoding: Encoding::Amplitude, seed: None, ..Default::default() }; - let vector_len = vector_len(config.num_qubits, &config.encoding_method); + let vector_len = vector_len(config.num_qubits, config.encoding); let batch = generate_batch(&config, 0, vector_len); for &value in &batch { @@ -962,12 +965,12 @@ mod tests { let config = PipelineConfig { num_qubits: 5, batch_size: 1, - encoding_method: "amplitude".to_string(), + encoding: Encoding::Amplitude, seed: Some(123), ..Default::default() }; - let vector_len = vector_len(config.num_qubits, &config.encoding_method); + let vector_len = vector_len(config.num_qubits, config.encoding); let batch = generate_batch(&config, 0, vector_len); for &value in &batch { @@ -984,10 +987,10 @@ mod tests { total_batches: 5, num_qubits: 3, batch_size: 4, - encoding_method: "amplitude".to_string(), + encoding: Encoding::Amplitude, ..Default::default() }; - let vector_len = super::vector_len(config.num_qubits, &config.encoding_method); + let vector_len = super::vector_len(config.num_qubits, config.encoding); let mut producer = SyntheticProducer::new(config, vector_len); let mut count = 0; @@ -1003,10 +1006,10 @@ mod tests { total_batches: 1, num_qubits: 3, batch_size: 4, - encoding_method: "amplitude".to_string(), + encoding: Encoding::Amplitude, ..Default::default() }; - let vector_len = super::vector_len(config.num_qubits, &config.encoding_method); + let vector_len = super::vector_len(config.num_qubits, config.encoding); let mut 
producer = SyntheticProducer::new(config.clone(), vector_len); let batch_from_producer = producer.produce(None).unwrap().unwrap(); @@ -1020,7 +1023,7 @@ mod tests { let config = PipelineConfig { batch_size: 5, num_qubits: 2, - encoding_method: "amplitude".to_string(), + encoding: Encoding::Amplitude, ..Default::default() }; let sample_size = 4; // 2^2 @@ -1050,7 +1053,7 @@ mod tests { prefetch_depth: 16, ..Default::default() }; - let vector_len = super::vector_len(config.num_qubits, &config.encoding_method); + let vector_len = super::vector_len(config.num_qubits, config.encoding); let producer = SyntheticProducer::new(config, vector_len); let (rx, _recycle_tx, handle) = spawn_producer(producer, 16).unwrap(); @@ -1072,7 +1075,7 @@ mod tests { prefetch_depth: 16, ..Default::default() }; - let vector_len = super::vector_len(config.num_qubits, &config.encoding_method); + let vector_len = super::vector_len(config.num_qubits, config.encoding); let producer = SyntheticProducer::new(config, vector_len); let (rx, _recycle_tx, handle) = spawn_producer(producer, 16).unwrap(); @@ -1093,18 +1096,18 @@ mod tests { total_batches: 2, num_qubits: 3, batch_size: 4, - encoding_method: "amplitude".to_string(), - float32_pipeline: true, + encoding: Encoding::Amplitude, + dtype: Precision::Float32, ..Default::default() }; config.normalize(); - let vector_len = super::vector_len(config.num_qubits, &config.encoding_method); + let vector_len = super::vector_len(config.num_qubits, config.encoding); let mut producer = SyntheticProducer::new(config, vector_len); let batch = producer.produce(None).unwrap().unwrap(); assert!( matches!(batch.data, BatchData::F32(_)), - "amplitude with float32_pipeline=true should produce F32 data" + "amplitude with dtype=Float32 should produce F32 data" ); // Verify data is non-zero (was actually filled) @@ -1122,18 +1125,18 @@ mod tests { total_batches: 1, num_qubits: 3, batch_size: 4, - encoding_method: "angle".to_string(), - float32_pipeline: true, // 
requested f32, but angle doesn't support it + encoding: Encoding::Angle, + dtype: Precision::Float32, // requested f32, but angle doesn't support native f32 batch path ..Default::default() }; config.normalize(); - let vector_len = super::vector_len(config.num_qubits, &config.encoding_method); + let vector_len = super::vector_len(config.num_qubits, config.encoding); let mut producer = SyntheticProducer::new(config, vector_len); let batch = producer.produce(None).unwrap().unwrap(); assert!( matches!(batch.data, BatchData::F64(_)), - "angle with float32_pipeline=true should fall back to F64 data" + "angle with requested Float32 should fall back to F64 batch data (no encode_batch_f32 yet)" ); } @@ -1143,28 +1146,28 @@ mod tests { total_batches: 1, num_qubits: 3, batch_size: 4, - encoding_method: "basis".to_string(), - float32_pipeline: true, + encoding: Encoding::Basis, + dtype: Precision::Float32, ..Default::default() }; config.normalize(); - let vector_len = super::vector_len(config.num_qubits, &config.encoding_method); + let vector_len = super::vector_len(config.num_qubits, config.encoding); let mut producer = SyntheticProducer::new(config, vector_len); let batch = producer.produce(None).unwrap().unwrap(); assert!( matches!(batch.data, BatchData::F64(_)), - "basis with float32_pipeline=true should fall back to F64 data" + "basis with requested Float32 should fall back to F64 batch data (no encode_batch_f32 yet)" ); } #[test] fn test_encoding_supports_f32() { - assert!(super::encoding_supports_f32("amplitude")); - assert!(super::encoding_supports_f32("Amplitude")); - assert!(super::encoding_supports_f32("AMPLITUDE")); - assert!(!super::encoding_supports_f32("angle")); - assert!(!super::encoding_supports_f32("basis")); - assert!(!super::encoding_supports_f32("iqp")); + assert!(Encoding::Amplitude.supports_f32()); + assert!(Encoding::from_str_ci("Amplitude").unwrap().supports_f32()); + assert!(Encoding::from_str_ci("AMPLITUDE").unwrap().supports_f32()); + 
assert!(!Encoding::Angle.supports_f32()); + assert!(!Encoding::Basis.supports_f32()); + assert!(!Encoding::Iqp.supports_f32()); } } diff --git a/qdp/qdp-core/src/types.rs b/qdp/qdp-core/src/types.rs new file mode 100644 index 0000000000..f8a98834b1 --- /dev/null +++ b/qdp/qdp-core/src/types.rs @@ -0,0 +1,160 @@ +// +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Canonical domain types for encodings and element dtypes (`Dtype`). +//! +//! ## `Encoding::supports_f32` +//! +//! A future shape of this API may return true for amplitude, angle, and basis once each encoder +//! has a batch float32 GPU path. **Today only amplitude implements** +//! [`QuantumEncoder::encode_batch_f32`] for the synthetic prefetch pipeline, so +//! [`Encoding::supports_f32`](Encoding::supports_f32) stays amplitude-only and +//! [`crate::pipeline_runner::PipelineConfig::normalize`] avoids routing other encodings through +//! `encode_batch_f32`. Widen this method when angle/basis gain real `encode_batch_f32` +//! implementations. 
+ +use crate::error::{MahoutError, Result}; +use crate::gpu::encodings::{ + AmplitudeEncoder, AngleEncoder, BasisEncoder, PhaseEncoder, QuantumEncoder, iqp_full_encoder, + iqp_z_encoder, +}; + +/// Dtype for pipeline configuration (re-export of [`crate::gpu::memory::Precision`]). +pub use crate::gpu::memory::Precision as Dtype; + +impl crate::gpu::memory::Precision { + /// Parse dtype from a short user string (case-insensitive, trimmed). + pub fn from_str_ci(s: &str) -> Result<Self> { + let t = s.trim(); + if t.eq_ignore_ascii_case("f32") + || t.eq_ignore_ascii_case("float32") + || t.eq_ignore_ascii_case("float") + { + Ok(Self::Float32) + } else if t.eq_ignore_ascii_case("f64") + || t.eq_ignore_ascii_case("float64") + || t.eq_ignore_ascii_case("double") + { + Ok(Self::Float64) + } else { + Err(MahoutError::InvalidInput(format!( + "Unknown dtype: {s}. Use 'f32' or 'f64'." + ))) + } + } + + /// Element size in bytes for real scalar components (f32/f64). + #[must_use] + pub const fn bytes(self) -> usize { + match self { + Self::Float32 => 4, + Self::Float64 => 8, + } + } +} + +/// Quantum encoding method (canonical; parse user strings once at API boundaries). +#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] +pub enum Encoding { + Amplitude, + Angle, + Basis, + Iqp, + IqpZ, + Phase, +} + +impl Encoding { + /// Parse encoding name (case-insensitive ASCII, stack buffer; no heap allocation). + pub fn from_str_ci(s: &str) -> Result<Self> { + let mut buf = [0u8; 16]; + let bytes = s.as_bytes(); + if bytes.len() > buf.len() { + return Err(MahoutError::InvalidInput(format!( + "Unknown encoding: {s}. 
Available: amplitude, angle, basis, iqp, iqp-z, phase" + ))); + } + for (i, b) in bytes.iter().enumerate() { + buf[i] = b.to_ascii_lowercase(); + } + match &buf[..bytes.len()] { + b"amplitude" => Ok(Self::Amplitude), + b"angle" => Ok(Self::Angle), + b"basis" => Ok(Self::Basis), + b"iqp" => Ok(Self::Iqp), + b"iqp-z" => Ok(Self::IqpZ), + b"phase" => Ok(Self::Phase), + _ => Err(MahoutError::InvalidInput(format!( + "Unknown encoding: {s}. Available: amplitude, angle, basis, iqp, iqp-z, phase" + ))), + } + } + + #[must_use] + pub const fn as_str(self) -> &'static str { + match self { + Self::Amplitude => "amplitude", + Self::Angle => "angle", + Self::Basis => "basis", + Self::Iqp => "iqp", + Self::IqpZ => "iqp-z", + Self::Phase => "phase", + } + } + + /// Input feature dimension per sample for this encoding and qubit count. + /// + /// Matches each encoder's `expected_data_len` / `sample_size` contract: + /// - `Amplitude`: full state vector (`2^n`) + /// - `Angle` / `IqpZ` / `Phase`: one value per qubit (`n`) + /// - `Iqp`: single-qubit + pairwise ZZ terms (`n + n*(n-1)/2`) + /// - `Basis`: single integer index (`1`) + #[must_use] + pub const fn vector_len(self, num_qubits: u32) -> usize { + let n = num_qubits as usize; + match self { + Self::Amplitude => 1 << n, + Self::Angle | Self::IqpZ | Self::Phase => n, + Self::Iqp => n + n * n.saturating_sub(1) / 2, + Self::Basis => 1, + } + } + + /// Whether the **synthetic batch pipeline** may keep [`crate::gpu::memory::Precision::Float32`] + /// end-to-end (prefetched host `Vec<f32>` plus [`crate::QdpEngine::encode_batch_f32`]). + /// + /// This must match encoders that actually implement [`QuantumEncoder::encode_batch_f32`]. + /// Long-term design may include angle/basis here; today only amplitude does, so angle/basis + /// still normalize to `Float64` in [`crate::pipeline_runner::PipelineConfig::normalize`] + /// until their batch f32 GPU paths exist in the encoder implementations. 
+ #[must_use] + pub const fn supports_f32(self) -> bool { + matches!(self, Self::Amplitude) + } + + /// Static encoder dispatch (no per-call heap allocation). + #[must_use] + pub fn encoder(self) -> &'static dyn QuantumEncoder { + match self { + Self::Amplitude => &AmplitudeEncoder, + Self::Angle => &AngleEncoder, + Self::Basis => &BasisEncoder, + Self::Iqp => iqp_full_encoder(), + Self::IqpZ => iqp_z_encoder(), + Self::Phase => &PhaseEncoder, + } + } +} diff --git a/qdp/qdp-core/tests/gpu_angle_encoding.rs b/qdp/qdp-core/tests/gpu_angle_encoding.rs index c66e5eda60..e8a4e0ceeb 100644 --- a/qdp/qdp-core/tests/gpu_angle_encoding.rs +++ b/qdp/qdp-core/tests/gpu_angle_encoding.rs @@ -114,6 +114,28 @@ fn test_angle_infinity_rejected() { // ---- Successful encoding (kernel launch path) ---- +/// Regression: streaming Parquet path accepts mixed-case encoding names via `Encoding::from_str_ci`. +#[test] +fn test_angle_parquet_encoding_case_insensitive() { + let Some(engine) = common::qdp_engine() else { + return; + }; + + let num_qubits = 2; + let data: Vec = vec![0.1, 0.2]; + let path = "/tmp/test_angle_case.parquet"; + common::write_fixed_size_list_parquet(path, &data, num_qubits); + + let dlpack_ptr = engine + .encode_from_parquet(path, num_qubits, "Angle") + .expect("mixed-case 'Angle' should match streaming angle encoder"); + let _ = std::fs::remove_file(path); + + unsafe { + common::assert_dlpack_shape_2d_and_delete(dlpack_ptr, 1, (1 << num_qubits) as i64); + } +} + #[test] fn test_angle_successful_encoding_from_parquet() { let Some(engine) = common::qdp_engine() else { diff --git a/qdp/qdp-core/tests/gpu_iqp_encoding.rs b/qdp/qdp-core/tests/gpu_iqp_encoding.rs index f45ba3eac0..4954ab5b38 100644 --- a/qdp/qdp-core/tests/gpu_iqp_encoding.rs +++ b/qdp/qdp-core/tests/gpu_iqp_encoding.rs @@ -797,7 +797,7 @@ fn test_iqp_fwt_zero_parameters_identity() { #[test] #[cfg(target_os = "linux")] fn test_iqp_encoder_via_factory() { - println!("Testing IQP encoder creation via 
get_encoder..."); + println!("Testing IQP encoder creation via Encoding::from_str_ci / encode..."); let Some(engine) = common::qdp_engine() else { println!("SKIP: No GPU available"); @@ -836,7 +836,7 @@ fn test_iqp_encoder_via_factory() { #[test] #[cfg(target_os = "linux")] fn test_iqp_z_encoder_via_factory() { - println!("Testing IQP-Z encoder creation via get_encoder..."); + println!("Testing IQP-Z encoder creation via encode..."); let Some(engine) = common::qdp_engine() else { println!("SKIP: No GPU available"); diff --git a/qdp/qdp-core/tests/gpu_validation.rs b/qdp/qdp-core/tests/gpu_validation.rs index 3235b24fa9..291f92dce8 100644 --- a/qdp/qdp-core/tests/gpu_validation.rs +++ b/qdp/qdp-core/tests/gpu_validation.rs @@ -38,8 +38,8 @@ fn test_input_validation_invalid_strategy() { match result { Err(MahoutError::InvalidInput(msg)) => { assert!( - msg.contains("Unknown encoder"), - "Error message should mention unknown encoder" + msg.contains("Unknown encoding"), + "Error message should mention unknown encoding" ); println!("PASS: Correctly rejected invalid strategy: {}", msg); } diff --git a/qdp/qdp-core/tests/types.rs b/qdp/qdp-core/tests/types.rs new file mode 100644 index 0000000000..f0ccd97534 --- /dev/null +++ b/qdp/qdp-core/tests/types.rs @@ -0,0 +1,63 @@ +// +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +//! Tests for [`qdp_core::Encoding`] and [`qdp_core::Dtype`]. + +use qdp_core::{Dtype, Encoding}; + +#[test] +fn encoding_case_insensitive() { + assert_eq!( + Encoding::from_str_ci("Amplitude").unwrap(), + Encoding::Amplitude + ); + assert_eq!( + Encoding::from_str_ci("AMPLITUDE").unwrap(), + Encoding::Amplitude + ); + assert_eq!(Encoding::from_str_ci("iqp-z").unwrap(), Encoding::IqpZ); +} + +#[test] +fn encoding_unknown_returns_err() { + assert!(Encoding::from_str_ci("not_real").is_err()); +} + +#[test] +fn vector_len_matches_encoder_contracts() { + let n = 5u32; + assert_eq!(Encoding::Amplitude.vector_len(n), 32); // 2^5 + assert_eq!(Encoding::Angle.vector_len(n), 5); // n + assert_eq!(Encoding::IqpZ.vector_len(n), 5); // n (z-only) + assert_eq!(Encoding::Phase.vector_len(n), 5); // n (one angle per qubit) + assert_eq!(Encoding::Iqp.vector_len(n), 5 + 5 * 4 / 2); // n + n*(n-1)/2 = 15 + assert_eq!(Encoding::Basis.vector_len(n), 1); +} + +#[test] +fn static_encoder_same_instance_across_calls() { + assert!( + std::ptr::eq(Encoding::Amplitude.encoder(), Encoding::Amplitude.encoder(),), + "static dispatch must return the same 'static reference" + ); +} + +#[test] +fn dtype_from_str_ci() { + assert_eq!(Dtype::from_str_ci("f32").unwrap(), Dtype::Float32); + assert_eq!(Dtype::from_str_ci("Float64").unwrap(), Dtype::Float64); + assert!(Dtype::from_str_ci("bf16").is_err()); +} diff --git a/qdp/qdp-python/README.md b/qdp/qdp-python/README.md index ee70834ad4..074fd75c11 100644 --- a/qdp/qdp-python/README.md +++ b/qdp/qdp-python/README.md @@ -36,6 +36,16 @@ print(tensor) # Complex tensor on CUDA | `basis` | Encode integer as computational basis state | | `iqp` | IQP-style encoding with entanglement | +### Pipeline / loader dtype (Rust internals) + +`QuantumDataLoader` and `run_throughput_pipeline` build a Rust `PipelineConfig` with an +`encoding` plus a `dtype` 
(float32 vs float64). The prefetch thread can only keep an +end-to-end **float32 host batch** for encodings whose GPU stack implements the batch **f32** +path (`encode_batch_f32`). **Today that is amplitude only.** Angle and basis still fall back +to float64 for that loop until their batch f32 implementations exist. The eventual full +matrix (e.g. angle/basis under `supports_f32` once kernels are wired) is broader than what +the pipeline uses today. + ## Input Sources ```python diff --git a/qdp/qdp-python/qumat_qdp/loader.py b/qdp/qdp-python/qumat_qdp/loader.py index 9a180baf00..d6f58f0de1 100644 --- a/qdp/qdp-python/qumat_qdp/loader.py +++ b/qdp/qdp-python/qumat_qdp/loader.py @@ -41,6 +41,11 @@ # Seed must fit Rust u64: 0 <= seed <= 2^64 - 1. _U64_MAX = 2**64 - 1 +# Canonical encoding names (must match Encoding enum in qdp-core/src/types.rs). +_VALID_ENCODINGS: frozenset[str] = frozenset( + {"amplitude", "angle", "basis", "iqp", "iqp-z", "phase"} +) + # Fallback-supported file extensions (loadable without _qdp). _TORCH_FILE_EXTS = frozenset({".pt", ".pth"}) _NUMPY_FILE_EXTS = frozenset({".npy"}) @@ -71,6 +76,11 @@ def _validate_loader_args( raise ValueError( f"encoding_method must be a non-empty string, got {encoding_method!r}" ) + if encoding_method.lower() not in _VALID_ENCODINGS: + raise ValueError( + f"Unknown encoding_method {encoding_method!r}. " + f"Valid options: {sorted(_VALID_ENCODINGS)}" + ) if seed is not None: if not isinstance(seed, int): raise ValueError( @@ -162,6 +172,11 @@ def encoding(self, method: str) -> QuantumDataLoader: raise ValueError( f"encoding_method must be a non-empty string, got {method!r}" ) + if method.lower() not in _VALID_ENCODINGS: + raise ValueError( + f"Unknown encoding {method!r}. 
" + f"Valid options: {sorted(_VALID_ENCODINGS)}" + ) self._encoding_method = method return self diff --git a/qdp/qdp-python/src/engine.rs b/qdp/qdp-python/src/engine.rs index b2b006ff78..92fcbc8d17 100644 --- a/qdp/qdp-python/src/engine.rs +++ b/qdp/qdp-python/src/engine.rs @@ -23,7 +23,7 @@ use crate::tensor::QuantumTensor; use numpy::{PyReadonlyArray1, PyReadonlyArray2, PyUntypedArrayMethods}; use pyo3::exceptions::PyRuntimeError; use pyo3::prelude::*; -use qdp_core::{Precision, QdpEngine as CoreEngine}; +use qdp_core::{Dtype, Encoding, QdpEngine as CoreEngine}; #[cfg(target_os = "linux")] use crate::loader::{PyQuantumLoader, config_from_args, parse_null_handling, path_from_py}; @@ -52,16 +52,8 @@ impl QdpEngine { #[new] #[pyo3(signature = (device_id=0, precision="float32"))] fn new(device_id: usize, precision: &str) -> PyResult { - let precision = match precision.to_ascii_lowercase().as_str() { - "float32" | "f32" | "float" => Precision::Float32, - "float64" | "f64" | "double" => Precision::Float64, - other => { - return Err(PyRuntimeError::new_err(format!( - "Unsupported precision '{}'. 
Use 'float32' (default) or 'float64'.", - other - ))); - } - }; + let precision = + Dtype::from_str_ci(precision).map_err(|e| PyRuntimeError::new_err(e.to_string()))?; let engine = CoreEngine::new_with_precision(device_id, precision) .map_err(|e| PyRuntimeError::new_err(format!("Failed to initialize: {}", e)))?; @@ -485,17 +477,18 @@ impl QdpEngine { num_qubits: usize, encoding_method: &str, ) -> PyResult { - validate_cuda_tensor_for_encoding(data, self.engine.device().ordinal(), encoding_method)?; - + let encoding = validate_cuda_tensor_for_encoding( + data, + self.engine.device().ordinal(), + encoding_method, + )?; let dtype = data.getattr("dtype")?; let dtype_str: String = dtype.str()?.extract()?; - let dtype_str_lower = dtype_str.to_ascii_lowercase(); - let is_f32 = dtype_str_lower.contains("float32"); - let method = encoding_method.to_ascii_lowercase(); + let is_f32 = dtype_str.to_ascii_lowercase().contains("float32"); let ndim: usize = data.call_method0("dim")?.extract()?; let tensor_info = extract_cuda_tensor_info(data)?; - if method.as_str() == "amplitude" && is_f32 { + if encoding == Encoding::Amplitude && is_f32 { match ndim { 1 => { let input_len: usize = data.call_method0("numel")?.extract()?; @@ -634,7 +627,7 @@ impl QdpEngine { seed, nh, true, - ); + )?; let iter = qdp_core::PipelineIterator::new_synthetic(self.engine.clone(), config).map_err( |e| PyRuntimeError::new_err(format!("create_synthetic_loader failed: {}", e)), )?; @@ -667,7 +660,7 @@ impl QdpEngine { None, nh, true, // float32_pipeline - ); + )?; let engine = self.engine.clone(); // Resolve remote URLs before detaching from GIL. The _resolved guard keeps the // temp file alive until after the file is fully read inside py.detach. @@ -716,7 +709,7 @@ impl QdpEngine { None, nh, true, // float32_pipeline - ); + )?; let engine = self.engine.clone(); // Resolve remote URLs before detaching from GIL. 
The _resolved guard keeps the // temp file alive; the streaming reader's open fd preserves data after drop. diff --git a/qdp/qdp-python/src/lib.rs b/qdp/qdp-python/src/lib.rs index 5348c3f4af..082437aa23 100644 --- a/qdp/qdp-python/src/lib.rs +++ b/qdp/qdp-python/src/lib.rs @@ -49,11 +49,16 @@ fn run_throughput_pipeline_py( num_qubits, batch_size, total_batches, - encoding_method, + encoding: qdp_core::Encoding::from_str_ci(&encoding_method) + .map_err(|e| PyRuntimeError::new_err(format!("Invalid encoding_method: {e}")))?, seed, warmup_batches, null_handling: qdp_core::NullHandling::default(), - float32_pipeline, + dtype: if float32_pipeline { + qdp_core::Precision::Float32 + } else { + qdp_core::Precision::Float64 + }, prefetch_depth: 16, }; let result = py diff --git a/qdp/qdp-python/src/loader.rs b/qdp/qdp-python/src/loader.rs index 7ad7632cb9..32531b8311 100644 --- a/qdp/qdp-python/src/loader.rs +++ b/qdp/qdp-python/src/loader.rs @@ -21,7 +21,9 @@ mod loader_impl { use pyo3::exceptions::PyRuntimeError; use pyo3::prelude::*; use qdp_core::reader::NullHandling; - use qdp_core::{PipelineConfig, PipelineIterator, QdpEngine as CoreEngine}; + use qdp_core::{ + Encoding, PipelineConfig, PipelineIterator, Precision, QdpEngine as CoreEngine, + }; /// Rust-backed iterator yielding one QuantumTensor per batch; used by QuantumDataLoader. 
#[pyclass] @@ -94,19 +96,25 @@ mod loader_impl { seed: Option, null_handling: NullHandling, float32_pipeline: bool, - ) -> PipelineConfig { - PipelineConfig { + ) -> PyResult { + let encoding = Encoding::from_str_ci(encoding_method) + .map_err(|e| PyRuntimeError::new_err(format!("Invalid encoding: {e}")))?; + Ok(PipelineConfig { device_id: 0, num_qubits, batch_size, total_batches, - encoding_method: encoding_method.to_string(), + encoding, seed, warmup_batches: 0, null_handling, - float32_pipeline, + dtype: if float32_pipeline { + Precision::Float32 + } else { + Precision::Float64 + }, prefetch_depth: 16, - } + }) } /// Resolve path from Python str or pathlib.Path (__fspath__). diff --git a/qdp/qdp-python/src/pytorch.rs b/qdp/qdp-python/src/pytorch.rs index 17185fd548..ae8caaee2e 100644 --- a/qdp/qdp-python/src/pytorch.rs +++ b/qdp/qdp-python/src/pytorch.rs @@ -18,7 +18,8 @@ use pyo3::exceptions::PyRuntimeError; use pyo3::prelude::*; use std::ffi::c_void; -use crate::constants::{CUDA_ENCODING_METHODS, format_supported_cuda_encoding_methods}; +use crate::constants::format_supported_cuda_encoding_methods; +use qdp_core::Encoding; /// Helper to detect PyTorch tensor pub fn is_pytorch_tensor(obj: &Bound<'_, PyAny>) -> PyResult { @@ -143,16 +144,20 @@ pub fn get_torch_cuda_stream_ptr(tensor: &Bound<'_, PyAny>) -> PyResult<*mut c_v }) } -/// Validate a CUDA tensor for direct GPU encoding -/// Checks: dtype matches encoding method, contiguous, non-empty, device_id matches engine +/// Validate a CUDA tensor for direct GPU encoding and return the parsed `Encoding`. +/// +/// Checks dtype compatibility, contiguity, non-empty, and device match. +/// Returns the parsed `Encoding` so the caller avoids re-parsing the same string. 
pub fn validate_cuda_tensor_for_encoding( tensor: &Bound<'_, PyAny>, expected_device_id: usize, encoding_method: &str, -) -> PyResult<()> { - let method = encoding_method.to_ascii_lowercase(); +) -> PyResult { + let encoding = Encoding::from_str_ci(encoding_method) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; - if !CUDA_ENCODING_METHODS.contains(&method.as_str()) { + // Phase has no CUDA tensor path yet. + if matches!(encoding, Encoding::Phase) { return Err(PyRuntimeError::new_err(format!( "CUDA tensor encoding currently only supports {} methods, got '{}'. \ Use tensor.cpu() to convert to CPU tensor for other encoding methods.", @@ -161,12 +166,11 @@ pub fn validate_cuda_tensor_for_encoding( ))); } - // Check encoding method support and dtype (ASCII lowercase for case-insensitive match). let dtype = tensor.getattr("dtype")?; let dtype_str: String = dtype.str()?.extract()?; let dtype_str_lower = dtype_str.to_ascii_lowercase(); - match method.as_str() { - "amplitude" => { + match encoding { + Encoding::Amplitude => { if !(dtype_str_lower.contains("float64") || dtype_str_lower.contains("float32")) { return Err(PyRuntimeError::new_err(format!( "CUDA tensor must have dtype float64 or float32 for amplitude encoding, got {}. \ @@ -175,16 +179,17 @@ pub fn validate_cuda_tensor_for_encoding( ))); } } - "angle" | "iqp" | "iqp-z" => { + Encoding::Angle | Encoding::Iqp | Encoding::IqpZ => { if !dtype_str_lower.contains("float64") { return Err(PyRuntimeError::new_err(format!( "CUDA tensor must have dtype float64 for {} encoding, got {}. \ Use tensor.to(torch.float64)", - method, dtype_str + encoding.as_str(), + dtype_str ))); } } - "basis" => { + Encoding::Basis => { if !dtype_str_lower.contains("int64") { return Err(PyRuntimeError::new_err(format!( "CUDA tensor must have dtype int64 for basis encoding, got {}. 
\ @@ -193,12 +198,7 @@ pub fn validate_cuda_tensor_for_encoding( ))); } } - _ => { - return Err(PyRuntimeError::new_err(format!( - "Internal error: missing CUDA validation branch for supported method '{}'", - method - ))); - } + Encoding::Phase => unreachable!("Phase filtered above"), } // Check contiguous @@ -225,7 +225,7 @@ pub fn validate_cuda_tensor_for_encoding( ))); } - Ok(()) + Ok(encoding) } /// Minimal CUDA tensor metadata extracted via PyTorch APIs.