Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 59 additions & 83 deletions rstsr-common/src/layout/matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ Layout manipulation for matmul and other linalg operations

We refer to the [Python array API](https://data-apis.org/array-api/2024.12/specification/generated/array_api.matmul.html) for more information.

Please note that the following rule only applies to row-major.
The rules below are written for row-major; the last two axes of each operand
are the matmul dimensions and any leading axes broadcast.

| Id | A | B | C |
|----|---|---|---|
Expand All @@ -18,7 +19,12 @@ Please note that the following rule only applies to row-major.
| 6. | `..., M, K` | ` K, N` | `..., M, N` |
| 7. | `..., M, K` | `..., K, N` | `..., M, N` |

For col-major, only rule 1, 2, (part of) 3, (part of) 4 are valid.
For col-major, the same rules apply *with all axes reversed*: the matmul
dimensions are the first two of each operand and any trailing axes broadcast.
This is implemented by delegating to the row-major routine on reversed-and-
swapped inputs, using the identity
`C[t, m, n] = A[t, m, k] @ B[t, k, n]` (row-major) `==`
`C[n, m, t] = B[n, k, t] @ A[k, m, t]` (col-major).

*/

Expand Down Expand Up @@ -266,83 +272,37 @@ fn layout_matmul_dyn_row_major(la: &Layout<IxD>, lb: &Layout<IxD>) -> Result<Lay
}

