Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 23 additions & 6 deletions python/tvm/relax/frontend/nn/llm/position_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,23 @@ def rope_freq_default(s: tirx.Var, d: tirx.Var, d_range: int, theta: float, dtyp
return cos_freq, sin_freq, {freq_var: freq}


def rope_freq_gptj(s: tirx.Var, d: tirx.Var, d_range: int, theta: float, dtype: str):
"""Compute the inverse frequency of RoPE for gptj RoPE scaling."""
freq = s / tirx.power(theta, 2 * (d // 2) % d_range / tirx.const(d_range, "float32"))
def rope_freq_gptj(
s: tirx.Var, d: tirx.Var, d_range: int, theta: float, dtype: str,
freq_dim_base: int = 0,
):
"""Compute the inverse frequency of RoPE for gptj RoPE scaling.

Parameters
----------
freq_dim_base : int
If > 0, use this as the denominator in the frequency exponent instead
of d_range. This supports partial rotary embeddings where the frequency
base dimension (head_dim) differs from the number of rotated dimensions
(rotary_dim). E.g., Gemma 4 full-attention layers have head_dim=512
but only rotate 128 dims (partial_rotary_factor=0.25).
"""
denom = freq_dim_base if freq_dim_base > 0 else d_range
freq = s / tirx.power(theta, 2 * (d // 2) % d_range / tirx.const(denom, "float32"))
freq_var = tirx.Var("freq", "float32")
cos_freq = tirx.cos(freq_var).astype(dtype)
sin_freq = tirx.sin(freq_var).astype(dtype)
Expand Down Expand Up @@ -262,6 +276,9 @@ def switch_rope_freq_func(rope_scaling: dict[str, Any]) -> Callable:
if "rope_type" not in rope_scaling:
return rope_freq_default
if rope_scaling["rope_type"] == "gptj":
freq_dim_base = rope_scaling.get("freq_dim_base", 0)
if freq_dim_base > 0:
return partial(rope_freq_gptj, freq_dim_base=freq_dim_base)
return rope_freq_gptj
if rope_scaling["rope_type"] == "llama3":
return partial(
Expand Down Expand Up @@ -522,14 +539,14 @@ def _rope( # pylint: disable=too-many-arguments
expr = tirx.Let(var, value, expr)
return expr

@T.prim_func
@T.prim_func(private=True)
def fused_rope( # pylint: disable=too-many-locals
var_qkv: T.handle,
var_position_map: T.handle,
var_q: T.handle,
var_k: T.handle,
var_v: T.handle,
apply_rope: T.int32,
apply_rope: T.int64,
):
T.func_attr(
{
Expand Down Expand Up @@ -564,7 +581,7 @@ def fused_rope( # pylint: disable=too-many-locals
else:
v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d]

@T.prim_func
@T.prim_func(private=True)
def fused_rope_longrope_scaling( # pylint: disable=too-many-locals
var_qkv: T.handle,
var_position_map: T.handle,
Expand Down
25 changes: 16 additions & 9 deletions python/tvm/tirx/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,18 +104,25 @@ def is_host_func(f):

host_mod = tvm.tirx.transform.Filter(is_host_func)(mod)
device_mod = tvm.tirx.transform.Filter(lambda f: not is_host_func(f))(mod)
# TODO(syfeng): Here we use str as key since target hash is not correct
target_str2target = {}
device_func_dict = {}
# Group device functions by target kind name (e.g. "webgpu", "cuda") rather
# than the full target string. Different TIR passes may attach slightly
# different target objects (e.g. with or without max_num_threads) to
# functions that should all end up in the same device module. Using the
# full str(target) as key splits them into separate modules, causing the
# later module to shadow the earlier one at runtime.
kind2target: dict[str, "Target"] = {}
kind2funcs: dict[str, dict] = {}
device_mod_dict: dict[Target, IRModule] = {}
for gv, func in device_mod.functions.items():
target = func.attrs.get("target", None)
target_str = str(target) if target is not None else ""
target_str2target[target_str] = target # This might be overridden by the last one
device_func_dict.setdefault(target_str, dict()).update({gv: func})
for target_str in target_str2target.keys():
target = target_str2target[target_str]
device_mod_dict[target] = tvm.IRModule(device_func_dict[target_str], attrs=device_mod.attrs)
kind = target.kind.name if target is not None else ""
# Keep the first target encountered for each kind as the canonical one
if kind not in kind2target:
kind2target[kind] = target
kind2funcs.setdefault(kind, dict()).update({gv: func})
for kind in kind2target:
target = kind2target[kind]
device_mod_dict[target] = tvm.IRModule(kind2funcs[kind], attrs=device_mod.attrs)
return host_mod, device_mod_dict


Expand Down
40 changes: 22 additions & 18 deletions src/runtime/vm/paged_kv_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -958,13 +958,26 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
std::fill(is_chain_on_depths_.begin(), is_chain_on_depths_.end(), true);
}

if (append_before_attn_) {
// Right now we use different kernels when depth is 1 or not 1.
// For the case where maximum depth is 1, we create the auxiliary
// data structure with regard to the page table after appending.
for (int i = 0; i < cur_batch_size_; ++i) {
ReserveAppendLengthInSeq(sequences[i], append_lengths[i]);
}
// Reserve pages BEFORE the aux-data loop unconditionally. The aux-data
// loop below reads `block.page_ids.size()` to populate page_indptr /
// page_indices / length_info, so pages must already be reserved in this
// call's blocks for the metadata to reflect the current prefill state.
//
// Previously the reserve was conditional on `append_before_attn_`: the
// `=true` branch reserved before the loop (correct), but the `=false`
// branch reserved after the loop, producing zero-page metadata for the
// first prefill into an empty cache. That broke models that perform
// intra-prefill shared-KV cross-attention (e.g. Gemma 4 layers 15-34
// reading the K/V written by layers 13/14 inside the same prefill call):
// MHACrossAttnInternal saw `page_indices->shape[0] == 0` and skipped
// the entire computation, leaving the model-supplied `o_data` as
// uninitialised memory.
//
// The K/V-append timing is unchanged: the actual append (via
// `f_transpose_append_mha`) is still controlled by `append_before_attn_`
// at attention time, not by when page slots are reserved here.
for (int i = 0; i < cur_batch_size_; ++i) {
ReserveAppendLengthInSeq(sequences[i], append_lengths[i]);
}

for (int d = 0; d < num_depths_; ++d) {
Expand Down Expand Up @@ -1106,15 +1119,6 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
}
}

if (!append_before_attn_) {
// Right now we use different kernels when depth is 1 or not 1.
// For the case where maximum depth is not 1, we create the auxiliary
// data structure with regard to the page table before appending.
for (int i = 0; i < cur_batch_size_; ++i) {
ReserveAppendLengthInSeq(sequences[i], append_lengths[i]);
}
}

// Map each token position in the input batch to the position
// in the global KV cache. The mapping is used when appending k/v values.
q_rope_position_map_host_.clear();
Expand Down Expand Up @@ -1415,7 +1419,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
// The auxiliary data structure on device must have been synchronized.
TVM_FFI_ICHECK(!dirty_aux_data_device_);

if (attn_kind == AttnKind::kMHA) {
if (attn_kind == AttnKind::kMHA || attn_kind == AttnKind::kMHASliding) {
MHASelfAttnInternal(q_data, k_data, v_data, o_data, lse_data, sm_scale);
} else {
MLASelfAttnInternal(q_data, k_data, v_data, o_data, lse_data, sm_scale);
Expand Down Expand Up @@ -1454,7 +1458,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
// The auxiliary data structure on device must have been synchronized.
TVM_FFI_ICHECK(!dirty_aux_data_device_);

if (attn_kind == AttnKind::kMHA) {
if (attn_kind == AttnKind::kMHA || attn_kind == AttnKind::kMHASliding) {
MHACrossAttnInternal(local_layer_id, q_data, o_data, lse_data, sm_scale,
/*is_first_kernel=*/true);
} else {
Expand Down
5 changes: 5 additions & 0 deletions src/target/target_kind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,11 @@ TVM_REGISTER_TARGET_KIND("webgpu", kDLWebGPU)
// thread_warp_size=1: is_subwarp_reduction and is_multiwarp_reduction returns false, so no
// subgroup ops are emitted.
.add_attr_option<int64_t>("thread_warp_size", refl::DefaultValue(1))
// The WebGPU spec mandates `maxComputeWorkgroupStorageSize >= 16384`;
// Chrome/Dawn currently exposes 32768. Without this default the Dlight
// scheduler falls back to the generic (48 KB) budget and emits kernels
// that exceed Chrome's limit at launch time.
.add_attr_option<int64_t>("max_shared_memory_per_block", refl::DefaultValue(32768))
.set_target_canonicalizer(UpdateWebGPUAttrs)
.set_default_keys({"webgpu", "gpu"});

Expand Down
57 changes: 34 additions & 23 deletions web/emcc/wasm_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,25 @@
#include <tvm/ffi/reflection/registry.h>
#include <tvm/runtime/logging.h>

// FFI core must come before runtime .cc includes in this single translation
// unit. Otherwise, static initialisation can resolve ffi globals before
// object.cc registers them, leading to crashes at module init time.
#include "3rdparty/tvm-ffi/src/ffi/backtrace.cc"
#include "3rdparty/tvm-ffi/src/ffi/container.cc"
#include "3rdparty/tvm-ffi/src/ffi/dtype.cc"
#include "3rdparty/tvm-ffi/src/ffi/error.cc"
#include "3rdparty/tvm-ffi/src/ffi/function.cc"
#include "3rdparty/tvm-ffi/src/ffi/object.cc"
#include "3rdparty/tvm-ffi/src/ffi/tensor.cc"
#include "3rdparty/tvm-ffi/src/ffi/extra/env_c_api.cc"
#include "3rdparty/tvm-ffi/src/ffi/extra/env_context.cc"
#include "3rdparty/tvm-ffi/src/ffi/extra/json_parser.cc"
#include "3rdparty/tvm-ffi/src/ffi/extra/json_writer.cc"
#include "3rdparty/tvm-ffi/src/ffi/extra/library_module.cc"
#include "3rdparty/tvm-ffi/src/ffi/extra/library_module_system_lib.cc"
#include "3rdparty/tvm-ffi/src/ffi/extra/module.cc"
#include "3rdparty/tvm-ffi/src/ffi/testing/testing.cc"

#include "src/runtime/contrib/sort/sort.cc"
#include "src/runtime/cpu_device_api.cc"
#include "src/runtime/device_api.cc"
Expand All @@ -47,22 +66,6 @@
#include "src/runtime/rpc/rpc_session.cc"
#include "src/runtime/tensor.cc"
#include "src/runtime/workspace_pool.cc"
// relax setup
#include "3rdparty/tvm-ffi/src/ffi/backtrace.cc"
#include "3rdparty/tvm-ffi/src/ffi/container.cc"
#include "3rdparty/tvm-ffi/src/ffi/dtype.cc"
#include "3rdparty/tvm-ffi/src/ffi/error.cc"
#include "3rdparty/tvm-ffi/src/ffi/extra/env_c_api.cc"
#include "3rdparty/tvm-ffi/src/ffi/extra/env_context.cc"
#include "3rdparty/tvm-ffi/src/ffi/extra/json_parser.cc"
#include "3rdparty/tvm-ffi/src/ffi/extra/json_writer.cc"
#include "3rdparty/tvm-ffi/src/ffi/extra/library_module.cc"
#include "3rdparty/tvm-ffi/src/ffi/extra/library_module_system_lib.cc"
#include "3rdparty/tvm-ffi/src/ffi/extra/module.cc"
#include "3rdparty/tvm-ffi/src/ffi/function.cc"
#include "3rdparty/tvm-ffi/src/ffi/object.cc"
#include "3rdparty/tvm-ffi/src/ffi/tensor.cc"
#include "3rdparty/tvm-ffi/src/ffi/testing/testing.cc"
#include "src/runtime/memory/memory_manager.cc"
#include "src/runtime/nvtx.cc"
#include "src/runtime/vm/attn_backend.cc"
Expand Down Expand Up @@ -132,20 +135,28 @@ void ArrayDecodeStorage(Tensor cpu_arr, TVMFFIByteArray* bytes, const std::strin
const char* byte_data = bytes->data;
const size_t byte_size = bytes->size;
if (format == "f32-to-bf16" && dtype == "float32") {
const uint16_t* bf16 = reinterpret_cast<const uint16_t*>(byte_data);
uint32_t* data = static_cast<uint32_t*>(cpu_arr->data);
TVM_FFI_ICHECK(cpu_arr.IsContiguous());
size_t size = 1;
for (int i = 0; i < cpu_arr->ndim; ++i) {
size *= cpu_arr->shape[i];
}
TVM_FFI_ICHECK_EQ(size, byte_size / 2);
for (size_t i = 0; i < size; ++i) {
data[i] = static_cast<uint32_t>(bf16[i]) << 16;
// The "f32-to-bf16" format encodes a float32 tensor as packed bf16 (2
// bytes per element). When the byte_size matches that expectation, expand
// back to f32. If the byte_size matches the native float32 width
// (4 bytes per element), the payload is already raw float32 — fall through
// to the generic byte copy. This makes the loader tolerant of weight
// shards produced by older / alternate quantisation pipelines that retain
// the "f32-to-bf16" tag without performing the bf16 truncation.
if (size == byte_size / 2) {
const uint16_t* bf16 = reinterpret_cast<const uint16_t*>(byte_data);
uint32_t* data = static_cast<uint32_t*>(cpu_arr->data);
for (size_t i = 0; i < size; ++i) {
data[i] = static_cast<uint32_t>(bf16[i]) << 16;
}
return;
}
} else {
cpu_arr.CopyFromBytes(byte_data, byte_size);
}
cpu_arr.CopyFromBytes(byte_data, byte_size);
}

TVM_FFI_STATIC_INIT_BLOCK() {
Expand Down
103 changes: 100 additions & 3 deletions web/src/runtime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1323,6 +1323,7 @@ export class Instance implements Disposable {
artifactCache: ArtifactCacheTemplate,
signal?: AbortSignal,
) {
const maxChunkBytes = 128 * 1024 * 1024;
const perf = compact.getPerformance();
const tstart = perf.now();
let totalBytes = 0;
Expand Down Expand Up @@ -1421,9 +1422,53 @@ export class Instance implements Disposable {
this.empty(rec.shape, rec.dtype, this.cpu())
)
});
const recSource = buffer.slice(rec.byteOffset, rec.byteOffset + rec.nbytes);
const shardBytes = buffer instanceof Uint8Array ? buffer : new Uint8Array(buffer);
const recSource =
rec.byteOffset === 0 && rec.nbytes === shardBytes.byteLength
? shardBytes
: shardBytes.subarray(rec.byteOffset, rec.byteOffset + rec.nbytes);
const canChunkRecord =
rec.nbytes > maxChunkBytes &&
rec.shape.length >= 1 &&
Number.isInteger(rec.shape[0]) &&
rec.shape[0] > 0 &&
rec.nbytes % rec.shape[0] === 0;
const copyRecordToTensor = (targetTensor: Tensor, sourceBytes: Uint8Array) => {
if (!canChunkRecord) {
this.ctx.arrayDecodeStorage(targetTensor, sourceBytes, rec.format, rec.dtype);
return;
}
const outerDim = rec.shape[0];
const chunkStrideBytes = rec.nbytes / outerDim;
const chunkOuterDim = Math.max(1, Math.floor(maxChunkBytes / chunkStrideBytes));
for (let outerOffset = 0; outerOffset < outerDim; outerOffset += chunkOuterDim) {
const outerCount = Math.min(chunkOuterDim, outerDim - outerOffset);
const chunkByteOffset = outerOffset * chunkStrideBytes;
const chunkBytes = outerCount * chunkStrideBytes;
const chunkShape = rec.shape.slice();
chunkShape[0] = outerCount;
// Wrap in withNewScope so TVM intermediate objects (shape tuple)
// are disposed after each chunk, but detach the view we need.
const chunkView = this.withNewScope(() => {
return this.detachFromCurrentScope(
this.ctx.tensorCreateView(
targetTensor,
this.ctx.makeShapeTuple(...chunkShape.map((value) => new Scalar(value, "int"))),
rec.dtype,
new Scalar(chunkByteOffset, "int"),
)
);
});
const chunkSource = sourceBytes.subarray(chunkByteOffset, chunkByteOffset + chunkBytes);
try {
this.ctx.arrayDecodeStorage(chunkView, chunkSource, rec.format, rec.dtype);
} finally {
chunkView.dispose();
}
}
};
// first sync copy to cpu.
this.ctx.arrayDecodeStorage(cpu_arr, new Uint8Array(recSource), rec.format, rec.dtype);
copyRecordToTensor(cpu_arr, recSource);
// then async stream into GPU if needed
if (device.deviceType === DeviceStrToEnum.cpu) {
this.tensorCacheUpdate(rec.name, cpu_arr, false);
Expand All @@ -1435,7 +1480,40 @@ export class Instance implements Disposable {
this.empty(rec.shape, rec.dtype, device)
)
});
gpu_arr.copyFrom(cpu_arr);
if (!canChunkRecord) {
gpu_arr.copyFrom(cpu_arr);
} else {
const outerDim = rec.shape[0];
const chunkStrideBytes = rec.nbytes / outerDim;
const chunkOuterDim = Math.max(1, Math.floor(maxChunkBytes / chunkStrideBytes));
for (let outerOffset = 0; outerOffset < outerDim; outerOffset += chunkOuterDim) {
const outerCount = Math.min(chunkOuterDim, outerDim - outerOffset);
const chunkByteOffset = outerOffset * chunkStrideBytes;
const chunkShape = rec.shape.slice();
chunkShape[0] = outerCount;
// Use withNewScope so the shape tuple is auto-disposed,
// and detach the views we need for manual lifetime control.
const [cpuView, gpuView] = this.withNewScope(() => {
const chunkShapeTuple = this.ctx.makeShapeTuple(
...chunkShape.map((value) => new Scalar(value, "int")),
);
return [
this.detachFromCurrentScope(
this.ctx.tensorCreateView(cpu_arr, chunkShapeTuple, rec.dtype, new Scalar(chunkByteOffset, "int"))
),
this.detachFromCurrentScope(
this.ctx.tensorCreateView(gpu_arr, chunkShapeTuple, rec.dtype, new Scalar(chunkByteOffset, "int"))
),
];
});
try {
gpuView.copyFrom(cpuView);
} finally {
cpuView.dispose();
gpuView.dispose();
}
}
}
await device.sync();
this.tensorCacheUpdate(rec.name, gpu_arr, false);
cpu_arr.dispose();
Expand Down Expand Up @@ -2258,6 +2336,25 @@ export class Instance implements Disposable {
case TypeIndex.kTVMFFIOpaquePtr: {
return this.memory.loadPointer(valuePtr);
}
case TypeIndex.kTVMFFIShape: {
const shapeObjPtr = this.memory.loadPointer(valuePtr);
if (callbackArg) {
const shapeCellPtr = shapeObjPtr + SizeOf.ObjectHeader;
const shapeDataPtr = this.memory.loadPointer(shapeCellPtr);
const shapeLen = this.memory.loadUSize(shapeCellPtr + this.memory.sizeofPtr());
const result = new Array<number>(shapeLen);
for (let i = 0; i < shapeLen; ++i) {
result[i] = this.memory.loadI64(shapeDataPtr + i * SizeOf.I64);
}
this.lib.checkCall(
(this.lib.exports.TVMFFIObjectDecRef as ctypes.FTVMFFIObjectDecRef)(shapeObjPtr)
);
return result;
}
return this.ctx.attachToCurrentScope(
new TVMObject(shapeObjPtr, this.lib, this.ctx)
);
}
case TypeIndex.kTVMFFITensor: {
return this.ctx.attachToCurrentScope(
new Tensor(this.memory.loadPointer(valuePtr), this.lib, this.ctx, false)
Expand Down