Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 178 additions & 0 deletions crates/core_arch/src/x86_64/amx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,72 @@ pub unsafe fn _tile_cvtrowps2phli<const TILE: i32, const ROW: i32>() -> __m512h
tcvtrowps2phli(TILE as i8, ROW as u32).as_m512h()
}

/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
/// floating-point elements to packed BF16 (16-bit) floating-point elements. The resulting
/// 16-bit elements are placed in the high 16-bits within each 32-bit element of the returned vector.
///
/// The tile is selected by the const generic `TILE`, which `static_assert_uimm_bits!(TILE, 3)`
/// checks at compile time; the row index `row` is a runtime value that is passed unchanged to
/// the underlying `tcvtrowps2bf16h` intrinsic.
#[inline]
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-avx512,avx10.2")]
// NOTE(review): the review thread captured inside this attribute reports that only
// apple-darwin still fails, so this platform gate may be wider than needed — confirm
// against the final revision of the change.
#[cfg_attr(
all(test, any(target_os = "linux", target_env = "msvc")),
Copy link
Copy Markdown
Contributor

@folkertdev folkertdev Apr 19, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's up with windows gnu?

View changes since the review

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

iirc it had some problems decoding the amx instructions, didn't check recently though. let me just check in CI once

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, so only apple-darwin fails now. let me reduce the conditions a bit

assert_instr(tcvtrowps2bf16h, TILE = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_cvtrowps2bf16h<const TILE: i32>(row: u32) -> __m512bh {
static_assert_uimm_bits!(TILE, 3);
tcvtrowps2bf16h(TILE as i8, row).as_m512bh()
}

/// Converts one row of a tile from packed single-precision (32-bit) floating-point elements to
/// packed BF16 (16-bit) floating-point elements and returns the result in a zmm-sized vector.
///
/// Every converted 16-bit value is stored in the high 16 bits of its 32-bit element of the
/// returned vector.
///
/// This is the immediate form of the conversion: both the tile (`TILE`) and the row (`ROW`)
/// are const generics, statically asserted to fit in 3 and 6 bits respectively.
#[inline]
#[rustc_legacy_const_generics(0, 1)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tcvtrowps2bf16h, TILE = 0, ROW = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_cvtrowps2bf16hi<const TILE: i32, const ROW: i32>() -> __m512bh {
    // Reject out-of-range immediates at compile time.
    static_assert_uimm_bits!(TILE, 3);
    static_assert_uimm_bits!(ROW, 6);

    // The intrinsic yields raw 16-bit lanes; reinterpret them as the public BF16 vector type.
    let raw_lanes = tcvtrowps2bf16hi(TILE as i8, ROW as u32);
    raw_lanes.as_m512bh()
}

/// Converts one row of a tile from packed single-precision (32-bit) floating-point elements to
/// packed BF16 (16-bit) floating-point elements and returns the result in a zmm-sized vector.
///
/// Every converted 16-bit value is stored in the low 16 bits of its 32-bit element of the
/// returned vector.
///
/// The tile is chosen by the const generic `TILE` (statically asserted to fit in 3 bits), while
/// the row index `row` is supplied at run time.
#[inline]
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tcvtrowps2bf16l, TILE = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_cvtrowps2bf16l<const TILE: i32>(row: u32) -> __m512bh {
    // Reject an out-of-range tile index at compile time.
    static_assert_uimm_bits!(TILE, 3);

    // The intrinsic yields raw 16-bit lanes; reinterpret them as the public BF16 vector type.
    let raw_lanes = tcvtrowps2bf16l(TILE as i8, row);
    raw_lanes.as_m512bh()
}

