diff --git a/python/tvm/relax/frontend/nn/llm/position_embedding.py b/python/tvm/relax/frontend/nn/llm/position_embedding.py
index dec80c50c2..ab034c7097 100644
--- a/python/tvm/relax/frontend/nn/llm/position_embedding.py
+++ b/python/tvm/relax/frontend/nn/llm/position_embedding.py
@@ -67,9 +67,23 @@ def rope_freq_default(s: tirx.Var, d: tirx.Var, d_range: int, theta: float, dtyp
     return cos_freq, sin_freq, {freq_var: freq}


-def rope_freq_gptj(s: tirx.Var, d: tirx.Var, d_range: int, theta: float, dtype: str):
-    """Compute the inverse frequency of RoPE for gptj RoPE scaling."""
-    freq = s / tirx.power(theta, 2 * (d // 2) % d_range / tirx.const(d_range, "float32"))
+def rope_freq_gptj(
+    s: tirx.Var, d: tirx.Var, d_range: int, theta: float, dtype: str,
+    freq_dim_base: int = 0,
+):
+    """Compute the inverse frequency of RoPE for gptj RoPE scaling.
+
+    Parameters
+    ----------
+    freq_dim_base : int
+        If > 0, use this as the denominator in the frequency exponent instead
+        of d_range. This supports partial rotary embeddings where the frequency
+        base dimension (head_dim) differs from the number of rotated dimensions
+        (rotary_dim). E.g., Gemma 4 full-attention layers have head_dim=512
+        but only rotate 128 dims (partial_rotary_factor=0.25).
+    """
+    denom = freq_dim_base if freq_dim_base > 0 else d_range
+    freq = s / tirx.power(theta, 2 * (d // 2) % d_range / tirx.const(denom, "float32"))
     freq_var = tirx.Var("freq", "float32")
     cos_freq = tirx.cos(freq_var).astype(dtype)
     sin_freq = tirx.sin(freq_var).astype(dtype)
@@ -262,6 +276,9 @@ def switch_rope_freq_func(rope_scaling: dict[str, Any]) -> Callable:
     if "rope_type" not in rope_scaling:
         return rope_freq_default
     if rope_scaling["rope_type"] == "gptj":
+        freq_dim_base = rope_scaling.get("freq_dim_base", 0)
+        if freq_dim_base > 0:
+            return partial(rope_freq_gptj, freq_dim_base=freq_dim_base)
         return rope_freq_gptj
     if rope_scaling["rope_type"] == "llama3":
         return partial(
@@ -522,14 +539,14 @@ def _rope(  # pylint: disable=too-many-arguments
             expr = tirx.Let(var, value, expr)
         return expr

-    @T.prim_func
+    @T.prim_func(private=True)
     def fused_rope(  # pylint: disable=too-many-locals
         var_qkv: T.handle,
         var_position_map: T.handle,
         var_q: T.handle,
         var_k: T.handle,
         var_v: T.handle,
-        apply_rope: T.int32,
+        apply_rope: T.int64,
     ):
         T.func_attr(
             {
@@ -564,7 +581,7 @@ def fused_rope(  # pylint: disable=too-many-locals
                 else:
                     v[s, h - (num_q_heads + num_kv_heads), d] = qkv[s, h, d]

-    @T.prim_func
+    @T.prim_func(private=True)
     def fused_rope_longrope_scaling(  # pylint: disable=too-many-locals
         var_qkv: T.handle,
         var_position_map: T.handle,
diff --git a/python/tvm/tirx/build.py b/python/tvm/tirx/build.py
index 020730d2f9..d4167962c7 100644
--- a/python/tvm/tirx/build.py
+++ b/python/tvm/tirx/build.py
@@ -104,18 +104,25 @@ def is_host_func(f):
     host_mod = tvm.tirx.transform.Filter(is_host_func)(mod)
     device_mod = tvm.tirx.transform.Filter(lambda f: not is_host_func(f))(mod)

-    # TODO(syfeng): Here we use str as key since target hash is not correct
-    target_str2target = {}
-    device_func_dict = {}
+    # Group device functions by target kind name (e.g. "webgpu", "cuda") rather
+    # than the full target string. Different TIR passes may attach slightly
+    # different target objects (e.g. with or without max_num_threads) to
+    # functions that should all end up in the same device module. Using the
+    # full str(target) as key splits them into separate modules, causing the
+    # later module to shadow the earlier one at runtime.
+    kind2target: dict[str, "Target"] = {}
+    kind2funcs: dict[str, dict] = {}
     device_mod_dict: dict[Target, IRModule] = {}
     for gv, func in device_mod.functions.items():
         target = func.attrs.get("target", None)
-        target_str = str(target) if target is not None else ""
-        target_str2target[target_str] = target  # This might be overridden by the last one
-        device_func_dict.setdefault(target_str, dict()).update({gv: func})
-    for target_str in target_str2target.keys():
-        target = target_str2target[target_str]
-        device_mod_dict[target] = tvm.IRModule(device_func_dict[target_str], attrs=device_mod.attrs)
+        kind = target.kind.name if target is not None else ""
+        # Keep the first target encountered for each kind as the canonical one
+        if kind not in kind2target:
+            kind2target[kind] = target
+        kind2funcs.setdefault(kind, dict()).update({gv: func})
+    for kind in kind2target:
+        target = kind2target[kind]
+        device_mod_dict[target] = tvm.IRModule(kind2funcs[kind], attrs=device_mod.attrs)

     return host_mod, device_mod_dict

diff --git a/src/runtime/vm/paged_kv_cache.cc b/src/runtime/vm/paged_kv_cache.cc
index 36f7697237..663fa2765d 100644
--- a/src/runtime/vm/paged_kv_cache.cc
+++ b/src/runtime/vm/paged_kv_cache.cc
@@ -958,13 +958,26 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       std::fill(is_chain_on_depths_.begin(), is_chain_on_depths_.end(), true);
     }

-    if (append_before_attn_) {
-      // Right now we use different kernels when depth is 1 or not 1.
-      // For the case where maximum depth is 1, we create the auxiliary
-      // data structure with regard to the page table after appending.
-      for (int i = 0; i < cur_batch_size_; ++i) {
-        ReserveAppendLengthInSeq(sequences[i], append_lengths[i]);
-      }
+    // Reserve pages BEFORE the aux-data loop unconditionally. The aux-data
+    // loop below reads `block.page_ids.size()` to populate page_indptr /
+    // page_indices / length_info, so pages must already be reserved in this
+    // call's blocks for the metadata to reflect the current prefill state.
+    //
+    // Previously the reserve was conditional on `append_before_attn_`: the
+    // `=true` branch reserved before the loop (correct), but the `=false`
+    // branch reserved after the loop, producing zero-page metadata for the
+    // first prefill into an empty cache. That broke models that perform
+    // intra-prefill shared-KV cross-attention (e.g. Gemma 4 layers 15-34
+    // reading the K/V written by layers 13/14 inside the same prefill call):
+    // MHACrossAttnInternal saw `page_indices->shape[0] == 0` and skipped
+    // the entire computation, leaving the model-supplied `o_data` as
+    // uninitialised memory.
+    //
+    // The K/V-append timing is unchanged: the actual append (via
+    // `f_transpose_append_mha`) is still controlled by `append_before_attn_`
+    // at attention time, not by when page slots are reserved here.
+    for (int i = 0; i < cur_batch_size_; ++i) {
+      ReserveAppendLengthInSeq(sequences[i], append_lengths[i]);
     }

     for (int d = 0; d < num_depths_; ++d) {
@@ -1106,15 +1119,6 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       }
     }

-    if (!append_before_attn_) {
-      // Right now we use different kernels when depth is 1 or not 1.
-      // For the case where maximum depth is not 1, we create the auxiliary
-      // data structure with regard to the page table before appending.
-      for (int i = 0; i < cur_batch_size_; ++i) {
-        ReserveAppendLengthInSeq(sequences[i], append_lengths[i]);
-      }
-    }
-
     // Map each the token position in the input batch to the position
     // in the global KV cache. The mapping is used in when appending k/v values.
     q_rope_position_map_host_.clear();
@@ -1415,7 +1419,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     // The auxiliary data structure on device must have been synchronized.
     TVM_FFI_ICHECK(!dirty_aux_data_device_);

-    if (attn_kind == AttnKind::kMHA) {
+    if (attn_kind == AttnKind::kMHA || attn_kind == AttnKind::kMHASliding) {
       MHASelfAttnInternal(q_data, k_data, v_data, o_data, lse_data, sm_scale);
     } else {
       MLASelfAttnInternal(q_data, k_data, v_data, o_data, lse_data, sm_scale);
@@ -1454,7 +1458,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     // The auxiliary data structure on device must have been synchronized.
     TVM_FFI_ICHECK(!dirty_aux_data_device_);

-    if (attn_kind == AttnKind::kMHA) {
+    if (attn_kind == AttnKind::kMHA || attn_kind == AttnKind::kMHASliding) {
       MHACrossAttnInternal(local_layer_id, q_data, o_data, lse_data, sm_scale,
                            /*is_first_kernel=*/true);
     } else {
diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc
index 2fb5e17d5f..054ba06faf 100644
--- a/src/target/target_kind.cc
+++ b/src/target/target_kind.cc
@@ -463,6 +463,11 @@ TVM_REGISTER_TARGET_KIND("webgpu", kDLWebGPU)
     // thread_warp_size=1: is_subwarp_reduction and is_multiwarp_reduction returns false, so no
    // subgroup ops are emitted.
     .add_attr_option("thread_warp_size", refl::DefaultValue(1))
+    // The WebGPU spec mandates `maxComputeWorkgroupStorageSize >= 16384`;
+    // Chrome/Dawn currently exposes 32768. Without this default the Dlight
+    // scheduler falls back to the generic (48 KB) budget and emits kernels
+    // that exceed Chrome's limit at launch time.
+    .add_attr_option("max_shared_memory_per_block", refl::DefaultValue(32768))
     .set_target_canonicalizer(UpdateWebGPUAttrs)
     .set_default_keys({"webgpu", "gpu"});

diff --git a/web/emcc/wasm_runtime.cc b/web/emcc/wasm_runtime.cc
index f112a48b8e..eea62f4e88 100644
--- a/web/emcc/wasm_runtime.cc
+++ b/web/emcc/wasm_runtime.cc
@@ -33,6 +33,25 @@
 #include
 #include

+// FFI core must come before runtime .cc includes in this single translation
+// unit. Otherwise, static initialisation can resolve ffi globals before
+// object.cc registers them, leading to crashes at module init time.
+#include "3rdparty/tvm-ffi/src/ffi/backtrace.cc" +#include "3rdparty/tvm-ffi/src/ffi/container.cc" +#include "3rdparty/tvm-ffi/src/ffi/dtype.cc" +#include "3rdparty/tvm-ffi/src/ffi/error.cc" +#include "3rdparty/tvm-ffi/src/ffi/function.cc" +#include "3rdparty/tvm-ffi/src/ffi/object.cc" +#include "3rdparty/tvm-ffi/src/ffi/tensor.cc" +#include "3rdparty/tvm-ffi/src/ffi/extra/env_c_api.cc" +#include "3rdparty/tvm-ffi/src/ffi/extra/env_context.cc" +#include "3rdparty/tvm-ffi/src/ffi/extra/json_parser.cc" +#include "3rdparty/tvm-ffi/src/ffi/extra/json_writer.cc" +#include "3rdparty/tvm-ffi/src/ffi/extra/library_module.cc" +#include "3rdparty/tvm-ffi/src/ffi/extra/library_module_system_lib.cc" +#include "3rdparty/tvm-ffi/src/ffi/extra/module.cc" +#include "3rdparty/tvm-ffi/src/ffi/testing/testing.cc" + #include "src/runtime/contrib/sort/sort.cc" #include "src/runtime/cpu_device_api.cc" #include "src/runtime/device_api.cc" @@ -47,22 +66,6 @@ #include "src/runtime/rpc/rpc_session.cc" #include "src/runtime/tensor.cc" #include "src/runtime/workspace_pool.cc" -// relax setup -#include "3rdparty/tvm-ffi/src/ffi/backtrace.cc" -#include "3rdparty/tvm-ffi/src/ffi/container.cc" -#include "3rdparty/tvm-ffi/src/ffi/dtype.cc" -#include "3rdparty/tvm-ffi/src/ffi/error.cc" -#include "3rdparty/tvm-ffi/src/ffi/extra/env_c_api.cc" -#include "3rdparty/tvm-ffi/src/ffi/extra/env_context.cc" -#include "3rdparty/tvm-ffi/src/ffi/extra/json_parser.cc" -#include "3rdparty/tvm-ffi/src/ffi/extra/json_writer.cc" -#include "3rdparty/tvm-ffi/src/ffi/extra/library_module.cc" -#include "3rdparty/tvm-ffi/src/ffi/extra/library_module_system_lib.cc" -#include "3rdparty/tvm-ffi/src/ffi/extra/module.cc" -#include "3rdparty/tvm-ffi/src/ffi/function.cc" -#include "3rdparty/tvm-ffi/src/ffi/object.cc" -#include "3rdparty/tvm-ffi/src/ffi/tensor.cc" -#include "3rdparty/tvm-ffi/src/ffi/testing/testing.cc" #include "src/runtime/memory/memory_manager.cc" #include "src/runtime/nvtx.cc" #include "src/runtime/vm/attn_backend.cc" @@ -132,20 +135,28 @@ void ArrayDecodeStorage(Tensor cpu_arr, TVMFFIByteArray* bytes, const std::strin const char* byte_data = bytes->data; const size_t byte_size = bytes->size; if (format == "f32-to-bf16" && dtype == "float32") { - const uint16_t* bf16 = reinterpret_cast(byte_data); - uint32_t* data = static_cast(cpu_arr->data); TVM_FFI_ICHECK(cpu_arr.IsContiguous()); size_t size = 1; for (int i = 0; i < cpu_arr->ndim; ++i) { size *= cpu_arr->shape[i]; } - TVM_FFI_ICHECK_EQ(size, byte_size / 2); - for (size_t i = 0; i < size; ++i) { - data[i] = static_cast(bf16[i]) << 16; + // The "f32-to-bf16" format encodes a float32 tensor as packed bf16 (2 + // bytes per element). When the byte_size matches that expectation, expand + // back to f32. If the byte_size matches the native float32 width + // (4 bytes per element), the payload is already raw float32 — fall through + // to the generic byte copy. This makes the loader tolerant of weight + // shards produced by older / alternate quantisation pipelines that retain + // the "f32-to-bf16" tag without performing the bf16 truncation. 
+    if (size == byte_size / 2) {
+      const uint16_t* bf16 = reinterpret_cast<const uint16_t*>(byte_data);
+      uint32_t* data = static_cast<uint32_t*>(cpu_arr->data);
+      for (size_t i = 0; i < size; ++i) {
+        data[i] = static_cast<uint32_t>(bf16[i]) << 16;
+      }
+      return;
     }
-  } else {
-    cpu_arr.CopyFromBytes(byte_data, byte_size);
   }
+  cpu_arr.CopyFromBytes(byte_data, byte_size);
 }

 TVM_FFI_STATIC_INIT_BLOCK() {
diff --git a/web/src/runtime.ts b/web/src/runtime.ts
index a7b3a56f3e..5ab25dc772 100644
--- a/web/src/runtime.ts
+++ b/web/src/runtime.ts
@@ -1323,6 +1323,7 @@ export class Instance implements Disposable {
     artifactCache: ArtifactCacheTemplate,
     signal?: AbortSignal,
   ) {
+    const maxChunkBytes = 128 * 1024 * 1024;
     const perf = compact.getPerformance();
     const tstart = perf.now();
     let totalBytes = 0;
@@ -1421,9 +1422,53 @@
          this.empty(rec.shape, rec.dtype, this.cpu())
        )
      });
-      const recSource = buffer.slice(rec.byteOffset, rec.byteOffset + rec.nbytes);
+      const shardBytes = buffer instanceof Uint8Array ? buffer : new Uint8Array(buffer);
+      const recSource =
+        rec.byteOffset === 0 && rec.nbytes === shardBytes.byteLength
+          ? shardBytes
+          : shardBytes.subarray(rec.byteOffset, rec.byteOffset + rec.nbytes);
+      const canChunkRecord =
+        rec.nbytes > maxChunkBytes &&
+        rec.shape.length >= 1 &&
+        Number.isInteger(rec.shape[0]) &&
+        rec.shape[0] > 0 &&
+        rec.nbytes % rec.shape[0] === 0;
+      const copyRecordToTensor = (targetTensor: Tensor, sourceBytes: Uint8Array) => {
+        if (!canChunkRecord) {
+          this.ctx.arrayDecodeStorage(targetTensor, sourceBytes, rec.format, rec.dtype);
+          return;
+        }
+        const outerDim = rec.shape[0];
+        const chunkStrideBytes = rec.nbytes / outerDim;
+        const chunkOuterDim = Math.max(1, Math.floor(maxChunkBytes / chunkStrideBytes));
+        for (let outerOffset = 0; outerOffset < outerDim; outerOffset += chunkOuterDim) {
+          const outerCount = Math.min(chunkOuterDim, outerDim - outerOffset);
+          const chunkByteOffset = outerOffset * chunkStrideBytes;
+          const chunkBytes = outerCount * chunkStrideBytes;
+          const chunkShape = rec.shape.slice();
+          chunkShape[0] = outerCount;
+          // Wrap in withNewScope so TVM intermediate objects (shape tuple)
+          // are disposed after each chunk, but detach the view we need.
+          const chunkView = this.withNewScope(() => {
+            return this.detachFromCurrentScope(
+              this.ctx.tensorCreateView(
+                targetTensor,
+                this.ctx.makeShapeTuple(...chunkShape.map((value) => new Scalar(value, "int"))),
+                rec.dtype,
+                new Scalar(chunkByteOffset, "int"),
+              )
+            );
+          });
+          const chunkSource = sourceBytes.subarray(chunkByteOffset, chunkByteOffset + chunkBytes);
+          try {
+            this.ctx.arrayDecodeStorage(chunkView, chunkSource, rec.format, rec.dtype);
+          } finally {
+            chunkView.dispose();
+          }
+        }
+      };
       // first sync copy to cpu.
-      this.ctx.arrayDecodeStorage(cpu_arr, new Uint8Array(recSource), rec.format, rec.dtype);
+      copyRecordToTensor(cpu_arr, recSource);
       // then async stream into GPU if needed
       if (device.deviceType === DeviceStrToEnum.cpu) {
         this.tensorCacheUpdate(rec.name, cpu_arr, false);
@@ -1435,7 +1480,40 @@
            this.empty(rec.shape, rec.dtype, device)
          )
        });
-        gpu_arr.copyFrom(cpu_arr);
+        if (!canChunkRecord) {
+          gpu_arr.copyFrom(cpu_arr);
+        } else {
+          const outerDim = rec.shape[0];
+          const chunkStrideBytes = rec.nbytes / outerDim;
+          const chunkOuterDim = Math.max(1, Math.floor(maxChunkBytes / chunkStrideBytes));
+          for (let outerOffset = 0; outerOffset < outerDim; outerOffset += chunkOuterDim) {
+            const outerCount = Math.min(chunkOuterDim, outerDim - outerOffset);
+            const chunkByteOffset = outerOffset * chunkStrideBytes;
+            const chunkShape = rec.shape.slice();
+            chunkShape[0] = outerCount;
+            // Use withNewScope so the shape tuple is auto-disposed,
+            // and detach the views we need for manual lifetime control.
+            const [cpuView, gpuView] = this.withNewScope(() => {
+              const chunkShapeTuple = this.ctx.makeShapeTuple(
+                ...chunkShape.map((value) => new Scalar(value, "int")),
+              );
+              return [
+                this.detachFromCurrentScope(
+                  this.ctx.tensorCreateView(cpu_arr, chunkShapeTuple, rec.dtype, new Scalar(chunkByteOffset, "int"))
+                ),
+                this.detachFromCurrentScope(
+                  this.ctx.tensorCreateView(gpu_arr, chunkShapeTuple, rec.dtype, new Scalar(chunkByteOffset, "int"))
+                ),
+              ];
+            });
+            try {
+              gpuView.copyFrom(cpuView);
+            } finally {
+              cpuView.dispose();
+              gpuView.dispose();
+            }
+          }
+        }
         await device.sync();
         this.tensorCacheUpdate(rec.name, gpu_arr, false);
         cpu_arr.dispose();
@@ -2258,6 +2336,25 @@ export class Instance implements Disposable {
       case TypeIndex.kTVMFFIOpaquePtr: {
         return this.memory.loadPointer(valuePtr);
       }
+      case TypeIndex.kTVMFFIShape: {
+        const shapeObjPtr = this.memory.loadPointer(valuePtr);
+        if (callbackArg) {
+          const shapeCellPtr = shapeObjPtr + SizeOf.ObjectHeader;
+          const shapeDataPtr = this.memory.loadPointer(shapeCellPtr);
+          const shapeLen = this.memory.loadUSize(shapeCellPtr + this.memory.sizeofPtr());
+          const result = new Array(shapeLen);
+          for (let i = 0; i < shapeLen; ++i) {
+            result[i] = this.memory.loadI64(shapeDataPtr + i * SizeOf.I64);
+          }
+          this.lib.checkCall(
+            (this.lib.exports.TVMFFIObjectDecRef as ctypes.FTVMFFIObjectDecRef)(shapeObjPtr)
+          );
+          return result;
+        }
+        return this.ctx.attachToCurrentScope(
+          new TVMObject(shapeObjPtr, this.lib, this.ctx)
+        );
+      }
       case TypeIndex.kTVMFFITensor: {
         return this.ctx.attachToCurrentScope(
           new Tensor(this.memory.loadPointer(valuePtr), this.lib, this.ctx, false)