ggml-webgpu: Add fused RMS_NORM + MUL #21983
Conversation
@reeselevine This PR may conflict with #21873 (especially in
8fd976c to 5fef017
reeselevine left a comment:
Thanks for starting to work on fusion, this is a big step!
The performance not changing much is a little disappointing, but it is not a blocker either. Once we get the fusion format working we can optimize. I wonder if the reason for the lack of improvement is simply that the current RMS_NORM is not very well optimized, so the reduction in bandwidth ends up being hidden because RMS_NORM itself is too slow.
Do you know if this fusion path leads to significant performance gains in other backends?
```cpp
size_t memset_bytes_per_thread;

bool disable_fusion;
uint32_t num_additional_fused_ops;
```
This general structure comes from the Vulkan backend, right? I haven't looked into it too closely, but my first thought is that it seems too general, at least based on this PR: you end up having to check which ops you are actually fusing, and this structure doesn't encode that at all.
```cpp
static bool ggml_webgpu_can_fuse_check(webgpu_context & ctx, const struct ggml_cgraph * cgraph, int node_idx) {
    // RMS_NORM + MUL
    if (ggml_webgpu_can_fuse(cgraph, node_idx, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
```
Under the hood, can_fuse ends up repeating the check on RMS_NORM + MUL, so shouldn't we have a separate function for each set of potential fusions?
```cpp
if (!ctx->disable_fusion) {
    ggml_webgpu_can_fuse_check(ctx, cgraph, i);
}
if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes, i)) {
```
Right now, whether encode_node actually encodes is hidden behind whether the next node will be fused.
Instead, what do you think of just calling encode_node (maybe the function name should change to encode, to reflect that it might encode multiple nodes) and updating i based on the number of fused operations? For the new RMS_NORM + MUL, we would end up updating i by 2. That avoids hiding the fusion in the num_additional_fused_ops variable. That seems cleaner to me for now, but maybe I'm missing something that doesn't translate well to future fusions?
```cpp
(uint32_t) dst->ne[1],
(uint32_t) dst->ne[2],
(uint32_t) dst->ne[3],
*(uint32_t *) rn_dst->op_params // epsilon, treated as f32 in the shader
```
This leads to compiler warnings and will fail when the new ggml-webgpu-nvidia-ci is enabled; I moved to a new format.
Overview
This PR adds initial kernel fusion to the WebGPU backend, starting with RMS_NORM + MUL (similar to #14800).
The performance on major models on my device (M2, Metal 4) is as follows; unfortunately, it is almost unchanged with this implementation.
The command is like this:

```
llama-bench -m Llama-3.2-3B-Instruct-Q4_K_M.gguf -fa 1 -p 512 -n 0
```