/// Converts one row of a tile from packed single-precision (32-bit) floating-point elements to
/// packed BF16 (16-bit) floating-point elements and returns the result in a zmm-sized vector.
///
/// Every converted 16-bit value is stored in the low 16 bits of its 32-bit element of the
/// returned vector.
///
/// This is the immediate form of the conversion: both the tile (`TILE`) and the row (`ROW`)
/// are const generics, statically asserted to fit in 3 and 6 bits respectively.
#[inline]
#[rustc_legacy_const_generics(0, 1)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tcvtrowps2bf16l, TILE = 0, ROW = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_cvtrowps2bf16li<const TILE: i32, const ROW: i32>() -> __m512bh {
    // Reject out-of-range immediates at compile time.
    static_assert_uimm_bits!(TILE, 3);
    static_assert_uimm_bits!(ROW, 6);

    // The intrinsic yields raw 16-bit lanes; reinterpret them as the public BF16 vector type.
    let raw_lanes = tcvtrowps2bf16li(TILE as i8, ROW as u32);
    raw_lanes.as_m512bh()
}

/// Moves one row of tile data into a zmm vector register
#[inline]
#[rustc_legacy_const_generics(0)]
Expand Down Expand Up @@ -567,6 +633,14 @@ unsafe extern "C" {
fn tcvtrowps2phl(tile: i8, row: u32) -> f16x32;
// Backs `_tile_cvtrowps2phli`, which calls it with a const-generic row index.
#[link_name = "llvm.x86.tcvtrowps2phli"]
fn tcvtrowps2phli(tile: i8, row: u32) -> f16x32;
// Raw LLVM intrinsics behind the `_tile_cvtrowps2bf16*` wrappers. Each returns `u16x32`,
// which the wrappers reinterpret as `__m512bh` via `as_m512bh()`. The `h`/`l` names pair
// with the wrappers documented as filling the high/low 16 bits of every 32-bit element;
// the trailing-`i` variants are the ones the wrappers call with a const-generic row.
#[link_name = "llvm.x86.tcvtrowps2bf16h"]
fn tcvtrowps2bf16h(tile: i8, row: u32) -> u16x32;
#[link_name = "llvm.x86.tcvtrowps2bf16hi"]
fn tcvtrowps2bf16hi(tile: i8, row: u32) -> u16x32;
#[link_name = "llvm.x86.tcvtrowps2bf16l"]
fn tcvtrowps2bf16l(tile: i8, row: u32) -> u16x32;
#[link_name = "llvm.x86.tcvtrowps2bf16li"]
fn tcvtrowps2bf16li(tile: i8, row: u32) -> u16x32;
// Returns a tile row as `i32x16`; presumably backs the "moves one row of tile data"
// intrinsic documented earlier in this file — confirm against its wrapper.
#[link_name = "llvm.x86.tilemovrow"]
fn tilemovrow(tile: i8, row: u32) -> i32x16;
#[link_name = "llvm.x86.tilemovrowi"]
Expand Down Expand Up @@ -1276,6 +1350,110 @@ mod tests {
}
}

// Checks the runtime-row "high" conversion: every odd 16-bit lane must hold the BF16
// encoding of the row's value and every even lane must be zero.
#[simd_test(enable = "amx-avx512,avx10.2")]
fn test_tile_cvtrowps2bf16h() {
    unsafe {
        _init_amx();
        // Row `r` of the source matrix is sixteen copies of `r as f32`.
        let source: [[f32; 16]; 16] = array::from_fn(|r| [r as _; _]);

        // Configure tile 0 as 16 rows of 64 bytes and load the matrix into it.
        let mut cfg = __tilecfg::default();
        cfg.palette = 1;
        cfg.colsb[0] = 64;
        cfg.rows[0] = 16;
        _tile_loadconfig(cfg.as_ptr());
        _tile_loadd::<0>(source.as_ptr().cast(), 64);

        for i in 0..16 {
            let converted = _tile_cvtrowps2bf16h::<0>(i);
            // Every element of row `i` is the same value, so one reference conversion suffices.
            let bits = _mm_cvtness_sbh(i as _).to_bits();
            let expected = array::from_fn(|lane| if lane % 2 == 1 { bits } else { 0 });
            assert_eq!(*converted.as_u16x32().as_array(), expected);
        }
    }
}

