diff --git a/xtask/src/utils/operator/cast.rs b/xtask/src/utils/operator/cast.rs index 93eb848..56da87b 100644 --- a/xtask/src/utils/operator/cast.rs +++ b/xtask/src/utils/operator/cast.rs @@ -28,7 +28,7 @@ impl Operator { impl Content<'_> { pub(super) fn cast(&mut self, types: HashMap) { match self.general_architecture().unwrap() { - "llama" | "gpt2" | "qwen2" => { + "llama" | "gpt2" | "qwen2" | "qwen3" => { let [linear, embd, norm, else_] = ["linear", "embd", "norm", "else"].map(|name| types.get(name).copied()); self.cast_(linear, |name, shape| { diff --git a/xtask/src/utils/operator/permute_qk.rs b/xtask/src/utils/operator/permute_qk.rs index e71ba98..0270dc8 100644 --- a/xtask/src/utils/operator/permute_qk.rs +++ b/xtask/src/utils/operator/permute_qk.rs @@ -22,6 +22,8 @@ impl Content<'_> { for (name, tensor) in tensors { static QK_REGEX: LazyLock = LazyLock::new(|| Regex::new(r"(attn_q|attn_k|attn_qkv)\.(weight|bias)$").unwrap()); + static QK_NORM_REGEX: LazyLock = + LazyLock::new(|| Regex::new(r"(attn_q_norm|attn_k_norm)\.(weight)$").unwrap()); let tensor = if let Some(captures) = QK_REGEX.captures(&name) { match &captures[1] { @@ -35,6 +37,8 @@ impl Content<'_> { } _ => unreachable!(), } + } else if QK_NORM_REGEX.is_match(&name) { + permute_qk_norm(tensor, !direction) } else { tensor }; @@ -73,3 +77,29 @@ fn permute_qk(tensor: Tensor, nh: usize, rev: bool) -> Tensor { }); Tensor { ty, shape, data } } + +fn permute_qk_norm(tensor: Tensor, rev: bool) -> Tensor { + let Tensor { ty, shape, data } = tensor; + let [r] = match &*shape { + &[r] => [r], + [..] => todo!("permute_qk_norm only supports 1D tensors"), + }; + let c = ty.size().elements_to_bytes(&[1]); + let r = r as usize; + + let tiles = if rev { [2, r / 2] } else { [r / 2, 2] }; + + type Layout = mem_rearrange::ndarray_layout::ArrayLayout<4>; + let src = Layout::new_contiguous(&[c, r], LittleEndian, 1) + .tile_le(1, &tiles) + .transpose(&[2, 1]); + let dst = Layout::new_contiguous(src.shape(), LittleEndian, 1); + let rearrange = Rearranging::new(&dst, &src, 1).unwrap(); + + let data = DataPromise::lazy(move || { + let mut ans = MmapMut::map_anon(c * r).unwrap(); + unsafe { rearrange.launch(ans.as_mut_ptr(), data.get().as_ptr()) }; + ans + }); + Tensor { ty, shape, data } +}