From 876b3dafdf3ebb8181a1c7c1b9af6b136efa7765 Mon Sep 17 00:00:00 2001 From: ajz34 Date: Sun, 7 Jun 2026 13:19:53 +0800 Subject: [PATCH] rstsr-common: support broadcasting matmul in col-major layout Replace the col-major branch of `LayoutMatMulConfig::layout_matmul` (which rejected the broadcasting forms of rules 3-4 and all of rules 5-7) with a thin wrapper that delegates to the row-major routine via the identity C[t, m, n] = A[t, m, k] @ B[t, k, n] (row-major) <=> C[n, m, t] = B[n, k, t] @ A[k, m, t] (col-major) i.e. reverse all axes and swap A/B. The returned layouts are reversed and the A/B fields are swapped back so callers see the original operand order. This brings full broadcasting (rules 3-7) to col-major for free, matching the convention that `DeviceMatMulAPI` impls already use (see e.g. `device_faer/matmul.rs`, which reverses axes and swaps A/B before dispatching to the row-major kernel). Also: - Update the module-level docs to describe the col-major convention (matmul dims are the first two; trailing dims broadcast). - Replace the test that asserted col-major broadcasting fails with positive tests covering rules 3, 4 and 7 (incl. a `1` broadcast). - Update the device_faer matmul test to exercise col-major broadcasting end-to-end via `set_default_order(ColMajor)`. Co-authored-by: Claude Code Co-authored-by: glm-5.1 --- rstsr-common/src/layout/matmul.rs | 142 +++++++++++---------------- rstsr-core/src/device_faer/matmul.rs | 72 +++++++++++--- 2 files changed, 115 insertions(+), 99 deletions(-) diff --git a/rstsr-common/src/layout/matmul.rs b/rstsr-common/src/layout/matmul.rs index dc6df22..d6de124 100644 --- a/rstsr-common/src/layout/matmul.rs +++ b/rstsr-common/src/layout/matmul.rs @@ -6,7 +6,8 @@ Layout manuplication for matmul and other linalg operations We Refer to [Python array API](https://data-apis.org/array-api/2024.12/specification/generated/array_api.matmul.html) for more information. -Please note that the following rule only applies to row-major. 
+The rules below are written for row-major; the last two axes of each operand +are the matmul dimensions and any leading axes broadcast. | Id | A | B | C | |----|---|---|---| @@ -18,7 +19,12 @@ Please note that the following rule only applies to row-major. | 6. | `..., M, K` | ` K, N` | `..., M, N` | | 7. | `..., M, K` | `..., K, N` | `..., M, N` | -For col-major, only rule 1, 2, (part of) 3, (part of) 4 are valid. +For col-major, the same rules apply *with all axes reversed*: the matmul +dimensions are the first two of each operand and any trailing axes broadcast. +This is implemented by delegating to the row-major routine on reversed-and- +swapped inputs, using the identity +`C[t, m, n] = A[t, m, k] @ B[t, k, n]` (row-major) `==` +`C[n, m, t] = B[n, k, t] @ A[k, m, t]` (col-major). */ @@ -266,83 +272,37 @@ fn layout_matmul_dyn_row_major(la: &Layout, lb: &Layout) -> Result, lb: &Layout) -> Result> { + // For col-major, we re-use the row-major implementation via the identity + // C[t, m, n] = A[t, m, k] @ B[t, k, n] (row-major) + // <=> C[n, m, t] = B[n, k, t] @ A[k, m, t] (col-major) + // i.e. reverse all axes and swap A/B. So we delegate to + // `layout_matmul_dyn_row_major(lb_rev, la_rev)` and then reverse-axes (and + // swap A/B back) on every layout field that the row-major impl returns. + // + // Note that rules 5/6/7 (broadcasting matmul) are only supported by the + // row-major rules in the array API spec; for col-major we accept them + // here, but the corresponding `DeviceMatMulAPI` impl must follow the same + // reverse-axes-and-swap convention (see e.g. `device_faer/matmul.rs`). 
let na = la.ndim(); let nb = lb.ndim(); - match (na, nb) { - (1, 1) => { - // rule 1: vector inner dot - rstsr_assert_eq!(la.shape(), lb.shape(), InvalidLayout)?; - let lc = unsafe { Layout::new_unchecked(vec![], vec![], 0) }; - Ok(LayoutMatMulConfig { - matmul_type: MatMulType::InnerDot, - lc: lc.clone(), - la_rest: None, - lb_rest: None, - lc_rest: None, - la_matmul: la.to_dim()?, - lb_matmul: lb.to_dim()?, - lc_matmul: lc.to_dim()?, - }) - }, - (2, 2) => { - // rule 2: matrix multiplication - // check and generate shape - rstsr_assert_eq!(la.shape()[1], lb.shape()[0], InvalidLayout)?; - let sc = vec![la.shape()[0], lb.shape()[1]]; - // layout order determination - let lc = sc.f(); - // return layout configuration - Ok(LayoutMatMulConfig { - matmul_type: MatMulType::GEMM22, - lc: lc.clone(), - la_rest: None, - lb_rest: None, - lc_rest: None, - la_matmul: la.to_dim()?, - lb_matmul: lb.to_dim()?, - lc_matmul: lc.to_dim()?, - }) - }, - (1, 2) => { - // rule 3: | ` K` | ` K, N` | ` N` | - // check and generate shape - rstsr_assert_eq!(la.shape()[0], lb.shape()[0], InvalidLayout)?; - let sc = vec![lb.shape()[1]]; - let lc = sc.f(); - Ok(LayoutMatMulConfig { - matmul_type: MatMulType::GEVM, - lc: lc.to_dim()?, - la_rest: None, - lb_rest: None, - lc_rest: None, - la_matmul: la.to_dim()?, - lb_matmul: lb.to_dim()?, - lc_matmul: lc.to_dim()?, - }) - }, - (2, 1) => { - // rule 4: | ` M, K` | ` K` | ` M` | - // check and generate shape - rstsr_assert_eq!(la.shape()[1], lb.shape()[0], InvalidLayout)?; - let sc = vec![la.shape()[0]]; - let lc = sc.f(); - // return layout configuration - Ok(LayoutMatMulConfig { - matmul_type: MatMulType::GEMV, - lc: lc.to_dim()?, - la_rest: None, - lb_rest: None, - lc_rest: None, - la_matmul: la.to_dim()?, - lb_matmul: lb.to_dim()?, - lc_matmul: lc.to_dim()?, - }) - }, - (1, 3..) | (3.., 1) | (2, 3..) | (3.., 2) | (3.., 3..) 
=> { - rstsr_raise!(InvalidLayout, "Broadcasting matmul is not supported in col-major.") - }, - (0, _) | (_, 0) => rstsr_invalid!((na, nb), "In matmul, 0-dim is not allowed."), + if na == 0 || nb == 0 { + return rstsr_invalid!((na, nb), "In matmul, 0-dim is not allowed."); } + let la_rev = la.reverse_axes(); + let lb_rev = lb.reverse_axes(); + let cfg = layout_matmul_dyn_row_major(&lb_rev, &la_rev)?; + Ok(LayoutMatMulConfig { + matmul_type: cfg.matmul_type, + lc: cfg.lc.reverse_axes(), + // row-major's `la_*` corresponds to (reversed) B, so it maps back to + // col-major's `lb_*` (after reversing axes again). + la_rest: cfg.lb_rest.map(|l| l.reverse_axes()), + lb_rest: cfg.la_rest.map(|l| l.reverse_axes()), + lc_rest: cfg.lc_rest.map(|l| l.reverse_axes()), + la_matmul: cfg.lb_matmul.reverse_axes(), + lb_matmul: cfg.la_matmul.reverse_axes(), + lc_matmul: cfg.lc_matmul.reverse_axes(), + }) } impl LayoutMatMulAPI for LayoutMatMulConfig { @@ -490,14 +450,30 @@ mod test_fixed { let config = LayoutMatMulConfig::layout_matmul(&la, &lb, RowMajor).unwrap(); assert_eq!(config.lc, [2, 3, 4, 5, 7].c()); - let la = [4, 3, 2, 5, 6].f().swapaxes(0, 2).unwrap(); - let lb = [4, 3, 2, 6, 7].f().swapaxes(0, 2).unwrap(); - let config = LayoutMatMulConfig::layout_matmul(&la, &lb, ColMajor); - assert!(config.is_err()); - - let la = [5, 6].c(); - let lb = [6, 7].c(); + // col-major broadcasting (mirror of the row-major cases above; the + // matmul dims are the first two, the trailing dims broadcast). + let la = [5, 6].f(); + let lb = [6, 7].f(); let config = LayoutMatMulConfig::layout_matmul(&la, &lb, ColMajor).unwrap(); assert_eq!(config.lc, [5, 7].f()); + + // rule 3 mirrored: K @ (K, N, ...) -> (N, ...) + let la = [5].f(); + let lb = [5, 6, 3, 4].f(); + let config = LayoutMatMulConfig::layout_matmul(&la, &lb, ColMajor).unwrap(); + assert_eq!(config.lc, [6, 3, 4].f()); + + // rule 4 mirrored: (M, K, ...) @ K -> (M, ...) 
+ let la = [5, 6, 3, 4].f(); + let lb = [6].f(); + let config = LayoutMatMulConfig::layout_matmul(&la, &lb, ColMajor).unwrap(); + assert_eq!(config.lc, [5, 3, 4].f()); + + // rule 7 mirrored: full 5D x 5D batched matmul, including broadcast + // on the trailing dims (`1` broadcasts against `2`/`3`). + let la = [5, 6, 2, 1, 4].f(); + let lb = [6, 7, 1, 3, 4].f(); + let config = LayoutMatMulConfig::layout_matmul(&la, &lb, ColMajor).unwrap(); + assert_eq!(config.lc, [5, 7, 2, 3, 4].f()); } } diff --git a/rstsr-core/src/device_faer/matmul.rs b/rstsr-core/src/device_faer/matmul.rs index d6faab3..717788f 100644 --- a/rstsr-core/src/device_faer/matmul.rs +++ b/rstsr-core/src/device_faer/matmul.rs @@ -316,7 +316,9 @@ mod test { #[test] fn test_matmul() { - let device = DeviceFaer::default(); + let mut device = DeviceFaer::default(); + device.set_default_order(RowMajor); + let a = linspace((0.0, 14.0, 15, &device)).into_shape([3, 5]); let b = linspace((0.0, 14.0, 15, &device)).into_shape([5, 3]); @@ -327,23 +329,61 @@ mod test { let b = linspace((0.0, 14.0, 15, &device)); println!("{:}", &a % &b); - #[cfg(not(feature = "col_major"))] - { - let a = linspace((0.0, 2.0, 3, &device)); - let b = linspace((0.0, 29.0, 30, &device)).into_shape([2, 3, 5]); - println!("{:}", &a % &b); + // check broadcasting in row-major - let a = linspace((0.0, 29.0, 30, &device)).into_shape([2, 3, 5]); - let b = linspace((0.0, 4.0, 5, &device)); - println!("{:}", &a % &b); + println!("check broadcasting in column-major"); - let a = linspace((0.0, 14.0, 15, &device)).into_shape([5, 3]); - let b = linspace((0.0, 29.0, 30, &device)).into_shape([2, 3, 5]); - println!("{:}", &a % &b); + let a = linspace((0.0, 2.0, 3, &device)); + let b = linspace((0.0, 29.0, 30, &device)).into_shape([2, 3, 5]); + let c = &a % &b; + println!("{:}", c); + assert!(c.shape() == &[2, 5]); - let a = linspace((0.0, 29.0, 30, &device)).into_shape([2, 3, 5]); - let b = linspace((0.0, 14.0, 15, &device)).into_shape([5, 3]); - 
println!("{:}", &a % &b); - } + let a = linspace((0.0, 29.0, 30, &device)).into_shape([2, 3, 5]); + let b = linspace((0.0, 4.0, 5, &device)); + let c = &a % &b; + println!("{:}", c); + assert!(c.shape() == &[2, 3]); + + let a = linspace((0.0, 14.0, 15, &device)).into_shape([5, 3]); + let b = linspace((0.0, 29.0, 30, &device)).into_shape([2, 3, 5]); + let c = &a % &b; + println!("{:}", &a % &b); + assert!(c.shape() == &[2, 5, 5]); + + let a = linspace((0.0, 29.0, 30, &device)).into_shape([2, 3, 5]); + let b = linspace((0.0, 14.0, 15, &device)).into_shape([5, 3]); + let c = &a % &b; + println!("{:}", c); + assert!(c.shape() == &[2, 3, 3]); + + // check broadcasting in column-major + + println!("check broadcasting in column-major"); + device.set_default_order(ColMajor); + + let a = linspace((0.0, 29.0, 30, &device)).into_shape([5, 3, 2]); + let b = linspace((0.0, 2.0, 3, &device)); + let c = &a % &b; + println!("{:}", c); + assert!(c.shape() == &[5, 2]); + + let a = linspace((0.0, 4.0, 5, &device)); + let b = linspace((0.0, 29.0, 30, &device)).into_shape([5, 3, 2]); + let c = &a % &b; + println!("{:}", c); + assert!(c.shape() == &[3, 2]); + + let a = linspace((0.0, 29.0, 30, &device)).into_shape([5, 3, 2]); + let b = linspace((0.0, 14.0, 15, &device)).into_shape([3, 5]); + let c = &a % &b; + println!("{:}", &a % &b); + assert!(c.shape() == &[5, 5, 2]); + + let a = linspace((0.0, 14.0, 15, &device)).into_shape([3, 5]); + let b = linspace((0.0, 29.0, 30, &device)).into_shape([5, 3, 2]); + let c = &a % &b; + println!("{:}", c); + assert!(c.shape() == &[3, 3, 2]); } }