fn layout_matmul_dyn_col_major(la: &Layout<IxD>, lb: &Layout<IxD>) -> Result<LayoutMatMulConfig<IxD, IxD>> {
// For col-major, we re-use the row-major implementation via the identity
// C[t, m, n] = A[t, m, k] @ B[t, k, n] (row-major)
// <=> C[n, m, t] = B[n, k, t] @ A[k, m, t] (col-major)
// i.e. reverse all axes and swap A/B. So we delegate to
// `layout_matmul_dyn_row_major(lb_rev, la_rev)` and then reverse-axes (and
// swap A/B back) on every layout field that the row-major impl returns.
//
// Note that broadcasting matmul (rules 5/6/7, and the batched parts of rules
// 3/4) is only covered by the row-major rules in the array API spec; for
// col-major we accept it here, but the corresponding `DeviceMatMulAPI` impl
// must follow the same reverse-axes-and-swap convention (see e.g.
// `device_faer/matmul.rs`).
let na = la.ndim();
let nb = lb.ndim();
match (na, nb) {
(1, 1) => {
// rule 1: vector inner dot
rstsr_assert_eq!(la.shape(), lb.shape(), InvalidLayout)?;
let lc = unsafe { Layout::new_unchecked(vec![], vec![], 0) };
Ok(LayoutMatMulConfig {
matmul_type: MatMulType::InnerDot,
lc: lc.clone(),
la_rest: None,
lb_rest: None,
lc_rest: None,
la_matmul: la.to_dim()?,
lb_matmul: lb.to_dim()?,
lc_matmul: lc.to_dim()?,
})
},
(2, 2) => {
// rule 2: matrix multiplication
// check and generate shape
rstsr_assert_eq!(la.shape()[1], lb.shape()[0], InvalidLayout)?;
let sc = vec![la.shape()[0], lb.shape()[1]];
// layout order determination
let lc = sc.f();
// return layout configuration
Ok(LayoutMatMulConfig {
matmul_type: MatMulType::GEMM22,
lc: lc.clone(),
la_rest: None,
lb_rest: None,
lc_rest: None,
la_matmul: la.to_dim()?,
lb_matmul: lb.to_dim()?,
lc_matmul: lc.to_dim()?,
})
},
(1, 2) => {
// rule 3: | ` K` | ` K, N` | ` N` |
// check and generate shape
rstsr_assert_eq!(la.shape()[0], lb.shape()[0], InvalidLayout)?;
let sc = vec![lb.shape()[1]];
let lc = sc.f();
Ok(LayoutMatMulConfig {
matmul_type: MatMulType::GEVM,
lc: lc.to_dim()?,
la_rest: None,
lb_rest: None,
lc_rest: None,
la_matmul: la.to_dim()?,
lb_matmul: lb.to_dim()?,
lc_matmul: lc.to_dim()?,
})
},
(2, 1) => {
// rule 4: | ` M, K` | ` K` | ` M` |
// check and generate shape
rstsr_assert_eq!(la.shape()[1], lb.shape()[0], InvalidLayout)?;
let sc = vec![la.shape()[0]];
let lc = sc.f();
// return layout configuration
Ok(LayoutMatMulConfig {
matmul_type: MatMulType::GEMV,
lc: lc.to_dim()?,
la_rest: None,
lb_rest: None,
lc_rest: None,
la_matmul: la.to_dim()?,
lb_matmul: lb.to_dim()?,
lc_matmul: lc.to_dim()?,
})
},
(1, 3..) | (3.., 1) | (2, 3..) | (3.., 2) | (3.., 3..) => {
rstsr_raise!(InvalidLayout, "Broadcasting matmul is not supported in col-major.")
},
(0, _) | (_, 0) => rstsr_invalid!((na, nb), "In matmul, 0-dim is not allowed."),
if na == 0 || nb == 0 {
return rstsr_invalid!((na, nb), "In matmul, 0-dim is not allowed.");
}
let la_rev = la.reverse_axes();
let lb_rev = lb.reverse_axes();
let cfg = layout_matmul_dyn_row_major(&lb_rev, &la_rev)?;
Ok(LayoutMatMulConfig {
matmul_type: cfg.matmul_type,
lc: cfg.lc.reverse_axes(),
// row-major's `la_*` corresponds to (reversed) B, so it maps back to
// col-major's `lb_*` (after reversing axes again).
la_rest: cfg.lb_rest.map(|l| l.reverse_axes()),
lb_rest: cfg.la_rest.map(|l| l.reverse_axes()),
lc_rest: cfg.lc_rest.map(|l| l.reverse_axes()),
la_matmul: cfg.lb_matmul.reverse_axes(),
lb_matmul: cfg.la_matmul.reverse_axes(),
lc_matmul: cfg.lc_matmul.reverse_axes(),
})
}

impl LayoutMatMulAPI<IxD, IxD> for LayoutMatMulConfig<IxD, IxD> {
Expand Down Expand Up @@ -490,14 +450,30 @@ mod test_fixed {
let config = LayoutMatMulConfig::layout_matmul(&la, &lb, RowMajor).unwrap();
assert_eq!(config.lc, [2, 3, 4, 5, 7].c());

let la = [4, 3, 2, 5, 6].f().swapaxes(0, 2).unwrap();
let lb = [4, 3, 2, 6, 7].f().swapaxes(0, 2).unwrap();
let config = LayoutMatMulConfig::layout_matmul(&la, &lb, ColMajor);
assert!(config.is_err());

let la = [5, 6].c();
let lb = [6, 7].c();
// col-major broadcasting (mirror of the row-major cases above; the
// matmul dims are the first two, the trailing dims broadcast).
let la = [5, 6].f();
let lb = [6, 7].f();
let config = LayoutMatMulConfig::layout_matmul(&la, &lb, ColMajor).unwrap();
assert_eq!(config.lc, [5, 7].f());

// rule 3 mirrored: K @ (K, N, ...) -> (N, ...)
let la = [5].f();
let lb = [5, 6, 3, 4].f();
let config = LayoutMatMulConfig::layout_matmul(&la, &lb, ColMajor).unwrap();
assert_eq!(config.lc, [6, 3, 4].f());

// rule 4 mirrored: (M, K, ...) @ K -> (M, ...)
let la = [5, 6, 3, 4].f();
let lb = [6].f();
let config = LayoutMatMulConfig::layout_matmul(&la, &lb, ColMajor).unwrap();
assert_eq!(config.lc, [5, 3, 4].f());

// rule 7 mirrored: full 5D x 5D batched matmul, including broadcast
// on the trailing dims (`1` broadcasts against `2`/`3`).
let la = [5, 6, 2, 1, 4].f();
let lb = [6, 7, 1, 3, 4].f();
let config = LayoutMatMulConfig::layout_matmul(&la, &lb, ColMajor).unwrap();
assert_eq!(config.lc, [5, 7, 2, 3, 4].f());
}
}
72 changes: 56 additions & 16 deletions rstsr-core/src/device_faer/matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,9 @@ mod test {

#[test]
fn test_matmul() {
let device = DeviceFaer::default();
let mut device = DeviceFaer::default();
device.set_default_order(RowMajor);

let a = linspace((0.0, 14.0, 15, &device)).into_shape([3, 5]);
let b = linspace((0.0, 14.0, 15, &device)).into_shape([5, 3]);

Expand All @@ -327,23 +329,61 @@ mod test {
let b = linspace((0.0, 14.0, 15, &device));
println!("{:}", &a % &b);

#[cfg(not(feature = "col_major"))]
{
let a = linspace((0.0, 2.0, 3, &device));
let b = linspace((0.0, 29.0, 30, &device)).into_shape([2, 3, 5]);
println!("{:}", &a % &b);
// check broadcasting in row-major

let a = linspace((0.0, 29.0, 30, &device)).into_shape([2, 3, 5]);
let b = linspace((0.0, 4.0, 5, &device));
println!("{:}", &a % &b);
println!("check broadcasting in column-major");

let a = linspace((0.0, 14.0, 15, &device)).into_shape([5, 3]);
let b = linspace((0.0, 29.0, 30, &device)).into_shape([2, 3, 5]);
println!("{:}", &a % &b);
let a = linspace((0.0, 2.0, 3, &device));
let b = linspace((0.0, 29.0, 30, &device)).into_shape([2, 3, 5]);
let c = &a % &b;
println!("{:}", c);
assert!(c.shape() == &[2, 5]);

let a = linspace((0.0, 29.0, 30, &device)).into_shape([2, 3, 5]);
let b = linspace((0.0, 14.0, 15, &device)).into_shape([5, 3]);
println!("{:}", &a % &b);
}
let a = linspace((0.0, 29.0, 30, &device)).into_shape([2, 3, 5]);
let b = linspace((0.0, 4.0, 5, &device));
let c = &a % &b;
println!("{:}", c);
assert!(c.shape() == &[2, 3]);

let a = linspace((0.0, 14.0, 15, &device)).into_shape([5, 3]);
let b = linspace((0.0, 29.0, 30, &device)).into_shape([2, 3, 5]);
let c = &a % &b;
println!("{:}", &a % &b);
assert!(c.shape() == &[2, 5, 5]);

let a = linspace((0.0, 29.0, 30, &device)).into_shape([2, 3, 5]);
let b = linspace((0.0, 14.0, 15, &device)).into_shape([5, 3]);
let c = &a % &b;
println!("{:}", c);
assert!(c.shape() == &[2, 3, 3]);

// check broadcasting in column-major

println!("check broadcasting in column-major");
device.set_default_order(ColMajor);

let a = linspace((0.0, 29.0, 30, &device)).into_shape([5, 3, 2]);
let b = linspace((0.0, 2.0, 3, &device));
let c = &a % &b;
println!("{:}", c);
assert!(c.shape() == &[5, 2]);

let a = linspace((0.0, 4.0, 5, &device));
let b = linspace((0.0, 29.0, 30, &device)).into_shape([5, 3, 2]);
let c = &a % &b;
println!("{:}", c);
assert!(c.shape() == &[3, 2]);

let a = linspace((0.0, 29.0, 30, &device)).into_shape([5, 3, 2]);
let b = linspace((0.0, 14.0, 15, &device)).into_shape([3, 5]);
let c = &a % &b;
println!("{:}", &a % &b);
assert!(c.shape() == &[5, 5, 2]);

let a = linspace((0.0, 14.0, 15, &device)).into_shape([3, 5]);
let b = linspace((0.0, 29.0, 30, &device)).into_shape([5, 3, 2]);
let c = &a % &b;
println!("{:}", c);
assert!(c.shape() == &[3, 3, 2]);
}
}
Loading