// Checks the immediate-row "high" conversion: every odd 16-bit lane must hold the BF16
// encoding of the row's value and every even lane must be zero.
#[simd_test(enable = "amx-avx512,avx10.2")]
fn test_tile_cvtrowps2bf16hi() {
    unsafe {
        _init_amx();
        // Row `r` of the source matrix is sixteen copies of `r as f32`.
        let source: [[f32; 16]; 16] = array::from_fn(|r| [r as _; _]);

        // Configure tile 0 as 16 rows of 64 bytes and load the matrix into it.
        let mut cfg = __tilecfg::default();
        cfg.palette = 1;
        cfg.colsb[0] = 64;
        cfg.rows[0] = 16;
        _tile_loadconfig(cfg.as_ptr());
        _tile_loadd::<0>(source.as_ptr().cast(), 64);

        for i in 0..16 {
            // `wrap_imm4!` turns the runtime loop index into the const row argument.
            let converted = wrap_imm4!(_tile_cvtrowps2bf16hi::<0>, i);
            // Every element of row `i` is the same value, so one reference conversion suffices.
            let bits = _mm_cvtness_sbh(i as _).to_bits();
            let expected = array::from_fn(|lane| if lane % 2 == 1 { bits } else { 0 });
            assert_eq!(*converted.as_u16x32().as_array(), expected);
        }
    }
}

// Checks the runtime-row "low" conversion: every even 16-bit lane must hold the BF16
// encoding of the row's value and every odd lane must be zero.
#[simd_test(enable = "amx-avx512,avx10.2")]
fn test_tile_cvtrowps2bf16l() {
    unsafe {
        _init_amx();
        // Row `r` of the source matrix is sixteen copies of `r as f32`.
        let source: [[f32; 16]; 16] = array::from_fn(|r| [r as _; _]);

        // Configure tile 0 as 16 rows of 64 bytes and load the matrix into it.
        let mut cfg = __tilecfg::default();
        cfg.palette = 1;
        cfg.colsb[0] = 64;
        cfg.rows[0] = 16;
        _tile_loadconfig(cfg.as_ptr());
        _tile_loadd::<0>(source.as_ptr().cast(), 64);

        for i in 0..16 {
            let converted = _tile_cvtrowps2bf16l::<0>(i);
            // Every element of row `i` is the same value, so one reference conversion suffices.
            let bits = _mm_cvtness_sbh(i as _).to_bits();
            let expected = array::from_fn(|lane| if lane % 2 == 0 { bits } else { 0 });
            assert_eq!(*converted.as_u16x32().as_array(), expected);
        }
    }
}

// Checks the immediate-row "low" conversion: every even 16-bit lane must hold the BF16
// encoding of the row's value and every odd lane must be zero.
#[simd_test(enable = "amx-avx512,avx10.2")]
fn test_tile_cvtrowps2bf16li() {
    unsafe {
        _init_amx();
        // Row `r` of the source matrix is sixteen copies of `r as f32`.
        let source: [[f32; 16]; 16] = array::from_fn(|r| [r as _; _]);

        // Configure tile 0 as 16 rows of 64 bytes and load the matrix into it.
        let mut cfg = __tilecfg::default();
        cfg.palette = 1;
        cfg.colsb[0] = 64;
        cfg.rows[0] = 16;
        _tile_loadconfig(cfg.as_ptr());
        _tile_loadd::<0>(source.as_ptr().cast(), 64);

        for i in 0..16 {
            // `wrap_imm4!` turns the runtime loop index into the const row argument.
            let converted = wrap_imm4!(_tile_cvtrowps2bf16li::<0>, i);
            // Every element of row `i` is the same value, so one reference conversion suffices.
            let bits = _mm_cvtness_sbh(i as _).to_bits();
            let expected = array::from_fn(|lane| if lane % 2 == 0 { bits } else { 0 });
            assert_eq!(*converted.as_u16x32().as_array(), expected);
        }
    }
}

#[simd_test(enable = "amx-tf32")]
fn test_tile_mmultf32ps() {
unsafe {
Expand Down