Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion xtask/src/utils/operator/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ impl Operator {
impl Content<'_> {
pub(super) fn cast(&mut self, types: HashMap<String, Ty>) {
match self.general_architecture().unwrap() {
"llama" | "gpt2" | "qwen2" => {
"llama" | "gpt2" | "qwen2" | "qwen3" => {
let [linear, embd, norm, else_] =
["linear", "embd", "norm", "else"].map(|name| types.get(name).copied());
self.cast_(linear, |name, shape| {
Expand Down
30 changes: 30 additions & 0 deletions xtask/src/utils/operator/permute_qk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ impl Content<'_> {
for (name, tensor) in tensors {
static QK_REGEX: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"(attn_q|attn_k|attn_qkv)\.(weight|bias)$").unwrap());
static QK_NORM_REGEX: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"(attn_q_norm|attn_k_norm)\.(weight)$").unwrap());

let tensor = if let Some(captures) = QK_REGEX.captures(&name) {
match &captures[1] {
Expand All @@ -35,6 +37,8 @@ impl Content<'_> {
}
_ => unreachable!(),
}
} else if QK_NORM_REGEX.is_match(&name) {
permute_qk_norm(tensor, !direction)
} else {
tensor
};
Expand Down Expand Up @@ -73,3 +77,29 @@ fn permute_qk(tensor: Tensor, nh: usize, rev: bool) -> Tensor {
});
Tensor { ty, shape, data }
}

/// Reorders a 1-D Q/K norm weight so its per-channel elements follow the
/// head-channel permutation used for rotary embeddings; `rev` selects the
/// inverse direction.
///
/// NOTE(review): intended to mirror the element permutation of `permute_qk`
/// for `attn_q_norm`/`attn_k_norm` vectors — confirm against the caller,
/// which invokes this with the direction negated.
fn permute_qk_norm(tensor: Tensor, rev: bool) -> Tensor {
    let Tensor { ty, shape, data } = tensor;
    // Only rank-1 tensors (per-channel norm weights) are handled here.
    let r = match &*shape {
        &[len] => len,
        _ => todo!("permute_qk_norm only supports 1D tensors"),
    };
    // Bytes occupied by a single element of this data type.
    let c = ty.size().elements_to_bytes(&[1]);
    let r = r as usize;

    // Splitting the r elements as [r/2, 2] vs [2, r/2] (then swapping the two
    // tile axes below) realizes the permutation and its inverse, respectively.
    let split = if rev { [2, r / 2] } else { [r / 2, 2] };

    type Layout = mem_rearrange::ndarray_layout::ArrayLayout<4>;
    // NOTE(review): LittleEndian layout appears to list the fastest-varying
    // dimension first, so [c, r] keeps each element's c bytes contiguous —
    // confirm against mem_rearrange's conventions.
    let src = Layout::new_contiguous(&[c, r], LittleEndian, 1)
        .tile_le(1, &split)
        .transpose(&[2, 1]);
    let dst = Layout::new_contiguous(src.shape(), LittleEndian, 1);
    let plan = Rearranging::new(&dst, &src, 1).unwrap();

    // Defer the actual byte shuffle until the tensor data is first requested.
    let data = DataPromise::lazy(move || {
        let mut out = MmapMut::map_anon(c * r).unwrap();
        // SAFETY: `out` holds exactly c * r bytes, matching the contiguous
        // destination layout, and the source promise yields the full tensor.
        unsafe { plan.launch(out.as_mut_ptr(), data.get().as_ptr()) };
        out
    });
    // Shape and type are unchanged; only the element order differs.
    Tensor { ty, shape, data }
}