diff --git a/.gitmodules b/.gitmodules index 4b9b6084..28f0e627 100644 --- a/.gitmodules +++ b/.gitmodules @@ -3,4 +3,4 @@ url = https://github.com/ml-explore/mlx [submodule "submodules/mlx-c"] path = Source/Cmlx/mlx-c - url = https://github.com/ml-explore/mlx-c + url = https://github.com/RNT56/mlx-c diff --git a/MAINTENANCE.md b/MAINTENANCE.md index f71c5b26..decc2902 100644 --- a/MAINTENANCE.md +++ b/MAINTENANCE.md @@ -126,13 +126,13 @@ git submodules to include the `mlx` and `mlx-c` repositories. When a new version of `mlx` and its equivalent `mlx-c` are to be used, there is a process to go through to update `mlx-swift`. -Additionally, SwiftPM supports plugins that can produce derived source for -building, but this can only produce new swift source. It is possible to use -plugins to generate new source `.cpp` files and even compile them, but at -best the `.o` is copied into the output as a resource, not linked. -This is important because `mlx` has some build-time source generation -(e.g. `make_compiled_preamble.sh`). This is handled in `mlx-swift` by -pre-generating the source when updating the `mlx` version. +Additionally, SwiftPM supports plugins that can produce derived source and +resources for building. It is possible to use plugins to generate new source +`.cpp` files and even compile them, but at best the `.o` is copied into the +output as a resource, not linked. This is important because `mlx` has some +build-time source generation (e.g. `make_compiled_preamble.sh`). This is +handled in `mlx-swift` by pre-generating the source when updating the `mlx` +version, while the SwiftPM Metal library is generated as a build resource. 1. Update the `mlx` and `mlx-c` submodules via `git pull` or `git checkout ...` - `Source/Cmlx/mlx` @@ -143,6 +143,9 @@ pre-generating the source when updating the `mlx` version. - this updates headers in Source/Cmlx/include - this updates headers in Source/Cmlx/include-framework - this generates various files in Source/Cmlx/mlx-generated + - SwiftPM builds generate `default.metallib` through the + `BuildSwiftPMMetalLibrary` plugin; do not check in copied Metal sources or + a concatenated embedded fallback. 4. Fix any build issues with SwiftPM build (opening Package.swift) 5. Fix any build issues with xcodeproj build (opening xcode/MLX.codeproj), see also [README.xcodeproj.md] @@ -163,7 +166,9 @@ After updating the mlx/mlx-c version the xcodeproj needs to be brought up to dat - no other headers in the project should be included as resources (public/private/project) - the easiest way to adjust is look at Project -> Cmlx -> Build Phases and then look at the Headers task - similarly there should be _no_ Copy Bundle Resources from the same section -- compilation issues in .metal files typically mean they are new to the project and need to be removed from Cmlx target membership +- compilation issues in `.metal` files usually mean the SwiftPM Metal plugin's + kernel list or include dependencies need to be updated, or the files need to + remain excluded from normal Cmlx target membership ### Cmlx @@ -181,4 +186,3 @@ Settings, including header search paths are in xcode/xcconfig. ### MLX, etc. These are just normal frameworks that link to Cmlx and others as needed. The source files are all swift and there are no special settings needed. - diff --git a/Package.swift b/Package.swift index 17a4178f..5be309ab 100644 --- a/Package.swift +++ b/Package.swift @@ -71,6 +71,8 @@ import PackageDescription "MLXFast.swift", "MLXFastKernel.swift", ] + + let cmlxPlugins: [Target.PluginUsage]? = nil #else let platformExcludes: [String] = [ "mlx/mlx/backend/cpu/compiled.cpp", @@ -102,6 +104,10 @@ import PackageDescription ] let mlxSwiftExcludes: [String] = [] + + let cmlxPlugins: [Target.PluginUsage]? = [ + "BuildSwiftPMMetalLibrary" + ] #endif let cmlx = Target.target( @@ -211,7 +217,8 @@ let cmlx = Target.target( .headerSearchPath("fmt/include"), .define("MLX_VERSION", to: "\"0.31.1\""), ], - linkerSettings: linkerSettings + linkerSettings: linkerSettings, + plugins: cmlxPlugins ) let package = Package( @@ -240,6 +247,10 @@ let package = Package( ], targets: [ cmlx, + .plugin( + name: "BuildSwiftPMMetalLibrary", + capability: .buildTool() + ), .testTarget( name: "CmlxTests", dependencies: ["Cmlx"] diff --git a/Plugins/BuildSwiftPMMetalLibrary/plugin.swift b/Plugins/BuildSwiftPMMetalLibrary/plugin.swift new file mode 100644 index 00000000..b61460e1 --- /dev/null +++ b/Plugins/BuildSwiftPMMetalLibrary/plugin.swift @@ -0,0 +1,55 @@ +import Foundation +import PackagePlugin + +@main +struct BuildSwiftPMMetalLibrary: BuildToolPlugin { + func createBuildCommands(context: PluginContext, target: any Target) async throws -> [Command] { + #if os(Linux) + return [] + #else + let packageRoot = context.package.directory + let script = packageRoot.appending("tools", "build-swiftpm-metallib.sh") + let output = context.pluginWorkDirectory.appending("default.metallib") + + return [ + .buildCommand( + displayName: "Build SwiftPM default.metallib", + executable: Path("/bin/bash"), + arguments: [script, output], + inputFiles: inputFiles(packageRoot: packageRoot, script: script), + outputFiles: [output] + ) + ] + #endif + } + + #if !os(Linux) + private func inputFiles(packageRoot: Path, script: Path) -> [Path] { + let kernelsDirectory = packageRoot.appending( + "Source", + "Cmlx", + "mlx", + "mlx", + "backend", + "metal", + "kernels" + ) + var files = [script] + files.append(contentsOf: recursivelyCollectedMetalInputs(in: kernelsDirectory)) + return files + } + + private func recursivelyCollectedMetalInputs(in directory: Path) -> [Path] { + let fileManager = FileManager.default + guard let enumerator = fileManager.enumerator(atPath: directory.string) else { + return [] + } + + return enumerator.compactMap { entry -> Path? in + guard let entry = entry as? String else { return nil } + guard entry.hasSuffix(".metal") || entry.hasSuffix(".h") else { return nil } + return directory.appending(subpath: entry) + }.sorted { $0.string < $1.string } + } + #endif +} diff --git a/Source/Cmlx/include/mlx/c/fast.h b/Source/Cmlx/include/mlx/c/fast.h index c825d00e..44027130 100644 --- a/Source/Cmlx/include/mlx/c/fast.h +++ b/Source/Cmlx/include/mlx/c/fast.h @@ -63,6 +63,10 @@ int mlx_fast_cuda_kernel_config_add_template_arg_int( mlx_fast_cuda_kernel_config cls, const char* name, int value); +int mlx_fast_cuda_kernel_config_add_template_arg_uint32( + mlx_fast_cuda_kernel_config cls, + const char* name, + uint32_t value); int mlx_fast_cuda_kernel_config_add_template_arg_bool( mlx_fast_cuda_kernel_config cls, const char* name, @@ -133,6 +137,10 @@ int mlx_fast_metal_kernel_config_add_template_arg_int( mlx_fast_metal_kernel_config cls, const char* name, int value); +int mlx_fast_metal_kernel_config_add_template_arg_uint32( + mlx_fast_metal_kernel_config cls, + const char* name, + uint32_t value); int mlx_fast_metal_kernel_config_add_template_arg_bool( mlx_fast_metal_kernel_config cls, const char* name, diff --git a/Source/Cmlx/mlx b/Source/Cmlx/mlx index ce45c525..d999c27e 160000 --- a/Source/Cmlx/mlx +++ b/Source/Cmlx/mlx @@ -1 +1 @@ -Subproject commit ce45c52505c8158ea48d2a54e8caae05efd86bfe +Subproject commit d999c27ecd549e65f8f689bdd5c83648da977b81 diff --git a/Source/Cmlx/mlx-c b/Source/Cmlx/mlx-c index 0726ca92..f710a589 160000 --- a/Source/Cmlx/mlx-c +++ b/Source/Cmlx/mlx-c @@ -1 +1 @@ -Subproject commit 0726ca922fc902c4c61ef9c27d94132be418e945 +Subproject commit f710a589ede164b9e8afb49d60163db8083a2550 diff --git a/Source/Cmlx/mlx-generated/metal/arange.h b/Source/Cmlx/mlx-generated/metal/arange.h deleted file mode 100644 index 5448fe9a..00000000 --- a/Source/Cmlx/mlx-generated/metal/arange.h +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright © 2023-2024 Apple Inc. -template -[[kernel]] void arange( - constant const T& start, - constant const T& step, - device T* out, - uint index [[thread_position_in_grid]]) { - out[index] = start + index * step; -} diff --git a/Source/Cmlx/mlx-generated/metal/arg_reduce.metal b/Source/Cmlx/mlx-generated/metal/arg_reduce.metal deleted file mode 100644 index 3cd95c52..00000000 --- a/Source/Cmlx/mlx-generated/metal/arg_reduce.metal +++ /dev/null @@ -1,182 +0,0 @@ -// Copyright © 2023 Apple Inc. - -#include - -#include "utils.h" - -using namespace metal; - -template -struct IndexValPair { - uint32_t index; - U val; -}; - -template -struct ArgMin { - static constexpr constant U init = Limits::max; - - IndexValPair reduce(IndexValPair best, IndexValPair current) { - if (best.val > current.val || - (best.val == current.val && best.index > current.index)) { - return current; - } else { - return best; - } - } - - template - IndexValPair - reduce_many(IndexValPair best, thread U* vals, uint32_t offset) { - for (int i = 0; i < N; i++) { - if (vals[i] < best.val) { - best.val = vals[i]; - best.index = offset + i; - } - } - return best; - } -}; - -template -struct ArgMax { - static constexpr constant U init = Limits::min; - - IndexValPair reduce(IndexValPair best, IndexValPair current) { - if (best.val < current.val || - (best.val == current.val && best.index > current.index)) { - return current; - } else { - return best; - } - } - - template - IndexValPair - reduce_many(IndexValPair best, thread U* vals, uint32_t offset) { - for (int i = 0; i < N; i++) { - if (vals[i] > best.val) { - best.val = vals[i]; - best.index = offset + i; - } - } - return best; - } -}; - -template -IndexValPair simd_shuffle_down(IndexValPair data, uint16_t delta) { - return IndexValPair{ - simd_shuffle_down(data.index, delta), simd_shuffle_down(data.val, delta)}; -} - -template -[[kernel]] void arg_reduce_general( - const device T* in [[buffer(0)]], - device uint32_t* out [[buffer(1)]], - const constant int* shape [[buffer(2)]], - const constant int64_t* in_strides [[buffer(3)]], - const constant int64_t* out_strides [[buffer(4)]], - const constant size_t& ndim [[buffer(5)]], - const constant int64_t& axis_stride [[buffer(6)]], - const constant size_t& axis_size [[buffer(7)]], - uint3 gid [[thread_position_in_grid]], - uint3 gsize [[threads_per_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint3 lsize [[threads_per_threadgroup]], - uint simd_size [[threads_per_simdgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - // Shapes and strides *do not* contain the reduction axis. The reduction size - // and stride are provided in axis_stride and axis_size. - // - // Note: in shape == out shape with this convention. - // - // The sketch of the kernel is as follows. - // 1. Launch prod(shape) * thread_group_size threads. - // 2. Loop ceildiv(axis_size / lsize) times - // 3. Read input values - // 4. Reduce among them and go to 3 - // 4. Reduce in each simd_group - // 6. Write in the thread local memory - // 6. Reduce them across thread group - // 7. Write the output without need for atomic - Op op; - - // Compute the input/output index. There is one beginning and one output for - // the whole threadgroup. - int64_t row_idx = gid.y + static_cast(gsize.y) * gid.z; - auto in_idx = elem_to_loc(row_idx, shape, in_strides, ndim); - auto out_idx = elem_to_loc(row_idx, shape, out_strides, ndim); - - IndexValPair best{0, Op::init}; - - threadgroup IndexValPair local_data[32]; - - // Loop over the reduction axis in lsize*N_READS buckets - for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize.x); r++) { - // Read the current value - uint32_t current_index = r * lsize.x * N_READS + lid.x * N_READS; - uint32_t offset = current_index; - const device T* current_in = in + in_idx + current_index * axis_stride; - T vals[N_READS]; - for (int i = 0; i < N_READS; i++) { - vals[i] = (current_index < axis_size) ? *current_in : T(Op::init); - current_index++; - current_in += axis_stride; - } - best = op.template reduce_many(best, vals, offset); - } - // At this point we have reduced the axis into thread group best values so we - // need to reduce across the thread group. - - // First per simd reduction. - for (uint offset = simd_size / 2; offset > 0; offset /= 2) { - IndexValPair neighbor = simd_shuffle_down(best, offset); - best = op.reduce(best, neighbor); - } - - // Write to the threadgroup memory - if (simd_lane_id == 0) { - local_data[simd_group_id] = best; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (simd_group_id != 0) { - return; - } - - // Read the appropriate value from local data and perform one simd reduction - uint simd_groups = ceildiv(lsize.x, simd_size); - if (simd_lane_id < simd_groups) { - best = local_data[simd_lane_id]; - } - for (uint offset = simd_size / 2; offset > 0; offset /= 2) { - IndexValPair neighbor = simd_shuffle_down(best, offset); - best = op.reduce(best, neighbor); - } - - // Finally write the output - if (lid.x == 0) { - out[out_idx] = best.index; - } -} - -// clang-format off -#define instantiate_arg_reduce(name, itype) \ - instantiate_kernel( \ - "argmin_" #name, arg_reduce_general, itype, ArgMin) \ - instantiate_kernel( \ - "argmax_" #name, arg_reduce_general, itype, ArgMax) - -instantiate_arg_reduce(bool_, bool) -instantiate_arg_reduce(uint8, uint8_t) -instantiate_arg_reduce(uint16, uint16_t) -instantiate_arg_reduce(uint32, uint32_t) -instantiate_arg_reduce(uint64, uint64_t) -instantiate_arg_reduce(int8, int8_t) -instantiate_arg_reduce(int16, int16_t) -instantiate_arg_reduce(int32, int32_t) -instantiate_arg_reduce(int64, int64_t) -instantiate_arg_reduce(float16, half) -instantiate_arg_reduce(float32, float) -instantiate_arg_reduce(bfloat16, bfloat16_t) // clang-format on diff --git a/Source/Cmlx/mlx-generated/metal/atomic.h b/Source/Cmlx/mlx-generated/metal/atomic.h deleted file mode 100644 index 93952c2c..00000000 --- a/Source/Cmlx/mlx-generated/metal/atomic.h +++ /dev/null @@ -1,345 +0,0 @@ -// Copyright © 2023 Apple Inc. - -#pragma once - -#include -#include - -using namespace metal; - -/////////////////////////////////////////////////////////////////////////////// -// Atomic utils -/////////////////////////////////////////////////////////////////////////////// - -#pragma METAL internals : enable -template -constexpr constant bool is_metal_atomic = _disjunction< - is_same, - is_same, - is_same, - is_same>::value; - -#pragma METAL internals : disable - -template -struct mlx_atomic { - atomic val; -}; - -template -struct mlx_atomic>> { - atomic val; -}; - -/////////////////////////////////////////////////////////////////////////////// -// Native metal atomics -/////////////////////////////////////////////////////////////////////////////// - -template , bool> = true> -METAL_FUNC T -mlx_atomic_load_explicit(device mlx_atomic* object, size_t offset) { - return atomic_load_explicit(&(object[offset].val), memory_order_relaxed); -} - -template , bool> = true> -METAL_FUNC void -mlx_atomic_store_explicit(device mlx_atomic* object, T val, size_t offset) { - atomic_store_explicit(&(object[offset].val), val, memory_order_relaxed); -} - -template , bool> = true> -METAL_FUNC void mlx_atomic_fetch_and_explicit( - device mlx_atomic* object, - T val, - size_t offset) { - atomic_fetch_and_explicit(&(object[offset].val), val, memory_order_relaxed); -} - -template , bool> = true> -METAL_FUNC void mlx_atomic_fetch_or_explicit( - device mlx_atomic* object, - T val, - size_t offset) { - atomic_fetch_or_explicit(&(object[offset].val), val, memory_order_relaxed); -} - -template , bool> = true> -METAL_FUNC void mlx_atomic_fetch_min_explicit( - device mlx_atomic* object, - T val, - size_t offset) { - atomic_fetch_min_explicit(&(object[offset].val), val, memory_order_relaxed); -} - -template , bool> = true> -METAL_FUNC void mlx_atomic_fetch_max_explicit( - device mlx_atomic* object, - T val, - size_t offset) { - atomic_fetch_max_explicit(&(object[offset].val), val, memory_order_relaxed); -} - -template , bool> = true> -METAL_FUNC void mlx_atomic_fetch_add_explicit( - device mlx_atomic* object, - T val, - size_t offset) { - atomic_fetch_add_explicit(&(object[offset].val), val, memory_order_relaxed); -} - -template , bool> = true> -METAL_FUNC void mlx_atomic_fetch_mul_explicit( - device mlx_atomic* object, - T val, - size_t offset) { - T expected = mlx_atomic_load_explicit(object, offset); - while (!mlx_atomic_compare_exchange_weak_explicit( - object, &expected, val * expected, offset)) { - } -} - -template , bool> = true> -METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit( - device mlx_atomic* object, - thread T* expected, - T val, - size_t offset) { - return atomic_compare_exchange_weak_explicit( - &(object[offset].val), - expected, - val, - memory_order_relaxed, - memory_order_relaxed); -} - -// Specialization for float since it does not atomic_fetch_min_explicit -template <> -METAL_FUNC void mlx_atomic_fetch_min_explicit( - device mlx_atomic* object, - float val, - size_t offset) { - float expected = mlx_atomic_load_explicit(object, offset); - while (val < expected) { - if (mlx_atomic_compare_exchange_weak_explicit( - object, &expected, val, offset)) { - return; - } - } -} - -// Specialization for float since it does not atomic_fetch_max_explicit -template <> -METAL_FUNC void mlx_atomic_fetch_max_explicit( - device mlx_atomic* object, - float val, - size_t offset) { - float expected = mlx_atomic_load_explicit(object, offset); - while (val > expected) { - if (mlx_atomic_compare_exchange_weak_explicit( - object, &expected, val, offset)) { - return; - } - } -} - -/////////////////////////////////////////////////////////////////////////////// -// Custom atomics -/////////////////////////////////////////////////////////////////////////////// - -namespace { - -template -constexpr constant uint packing_size = sizeof(uint) / sizeof(T); - -template -union uint_or_packed { - T val[packing_size]; - uint bits; -}; - -template -struct mlx_atomic_update_helper { - uint operator()(uint_or_packed init, T update, size_t elem_offset) { - Op op; - init.val[elem_offset] = op(update, init.val[elem_offset]); - return init.bits; - } -}; - -template -METAL_FUNC void mlx_atomic_update_and_store( - device mlx_atomic* object, - T update, - size_t offset) { - size_t pack_offset = offset / packing_size; - size_t elem_offset = offset % packing_size; - - mlx_atomic_update_helper helper; - uint_or_packed expected; - expected.bits = - atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed); - - while (Op::condition(update, expected.val[elem_offset]) && - !mlx_atomic_compare_exchange_weak_explicit( - object, - &(expected.bits), - helper(expected, update, elem_offset), - pack_offset)) { - } -} - -template -struct __None { - static bool condition(T a, T b) { -#pragma unused(a) -#pragma unused(b) - return true; - } - - T operator()(T a, T b) { -#pragma unused(b) - return a; - } -}; - -template -struct __Add { - static bool condition(T a, T b) { -#pragma unused(a) -#pragma unused(b) - return true; - } - - T operator()(T a, T b) { - return a + b; - } -}; - -template -struct __Mul { - static bool condition(T a, T b) { -#pragma unused(a) - return b != 0; - } - - T operator()(T a, T b) { - return a * b; - } -}; - -template -struct __Max { - static bool condition(T a, T b) { - return a > b; - } - - T operator()(T a, T b) { - return max(a, b); - } -}; - -template -struct __Min { - static bool condition(T a, T b) { - return a < b; - } - - T operator()(T a, T b) { - return min(a, b); - } -}; - -} // namespace - -template , bool> = true> -METAL_FUNC T -mlx_atomic_load_explicit(device mlx_atomic* object, size_t offset) { - size_t pack_offset = offset / sizeof(T); - size_t elem_offset = offset % sizeof(T); - uint_or_packed packed_val; - packed_val.bits = - atomic_load_explicit(&(object[pack_offset].val), memory_order_relaxed); - return packed_val.val[elem_offset]; -} - -template , bool> = true> -METAL_FUNC void -mlx_atomic_store_explicit(device mlx_atomic* object, T val, size_t offset) { - mlx_atomic_update_and_store>(object, val, offset); -} - -template , bool> = true> -METAL_FUNC void mlx_atomic_fetch_and_explicit( - device mlx_atomic* object, - T val, - size_t offset) { - size_t pack_offset = offset / packing_size; - size_t elem_offset = offset % packing_size; - uint_or_packed identity; - identity.bits = __UINT32_MAX__; - identity.val[elem_offset] = val; - - atomic_fetch_and_explicit( - &(object[pack_offset].val), identity.bits, memory_order_relaxed); -} - -template , bool> = true> -METAL_FUNC void mlx_atomic_fetch_or_explicit( - device mlx_atomic* object, - T val, - size_t offset) { - size_t pack_offset = offset / packing_size; - size_t elem_offset = offset % packing_size; - uint_or_packed identity; - identity.bits = 0; - identity.val[elem_offset] = val; - - atomic_fetch_or_explicit( - &(object[pack_offset].val), identity.bits, memory_order_relaxed); -} - -template , bool> = true> -METAL_FUNC void mlx_atomic_fetch_min_explicit( - device mlx_atomic* object, - T val, - size_t offset) { - mlx_atomic_update_and_store>(object, val, offset); -} - -template , bool> = true> -METAL_FUNC void mlx_atomic_fetch_max_explicit( - device mlx_atomic* object, - T val, - size_t offset) { - mlx_atomic_update_and_store>(object, val, offset); -} - -template , bool> = true> -METAL_FUNC void mlx_atomic_fetch_add_explicit( - device mlx_atomic* object, - T val, - size_t offset) { - mlx_atomic_update_and_store>(object, val, offset); -} - -template , bool> = true> -METAL_FUNC void mlx_atomic_fetch_mul_explicit( - device mlx_atomic* object, - T val, - size_t offset) { - mlx_atomic_update_and_store>(object, val, offset); -} - -template , bool> = true> -METAL_FUNC bool mlx_atomic_compare_exchange_weak_explicit( - device mlx_atomic* object, - thread uint* expected, - uint val, - size_t offset) { - return atomic_compare_exchange_weak_explicit( - &(object[offset].val), - expected, - val, - memory_order_relaxed, - memory_order_relaxed); -} diff --git a/Source/Cmlx/mlx-generated/metal/bf16.h b/Source/Cmlx/mlx-generated/metal/bf16.h deleted file mode 100644 index aa3c3c78..00000000 --- a/Source/Cmlx/mlx-generated/metal/bf16.h +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright © 2023 Apple Inc. - -#pragma once - -#include - -using namespace metal; - -typedef bfloat bfloat16_t; -inline uint16_t bfloat16_to_uint16(const bfloat16_t x) { - return as_type(x); -} - -inline bfloat16_t uint16_to_bfloat16(const uint16_t x) { - return as_type(x); -} diff --git a/Source/Cmlx/mlx-generated/metal/bf16_math.h b/Source/Cmlx/mlx-generated/metal/bf16_math.h deleted file mode 100644 index 0643fb3e..00000000 --- a/Source/Cmlx/mlx-generated/metal/bf16_math.h +++ /dev/null @@ -1,380 +0,0 @@ -// Copyright © 2023 Apple Inc. - -#pragma once - -/////////////////////////////////////////////////////////////////////////////// -// Metal math for bfloat16 -/////////////////////////////////////////////////////////////////////////////// - -/* - -Following the Metal Shading Language Specification (Metal 3.1) - -"bfloat is an extended itypeing point type that only allows implicit conversion - to a type of greater itypeing point rank. While bfloat can be implicitly - converted to itype, it cannot be implicitly converted to half, and neither - itype nor half can be implicitly converted to bfloat." - -Further, as far as I can tell, the stdlib math/simd functions are not defined -for bfloat and calling with an argument of type bfloat will result in that -argument getting implicitly converted to itype which then returns an output -that is (likely) a itype which cannot be implicitly converted into a bfloat - -This leads to situations where -bfloat a = 5.0bf; -bfloat b = metal::abs(a); // this will throw an error since abs return itype -bfloat c = static_cast(metal::abs(a)); // this is fine - -For the moment, I will be adding overloaded instantiations of the math -functions to accordingly automatically handle the casting - -*/ - -#define instantiate_metal_math_funcs(itype, otype, ctype, mfast) \ - \ - METAL_FUNC otype abs(itype x) { \ - return static_cast(__metal_fabs(static_cast(x), mfast)); \ - } \ - METAL_FUNC otype acos(itype x) { \ - return static_cast(__metal_acos(static_cast(x), mfast)); \ - } \ - METAL_FUNC otype acosh(itype x) { \ - return static_cast(__metal_acosh(static_cast(x), mfast)); \ - } \ - METAL_FUNC otype asin(itype x) { \ - return static_cast(__metal_asin(static_cast(x), mfast)); \ - } \ - METAL_FUNC otype asinh(itype x) { \ - return static_cast(__metal_asinh(static_cast(x), mfast)); \ - } \ - METAL_FUNC otype atan(itype y_over_x) { \ - return static_cast( \ - __metal_atan(static_cast(y_over_x), mfast)); \ - } \ - METAL_FUNC otype atan2(itype y, itype x) { \ - return static_cast( \ - __metal_atan2(static_cast(y), static_cast(x), mfast)); \ - } \ - METAL_FUNC otype atanh(itype x) { \ - return static_cast(__metal_atanh(static_cast(x), mfast)); \ - } \ - METAL_FUNC otype ceil(itype x) { \ - return static_cast(__metal_ceil(static_cast(x), mfast)); \ - } \ - METAL_FUNC otype cos(itype x) { \ - return static_cast(__metal_cos(static_cast(x), mfast)); \ - } \ - METAL_FUNC otype cosh(itype x) { \ - return static_cast(__metal_cosh(static_cast(x), mfast)); \ - } \ - METAL_FUNC otype cospi(itype x) { \ - return static_cast(__metal_cospi(static_cast(x), mfast)); \ - } \ - METAL_FUNC otype divide(itype x, itype y) { \ - return static_cast( \ - __metal_divide(static_cast(x), static_cast(y), mfast)); \ - } \ - METAL_FUNC otype exp(itype x) { \ - return static_cast(__metal_exp(static_cast(x), mfast)); \ - } \ - METAL_FUNC otype exp10(itype x) { \ - return static_cast(__metal_exp10(static_cast(x), mfast)); \ - } \ - METAL_FUNC otype exp2(itype x) { \ - return static_cast(__metal_exp2(static_cast(x), mfast)); \ - } \ - METAL_FUNC otype fabs(itype x) { \ - return static_cast(__metal_fabs(static_cast(x), mfast)); \ - } \ - METAL_FUNC otype fdim(itype x, itype y) { \ - ctype t = static_cast(x - y); \ - return static_cast(select(t, ctype(0), t < ctype(0) || x == y)); \ - } \ - METAL_FUNC otype floor(itype x) { \ - return static_cast(__metal_floor(static_cast(x), mfast)); \ - } \ - METAL_FUNC otype fma(itype x, itype y, itype z) { \ - return static_cast(__metal_fma( \ - static_cast(x), static_cast(y), static_cast(z))); \ - } \ - METAL_FUNC otype fmax(itype x, itype y) { \ - return static_cast( \ - __metal_fmax(static_cast(x), static_cast(y), mfast)); \ - } \ - METAL_FUNC otype fmax3(itype x, itype y, itype z) { \ - return static_cast(__metal_fmax3( \ - static_cast(x), \ - static_cast(y), \ - static_cast(z), \ - mfast)); \ - } \ - METAL_FUNC otype fmedian3(itype x, itype y, itype z) { \ - return static_cast(__metal_fmedian3( \ - static_cast(x), \ - static_cast(y), \ - static_cast(z), \ - mfast)); \ - } \ - METAL_FUNC otype fmin(itype x, itype y) { \ - return static_cast( \ - __metal_fmin(static_cast(x), static_cast(y), mfast)); \ - } \ - METAL_FUNC otype fmin3(itype x, itype y, itype z) { \ - return static_cast(__metal_fmin3( \ - static_cast(x), \ - static_cast(y), \ - static_cast(z), \ - mfast)); \ - } \ - METAL_FUNC otype fmod(itype x, itype y) { \ - return static_cast( \ - __metal_fmod(static_cast(x), static_cast(y), mfast)); \ - } \ - METAL_FUNC otype fract(itype x) { \ - return static_cast(__metal_fract(static_cast(x), mfast)); \ - } \ - METAL_FUNC otype frexp(itype x, thread int& exp) { \ - return static_cast(__metal_frexp(static_cast(x), &exp)); \ - } \ - METAL_FUNC otype ldexp(itype x, int k) { \ - return static_cast(__metal_ldexp(static_cast(x), k, mfast)); \ - } \ - METAL_FUNC otype log(itype x) { \ - return static_cast(__metal_log(static_cast(x), mfast)); \ - } \ - METAL_FUNC otype log10(itype x) { \ - return static_cast(__metal_log10(static_cast(x), mfast)); \ - } \ - METAL_FUNC otype log2(itype x) { \ - return static_cast(__metal_log2(static_cast(x), mfast)); \ - } \ - METAL_FUNC otype max(itype x, itype y) { \ - return static_cast( \ - __metal_fmax(static_cast(x), static_cast(y), mfast)); \ - } \ - METAL_FUNC otype max3(itype x, itype y, itype z) { \ - return static_cast(__metal_fmax3( \ - static_cast(x), \ - static_cast(y), \ - static_cast(z), \ - mfast)); \ - } \ - METAL_FUNC otype median3(itype x, itype y, itype z) { \ - return static_cast(__metal_fmedian3( \ - static_cast(x), \ - static_cast(y), \ - static_cast(z), \ - mfast)); \ - } \ - METAL_FUNC otype min(itype x, itype y) { \ - return static_cast( \ - __metal_fmin(static_cast(x), static_cast(y), mfast)); \ - } \ - METAL_FUNC otype min3(itype x, itype y, itype z) { \ - return static_cast(__metal_fmin3( \ - static_cast(x), \ - static_cast(y), \ - static_cast(z), \ - mfast)); \ - } \ - METAL_FUNC otype nextafter(itype x, itype y) { \ - return static_cast( \ - __metal_nextafter(static_cast(x), static_cast(y))); \ - } \ - METAL_FUNC otype pow(itype x, itype y) { \ - return static_cast( \ - __metal_pow(static_cast(x), static_cast(y), mfast)); \ - } \ - METAL_FUNC otype powr(itype x, itype y) { \ - return static_cast( \ - __metal_powr(static_cast(x), static_cast(y), mfast)); \ - } \ - METAL_FUNC otype rint(itype x) { \ - return static_cast(__metal_rint(static_cast(x), mfast)); \ - } \ - METAL_FUNC otype round(itype x) { \ - return static_cast(__metal_round(static_cast(x), mfast)); \ - } \ - METAL_FUNC otype rsqrt(itype x) { \ - return static_cast(__metal_rsqrt(static_cast(x), mfast)); \ - } \ - METAL_FUNC otype sin(itype x) { \ - return static_cast(__metal_sin(static_cast(x), mfast)); \ - } \ - METAL_FUNC otype sinh(itype x) { \ - return static_cast(__metal_sinh(static_cast(x), mfast)); \ - } \ - METAL_FUNC otype sinpi(itype x) { \ - return static_cast(__metal_sinpi(static_cast(x), mfast)); \ - } \ - METAL_FUNC otype sqrt(itype x) { \ - return static_cast(__metal_sqrt(static_cast(x), mfast)); \ - } \ - METAL_FUNC otype tan(itype x) { \ - return static_cast(__metal_tan(static_cast(x), mfast)); \ - } \ - METAL_FUNC otype tanh(itype x) { \ - return static_cast(__metal_tanh(static_cast(x), mfast)); \ - } \ - METAL_FUNC otype tanpi(itype x) { \ - return static_cast(__metal_tanpi(static_cast(x), mfast)); \ - } \ - METAL_FUNC otype trunc(itype x) { \ - return static_cast(__metal_trunc(static_cast(x), mfast)); \ - } - -namespace metal { - -instantiate_metal_math_funcs( - bfloat16_t, - bfloat16_t, - float, - __METAL_MAYBE_FAST_MATH__); - -namespace fast { - -instantiate_metal_math_funcs( - bfloat16_t, - bfloat16_t, - float, - __METAL_FAST_MATH__); - -} // namespace fast - -namespace precise { - -instantiate_metal_math_funcs( - bfloat16_t, - bfloat16_t, - float, - __METAL_PRECISE_MATH__); - -} // namespace precise - -} // namespace metal - -/////////////////////////////////////////////////////////////////////////////// -// Metal simd for bfloat16 -/////////////////////////////////////////////////////////////////////////////// - -#define instantiate_metal_simd_comm_funcs( \ - itype, otype, ctype, itype_to_ctype, ctype_to_otype) \ - \ - METAL_FUNC otype simd_broadcast(itype data, ushort broadcast_lane_id) { \ - return ctype_to_otype( \ - __metal_simd_broadcast(itype_to_ctype(data), broadcast_lane_id)); \ - } \ - \ - METAL_FUNC otype simd_shuffle(itype data, ushort simd_lane_id) { \ - return ctype_to_otype( \ - __metal_simd_shuffle(itype_to_ctype(data), simd_lane_id)); \ - } \ - \ - METAL_FUNC otype simd_shuffle_and_fill_down( \ - itype data, itype filling_data, ushort delta, ushort modulo) { \ - return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \ - itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \ - } \ - \ - METAL_FUNC otype simd_shuffle_and_fill_down( \ - itype data, itype filling_data, ushort delta) { \ - return ctype_to_otype(__metal_simd_shuffle_and_fill_down( \ - itype_to_ctype(data), \ - itype_to_ctype(filling_data), \ - delta, \ - __metal_get_simdgroup_size(ushort()))); \ - } \ - \ - METAL_FUNC otype simd_shuffle_and_fill_up( \ - itype data, itype filling_data, ushort delta, ushort modulo) { \ - return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \ - itype_to_ctype(data), itype_to_ctype(filling_data), delta, modulo)); \ - } \ - \ - METAL_FUNC otype simd_shuffle_and_fill_up( \ - itype data, itype filling_data, ushort delta) { \ - return ctype_to_otype(__metal_simd_shuffle_and_fill_up( \ - itype_to_ctype(data), \ - itype_to_ctype(filling_data), \ - delta, \ - __metal_get_simdgroup_size(ushort()))); \ - } \ - \ - METAL_FUNC otype simd_shuffle_down(itype data, ushort delta) { \ - return ctype_to_otype( \ - __metal_simd_shuffle_down(itype_to_ctype(data), delta)); \ - } \ - \ - METAL_FUNC otype simd_shuffle_rotate_down(itype data, ushort delta) { \ - return ctype_to_otype( \ - __metal_simd_shuffle_rotate_down(itype_to_ctype(data), delta)); \ - } \ - \ - METAL_FUNC otype simd_shuffle_rotate_up(itype data, ushort delta) { \ - return ctype_to_otype( \ - __metal_simd_shuffle_rotate_up(itype_to_ctype(data), delta)); \ - } \ - \ - METAL_FUNC otype simd_shuffle_up(itype data, ushort delta) { \ - return ctype_to_otype( \ - __metal_simd_shuffle_up(itype_to_ctype(data), delta)); \ - } \ - \ - METAL_FUNC otype simd_shuffle_xor(itype data, ushort mask) { \ - return ctype_to_otype( \ - __metal_simd_shuffle_xor(itype_to_ctype(data), mask)); \ - } - -#define instantiate_metal_simd_reduction_funcs(itype, otype, ctype) \ - \ - METAL_FUNC otype simd_max(itype data) { \ - return static_cast(__metal_simd_max(static_cast(data))); \ - } \ - \ - METAL_FUNC otype simd_min(itype data) { \ - return static_cast(__metal_simd_min(static_cast(data))); \ - } \ - \ - METAL_FUNC otype simd_prefix_exclusive_product(itype data) { \ - return static_cast( \ - __metal_simd_prefix_exclusive_product(static_cast(data))); \ - } \ - \ - METAL_FUNC otype simd_prefix_exclusive_sum(itype data) { \ - return static_cast( \ - __metal_simd_prefix_exclusive_sum(static_cast(data))); \ - } \ - \ - METAL_FUNC otype simd_prefix_inclusive_product(itype data) { \ - return static_cast( \ - __metal_simd_prefix_inclusive_product(static_cast(data))); \ - } \ - \ - METAL_FUNC otype simd_prefix_inclusive_sum(itype data) { \ - return static_cast( \ - __metal_simd_prefix_inclusive_sum(static_cast(data))); \ - } \ - \ - METAL_FUNC otype simd_product(itype data) { \ - return static_cast(__metal_simd_product(static_cast(data))); \ - } \ - \ - METAL_FUNC otype simd_sum(itype data) { \ - return static_cast(__metal_simd_sum(static_cast(data))); \ - } \ - \ - METAL_FUNC otype simd_xor(itype data) { \ - return static_cast(__metal_simd_xor(static_cast(data))); \ - } - -namespace metal { - -instantiate_metal_simd_comm_funcs( - bfloat16_t, - bfloat16_t, - uint16_t, - bfloat16_to_uint16, - uint16_to_bfloat16); -instantiate_metal_simd_reduction_funcs(bfloat16_t, bfloat16_t, float); - -} // namespace metal diff --git a/Source/Cmlx/mlx-generated/metal/binary.h b/Source/Cmlx/mlx-generated/metal/binary.h deleted file mode 100644 index f1df8853..00000000 --- a/Source/Cmlx/mlx-generated/metal/binary.h +++ /dev/null @@ -1,199 +0,0 @@ -// Copyright © 2024 Apple Inc. - -template -[[kernel]] void binary_ss( - device const T* a, - device const T* b, - device U* c, - uint index [[thread_position_in_grid]]) { - c[index] = Op()(a[0], b[0]); -} - -template ::n> -[[kernel]] void binary_sv( - device const T* a, - device const T* b, - device U* c, - constant uint& size, - uint index [[thread_position_in_grid]]) { - index *= N; - if (N > 1 && index + N > size) { - for (int i = 0; index + i < size; ++i) { - c[index + i] = Op()(a[0], b[index + i]); - } - } else { - for (int i = 0; i < N; ++i) { - c[index + i] = Op()(a[0], b[index + i]); - } - } -} - -template ::n> -[[kernel]] void binary_vs( - device const T* a, - device const T* b, - device U* c, - constant uint& size, - uint index [[thread_position_in_grid]]) { - index *= N; - if (N > 1 && index + N > size) { - for (int i = 0; index + i < size; ++i) { - c[index + i] = Op()(a[index + i], b[0]); - } - } else { - for (int i = 0; i < N; ++i) { - c[index + i] = Op()(a[index + i], b[0]); - } - } -} - -template ::n> -[[kernel]] void binary_vv( - device const T* a, - device const T* b, - device U* c, - constant uint& size, - uint index [[thread_position_in_grid]]) { - index *= N; - if (N > 1 && index + N > size) { - for (int i = 0; index + i < size; ++i) { - c[index + i] = Op()(a[index + i], b[index + i]); - } - } else { - for (int i = 0; i < N; ++i) { - c[index + i] = Op()(a[index + i], b[index + i]); - } - } -} - -template ::n> -[[kernel]] void binary_sv2( - device const T* a, - device const T* b, - device U* c, - constant int64_t& size, - uint2 index [[thread_position_in_grid]], - uint2 grid_dim [[threads_per_grid]]) { - int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); - if (N > 1 && offset + N > size) { - for (int i = 0; offset + i < size; ++i) { - c[offset + i] = Op()(a[0], b[offset + i]); - } - } else { - for (int i = 0; i < N; ++i) { - c[offset + i] = Op()(a[0], b[offset + i]); - } - } -} - -template ::n> -[[kernel]] void binary_vs2( - device const T* a, - device const T* b, - device U* c, - constant int64_t& size, - uint2 index [[thread_position_in_grid]], - uint2 grid_dim [[threads_per_grid]]) { - int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); - if (N > 1 && offset + N > size) { - for (int i = 0; offset + i < size; ++i) { - c[offset + i] = Op()(a[offset + i], b[0]); - } - } else { - for (int i = 0; i < N; ++i) { - c[offset + i] = Op()(a[offset + i], b[0]); - } - } -} - -template ::n> -[[kernel]] void binary_vv2( - device const T* a, - device const T* b, - device U* c, - constant int64_t& size, - uint2 index [[thread_position_in_grid]], - uint2 grid_dim [[threads_per_grid]]) { - int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); - if (N > 1 && offset + N > size) { - for (int i = 0; offset + i < size; ++i) { - c[offset + i] = Op()(a[offset + i], b[offset + i]); - } - } else { - for (int i = 0; i < N; ++i) { - c[offset + i] = Op()(a[offset + i], b[offset + i]); - } - } -} - -template -[[kernel]] void binary_g_nd1( - device const T* a, - device const T* b, - device U* c, - constant const int64_t& a_stride, - constant const int64_t& b_stride, - uint index [[thread_position_in_grid]]) { - auto a_idx = elem_to_loc_1(index, a_stride); - auto b_idx = elem_to_loc_1(index, b_stride); - c[index] = Op()(a[a_idx], b[b_idx]); -} - -template -[[kernel]] void binary_g_nd2( - device const T* a, - device const T* b, - device U* c, - constant const int64_t a_strides[2], - constant const int64_t b_strides[2], - uint2 index [[thread_position_in_grid]], - uint2 grid_dim [[threads_per_grid]]) { - auto a_idx = elem_to_loc_2(index, a_strides); - auto b_idx = elem_to_loc_2(index, b_strides); - IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y; - c[out_idx] = Op()(a[a_idx], b[b_idx]); -} - -template -[[kernel]] void binary_g_nd3( - device const T* a, - device const T* b, - device U* c, - constant const int64_t a_strides[3], - constant const int64_t b_strides[3], - uint3 index [[thread_position_in_grid]], - uint3 grid_dim [[threads_per_grid]]) { - auto a_idx = elem_to_loc_3(index, a_strides); - auto b_idx = elem_to_loc_3(index, b_strides); - IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z); - c[out_idx] = Op()(a[a_idx], b[b_idx]); -} - -template < - typename T, - typename U, - typename Op, - int N = 1, - typename IdxT = int64_t> -[[kernel]] void binary_g( - device const T* a, - device const T* b, - device U* c, - constant const int* shape, - constant const int64_t* a_strides, - constant const int64_t* b_strides, - constant const int& ndim, - uint3 index [[thread_position_in_grid]], - uint3 grid_dim [[threads_per_grid]]) { - auto idx = elem_to_loc_2_nd( - {N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim); - auto xshape = shape[ndim - 1]; - IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z); - IdxT a_xstride = a_strides[ndim - 1]; - IdxT b_xstride = b_strides[ndim - 1]; - for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { - c[out_idx++] = Op()(a[idx.x], b[idx.y]); - idx.x += a_xstride; - idx.y += b_xstride; - } -} diff --git a/Source/Cmlx/mlx-generated/metal/binary_ops.h b/Source/Cmlx/mlx-generated/metal/binary_ops.h deleted file mode 100644 index 4e3d881f..00000000 --- a/Source/Cmlx/mlx-generated/metal/binary_ops.h +++ /dev/null @@ -1,330 +0,0 @@ -// Copyright © 2023-2024 Apple Inc. - -#pragma once - -#include -#include - -constant mlx::os_log logger("mlx", "binary_ops"); - -struct Add { - template - T operator()(T x, T y) { - return x + y; - } -}; - -struct FloorDivide { - template - T operator()(T x, T y) { - return x / y; - } - template <> - float operator()(float x, float y) { - return trunc(x / y); - } - template <> - half operator()(half x, half y) { - return trunc(x / y); - } - template <> - bfloat16_t operator()(bfloat16_t x, bfloat16_t y) { - return trunc(x / y); - } -}; - -struct Divide { - template - T operator()(T x, T y) { - return x / y; - } -}; - -struct Remainder { - template - metal::enable_if_t & !metal::is_signed_v, T> - operator()(T x, T y) { - return x % y; - } - template - metal::enable_if_t & metal::is_signed_v, T> - operator()(T x, T y) { - auto r = x % y; - if (r != 0 && (r < 0 != y < 0)) { - r += y; - } - return r; - } - template - metal::enable_if_t, T> operator()(T x, T y) { - T r = fmod(x, y); - if (r != 0 && (r < 0 != y < 0)) { - r += y; - } - return r; - } - template <> - complex64_t operator()(complex64_t x, complex64_t y) { - return x % y; - } -}; - -struct Equal { - template - bool operator()(T x, T y) { - return x == y; - } -}; - -struct NaNEqual { - template - bool operator()(T x, T y) { - return x == y || (metal::isnan(x) && metal::isnan(y)); - } - template <> - bool operator()(complex64_t x, complex64_t y) { - return x == y || - (metal::isnan(x.real) && metal::isnan(y.real) && metal::isnan(x.imag) && - metal::isnan(y.imag)) || - (x.real == y.real && metal::isnan(x.imag) && metal::isnan(y.imag)) || - (metal::isnan(x.real) && metal::isnan(y.real) && x.imag == y.imag); - } -}; - -struct Greater { - template - bool operator()(T x, T y) { - return x > y; - } -}; - -struct GreaterEqual { - template - bool operator()(T x, T y) { - return x >= y; - } -}; - -struct Less { - template - bool operator()(T x, T y) { - return x < y; - } -}; - -struct LessEqual { - template - bool operator()(T x, T y) { - return x <= y; - } -}; - -struct LogAddExp { - template - T operator()(T x, T y) { - if (metal::isnan(x) || metal::isnan(y)) { - return metal::numeric_limits::quiet_NaN(); - } - constexpr T inf = metal::numeric_limits::infinity(); - T maxval = metal::max(x, y); - T minval = metal::min(x, y); - return (minval == -inf || maxval == inf) - ? maxval - : (maxval + log1p(metal::exp(minval - maxval))); - }; - - complex64_t operator()(complex64_t x, complex64_t y) { - if (metal::isnan(x.real) || metal::isnan(x.imag) || metal::isnan(y.real) || - metal::isnan(y.imag)) { - return metal::numeric_limits::quiet_NaN(); - } - constexpr float inf = metal::numeric_limits::infinity(); - complex64_t maxval = x > y ? x : y; - complex64_t minval = x < y ? x : y; - if (minval.real == -inf || maxval.real == inf) - return maxval; - float m = metal::exp(minval.real - maxval.real); - complex64_t dexp{ - m * metal::cos(minval.imag - maxval.imag), - m * metal::sin(minval.imag - maxval.imag), - }; - return maxval + log1p(dexp); - } -}; - -struct Maximum { - template - metal::enable_if_t, T> operator()(T x, T y) { - return metal::max(x, y); - } - - template - metal::enable_if_t, T> operator()(T x, T y) { - if (metal::isnan(x)) { - return x; - } - return x > y ? x : y; - } - - template <> - complex64_t operator()(complex64_t x, complex64_t y) { - if (metal::isnan(x.real) || metal::isnan(x.imag)) { - return x; - } - return x > y ? x : y; - } -}; - -struct Minimum { - template - metal::enable_if_t, T> operator()(T x, T y) { - return metal::min(x, y); - } - - template - metal::enable_if_t, T> operator()(T x, T y) { - if (metal::isnan(x)) { - return x; - } - return x < y ? x : y; - } - - template <> - complex64_t operator()(complex64_t x, complex64_t y) { - if (metal::isnan(x.real) || metal::isnan(x.imag)) { - return x; - } - return x < y ? x : y; - } -}; - -struct Multiply { - template - T operator()(T x, T y) { - return x * y; - } -}; - -struct NotEqual { - template - bool operator()(T x, T y) { - return x != y; - } - template <> - bool operator()(complex64_t x, complex64_t y) { - return x.real != y.real || x.imag != y.imag; - } -}; - -struct Power { - template - metal::enable_if_t, T> operator()(T base, T exp) { - return metal::pow(base, exp); - } - - template - metal::enable_if_t, T> operator()(T base, T exp) { - T res = 1; - // Undefined to raise integer to negative power - if (exp < 0) { - logger.log_debug( - "int pow exp<0 (base=%ld exp=%ld)", (long)base, (long)exp); - return 0; - } - - while (exp) { - if (exp & 1) { - res *= base; - } - exp >>= 1; - base *= base; - } - return res; - } - - template <> - complex64_t operator()(complex64_t x, complex64_t y) { - if (x.real == 0 && x.imag == 0) { - if (metal::isnan(y.real) || metal::isnan(y.imag)) { - auto nan = metal::numeric_limits::quiet_NaN(); - return {nan, nan}; - } - return {0.0, 0.0}; - } - auto x_theta = metal::atan2(x.imag, x.real); - auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag); - auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta); - auto phase = y.imag * x_ln_r + y.real * x_theta; - return {mag * metal::cos(phase), mag * metal::sin(phase)}; - } -}; - -struct Subtract { - template - T operator()(T x, T y) { - return x - y; - } -}; - -struct LogicalAnd { - template - T operator()(T x, T y) { - return x && y; - }; -}; - -struct LogicalOr { - template - T operator()(T x, T y) { - return x || y; - }; -}; - -struct BitwiseAnd { - template - T operator()(T x, T y) { - return x & y; - }; -}; - -struct BitwiseOr { - template - T operator()(T x, T y) { - return x | y; - }; -}; - -struct BitwiseXor { - template - T operator()(T x, T y) { - return x ^ y; - }; -}; - -struct LeftShift { - template - T operator()(T x, T y) { - return x << y; - }; -}; - -struct RightShift { - template - T operator()(T x, T y) { - return x >> y; - }; -}; - -struct ArcTan2 { - template - T operator()(T y, T x) { - return metal::precise::atan2(y, x); - } -}; - -struct DivMod { - template - metal::array operator()(T x, T y) { - return {FloorDivide{}(x, y), Remainder{}(x, y)}; - }; -}; diff --git a/Source/Cmlx/mlx-generated/metal/binary_two.h b/Source/Cmlx/mlx-generated/metal/binary_two.h deleted file mode 100644 index 4455e4ca..00000000 --- a/Source/Cmlx/mlx-generated/metal/binary_two.h +++ /dev/null @@ -1,244 +0,0 @@ -// Copyright © 2024 Apple Inc. - -template -[[kernel]] void binary_ss( - device const T* a, - device const T* b, - device U* c, - device U* d, - uint index [[thread_position_in_grid]]) { - auto out = Op()(a[0], b[0]); - c[index] = out[0]; - d[index] = out[1]; -} - -template ::n> -[[kernel]] void binary_sv( - device const T* a, - device const T* b, - device U* c, - device U* d, - constant uint& size, - uint index [[thread_position_in_grid]]) { - index *= N; - if (N > 1 && index + N > size) { - for (int i = 0; index + i < size; ++i) { - auto out = Op()(a[0], b[index + i]); - c[index + i] = out[0]; - d[index + i] = out[1]; - } - } else { - for (int i = 0; i < N; ++i) { - auto out = Op()(a[0], b[index + i]); - c[index + i] = out[0]; - d[index + i] = out[1]; - } - } -} - -template ::n> -[[kernel]] void binary_vs( - device const T* a, - device const T* b, - device U* c, - device U* d, - constant uint& size, - uint index [[thread_position_in_grid]]) { - index *= N; - if (N > 1 && index + N > size) { - for (int i = 0; index + i < size; ++i) { - auto out = Op()(a[index + i], b[0]); - c[index + i] = out[0]; - d[index + i] = out[1]; - } - } else { - for (int i = 0; i < N; ++i) { - auto out = Op()(a[index + i], b[0]); - c[index + i] = out[0]; - d[index + i] = out[1]; - } - } -} - -template ::n> -[[kernel]] void binary_vv( - device const T* a, - device const T* b, - device U* c, - device U* d, - constant uint& size, - uint index [[thread_position_in_grid]]) { - index *= N; - if (N > 1 && index + N > size) { - for (int i = 0; index + i < size; ++i) { - auto out = Op()(a[index + i], b[index + i]); - c[index + i] = out[0]; - d[index + i] = out[1]; - } - } else { - for (int i = 0; i < N; ++i) { - auto out = Op()(a[index + i], b[index + i]); - c[index + i] = out[0]; - d[index + i] = out[1]; - } - } -} - -template ::n> -[[kernel]] void binary_sv2( - device const T* a, - device const T* b, - device U* c, - device U* d, - constant int64_t& size, - uint2 index [[thread_position_in_grid]], - uint2 grid_dim [[threads_per_grid]]) { - int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); - if (N > 1 && offset + N > size) { - for (int i = 0; offset + i < size; ++i) { - auto out = Op()(a[0], b[offset + i]); - c[offset + i] = out[0]; - d[offset + i] = out[1]; - } - } else { - for (int i = 0; i < N; ++i) { - auto out = Op()(a[0], b[offset + i]); - c[offset + i] = out[0]; - d[offset + i] = out[1]; - } - } -} - -template ::n> -[[kernel]] void binary_vs2( - device const T* a, - device const T* b, - device U* c, - device U* d, - constant int64_t& size, - uint2 index [[thread_position_in_grid]], - uint2 grid_dim [[threads_per_grid]]) { - int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); - if (N > 1 && offset + N > size) { - for (int i = 0; offset + i < size; ++i) { - auto out = Op()(a[offset + i], b[0]); - c[offset + i] = out[0]; - d[offset + i] = out[1]; - } - } else { - for (int i = 0; i < N; ++i) { - auto out = Op()(a[offset + i], b[0]); - c[offset + i] = out[0]; - d[offset + i] = out[1]; - } - } -} - -template ::n> -[[kernel]] void binary_vv2( - device const T* a, - device const T* b, - device U* c, - device U* d, - constant int64_t& size, - uint2 index [[thread_position_in_grid]], - uint2 grid_dim [[threads_per_grid]]) { - int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); - if (N > 1 && offset + N > size) { - for (int i = 0; offset + i < size; ++i) { - auto out = Op()(a[offset + i], b[offset + i]); - c[offset + i] = out[0]; - d[offset + i] = out[1]; - } - } else { - for (int i = 0; i < N; ++i) { - auto out = Op()(a[offset + i], b[offset + i]); - c[offset + i] = out[0]; - d[offset + i] = out[1]; - } - } -} - -template -[[kernel]] void binary_g_nd1( - device const T* a, - device const T* b, - device U* c, - device U* d, - constant const int64_t& a_stride, - constant const int64_t& b_stride, - uint index [[thread_position_in_grid]]) { - auto a_idx = elem_to_loc_1(index, a_stride); - auto b_idx = elem_to_loc_1(index, b_stride); - auto out = Op()(a[a_idx], b[b_idx]); - c[index] = out[0]; - d[index] = out[1]; -} - -template -[[kernel]] void binary_g_nd2( - device const T* a, - device const T* b, - device U* c, - device U* d, - constant const int64_t a_strides[2], - constant const int64_t b_strides[2], - uint2 index [[thread_position_in_grid]], - uint2 grid_dim [[threads_per_grid]]) { - auto a_idx = elem_to_loc_2(index, a_strides); - auto b_idx = elem_to_loc_2(index, b_strides); - IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y; - auto out = Op()(a[a_idx], b[b_idx]); - c[out_idx] = out[0]; - d[out_idx] = out[1]; -} - -template -[[kernel]] void binary_g_nd3( - device const T* a, - device const T* b, - device U* c, - device U* d, - constant const int64_t a_strides[3], - constant const int64_t b_strides[3], - uint3 index [[thread_position_in_grid]], - uint3 grid_dim [[threads_per_grid]]) { - auto a_idx = elem_to_loc_3(index, a_strides); - auto b_idx = elem_to_loc_3(index, b_strides); - IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z); - auto out = Op()(a[a_idx], b[b_idx]); - c[out_idx] = out[0]; - d[out_idx] = out[1]; -} - -template < - typename T, - typename U, - typename Op, - int N = 1, - typename IdxT = int64_t> -[[kernel]] void binary_g( - device const T* a, - device const T* b, - device U* c, - device U* d, - constant const int* shape, - constant const int64_t* a_strides, - constant const int64_t* b_strides, - constant const int& ndim, - uint3 index [[thread_position_in_grid]], - uint3 grid_dim [[threads_per_grid]]) { - auto idx = elem_to_loc_2_nd( - {N * index.x, index.y, index.z}, shape, a_strides, b_strides, ndim); - auto xshape = shape[ndim - 1]; - IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z); - IdxT a_xstride = a_strides[ndim - 1]; - IdxT b_xstride = b_strides[ndim - 1]; - for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { - auto out = Op()(a[idx.x], b[idx.y]); - c[out_idx] = out[0]; - d[out_idx++] = out[1]; - idx.x += a_xstride; - idx.y += b_xstride; - } -} diff --git a/Source/Cmlx/mlx-generated/metal/cexpf.h b/Source/Cmlx/mlx-generated/metal/cexpf.h deleted file mode 100644 index b45fe6a2..00000000 --- a/Source/Cmlx/mlx-generated/metal/cexpf.h +++ /dev/null @@ -1,134 +0,0 @@ -// Copyright © 2025 Apple Inc. -// Copyright © 2008-2013 NVIDIA Corporation -// Copyright © 2013 Filipe RNC Maia -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// Forked from -// https://github.com/NVIDIA/cccl/blob/main/thrust/thrust/detail/complex/cexpf.h - -// TODO: We should use thrust::exp but the thrust header in old CUDA versions -// can not be used in JIT. - -#pragma once - -#include - -using ieee_float_shape_type = union { - float value; - uint32_t word; -}; - -inline void get_float_word(thread uint32_t& i, float d) { - ieee_float_shape_type gf_u; - gf_u.value = (d); - (i) = gf_u.word; -} - -inline void get_float_word(thread int32_t& i, float d) { - ieee_float_shape_type gf_u; - gf_u.value = (d); - (i) = gf_u.word; -} - -inline void set_float_word(thread float& d, uint32_t i) { - ieee_float_shape_type sf_u; - sf_u.word = (i); - (d) = sf_u.value; -} - -inline float frexp_expf(float x, thread int* expt) { - const uint32_t k = 235; - const float kln2 = 162.88958740F; - - float exp_x; - uint32_t hx; - - exp_x = metal::exp(x - kln2); - get_float_word(hx, exp_x); - *expt = (hx >> 23) - (0x7f + 127) + k; - set_float_word(exp_x, (hx & 0x7fffff) | ((0x7f + 127) << 23)); - return exp_x; -} - -inline complex64_t ldexp_cexpf(complex64_t z, int expt) { - float x, y, exp_x, scale1, scale2; - int ex_expt, half_expt; - - x = z.real; - y = z.imag; - exp_x = frexp_expf(x, &ex_expt); - expt += ex_expt; - - half_expt = expt / 2; - set_float_word(scale1, (0x7f + half_expt) << 23); - half_expt = expt - half_expt; - set_float_word(scale2, (0x7f + half_expt) << 23); - - return complex64_t{ - metal::cos(y) * exp_x * scale1 * scale2, - metal::sin(y) * exp_x * scale1 * scale2}; -} - -inline complex64_t cexpf(const thread complex64_t& z) { - float x, y, exp_x; - uint32_t hx, hy; - - const uint32_t exp_ovfl = 0x42b17218, cexp_ovfl = 0x43400074; - - x = z.real; - y = z.imag; - - get_float_word(hy, y); - hy &= 0x7fffffff; - - /* cexp(x + I 0) = exp(x) + I 0 */ - if (hy == 0) { - return complex64_t{metal::exp(x), y}; - } - get_float_word(hx, x); - /* cexp(0 + I y) = cos(y) + I sin(y) */ - if ((hx & 0x7fffffff) == 0) { - return complex64_t{metal::cos(y), metal::sin(y)}; - } - if (hy >= 0x7f800000) { - if ((hx & 0x7fffffff) != 0x7f800000) { - /* cexp(finite|NaN +- I Inf|NaN) = NaN + I NaN */ - return complex64_t{y - y, y - y}; - } else if (hx & 0x80000000) { - /* cexp(-Inf +- I Inf|NaN) = 0 + I 0 */ - return complex64_t{0.0, 0.0}; - } else { - /* cexp(+Inf +- I Inf|NaN) = Inf + I NaN */ - return complex64_t{x, y - y}; - } - } - - if (hx >= exp_ovfl && hx <= cexp_ovfl) { - /* - * x is between 88.7 and 192, so we must scale to avoid - * overflow in expf(x). - */ - return ldexp_cexpf(z, 0); - } else { - /* - * Cases covered here: - * - x < exp_ovfl and exp(x) won't overflow (common case) - * - x > cexp_ovfl, so exp(x) * s overflows for all s > 0 - * - x = +-Inf (generated by exp()) - * - x = NaN (spurious inexact exception from y) - */ - exp_x = metal::exp(x); - return complex64_t{exp_x * metal::cos(y), exp_x * metal::sin(y)}; - } -} diff --git a/Source/Cmlx/mlx-generated/metal/complex.h b/Source/Cmlx/mlx-generated/metal/complex.h deleted file mode 100644 index 6e391483..00000000 --- a/Source/Cmlx/mlx-generated/metal/complex.h +++ /dev/null @@ -1,173 +0,0 @@ -// Copyright © 2023 Apple Inc. - -#pragma once - -#include - -using namespace metal; - -struct complex64_t; - -template -static constexpr constant bool can_convert_to_complex64 = - !is_same_v && is_convertible_v; - -template -static constexpr constant bool can_convert_from_complex64 = - !is_same_v && - (is_convertible_v || is_convertible_v); - -struct complex64_t { - float real; - float imag; - - // Constructors - constexpr complex64_t(float real, float imag) : real(real), imag(imag) {}; - constexpr complex64_t() : real(0), imag(0) {}; - constexpr complex64_t() threadgroup : real(0), imag(0) {}; - - // Conversions to complex64_t - template < - typename T, - typename = typename enable_if>::type> - constexpr complex64_t(T x) thread : real(x), imag(0) {} - - template < - typename T, - typename = typename enable_if>::type> - constexpr complex64_t(T x) threadgroup : real(x), imag(0) {} - - template < - typename T, - typename = typename enable_if>::type> - constexpr complex64_t(T x) device : real(x), imag(0) {} - - template < - typename T, - typename = typename enable_if>::type> - constexpr complex64_t(T x) constant : real(x), imag(0) {} - - // Conversions from complex64_t - template < - typename T, - typename = typename enable_if>::type> - constexpr operator T() const thread { - return static_cast(real); - } - - template < - typename T, - typename = typename enable_if>::type> - constexpr operator T() const threadgroup { - return static_cast(real); - } - - template < - typename T, - typename = typename enable_if>::type> - constexpr operator T() const device { - return static_cast(real); - } - - template < - typename T, - typename = typename enable_if>::type> - constexpr operator T() const constant { - return static_cast(real); - } -}; - -constexpr complex64_t operator-(complex64_t x) { - return {-x.real, -x.imag}; -} - -constexpr bool operator>=(complex64_t a, complex64_t b) { - return (a.real > b.real) || (a.real == b.real && a.imag >= b.imag); -} - -constexpr bool operator>(complex64_t a, complex64_t b) { - return (a.real > b.real) || (a.real == b.real && a.imag > b.imag); -} - -constexpr bool operator<=(complex64_t a, complex64_t b) { - return operator>=(b, a); -} - -constexpr bool operator<(complex64_t a, complex64_t b) { - return operator>(b, a); -} - -constexpr bool operator==(complex64_t a, complex64_t b) { - return a.real == b.real && a.imag == b.imag; -} - -constexpr complex64_t operator+(complex64_t a, complex64_t b) { - return {a.real + b.real, a.imag + b.imag}; -} - -constexpr thread complex64_t& operator+=(thread complex64_t& a, complex64_t b) { - a.real += b.real; - a.imag += b.imag; - return a; -} - -constexpr threadgroup complex64_t& operator+=( - threadgroup complex64_t& a, - complex64_t b) { - a.real += b.real; - a.imag += b.imag; - return a; -} - -constexpr device complex64_t& operator+=(device complex64_t& a, complex64_t b) { - a.real += b.real; - a.imag += b.imag; - return a; -} - -constexpr complex64_t operator+(float a, complex64_t b) { - return {a + b.real, b.imag}; -} -constexpr complex64_t operator+(complex64_t a, float b) { - return {a.real + b, a.imag}; -} - -constexpr complex64_t operator-(complex64_t a, complex64_t b) { - return {a.real - b.real, a.imag - b.imag}; -} -constexpr complex64_t operator-(float a, complex64_t b) { - return {a - b.real, -b.imag}; -} -constexpr complex64_t operator-(complex64_t a, float b) { - return {a.real - b, a.imag}; -} - -constexpr complex64_t operator*(complex64_t a, complex64_t b) { - return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real}; -} - -constexpr complex64_t operator/(complex64_t a, complex64_t b) { - auto denom = b.real * b.real + b.imag * b.imag; - auto x = a.real * b.real + a.imag * b.imag; - auto y = a.imag * b.real - a.real * b.imag; - return {x / denom, y / denom}; -} - -constexpr complex64_t operator/(float a, complex64_t b) { - auto denom = b.real * b.real + b.imag * b.imag; - auto x = a * b.real; - auto y = -a * b.imag; - return {x / denom, y / denom}; -} - -constexpr complex64_t operator%(complex64_t a, complex64_t b) { - auto real = a.real - (b.real * static_cast(a.real / b.real)); - auto imag = a.imag - (b.imag * static_cast(a.imag / b.imag)); - if (real != 0 && (real < 0 != b.real < 0)) { - real += b.real; - } - if (imag != 0 && (imag < 0 != b.imag < 0)) { - imag += b.imag; - } - return {real, imag}; -} diff --git a/Source/Cmlx/mlx-generated/metal/conv.metal b/Source/Cmlx/mlx-generated/metal/conv.metal deleted file mode 100644 index e6cc127c..00000000 --- a/Source/Cmlx/mlx-generated/metal/conv.metal +++ /dev/null @@ -1,702 +0,0 @@ -// Copyright © 2023-2024 Apple Inc. - -#include -#include -#include - -#include "steel/conv/params.h" -#include "utils.h" - -#define MLX_MTL_CONST static constant constexpr const - -using namespace metal; - -/////////////////////////////////////////////////////////////////////////////// -/// Naive unfold with dilation -/////////////////////////////////////////////////////////////////////////////// - -template -[[kernel]] void naive_unfold_Nd( - const device T* in [[buffer(0)]], - device T* out [[buffer(1)]], - const constant MLXConvParams* params [[buffer(2)]], - uint3 gid [[thread_position_in_grid]]) { - int filter_size = params->C; - for (short i = 0; i < N; i++) - filter_size *= params->wS[i]; - - int out_pixels = 1; - for (short i = 0; i < N; i++) - out_pixels *= params->oS[i]; - - // Set out - out += (size_t)gid.z * filter_size + (size_t)gid.y * (params->C); - - // Coordinates in input - int is[N] = {0}; - - // gid.z: N oS (Batch and row in unfolded output) - // gid.y: wS (Filter location to unfold input) - // gid.x: C (channel) - - int n = (gid.z) / out_pixels; - int oS = (gid.z) % out_pixels; - int wS = gid.y; - - bool valid = n < params->N; - - // Unroll dimensions - for (int i = N - 1; i >= 0; --i) { - int os_ = (oS % params->oS[i]); - int ws_ = (wS % params->wS[i]); - - ws_ = params->flip ? params->wS[i] - ws_ - 1 : ws_; - - int is_ = os_ * params->str[i] - params->pad[i] + ws_ * params->kdil[i]; - int is_max = 1 + params->idil[i] * (params->iS[i] - 1); - - valid &= is_ >= 0 && is_ < is_max && (is_ % params->idil[i] == 0); - - is[i] = is_ / params->idil[i]; - - oS /= params->oS[i]; - wS /= params->wS[i]; - } - - if (valid) { - size_t in_offset = n * params->in_strides[0]; - - for (int i = 0; i < N; ++i) { - in_offset += is[i] * params->in_strides[i + 1]; - } - - out[gid.x] = in[in_offset + gid.x]; - } else { - out[gid.x] = T(0); - } -} - -// This kernel unfolds the input array of size (N, *spatial_dims, C) -// into an array of size (N x *spatial_dims, C x *kernel_dims). -template -[[kernel]] void naive_unfold_transpose_Nd( - const device T* in [[buffer(0)]], - device T* out [[buffer(1)]], - const constant MLXConvParams* params [[buffer(2)]], - uint3 gid [[thread_position_in_grid]]) { - int filter_size = params->C; - for (short i = 0; i < N; i++) - filter_size *= params->wS[i]; - - int out_pixels = 1; - for (short i = 0; i < N; i++) - out_pixels *= params->oS[i]; - - // Set out - out += - (size_t)gid.z * filter_size + (size_t)gid.x * (filter_size / params->C); - - // Coordinates in input - int is[N] = {0}; - - // gid.z: N oS (Batch and row in unfolded output) - // gid.y: wS (Filter location to unfold input) - // gid.x: C (channel) - - int n = (gid.z) / out_pixels; - int oS = (gid.z) % out_pixels; - int wS = gid.y; - - bool valid = n < params->N; - - // Unroll dimensions - int kernel_stride = 1; - for (int i = N - 1; i >= 0; --i) { - int os_ = (oS % params->oS[i]); - int ws_ = (wS % params->wS[i]); - out += ws_ * kernel_stride; - - ws_ = params->flip ? params->wS[i] - ws_ - 1 : ws_; - - int is_ = os_ * params->str[i] - params->pad[i] + ws_ * params->kdil[i]; - int is_max = 1 + params->idil[i] * (params->iS[i] - 1); - - valid &= is_ >= 0 && is_ < is_max && (is_ % params->idil[i] == 0); - - is[i] = is_ / params->idil[i]; - - oS /= params->oS[i]; - wS /= params->wS[i]; - - kernel_stride *= params->wS[i]; - } - - if (valid) { - size_t in_offset = n * params->in_strides[0]; - - for (int i = 0; i < N; ++i) { - in_offset += is[i] * params->in_strides[i + 1]; - } - - out[0] = in[in_offset + gid.x]; - } else { - out[0] = T(0); - } -} - -#define instantiate_naive_unfold_nd(name, itype, n) \ - template [[host_name("naive_unfold_nd_" #name "_" #n)]] [[kernel]] void \ - naive_unfold_Nd( \ - const device itype* in [[buffer(0)]], \ - device itype* out [[buffer(1)]], \ - const constant MLXConvParams* params [[buffer(2)]], \ - uint3 gid [[thread_position_in_grid]]); \ - template \ - [[host_name("naive_unfold_transpose_nd_" #name "_" #n)]] [[kernel]] void \ - naive_unfold_transpose_Nd( \ - const device itype* in [[buffer(0)]], \ - device itype* out [[buffer(1)]], \ - const constant MLXConvParams* params [[buffer(2)]], \ - uint3 gid [[thread_position_in_grid]]); - -#define instantiate_naive_unfold_nd_dims(name, itype) \ - instantiate_naive_unfold_nd(name, itype, 1) instantiate_naive_unfold_nd( \ - name, itype, 2) instantiate_naive_unfold_nd(name, itype, 3) - -instantiate_naive_unfold_nd_dims(float32, float); -instantiate_naive_unfold_nd_dims(float16, half); -instantiate_naive_unfold_nd_dims(bfloat16, bfloat16_t); - -/////////////////////////////////////////////////////////////////////////////// -/// Depthwise convolution kernels -/////////////////////////////////////////////////////////////////////////////// - -constant int ker_h [[function_constant(00)]]; -constant int ker_w [[function_constant(01)]]; -constant int str_h [[function_constant(10)]]; -constant int str_w [[function_constant(11)]]; -constant int tgp_h [[function_constant(100)]]; -constant int tgp_w [[function_constant(101)]]; -constant bool do_flip [[function_constant(200)]]; - -constant int span_h = tgp_h * str_h + ker_h - 1; -constant int span_w = tgp_w * str_w + ker_w - 1; -constant int span_hw = span_h * span_w; - -template -[[kernel]] void depthwise_conv_2d( - const device T* in [[buffer(0)]], - const device T* wt [[buffer(1)]], - device T* out [[buffer(2)]], - const constant MLXConvParams<2>& params [[buffer(3)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint3 gid [[thread_position_in_grid]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - constexpr int tc = 8; - constexpr int tw = 8; - constexpr int th = 4; - - constexpr int c_per_thr = 8; - - constexpr int TGH = th * 2 + 6; - constexpr int TGW = tw * 2 + 6; - constexpr int TGC = tc; - - threadgroup T ins[TGH * TGW * TGC]; - - const int n_tgblocks_h = params.oS[0] / th; - const int n = tid.z / n_tgblocks_h; - const int tghid = tid.z % n_tgblocks_h; - const int oh = tghid * th + lid.z; - const int ow = gid.y; - const int c = gid.x; - - in += n * params.in_strides[0]; - - // Load in - { - constexpr int n_threads = th * tw * tc; - const int tg_oh = (tghid * th) * str_h - params.pad[0]; - const int tg_ow = (tid.y * tw) * str_w - params.pad[1]; - const int tg_c = tid.x * tc; - - const int thread_idx = simd_gid * 32 + simd_lid; - constexpr int thr_per_hw = tc / c_per_thr; - constexpr int hw_per_group = n_threads / thr_per_hw; - - const int thr_c = thread_idx % thr_per_hw; - const int thr_hw = thread_idx / thr_per_hw; - - for (int hw = thr_hw; hw < span_hw; hw += hw_per_group) { - const int h = hw / span_w; - const int w = hw % span_w; - - const int ih = tg_oh + h; - const int iw = tg_ow + w; - - const int in_s_offset = h * span_w * TGC + w * TGC; - - if (ih >= 0 && ih < params.iS[0] && iw >= 0 && iw < params.iS[1]) { - const auto in_load = - in + ih * params.in_strides[1] + iw * params.in_strides[2] + tg_c; - - MLX_MTL_PRAGMA_UNROLL - for (int cc = 0; cc < c_per_thr; ++cc) { - ins[in_s_offset + c_per_thr * thr_c + cc] = - in_load[c_per_thr * thr_c + cc]; - } - } else { - MLX_MTL_PRAGMA_UNROLL - for (int cc = 0; cc < c_per_thr; ++cc) { - ins[in_s_offset + c_per_thr * thr_c + cc] = T(0); - } - } - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - wt += c * params.wt_strides[0]; - - const auto ins_ptr = - &ins[lid.z * str_h * span_w * TGC + lid.y * str_w * TGC + lid.x]; - float o = 0.; - for (int h = 0; h < ker_h; ++h) { - for (int w = 0; w < ker_w; ++w) { - int wt_h = h; - int wt_w = w; - if (do_flip) { - wt_h = ker_h - h - 1; - wt_w = ker_w - w - 1; - } - auto inv = ins_ptr[h * span_w * TGC + w * TGC]; - auto wtv = wt[wt_h * ker_w + wt_w]; - o += inv * wtv; - } - } - threadgroup_barrier(mem_flags::mem_none); - - out += n * params.out_strides[0] + oh * params.out_strides[1] + - ow * params.out_strides[2]; - out[c] = static_cast(o); -} - -#define instantiate_depthconv2d(iname, itype) \ - instantiate_kernel("depthwise_conv_2d_" #iname, depthwise_conv_2d, itype) - -instantiate_depthconv2d(float32, float); -instantiate_depthconv2d(float16, half); -instantiate_depthconv2d(bfloat16, bfloat16_t); - -template -[[kernel]] void depthwise_conv_1d( - const device T* in [[buffer(0)]], - const device T* w [[buffer(1)]], - device T* out [[buffer(2)]], - constant const IdxT strides[3], - constant const int& kernel_size, - uint3 tid [[thread_position_in_grid]], - uint3 grid_dim [[threads_per_grid]]) { - out += (tid.z * static_cast(grid_dim.y) + tid.y) * grid_dim.x + tid.x; - in += tid.z * strides[0] + tid.y * strides[1] + tid.x * strides[2]; - w += tid.x * kernel_size; - - float acc = 0.0; - for (int i = 0; i < kernel_size; ++i) { - acc += static_cast(in[0]) * w[i]; - in += strides[1]; - } - *out = static_cast(acc); -} - -#define instantiate_depthconv1d(iname, itype) \ - instantiate_kernel( \ - "depthwise_conv_1d_" #iname, depthwise_conv_1d, itype, int32_t) \ - instantiate_kernel( \ - "depthwise_conv_1d_" #iname "_large", \ - depthwise_conv_1d, \ - itype, \ - int64_t) - -instantiate_depthconv1d(float32, float); -instantiate_depthconv1d(float16, half); -instantiate_depthconv1d(bfloat16, bfloat16_t); - -/////////////////////////////////////////////////////////////////////////////// -/// Winograd kernels -/////////////////////////////////////////////////////////////////////////////// - -template -struct WinogradTransforms {}; - -template <> -struct WinogradTransforms<6, 3, 8> { - MLX_MTL_CONST int OUT_TILE_SIZE = 6; - MLX_MTL_CONST int FILTER_SIZE = 3; - MLX_MTL_CONST int IN_TILE_SIZE = OUT_TILE_SIZE + FILTER_SIZE - 1; - MLX_MTL_CONST int SIMD_MATRIX_SIZE = 8; - MLX_MTL_CONST float in_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = { - {1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f}, - {0.00f, 1.00f, -1.00f, 0.50f, -0.50f, 2.00f, -2.00f, -1.00f}, - {-5.25f, 1.00f, 1.00f, 0.25f, 0.25f, 4.00f, 4.00f, 0.00f}, - {0.00f, -4.25f, 4.25f, -2.50f, 2.50f, -2.50f, 2.50f, 5.25f}, - {5.25f, -4.25f, -4.25f, -1.25f, -1.25f, -5.00f, -5.00f, 0.00f}, - {0.00f, 1.00f, -1.00f, 2.00f, -2.00f, 0.50f, -0.50f, -5.25f}, - {-1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 0.00f}, - {0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f}, - }; - - MLX_MTL_CONST float out_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = { - {1.00f, 0.00f, 0.00f, 0.00f, 0.00f, 0.00f}, - {1.00f, 1.00f, 1.00f, 1.00f, 1.00f, 1.00f}, - {1.00f, -1.00f, 1.00f, -1.00f, 1.00f, -1.00f}, - {1.00f, 2.00f, 4.00f, 8.00f, 16.00f, 32.00f}, - {1.00f, -2.00f, 4.00f, -8.00f, 16.00f, -32.00f}, - {1.00f, 0.50f, 0.25f, 0.125f, 0.0625f, 0.03125f}, - {1.00f, -0.50f, 0.25f, -0.125f, 0.0625f, -0.03125f}, - {0.00f, 0.00f, 0.00f, 0.00f, 0.00f, 1.00f}, - }; - - MLX_MTL_CONST float wt_transform[SIMD_MATRIX_SIZE][SIMD_MATRIX_SIZE] = { - {1.00, 0.00, 0.00}, - {-2.0 / 9.00, -2.0 / 9.00, -2.0 / 9.00}, - {-2.0 / 9.00, 2.0 / 9.00, -2.0 / 9.00}, - {1.0 / 90.0, 1.0 / 45.0, 2.0 / 45.0}, - {1.0 / 90.0, -1.0 / 45.0, 2.0 / 45.0}, - {32.0 / 45.0, 16.0 / 45.0, 8.0 / 45.0}, - {32.0 / 45.0, -16.0 / 45.0, 8.0 / 45.0}, - {0.00, 0.00, 1.00}, - }; -}; - -constant constexpr const float WinogradTransforms<6, 3, 8>::wt_transform[8][8]; -constant constexpr const float WinogradTransforms<6, 3, 8>::in_transform[8][8]; -constant constexpr const float WinogradTransforms<6, 3, 8>::out_transform[8][8]; - -template -[[kernel, max_total_threads_per_threadgroup(BO * 32)]] void -winograd_conv_2d_weight_transform( - const device T* wt_in [[buffer(0)]], - device T* wt_out [[buffer(1)]], - const constant int& C [[buffer(2)]], - const constant int& O [[buffer(3)]], - uint tid [[threadgroup_position_in_grid]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]]) { - using WGT = WinogradTransforms; - - // Get lane position in simdgroup - const short qid = simd_lane_id / 4; - const short sm = (qid & 4) + (simd_lane_id / 2) % 4; - const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; - - // Initialize G matrix - simdgroup_matrix G; - G.thread_elements()[0] = WGT::wt_transform[sm][sn]; - G.thread_elements()[1] = WGT::wt_transform[sm][sn + 1]; - - // Initialize Gt matrix - simdgroup_matrix Gt; - Gt.thread_elements()[0] = WGT::wt_transform[sn][sm]; - Gt.thread_elements()[1] = WGT::wt_transform[sn + 1][sm]; - - // Move to the correct output filter - size_t ko = BO * tid + simd_group_id; - wt_in += ko * R * R * C; - - // wt_out is stored transposed (A x A x C x O) - short ohw_0 = sm * 8 + sn; - short ohw_1 = sm * 8 + sn + 1; - device T* wt_out_0 = wt_out + ohw_0 * C * O + ko; - device T* wt_out_1 = wt_out + ohw_1 * C * O + ko; - - // Prepare shared memory - threadgroup T Ws[BO][R][R][BC]; - - // Loop over C - for (int bc = 0; bc < C; bc += BC) { - threadgroup_barrier(mem_flags::mem_threadgroup); - // Read into shared memory - for (int kh = 0; kh < R; ++kh) { - for (int kw = 0; kw < R; ++kw) { - for (int kc = simd_lane_id; kc < BC; kc += 32) { - Ws[simd_group_id][kh][kw][kc] = wt_in[kh * R * C + kw * C + kc]; - } - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - // Do transform and store the result - for (int c = 0; c < BC; ++c) { - simdgroup_matrix g; - g.thread_elements()[0] = - sm < R && sn < R ? Ws[simd_group_id][sm][sn][c] : T(0); - g.thread_elements()[1] = - sm < R && sn + 1 < R ? Ws[simd_group_id][sm][sn + 1][c] : T(0); - - simdgroup_matrix g_out = (G * g) * Gt; - wt_out_0[c * O] = static_cast(g_out.thread_elements()[0]); - wt_out_1[c * O] = static_cast(g_out.thread_elements()[1]); - } - - wt_in += BC; - wt_out_0 += BC * O; - wt_out_1 += BC * O; - } -} - -#define instantiate_winograd_conv_2d_weight_transform_base(name, itype, bc) \ - template [[host_name( \ - "winograd_conv_2d_weight_transform_" #name "_bc" #bc)]] [[kernel]] void \ - winograd_conv_2d_weight_transform( \ - const device itype* wt_in [[buffer(0)]], \ - device itype* wt_out [[buffer(1)]], \ - const constant int& C [[buffer(2)]], \ - const constant int& O [[buffer(3)]], \ - uint tid [[threadgroup_position_in_grid]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]], \ - uint simd_lane_id [[thread_index_in_simdgroup]]); - -template -[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void -winograd_conv_2d_input_transform( - const device T* inp_in [[buffer(0)]], - device T* inp_out [[buffer(1)]], - const constant MLXConvParams<2>& params [[buffer(2)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint3 tgp_per_grid [[threadgroups_per_grid]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]]) { - (void)lid; - - using WGT = WinogradTransforms; - constexpr int A = WGT::IN_TILE_SIZE; - constexpr int N_SIMD_GROUPS = WM * WN; - - // Get lane position in simdgroup - const short qid = simd_lane_id / 4; - const short sm = (qid & 4) + (simd_lane_id / 2) % 4; - const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; - - // Initialize B matrix - simdgroup_matrix B; - B.thread_elements()[0] = WGT::in_transform[sm][sn]; - B.thread_elements()[1] = WGT::in_transform[sm][sn + 1]; - - // Initialize Bt matrix - simdgroup_matrix Bt; - Bt.thread_elements()[0] = WGT::in_transform[sn][sm]; - Bt.thread_elements()[1] = WGT::in_transform[sn + 1][sm]; - - // Resolve input tile - constexpr int TH = (A / WM); - constexpr int TW = (A / WN); - int kh = TH * (simd_group_id / WN); - int kw = TW * (simd_group_id % WN); - int bh = M * tid.y + kh; - int bw = M * tid.x + kw; - - // Move to the correct input tile - inp_in += tid.z * params.in_strides[0] + bh * params.in_strides[1] + - bw * params.in_strides[2]; - - // Pre compute strides - int jump_in[TH][TW]; - - for (int h = 0; h < TH; h++) { - for (int w = 0; w < TW; w++) { - jump_in[h][w] = h * params.in_strides[1] + w * params.in_strides[2]; - } - } - - // inp_out is stored interleaved (A x A x tiles x C) - size_t N_TILES = tgp_per_grid.x * tgp_per_grid.y * tgp_per_grid.z; - size_t tile_id = - tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x; - size_t ohw_0 = sm * 8 + sn; - size_t ohw_1 = sm * 8 + sn + 1; - device T* inp_out_0 = - inp_out + ohw_0 * N_TILES * params.C + tile_id * params.C; - device T* inp_out_1 = - inp_out + ohw_1 * N_TILES * params.C + tile_id * params.C; - - // Prepare shared memory - threadgroup T Is[A][A][BC]; - - // Loop over C - for (int bc = 0; bc < params.C; bc += BC) { - threadgroup_barrier(mem_flags::mem_threadgroup); - // Read into shared memory - for (int h = 0; h < TH; h++) { - for (int w = 0; w < TW; w++) { - const device T* in_ptr = inp_in + jump_in[h][w]; - for (int c = simd_lane_id; c < BC; c += 32) { - Is[kh + h][kw + w][c] = in_ptr[c]; - } - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - // Do transform and store the result - for (int c = simd_group_id; c < BC; c += N_SIMD_GROUPS) { - simdgroup_matrix I; - I.thread_elements()[0] = Is[sm][sn][c]; - I.thread_elements()[1] = Is[sm][sn + 1][c]; - - simdgroup_matrix I_out = (Bt * I) * B; - inp_out_0[c] = static_cast(I_out.thread_elements()[0]); - inp_out_1[c] = static_cast(I_out.thread_elements()[1]); - } - - inp_in += BC; - inp_out_0 += BC; - inp_out_1 += BC; - } -} - -#define instantiate_winograd_conv_2d_input_transform(name, itype, bc) \ - template [[host_name( \ - "winograd_conv_2d_input_transform_" #name "_bc" #bc)]] [[kernel]] void \ - winograd_conv_2d_input_transform( \ - const device itype* inp_in [[buffer(0)]], \ - device itype* inp_out [[buffer(1)]], \ - const constant MLXConvParams<2>& params [[buffer(2)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]], \ - uint3 tgp_per_grid [[threadgroups_per_grid]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]], \ - uint simd_lane_id [[thread_index_in_simdgroup]]); - -template -[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void -winograd_conv_2d_output_transform( - const device T* out_in [[buffer(0)]], - device T* out_out [[buffer(1)]], - const constant MLXConvParams<2>& params [[buffer(2)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint3 tgp_per_grid [[threadgroups_per_grid]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]]) { - (void)lid; - - using WGT = WinogradTransforms; - constexpr int N_SIMD_GROUPS = WM * WN; - - // Get lane position in simdgroup - const short qid = simd_lane_id / 4; - const short sm = (qid & 4) + (simd_lane_id / 2) % 4; - const short sn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; - - // Initialize A matrix - simdgroup_matrix B; - B.thread_elements()[0] = WGT::out_transform[sm][sn]; - B.thread_elements()[1] = WGT::out_transform[sm][sn + 1]; - - // Initialize At matrix - simdgroup_matrix Bt; - Bt.thread_elements()[0] = WGT::out_transform[sn][sm]; - Bt.thread_elements()[1] = WGT::out_transform[sn + 1][sm]; - - // Out_in comes in shape (A x A x tiles x O) - // We do transform and then write out to out_out in shape (N, H, W, O) - - // Resolve output tile - constexpr int TH = (M / WM); - constexpr int TW = (M / WN); - int kh = TH * (simd_group_id / WN); - int kw = TW * (simd_group_id % WN); - int bh = M * tid.y + kh; - int bw = M * tid.x + kw; - - // Move to the correct input tile - out_out += tid.z * params.out_strides[0] + bh * params.out_strides[1] + - bw * params.out_strides[2]; - - // Pre compute strides - int jump_in[TH][TW]; - - for (int h = 0; h < TH; h++) { - for (int w = 0; w < TW; w++) { - bool valid = ((bh + h) < params.oS[0]) && ((bw + w) < params.oS[1]); - jump_in[h][w] = - valid ? h * params.out_strides[1] + w * params.out_strides[2] : -1; - } - } - - // out_in is stored interleaved (A x A x tiles x O) - size_t N_TILES = tgp_per_grid.x * tgp_per_grid.y * tgp_per_grid.z; - size_t tile_id = - tid.z * tgp_per_grid.x * tgp_per_grid.y + tid.y * tgp_per_grid.x + tid.x; - size_t ohw_0 = sm * 8 + sn; - size_t ohw_1 = sm * 8 + sn + 1; - const device T* out_in_0 = - out_in + ohw_0 * N_TILES * params.O + tile_id * params.O; - const device T* out_in_1 = - out_in + ohw_1 * N_TILES * params.O + tile_id * params.O; - - // Prepare shared memory - threadgroup T Os[M][M][BO]; - - // Loop over O - for (int bo = 0; bo < params.O; bo += BO) { - threadgroup_barrier(mem_flags::mem_threadgroup); - // Do transform and store the result - for (int c = simd_group_id; c < BO; c += N_SIMD_GROUPS) { - simdgroup_matrix O_mat; - O_mat.thread_elements()[0] = out_in_0[c]; - O_mat.thread_elements()[1] = out_in_1[c]; - - simdgroup_matrix O_out = (Bt * (O_mat * B)); - if ((sm < M) && (sn < M)) { - Os[sm][sn][c] = static_cast(O_out.thread_elements()[0]); - } - if ((sm < M) && ((sn + 1) < M)) { - Os[sm][sn + 1][c] = static_cast(O_out.thread_elements()[1]); - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - // Read out from shared memory - for (int h = 0; h < TH; h++) { - for (int w = 0; w < TW; w++) { - if (jump_in[h][w] >= 0) { - device T* out_ptr = out_out + jump_in[h][w]; - for (int c = simd_lane_id; c < BO; c += 32) { - out_ptr[c] = Os[kh + h][kw + w][c]; - } - } - } - } - - out_out += BO; - out_in_0 += BO; - out_in_1 += BO; - } -} - -#define instantiate_winograd_conv_2d_output_transform(name, itype, bo) \ - template [[host_name( \ - "winograd_conv_2d_output_transform_" #name "_bo" #bo)]] [[kernel]] void \ - winograd_conv_2d_output_transform( \ - const device itype* out_in [[buffer(0)]], \ - device itype* out_out [[buffer(1)]], \ - const constant MLXConvParams<2>& params [[buffer(2)]], \ - uint3 tid [[threadgroup_position_in_grid]], \ - uint3 lid [[thread_position_in_threadgroup]], \ - uint3 tgp_per_grid [[threadgroups_per_grid]], \ - uint simd_group_id [[simdgroup_index_in_threadgroup]], \ - uint simd_lane_id [[thread_index_in_simdgroup]]); - -// clang-format off -#define instantiate_winograd_conv_2d(name, itype) \ - instantiate_winograd_conv_2d_weight_transform_base(name, itype, 32) \ - instantiate_winograd_conv_2d_input_transform(name, itype, 32) \ - instantiate_winograd_conv_2d_output_transform(name, itype, 32) // clang-format on - -// clang-format off -instantiate_winograd_conv_2d(float32, float); -instantiate_winograd_conv_2d(bfloat16, bfloat16_t); -instantiate_winograd_conv_2d(float16, half); // clang-format on diff --git a/Source/Cmlx/mlx-generated/metal/copy.h b/Source/Cmlx/mlx-generated/metal/copy.h deleted file mode 100644 index cf22347e..00000000 --- a/Source/Cmlx/mlx-generated/metal/copy.h +++ /dev/null @@ -1,276 +0,0 @@ -// Copyright © 2024 Apple Inc. - -template ::n> -[[kernel]] void copy_s( - device const T* src [[buffer(0)]], - device U* dst [[buffer(1)]], - constant uint& size, - uint index [[thread_position_in_grid]]) { - index *= N; - if (N > 1 && index + N > size) { - for (int i = 0; index + i < size; ++i) { - dst[index + i] = static_cast(src[0]); - } - } else { - for (int i = 0; i < N; ++i) { - dst[index + i] = static_cast(src[0]); - } - } -} - -template ::n> -[[kernel]] void copy_v( - device const T* src [[buffer(0)]], - device U* dst [[buffer(1)]], - constant uint& size, - uint index [[thread_position_in_grid]]) { - index *= N; - if (N > 1 && index + N > size) { - for (int i = 0; index + i < size; ++i) { - dst[index + i] = static_cast(src[index + i]); - } - } else { - for (int i = 0; i < N; ++i) { - dst[index + i] = static_cast(src[index + i]); - } - } -} - -template ::n> -[[kernel]] void copy_s2( - device const T* src [[buffer(0)]], - device U* dst [[buffer(1)]], - constant int64_t& size, - uint2 index [[thread_position_in_grid]], - uint2 grid_dim [[threads_per_grid]]) { - int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); - if (N > 1 && offset + N > size) { - for (int i = 0; offset + i < size; ++i) { - dst[offset + i] = static_cast(src[0]); - } - } else { - for (int i = 0; i < N; ++i) { - dst[offset + i] = static_cast(src[0]); - } - } -} - -template ::n> -[[kernel]] void copy_v2( - device const T* src [[buffer(0)]], - device U* dst [[buffer(1)]], - constant int64_t& size, - uint2 index [[thread_position_in_grid]], - uint2 grid_dim [[threads_per_grid]]) { - int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); - if (N > 1 && offset + N > size) { - for (int i = 0; offset + i < size; ++i) { - dst[offset + i] = static_cast(src[offset + i]); - } - } else { - for (int i = 0; i < N; ++i) { - dst[offset + i] = static_cast(src[offset + i]); - } - } -} - -template -[[kernel]] void copy_g_nd1( - device const T* src [[buffer(0)]], - device U* dst [[buffer(1)]], - constant const int64_t& src_stride [[buffer(3)]], - uint index [[thread_position_in_grid]]) { - auto src_idx = elem_to_loc_1(index, src_stride); - dst[index] = static_cast(src[src_idx]); -} - -template -[[kernel]] void copy_g_nd2( - device const T* src [[buffer(0)]], - device U* dst [[buffer(1)]], - constant const int64_t* src_strides [[buffer(3)]], - uint2 index [[thread_position_in_grid]], - uint2 grid_dim [[threads_per_grid]]) { - auto src_idx = elem_to_loc_2(index, src_strides); - IdxT dst_idx = index.x + IdxT(grid_dim.x) * index.y; - dst[dst_idx] = static_cast(src[src_idx]); -} - -template -[[kernel]] void copy_g_nd3( - device const T* src [[buffer(0)]], - device U* dst [[buffer(1)]], - constant const int64_t* src_strides [[buffer(3)]], - uint3 index [[thread_position_in_grid]], - uint3 grid_dim [[threads_per_grid]]) { - auto src_idx = elem_to_loc_3(index, src_strides); - IdxT dst_idx = - index.x + IdxT(grid_dim.x) * (index.y + IdxT(grid_dim.y) * index.z); - dst[dst_idx] = static_cast(src[src_idx]); -} - -template -[[kernel]] void copy_g( - device const T* src [[buffer(0)]], - device U* dst [[buffer(1)]], - constant const int* src_shape [[buffer(2)]], - constant const int64_t* src_strides [[buffer(3)]], - constant const int& ndim [[buffer(5)]], - uint3 index [[thread_position_in_grid]], - uint3 grid_dim [[threads_per_grid]]) { - auto src_idx = elem_to_loc( - {N * index.x, index.y, index.z}, src_shape, src_strides, ndim); - if (N == 1) { - IdxT dst_idx = - index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z); - dst[dst_idx] = static_cast(src[src_idx]); - return; - } - auto xshape = src_shape[ndim - 1]; - IdxT dst_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z); - auto src_xstride = src_strides[ndim - 1]; - for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { - dst[dst_idx + i] = static_cast(src[src_idx]); - src_idx += src_xstride; - } -} - -template -[[kernel]] void copy_gg_nd1( - device const T* src [[buffer(0)]], - device U* dst [[buffer(1)]], - constant const int64_t& src_stride [[buffer(3)]], - constant const int64_t& dst_stride [[buffer(4)]], - uint index [[thread_position_in_grid]]) { - auto src_idx = elem_to_loc_1(index, src_stride); - auto dst_idx = elem_to_loc_1(index, dst_stride); - dst[dst_idx] = static_cast(src[src_idx]); -} - -template -[[kernel]] void copy_gg_nd2( - device const T* src [[buffer(0)]], - device U* dst [[buffer(1)]], - constant const int64_t* src_strides [[buffer(3)]], - constant const int64_t* dst_strides [[buffer(4)]], - uint2 index [[thread_position_in_grid]]) { - auto src_idx = elem_to_loc_2(index, src_strides); - auto dst_idx = elem_to_loc_2(index, dst_strides); - dst[dst_idx] = static_cast(src[src_idx]); -} - -template -[[kernel]] void copy_gg_nd3( - device const T* src [[buffer(0)]], - device U* dst [[buffer(1)]], - constant const int64_t* src_strides [[buffer(3)]], - constant const int64_t* dst_strides [[buffer(4)]], - uint3 index [[thread_position_in_grid]]) { - auto src_idx = elem_to_loc_3(index, src_strides); - auto dst_idx = elem_to_loc_3(index, dst_strides); - dst[dst_idx] = static_cast(src[src_idx]); -} - -template -[[kernel]] void copy_gg( - device const T* src [[buffer(0)]], - device U* dst [[buffer(1)]], - constant const int* src_shape [[buffer(2)]], - constant const int64_t* src_strides [[buffer(3)]], - constant const int64_t* dst_strides [[buffer(4)]], - constant const int& ndim [[buffer(5)]], - uint3 index [[thread_position_in_grid]]) { - auto idx = elem_to_loc_2_nd( - {N * index.x, index.y, index.z}, - src_shape, - src_strides, - dst_strides, - ndim); - if (N == 1) { - dst[idx.y] = static_cast(src[idx.x]); - return; - } - IdxT src_xstride = src_strides[ndim - 1]; - IdxT dst_xstride = dst_strides[ndim - 1]; - auto xshape = src_shape[ndim - 1]; - for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { - dst[idx.y] = static_cast(src[idx.x]); - idx.x += src_xstride; - idx.y += dst_xstride; - } -} - -template -[[kernel]] void copy_gg_dynamic_nd1( - device const T* src [[buffer(0)]], - device U* dst [[buffer(1)]], - constant const int64_t& src_stride [[buffer(3)]], - constant const int64_t& dst_stride [[buffer(4)]], - constant const int64_t& src_offset [[buffer(6)]], - constant const int64_t& dst_offset [[buffer(7)]], - uint index [[thread_position_in_grid]]) { - auto src_idx = elem_to_loc_1(index, src_stride); - auto dst_idx = elem_to_loc_1(index, dst_stride); - dst[dst_idx + dst_offset] = src[src_idx + src_offset]; -} - -template -[[kernel]] void copy_gg_dynamic_nd2( - device const T* src [[buffer(0)]], - device U* dst [[buffer(1)]], - constant const int64_t* src_strides [[buffer(3)]], - constant const int64_t* dst_strides [[buffer(4)]], - constant const int64_t& src_offset [[buffer(6)]], - constant const int64_t& dst_offset [[buffer(7)]], - uint2 index [[thread_position_in_grid]]) { - auto src_idx = elem_to_loc_2(index, src_strides); - auto dst_idx = elem_to_loc_2(index, dst_strides); - dst[dst_idx + dst_offset] = src[src_idx + src_offset]; -} - -template -[[kernel]] void copy_gg_dynamic_nd3( - device const T* src [[buffer(0)]], - device U* dst [[buffer(1)]], - constant const int64_t* src_strides [[buffer(3)]], - constant const int64_t* dst_strides [[buffer(4)]], - constant const int64_t& src_offset [[buffer(6)]], - constant const int64_t& dst_offset [[buffer(7)]], - uint3 index [[thread_position_in_grid]]) { - auto src_idx = elem_to_loc_3(index, src_strides); - auto dst_idx = elem_to_loc_3(index, dst_strides); - dst[dst_idx + dst_offset] = src[src_idx + src_offset]; -} - -template -[[kernel]] void copy_gg_dynamic( - device const T* src [[buffer(0)]], - device U* dst [[buffer(1)]], - constant const int* src_shape [[buffer(2)]], - constant const int64_t* src_strides [[buffer(3)]], - constant const int64_t* dst_strides [[buffer(4)]], - constant const int& ndim [[buffer(5)]], - constant const int64_t& src_offset [[buffer(6)]], - constant const int64_t& dst_offset [[buffer(7)]], - uint3 index [[thread_position_in_grid]]) { - src += src_offset; - dst += dst_offset; - auto idx = elem_to_loc_2_nd( - {N * index.x, index.y, index.z}, - src_shape, - src_strides, - dst_strides, - ndim); - if (N == 1) { - dst[idx.y] = src[idx.x]; - return; - } - IdxT src_xstride = src_strides[ndim - 1]; - IdxT dst_xstride = dst_strides[ndim - 1]; - auto xshape = src_shape[ndim - 1]; - for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { - dst[idx.y] = src[idx.x]; - idx.x += src_xstride; - idx.y += dst_xstride; - } -} diff --git a/Source/Cmlx/mlx-generated/metal/defines.h b/Source/Cmlx/mlx-generated/metal/defines.h deleted file mode 100644 index c369adb7..00000000 --- a/Source/Cmlx/mlx-generated/metal/defines.h +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright © 2023 Apple Inc. - -#pragma once - -#if defined __METAL__ || defined MLX_METAL_JIT -#define MTL_CONST constant -#else -#define MTL_CONST -#endif - -static MTL_CONST constexpr int MAX_REDUCE_SPECIALIZED_DIMS = 4; -static MTL_CONST constexpr int REDUCE_N_READS = 4; -static MTL_CONST constexpr int REDUCE_N_WRITES = 4; -static MTL_CONST constexpr int SOFTMAX_N_READS = 4; -static MTL_CONST constexpr int RMS_N_READS = 4; -static MTL_CONST constexpr int RMS_LOOPED_LIMIT = 4096; - -// Instantiate a templated kernel. -// Extra args are used as template parameters: -// e.g. instantiate_kernel(binary_int, binary, a, b) -> -// [[host_name(binary_int)]] [kernel] binary -#define instantiate_kernel(name, func, ...) \ - template [[host_name( \ - name)]] [[kernel]] decltype(func<__VA_ARGS__>) func<__VA_ARGS__>; diff --git a/Source/Cmlx/mlx-generated/metal/erf.h b/Source/Cmlx/mlx-generated/metal/erf.h deleted file mode 100644 index 8a9499e2..00000000 --- a/Source/Cmlx/mlx-generated/metal/erf.h +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright © 2023 Apple Inc. - -#pragma once -#include -#include "expm1f.h" - -/* - * Approximation to the error function. - * Based on code from: - * https://stackoverflow.com/questions/35148198/efficient-faithfully-rounded-implementation-of-error-function-erff#answer-35148199 - */ -float erf(float a) { - float r, s, t, u; - t = metal::abs(a); - s = a * a; - if (t > 0.927734375f) { - // maximum error 0.99527 ulp - r = metal::fma( - -1.72853470e-5f, t, 3.83197126e-4f); // -0x1.220000p-16,0x1.91cfb2p-12 - u = metal::fma( - -3.88396438e-3f, t, 2.42546219e-2f); // -0x1.fd1438p-9, 0x1.8d6342p-6 - r = metal::fma(r, s, u); - r = metal::fma(r, t, -1.06777877e-1f); // -0x1.b55cb8p-4 - r = metal::fma(r, t, -6.34846687e-1f); // -0x1.450aa0p-1 - r = metal::fma(r, t, -1.28717512e-1f); // -0x1.079d0cp-3 - r = metal::fma(r, t, -t); - r = -expm1f(r); - r = metal::copysign(r, a); - } else { - // maximum error 0.98929 ulp - r = -5.96761703e-4f; // -0x1.38e000p-11 - r = metal::fma(r, s, 4.99119423e-3f); // 0x1.471a58p-8 - r = metal::fma(r, s, -2.67681349e-2f); // -0x1.b691b2p-6 - r = metal::fma(r, s, 1.12819925e-1f); // 0x1.ce1c44p-4 - r = metal::fma(r, s, -3.76125336e-1f); // -0x1.812700p-2 - r = metal::fma(r, s, 1.28379166e-1f); // 0x1.06eba8p-3 - r = metal::fma(r, a, a); - } - return r; -} - -float erfinv(float a) { - auto t = metal::fma(a, 0.0f - a, 1.0f); - t = metal::log(t); - float p; - if (metal::abs(t) > 6.125f) { // maximum ulp error = 2.35793 - p = 3.03697567e-10f; // 0x1.4deb44p-32 - p = metal::fma(p, t, 2.93243101e-8f); // 0x1.f7c9aep-26 - p = metal::fma(p, t, 1.22150334e-6f); // 0x1.47e512p-20 - p = metal::fma(p, t, 2.84108955e-5f); // 0x1.dca7dep-16 - p = metal::fma(p, t, 3.93552968e-4f); // 0x1.9cab92p-12 - p = metal::fma(p, t, 3.02698812e-3f); // 0x1.8cc0dep-9 - p = metal::fma(p, t, 4.83185798e-3f); // 0x1.3ca920p-8 - p = metal::fma(p, t, -2.64646143e-1f); // -0x1.0eff66p-2 - p = metal::fma(p, t, 8.40016484e-1f); // 0x1.ae16a4p-1 - } else { // maximum ulp error = 2.35002 - p = 5.43877832e-9f; // 0x1.75c000p-28 - p = metal::fma(p, t, 1.43285448e-7f); // 0x1.33b402p-23 - p = metal::fma(p, t, 1.22774793e-6f); // 0x1.499232p-20 - p = metal::fma(p, t, 1.12963626e-7f); // 0x1.e52cd2p-24 - p = metal::fma(p, t, -5.61530760e-5f); // -0x1.d70bd0p-15 - p = metal::fma(p, t, -1.47697632e-4f); // -0x1.35be90p-13 - p = metal::fma(p, t, 2.31468678e-3f); // 0x1.2f6400p-9 - p = metal::fma(p, t, 1.15392581e-2f); // 0x1.7a1e50p-7 - p = metal::fma(p, t, -2.32015476e-1f); // -0x1.db2aeep-3 - p = metal::fma(p, t, 8.86226892e-1f); // 0x1.c5bf88p-1 - } - return a * p; -} diff --git a/Source/Cmlx/mlx-generated/metal/expm1f.h b/Source/Cmlx/mlx-generated/metal/expm1f.h deleted file mode 100644 index 68224e17..00000000 --- a/Source/Cmlx/mlx-generated/metal/expm1f.h +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright © 2023 Apple Inc. - -#pragma once - -#include - -// Original license copied below: -// Copyright (c) 2015-2023 Norbert Juffa -// All rights reserved. -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions -// are met: -// -// 1. Redistributions of source code must retain the above copyright -// notice, this list of conditions and the following disclaimer. -// -// 2. Redistributions in binary form must reproduce the above copyright -// notice, this list of conditions and the following disclaimer in the -// documentation and/or other materials provided with the distribution. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -// HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -/* Compute exponential base e minus 1. Maximum ulp error = 0.997458 - - i = rint(a/log(2)), f = a-i*log(2). Then expm1(a) = 2**i * (expm1(f)+1) - 1. - Compute r = expm1(f). Then expm1(a)= 2 * (0.5 * 2**i * r + 0.5 * 2**i - 0.5). - With t = 0.5*2**i, expm1(a) = 2*(r * t + t-0.5). However, for best accuracy, - when i == 1, expm1(a)= 2*(r + 0.5), and when i == 0, expm1(a) = r. - - NOTE: Scale factor b is only applied if i < 0 or i > 1 (should be power of 2) -*/ -float expm1f_scaled_unchecked(float a, float b) { - float f, j, r, s, t, u, v, x, y; - int i; - - // exp(a) = 2**i * exp(f); i = rintf (a / log(2)) - j = fma(1.442695f, a, 12582912.f); // 0x1.715476p0, 0x1.8p23 - j = j - 12582912.0f; // 0x1.8p23 - i = (int)j; - f = fma(j, -6.93145752e-1f, a); - - // approximate r = exp(f)-1 on interval [-log(2)/2, +log(2)/2] - s = f * f; - if (a == 0.0f) - s = a; // ensure -0 is passed through - // err = 0.997458 ulp1 = 11081805 - r = 1.97350979e-4f; // 0x1.9de000p-13 - r = fma(r, f, 1.39309070e-3f); // 0x1.6d30bcp-10 - r = fma(r, f, 8.33343994e-3f); // 0x1.1111f6p-7 - r = fma(r, f, 4.16668020e-2f); // 0x1.55559ep-5 - r = fma(r, f, 1.66666716e-1f); // 0x1.55555cp-3 - r = fma(r, f, 4.99999970e-1f); // 0x1.fffffep-2 - u = (j == 1) ? (f + 0.5f) : f; - v = fma(r, s, u); - s = 0.5f * b; - t = ldexp(s, i); - y = t - s; - x = (t - y) - s; // double-float canonicalization of difference - r = fma(v, t, x) + y; - r = r + r; - if (j == 0) - r = v; - if (j == 1) - r = v + v; - return r; -} - -/* Compute exponential base e minus 1. max ulp err = 0.99746 */ -float expm1f(float a) { - float r; - - r = expm1f_scaled_unchecked(a, 1.0f); - /* handle severe overflow and underflow */ - if (abs(a - 1.0f) > 88.0f) { - r = pow(2, a); - r = fma(r, r, -1.0f); - } - return r; -} diff --git a/Source/Cmlx/mlx-generated/metal/fft.h b/Source/Cmlx/mlx-generated/metal/fft.h deleted file mode 100644 index 4f18730b..00000000 --- a/Source/Cmlx/mlx-generated/metal/fft.h +++ /dev/null @@ -1,486 +0,0 @@ -// Copyright © 2024 Apple Inc. - -// Metal FFT using Stockham's algorithm -// -// References: -// - VkFFT (https://github.com/DTolm/VkFFT) -// - Eric Bainville's excellent page (http://www.bealto.com/gpu-fft.html) - -#include - -#include "fft/radix.h" -#include "fft/readwrite.h" -#include "steel/defines.h" - -using namespace metal; - -#define MAX_RADIX 13 -// Reached when elems_per_thread_ = 6, max_radix = 13 -// and some threads have to do 3 radix 6s requiring 18 float2s. -#define MAX_OUTPUT_SIZE 18 - -// Specialize for a particular value of N at runtime -STEEL_CONST bool inv_ [[function_constant(0)]]; -STEEL_CONST bool is_power_of_2_ [[function_constant(1)]]; -STEEL_CONST int elems_per_thread_ [[function_constant(2)]]; -// rader_m = n / rader_n -STEEL_CONST int rader_m_ [[function_constant(3)]]; -// Stockham steps -STEEL_CONST int radix_13_steps_ [[function_constant(4)]]; -STEEL_CONST int radix_11_steps_ [[function_constant(5)]]; -STEEL_CONST int radix_8_steps_ [[function_constant(6)]]; -STEEL_CONST int radix_7_steps_ [[function_constant(7)]]; -STEEL_CONST int radix_6_steps_ [[function_constant(8)]]; -STEEL_CONST int radix_5_steps_ [[function_constant(9)]]; -STEEL_CONST int radix_4_steps_ [[function_constant(10)]]; -STEEL_CONST int radix_3_steps_ [[function_constant(11)]]; -STEEL_CONST int radix_2_steps_ [[function_constant(12)]]; -// Rader steps -STEEL_CONST int rader_13_steps_ [[function_constant(13)]]; -STEEL_CONST int rader_11_steps_ [[function_constant(14)]]; -STEEL_CONST int rader_8_steps_ [[function_constant(15)]]; -STEEL_CONST int rader_7_steps_ [[function_constant(16)]]; -STEEL_CONST int rader_6_steps_ [[function_constant(17)]]; -STEEL_CONST int rader_5_steps_ [[function_constant(18)]]; -STEEL_CONST int rader_4_steps_ [[function_constant(19)]]; -STEEL_CONST int rader_3_steps_ [[function_constant(20)]]; -STEEL_CONST int rader_2_steps_ [[function_constant(21)]]; - -// See "radix.h" for radix codelets -typedef void (*RadixFunc)(thread float2*, thread float2*); - -// Perform a single radix n butterfly with appropriate twiddles -template -METAL_FUNC void radix_butterfly( - int i, - int p, - thread float2* x, - thread short* indices, - thread float2* y) { - // i: the index in the overall DFT that we're processing. - // p: the size of the DFTs we're merging at this step. - // m: how many threads are working on this DFT. - int k, j; - - // Use faster bitwise operations when working with powers of two - constexpr bool radix_p_2 = (radix & (radix - 1)) == 0; - if (radix_p_2 && is_power_of_2_) { - constexpr short power = __builtin_ctz(radix); - k = i & (p - 1); - j = ((i - k) << power) + k; - } else { - k = i % p; - j = (i / p) * radix * p + k; - } - - // Apply twiddles - if (p > 1) { - float2 twiddle_1 = get_twiddle(k, radix * p); - float2 twiddle = twiddle_1; - x[1] = complex_mul(x[1], twiddle); - - STEEL_PRAGMA_UNROLL - for (int t = 2; t < radix; t++) { - twiddle = complex_mul(twiddle, twiddle_1); - x[t] = complex_mul(x[t], twiddle); - } - } - - radix_func(x, y); - - STEEL_PRAGMA_UNROLL - for (int t = 0; t < radix; t++) { - indices[t] = j + t * p; - } -} - -// Perform all the radix steps required for a -// particular radix size n. -template -METAL_FUNC void radix_n_steps( - int i, - thread int* p, - int m, - int n, - int num_steps, - thread float2* inputs, - thread short* indices, - thread float2* values, - threadgroup float2* buf) { - int m_r = n / radix; - // When combining different sized radices, we have to do - // multiple butterflies in a single thread. - // E.g. n = 28 = 4 * 7 - // 4 threads, 7 elems_per_thread - // All threads do 1 radix7 butterfly. - // 3 threads do 2 radix4 butterflies. - // 1 thread does 1 radix4 butterfly. - int max_radices_per_thread = (elems_per_thread_ + radix - 1) / radix; - - int index = 0; - int r_index = 0; - for (int s = 0; s < num_steps; s++) { - for (int t = 0; t < max_radices_per_thread; t++) { - index = i + t * m; - if (index < m_r) { - for (int r = 0; r < radix; r++) { - inputs[r] = buf[index + r * m_r]; - } - radix_butterfly( - index, *p, inputs, indices + t * radix, values + t * radix); - } - } - - // Wait until all threads have read their inputs into thread local mem - threadgroup_barrier(mem_flags::mem_threadgroup); - - for (int t = 0; t < max_radices_per_thread; t++) { - index = i + t * m; - if (index < m_r) { - for (int r = 0; r < radix; r++) { - r_index = t * radix + r; - buf[indices[r_index]] = values[r_index]; - } - } - } - - // Wait until all threads have written back to threadgroup mem - threadgroup_barrier(mem_flags::mem_threadgroup); - *p *= radix; - } -} - -#define RADIX_STEP(radix, radix_func, num_steps) \ - radix_n_steps( \ - fft_idx, p, m, n, num_steps, inputs, indices, values, buf); - -template -METAL_FUNC void -perform_fft(int fft_idx, thread int* p, int m, int n, threadgroup float2* buf) { - float2 inputs[MAX_RADIX]; - short indices[MAX_OUTPUT_SIZE]; - float2 values[MAX_OUTPUT_SIZE]; - - RADIX_STEP(2, radix2, rader ? rader_2_steps_ : radix_2_steps_); - RADIX_STEP(3, radix3, rader ? rader_3_steps_ : radix_3_steps_); - RADIX_STEP(4, radix4, rader ? rader_4_steps_ : radix_4_steps_); - RADIX_STEP(5, radix5, rader ? rader_5_steps_ : radix_5_steps_); - RADIX_STEP(6, radix6, rader ? rader_6_steps_ : radix_6_steps_); - RADIX_STEP(7, radix7, rader ? rader_7_steps_ : radix_7_steps_); - RADIX_STEP(8, radix8, rader ? rader_8_steps_ : radix_8_steps_); - RADIX_STEP(11, radix11, rader ? rader_11_steps_ : radix_11_steps_); - RADIX_STEP(13, radix13, rader ? rader_13_steps_ : radix_13_steps_); -} - -// Each FFT is computed entirely in shared GPU memory. -// -// N is decomposed into radix-n DFTs: -// e.g. 128 = 2 * 4 * 4 * 4 -template -[[kernel]] void fft( - const device in_T* in [[buffer(0)]], - device out_T* out [[buffer(1)]], - constant const int& n, - constant const int& batch_size, - uint3 elem [[thread_position_in_grid]], - uint3 grid [[threads_per_grid]]) { - threadgroup float2 shared_in[tg_mem_size]; - - thread ReadWriter read_writer = ReadWriter( - in, - &shared_in[0], - out, - n, - batch_size, - elems_per_thread_, - elem, - grid, - inv_); - - if (read_writer.out_of_bounds()) { - return; - }; - read_writer.load(); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - int p = 1; - int fft_idx = elem.z; // Thread index in DFT - int m = grid.z; // Threads per DFT - int tg_idx = elem.y * n; // Index of this DFT in threadgroup - threadgroup float2* buf = &shared_in[tg_idx]; - - perform_fft(fft_idx, &p, m, n, buf); - - read_writer.write(); -} - -template -[[kernel]] void rader_fft( - const device in_T* in [[buffer(0)]], - device out_T* out [[buffer(1)]], - const device float2* raders_b_q [[buffer(2)]], - const device short* raders_g_q [[buffer(3)]], - const device short* raders_g_minus_q [[buffer(4)]], - constant const int& n, - constant const int& batch_size, - constant const int& rader_n, - uint3 elem [[thread_position_in_grid]], - uint3 grid [[threads_per_grid]]) { - // Use Rader's algorithm to compute fast FFTs - // when a prime factor `p` of `n` is greater than 13 but - // has `p - 1` Stockham decomposable into to prime factors <= 13. - // - // E.g. n = 102 - // = 2 * 3 * 17 - // . = 2 * 3 * RADER(16) - // . = 2 * 3 * RADER(4 * 4) - // - // In numpy: - // x_perm = x[g_q] - // y = np.fft.fft(x_perm) * b_q - // z = np.fft.ifft(y) + x[0] - // out = z[g_minus_q] - // out[0] = x[1:].sum() - // - // Where the g_q and g_minus_q are permutations formed - // by the group under multiplicative modulo N using the - // primitive root of N and b_q is a constant. - // See https://en.wikipedia.org/wiki/Rader%27s_FFT_algorithm - // - // Rader's uses fewer operations than Bluestein's and so - // is more accurate. It's also faster in most cases. - threadgroup float2 shared_in[tg_mem_size]; - - thread ReadWriter read_writer = ReadWriter( - in, - &shared_in[0], - out, - n, - batch_size, - elems_per_thread_, - elem, - grid, - inv_); - - if (read_writer.out_of_bounds()) { - return; - }; - read_writer.load(); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // The number of the threads we're using for each DFT - int m = grid.z; - - int fft_idx = elem.z; - int tg_idx = elem.y * n; - threadgroup float2* buf = &shared_in[tg_idx]; - - // rader_m = n / rader_n; - int rader_m = rader_m_; - - // We have to load two x_0s for each thread since sometimes - // elems_per_thread_ crosses a boundary. - // E.g. with n = 34, rader_n = 17, elems_per_thread_ = 4 - // 0 0 0 0 1 1 1 1 2 2 2 2 3 3 3 3 4 4 4 4 5 5 5 5 6 6 6 6 7 7 7 7 8 8 - // 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 - short x_0_index = - metal::min(fft_idx * elems_per_thread_ / (rader_n - 1), rader_m - 1); - float2 x_0[2] = {buf[x_0_index], buf[x_0_index + 1]}; - - // Do the Rader permutation in shared memory - float2 temp[MAX_RADIX]; - int max_index = n - rader_m - 1; - for (int e = 0; e < elems_per_thread_; e++) { - short index = metal::min(fft_idx * elems_per_thread_ + e, max_index); - short g_q = raders_g_q[index / rader_m]; - temp[e] = buf[rader_m + (g_q - 1) * rader_m + index % rader_m]; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - for (int e = 0; e < elems_per_thread_; e++) { - short index = metal::min(fft_idx * elems_per_thread_ + e, max_index); - buf[index + rader_m] = temp[e]; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Rader FFT on x[rader_m:] - int p = 1; - perform_fft(fft_idx, &p, m, n - rader_m, buf + rader_m); - - // x_1 + ... + x_n is computed for us in the first FFT step so - // we save it in the first rader_m indices of the array for later. - int x_sum_index = metal::min(fft_idx, rader_m - 1); - buf[x_sum_index] = buf[rader_m + x_sum_index * (rader_n - 1)]; - - float2 inv = {1.0f, -1.0f}; - for (int e = 0; e < elems_per_thread_; e++) { - short index = metal::min(fft_idx * elems_per_thread_ + e, max_index); - short interleaved_index = - index / rader_m + (index % rader_m) * (rader_n - 1); - temp[e] = complex_mul( - buf[rader_m + interleaved_index], - raders_b_q[interleaved_index % (rader_n - 1)]); - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - for (int e = 0; e < elems_per_thread_; e++) { - short index = metal::min(fft_idx * elems_per_thread_ + e, max_index); - buf[rader_m + index] = temp[e] * inv; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Rader IFFT on x[rader_m:] - p = 1; - perform_fft(fft_idx, &p, m, n - rader_m, buf + rader_m); - - float2 rader_inv_factor = {1.0f / (rader_n - 1), -1.0f / (rader_n - 1)}; - - for (int e = 0; e < elems_per_thread_; e++) { - short index = metal::min(fft_idx * elems_per_thread_ + e, n - rader_m - 1); - short diff_index = index / (rader_n - 1) - x_0_index; - temp[e] = buf[rader_m + index] * rader_inv_factor + x_0[diff_index]; - } - - // Use the sum of elements that was computed in the first FFT - float2 x_sum = buf[x_0_index] + x_0[0]; - - threadgroup_barrier(mem_flags::mem_threadgroup); - - for (int e = 0; e < elems_per_thread_; e++) { - short index = metal::min(fft_idx * elems_per_thread_ + e, max_index); - short g_q_index = index % (rader_n - 1); - short g_q = raders_g_minus_q[g_q_index]; - short out_index = index - g_q_index + g_q + (index / (rader_n - 1)); - buf[out_index] = temp[e]; - } - - buf[x_0_index * rader_n] = x_sum; - - threadgroup_barrier(mem_flags::mem_threadgroup); - - p = rader_n; - perform_fft(fft_idx, &p, m, n, buf); - - read_writer.write(); -} - -template -[[kernel]] void bluestein_fft( - const device in_T* in [[buffer(0)]], - device out_T* out [[buffer(1)]], - const device float2* w_q [[buffer(2)]], - const device float2* w_k [[buffer(3)]], - constant const int& length, - constant const int& n, - constant const int& batch_size, - uint3 elem [[thread_position_in_grid]], - uint3 grid [[threads_per_grid]]) { - // Computes arbitrary length FFTs with Bluestein's algorithm - // - // In numpy: - // bluestein_n = next_power_of_2(2*n - 1) - // out = w_k * np.fft.ifft(np.fft.fft(w_k * in, bluestein_n) * w_q) - // - // Where w_k and w_q are precomputed on CPU in high precision as: - // w_k = np.exp(-1j * np.pi / n * (np.arange(-n + 1, n) ** 2)) - // w_q = np.fft.fft(1/w_k[-n:]) - threadgroup float2 shared_in[tg_mem_size]; - - thread ReadWriter read_writer = ReadWriter( - in, - &shared_in[0], - out, - n, - batch_size, - elems_per_thread_, - elem, - grid, - inv_); - - if (read_writer.out_of_bounds()) { - return; - }; - read_writer.load_padded(length, w_k); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - int p = 1; - int fft_idx = elem.z; // Thread index in DFT - int m = grid.z; // Threads per DFT - int tg_idx = elem.y * n; // Index of this DFT in threadgroup - threadgroup float2* buf = &shared_in[tg_idx]; - - // fft - perform_fft(fft_idx, &p, m, n, buf); - - float2 inv = float2(1.0f, -1.0f); - for (int t = 0; t < elems_per_thread_; t++) { - int index = fft_idx + t * m; - buf[index] = complex_mul(buf[index], w_q[index]) * inv; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // ifft - p = 1; - perform_fft(fft_idx, &p, m, n, buf); - - read_writer.write_padded(length, w_k); -} - -template < - int tg_mem_size, - typename in_T, - typename out_T, - int step, - bool real = false> -[[kernel]] void four_step_fft( - const device in_T* in [[buffer(0)]], - device out_T* out [[buffer(1)]], - constant const int& n1, - constant const int& n2, - constant const int& batch_size, - uint3 elem [[thread_position_in_grid]], - uint3 grid [[threads_per_grid]]) { - // Fast four step FFT implementation for powers of 2. - int overall_n = n1 * n2; - int n = step == 0 ? n1 : n2; - int stride = step == 0 ? n2 : n1; - - // The number of the threads we're using for each DFT - int m = grid.z; - int fft_idx = elem.z; - - threadgroup float2 shared_in[tg_mem_size]; - threadgroup float2* buf = &shared_in[elem.y * n]; - - using read_writer_t = ReadWriter; - read_writer_t read_writer = read_writer_t( - in, - &shared_in[0], - out, - n, - batch_size, - elems_per_thread_, - elem, - grid, - inv_); - - if (read_writer.out_of_bounds()) { - return; - }; - read_writer.load_strided(stride, overall_n); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - int p = 1; - perform_fft(fft_idx, &p, m, n, buf); - - read_writer.write_strided(stride, overall_n); -} diff --git a/Source/Cmlx/mlx-generated/metal/fft/radix.h b/Source/Cmlx/mlx-generated/metal/fft/radix.h deleted file mode 100644 index bd61eef6..00000000 --- a/Source/Cmlx/mlx-generated/metal/fft/radix.h +++ /dev/null @@ -1,328 +0,0 @@ -// Copyright © 2024 Apple Inc. - -/* Radix kernels - -We provide optimized, single threaded Radix codelets -for n=2,3,4,5,6,7,8,10,11,12,13. - -For n=2,3,4,5,6 we hand write the codelets. -For n=8,10,12 we combine smaller codelets. -For n=7,11,13 we use Rader's algorithm which decomposes -them into (n-1)=6,10,12 codelets. */ - -#pragma once - -#include -#include -#include - -METAL_FUNC float2 complex_mul(float2 a, float2 b) { - return float2(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x); -} - -// Complex mul followed by conjugate -METAL_FUNC float2 complex_mul_conj(float2 a, float2 b) { - return float2(a.x * b.x - a.y * b.y, -a.x * b.y - a.y * b.x); -} - -// Compute an FFT twiddle factor -METAL_FUNC float2 get_twiddle(int k, int p) { - float theta = -2.0f * k * M_PI_F / p; - - float2 twiddle = {metal::fast::cos(theta), metal::fast::sin(theta)}; - return twiddle; -} - -METAL_FUNC void radix2(thread float2* x, thread float2* y) { - y[0] = x[0] + x[1]; - y[1] = x[0] - x[1]; -} - -METAL_FUNC void radix3(thread float2* x, thread float2* y) { - float pi_2_3 = -0.8660254037844387; - - float2 a_1 = x[1] + x[2]; - float2 a_2 = x[1] - x[2]; - - y[0] = x[0] + a_1; - float2 b_1 = x[0] - 0.5 * a_1; - float2 b_2 = pi_2_3 * a_2; - - float2 b_2_j = {-b_2.y, b_2.x}; - y[1] = b_1 + b_2_j; - y[2] = b_1 - b_2_j; -} - -METAL_FUNC void radix4(thread float2* x, thread float2* y) { - float2 z_0 = x[0] + x[2]; - float2 z_1 = x[0] - x[2]; - float2 z_2 = x[1] + x[3]; - float2 z_3 = x[1] - x[3]; - float2 z_3_i = {z_3.y, -z_3.x}; - - y[0] = z_0 + z_2; - y[1] = z_1 + z_3_i; - y[2] = z_0 - z_2; - y[3] = z_1 - z_3_i; -} - -METAL_FUNC void radix5(thread float2* x, thread float2* y) { - float2 root_5_4 = 0.5590169943749475; - float2 sin_2pi_5 = 0.9510565162951535; - float2 sin_1pi_5 = 0.5877852522924731; - - float2 a_1 = x[1] + x[4]; - float2 a_2 = x[2] + x[3]; - float2 a_3 = x[1] - x[4]; - float2 a_4 = x[2] - x[3]; - - float2 a_5 = a_1 + a_2; - float2 a_6 = root_5_4 * (a_1 - a_2); - float2 a_7 = x[0] - a_5 / 4; - float2 a_8 = a_7 + a_6; - float2 a_9 = a_7 - a_6; - float2 a_10 = sin_2pi_5 * a_3 + sin_1pi_5 * a_4; - float2 a_11 = sin_1pi_5 * a_3 - sin_2pi_5 * a_4; - float2 a_10_j = {a_10.y, -a_10.x}; - float2 a_11_j = {a_11.y, -a_11.x}; - - y[0] = x[0] + a_5; - y[1] = a_8 + a_10_j; - y[2] = a_9 + a_11_j; - y[3] = a_9 - a_11_j; - y[4] = a_8 - a_10_j; -} - -METAL_FUNC void radix6(thread float2* x, thread float2* y) { - float sin_pi_3 = 0.8660254037844387; - float2 a_1 = x[2] + x[4]; - float2 a_2 = x[0] - a_1 / 2; - float2 a_3 = sin_pi_3 * (x[2] - x[4]); - float2 a_4 = x[5] + x[1]; - float2 a_5 = x[3] - a_4 / 2; - float2 a_6 = sin_pi_3 * (x[5] - x[1]); - float2 a_7 = x[0] + a_1; - - float2 a_3_i = {a_3.y, -a_3.x}; - float2 a_6_i = {a_6.y, -a_6.x}; - float2 a_8 = a_2 + a_3_i; - float2 a_9 = a_2 - a_3_i; - float2 a_10 = x[3] + a_4; - float2 a_11 = a_5 + a_6_i; - float2 a_12 = a_5 - a_6_i; - - y[0] = a_7 + a_10; - y[1] = a_8 - a_11; - y[2] = a_9 + a_12; - y[3] = a_7 - a_10; - y[4] = a_8 + a_11; - y[5] = a_9 - a_12; -} - -METAL_FUNC void radix7(thread float2* x, thread float2* y) { - // Rader's algorithm - float2 inv = {1 / 6.0, -1 / 6.0}; - - // fft - float2 in1[6] = {x[1], x[3], x[2], x[6], x[4], x[5]}; - radix6(in1, y + 1); - - y[0] = y[1] + x[0]; - - // b_q - y[1] = complex_mul_conj(y[1], float2(-1, 0)); - y[2] = complex_mul_conj(y[2], float2(2.44013336, -1.02261879)); - y[3] = complex_mul_conj(y[3], float2(2.37046941, -1.17510629)); - y[4] = complex_mul_conj(y[4], float2(0, -2.64575131)); - y[5] = complex_mul_conj(y[5], float2(2.37046941, 1.17510629)); - y[6] = complex_mul_conj(y[6], float2(-2.44013336, -1.02261879)); - - // ifft - radix6(y + 1, x + 1); - - y[1] = x[1] * inv + x[0]; - y[5] = x[2] * inv + x[0]; - y[4] = x[3] * inv + x[0]; - y[6] = x[4] * inv + x[0]; - y[2] = x[5] * inv + x[0]; - y[3] = x[6] * inv + x[0]; -} - -METAL_FUNC void radix8(thread float2* x, thread float2* y) { - float cos_pi_4 = 0.7071067811865476; - float2 w_0 = {cos_pi_4, -cos_pi_4}; - float2 w_1 = {-cos_pi_4, -cos_pi_4}; - float2 temp[8] = {x[0], x[2], x[4], x[6], x[1], x[3], x[5], x[7]}; - radix4(temp, x); - radix4(temp + 4, x + 4); - - y[0] = x[0] + x[4]; - y[4] = x[0] - x[4]; - float2 x_5 = complex_mul(x[5], w_0); - y[1] = x[1] + x_5; - y[5] = x[1] - x_5; - float2 x_6 = {x[6].y, -x[6].x}; - y[2] = x[2] + x_6; - y[6] = x[2] - x_6; - float2 x_7 = complex_mul(x[7], w_1); - y[3] = x[3] + x_7; - y[7] = x[3] - x_7; -} - -template -METAL_FUNC void radix10(thread float2* x, thread float2* y) { - float2 w[4]; - w[0] = {0.8090169943749475, -0.5877852522924731}; - w[1] = {0.30901699437494745, -0.9510565162951535}; - w[2] = {-w[1].x, w[1].y}; - w[3] = {-w[0].x, w[0].y}; - - if (raders_perm) { - float2 temp[10] = { - x[0], x[3], x[4], x[8], x[2], x[1], x[7], x[9], x[6], x[5]}; - radix5(temp, x); - radix5(temp + 5, x + 5); - } else { - float2 temp[10] = { - x[0], x[2], x[4], x[6], x[8], x[1], x[3], x[5], x[7], x[9]}; - radix5(temp, x); - radix5(temp + 5, x + 5); - } - - y[0] = x[0] + x[5]; - y[5] = x[0] - x[5]; - for (int t = 1; t < 5; t++) { - float2 a = complex_mul(x[t + 5], w[t - 1]); - y[t] = x[t] + a; - y[t + 5] = x[t] - a; - } -} - -METAL_FUNC void radix11(thread float2* x, thread float2* y) { - // Raders Algorithm - float2 inv = {1 / 10.0, -1 / 10.0}; - - // fft - radix10(x + 1, y + 1); - - y[0] = y[1] + x[0]; - - // b_q - y[1] = complex_mul_conj(y[1], float2(-1, 0)); - y[2] = complex_mul_conj(y[2], float2(0.955301878, -3.17606649)); - y[3] = complex_mul_conj(y[3], float2(2.63610556, 2.01269656)); - y[4] = complex_mul_conj(y[4], float2(2.54127802, 2.13117479)); - y[5] = complex_mul_conj(y[5], float2(2.07016210, 2.59122150)); - y[6] = complex_mul_conj(y[6], float2(0, -3.31662479)); - y[7] = complex_mul_conj(y[7], float2(2.07016210, -2.59122150)); - y[8] = complex_mul_conj(y[8], float2(-2.54127802, 2.13117479)); - y[9] = complex_mul_conj(y[9], float2(2.63610556, -2.01269656)); - y[10] = complex_mul_conj(y[10], float2(-0.955301878, -3.17606649)); - - // ifft - radix10(y + 1, x + 1); - - y[1] = x[1] * inv + x[0]; - y[6] = x[2] * inv + x[0]; - y[3] = x[3] * inv + x[0]; - y[7] = x[4] * inv + x[0]; - y[9] = x[5] * inv + x[0]; - y[10] = x[6] * inv + x[0]; - y[5] = x[7] * inv + x[0]; - y[8] = x[8] * inv + x[0]; - y[4] = x[9] * inv + x[0]; - y[2] = x[10] * inv + x[0]; -} - -template -METAL_FUNC void radix12(thread float2* x, thread float2* y) { - float2 w[6]; - float sin_pi_3 = 0.8660254037844387; - w[0] = {sin_pi_3, -0.5}; - w[1] = {0.5, -sin_pi_3}; - w[2] = {0, -1}; - w[3] = {-0.5, -sin_pi_3}; - w[4] = {-sin_pi_3, -0.5}; - - if (raders_perm) { - float2 temp[12] = { - x[0], - x[3], - x[2], - x[11], - x[8], - x[9], - x[1], - x[7], - x[5], - x[10], - x[4], - x[6]}; - radix6(temp, x); - radix6(temp + 6, x + 6); - } else { - float2 temp[12] = { - x[0], - x[2], - x[4], - x[6], - x[8], - x[10], - x[1], - x[3], - x[5], - x[7], - x[9], - x[11]}; - radix6(temp, x); - radix6(temp + 6, x + 6); - } - - y[0] = x[0] + x[6]; - y[6] = x[0] - x[6]; - for (int t = 1; t < 6; t++) { - float2 a = complex_mul(x[t + 6], w[t - 1]); - y[t] = x[t] + a; - y[t + 6] = x[t] - a; - } -} - -METAL_FUNC void radix13(thread float2* x, thread float2* y) { - // Raders Algorithm - float2 inv = {1 / 12.0, -1 / 12.0}; - - // fft - radix12(x + 1, y + 1); - - y[0] = y[1] + x[0]; - - // b_q - y[1] = complex_mul_conj(y[1], float2(-1, 0)); - y[2] = complex_mul_conj(y[2], float2(3.07497206, -1.88269669)); - y[3] = complex_mul_conj(y[3], float2(3.09912468, 1.84266823)); - y[4] = complex_mul_conj(y[4], float2(3.45084438, -1.04483161)); - y[5] = complex_mul_conj(y[5], float2(0.91083583, 3.48860690)); - y[6] = complex_mul_conj(y[6], float2(-3.60286363, 0.139189267)); - y[7] = complex_mul_conj(y[7], float2(3.60555128, 0)); - y[8] = complex_mul_conj(y[8], float2(3.60286363, 0.139189267)); - y[9] = complex_mul_conj(y[9], float2(0.91083583, -3.48860690)); - y[10] = complex_mul_conj(y[10], float2(-3.45084438, -1.04483161)); - y[11] = complex_mul_conj(y[11], float2(3.09912468, -1.84266823)); - y[12] = complex_mul_conj(y[12], float2(-3.07497206, -1.88269669)); - - // ifft - radix12(y + 1, x + 1); - - y[1] = x[1] * inv + x[0]; - y[7] = x[2] * inv + x[0]; - y[10] = x[3] * inv + x[0]; - y[5] = x[4] * inv + x[0]; - y[9] = x[5] * inv + x[0]; - y[11] = x[6] * inv + x[0]; - y[12] = x[7] * inv + x[0]; - y[6] = x[8] * inv + x[0]; - y[3] = x[9] * inv + x[0]; - y[8] = x[10] * inv + x[0]; - y[4] = x[11] * inv + x[0]; - y[2] = x[12] * inv + x[0]; -} \ No newline at end of file diff --git a/Source/Cmlx/mlx-generated/metal/fft/readwrite.h b/Source/Cmlx/mlx-generated/metal/fft/readwrite.h deleted file mode 100644 index 4459d36f..00000000 --- a/Source/Cmlx/mlx-generated/metal/fft/readwrite.h +++ /dev/null @@ -1,624 +0,0 @@ -// Copyright © 2024 Apple Inc. - -#include - -#include "../fft/radix.h" - -/* FFT helpers for reading and writing from/to device memory. - -For many sizes, GPU FFTs are memory bandwidth bound so -read/write performance is important. - -Where possible, we read 128 bits sequentially in each thread, -coalesced with accesses from adjacent threads for optimal performance. - -We implement specialized reading/writing for: - - FFT - - RFFT - - IRFFT - -Each with support for: - - Contiguous reads - - Padded reads - - Strided reads -*/ - -#define MAX_RADIX 13 - -using namespace metal; - -template < - typename in_T, - typename out_T, - int step = 0, - bool four_step_real = false> -struct ReadWriter { - const device in_T* in; - threadgroup float2* buf; - device out_T* out; - int n; - int batch_size; - int elems_per_thread; - uint3 elem; - uint3 grid; - int threads_per_tg; - bool inv; - - // Used for strided access - int strided_device_idx = 0; - int strided_shared_idx = 0; - - METAL_FUNC ReadWriter( - const device in_T* in_, - threadgroup float2* buf_, - device out_T* out_, - const short n_, - const int batch_size_, - const short elems_per_thread_, - const uint3 elem_, - const uint3 grid_, - const bool inv_) - : in(in_), - buf(buf_), - out(out_), - n(n_), - batch_size(batch_size_), - elems_per_thread(elems_per_thread_), - elem(elem_), - grid(grid_), - inv(inv_) { - // Account for padding on last threadgroup - threads_per_tg = elem.x == grid.x - 1 - ? (batch_size - (grid.x - 1) * grid.y) * grid.z - : grid.y * grid.z; - } - - // ifft(x) = 1/n * conj(fft(conj(x))) - METAL_FUNC float2 post_in(float2 elem) const { - return inv ? float2(elem.x, -elem.y) : elem; - } - - // Handle float case for generic RFFT alg - METAL_FUNC float2 post_in(float elem) const { - return float2(elem, 0); - } - - METAL_FUNC float2 pre_out(float2 elem) const { - return inv ? float2(elem.x / n, -elem.y / n) : elem; - } - - METAL_FUNC float2 pre_out(float2 elem, int length) const { - return inv ? float2(elem.x / length, -elem.y / length) : elem; - } - - METAL_FUNC bool out_of_bounds() const { - // Account for possible extra threadgroups - int grid_index = elem.x * grid.y + elem.y; - return grid_index >= batch_size; - } - - METAL_FUNC void load() const { - size_t batch_idx = size_t(elem.x * grid.y) * n; - short tg_idx = elem.y * grid.z + elem.z; - short max_index = grid.y * n - 2; - - // 2 complex64s = 128 bits - constexpr int read_width = 2; - for (short e = 0; e < (elems_per_thread / read_width); e++) { - short index = read_width * tg_idx + read_width * threads_per_tg * e; - index = metal::min(index, max_index); - // vectorized reads - buf[index] = post_in(in[batch_idx + index]); - buf[index + 1] = post_in(in[batch_idx + index + 1]); - } - max_index += 1; - if (elems_per_thread % 2 != 0) { - short index = tg_idx + - read_width * threads_per_tg * (elems_per_thread / read_width); - index = metal::min(index, max_index); - buf[index] = post_in(in[batch_idx + index]); - } - } - - METAL_FUNC void write() const { - size_t batch_idx = size_t(elem.x * grid.y) * n; - short tg_idx = elem.y * grid.z + elem.z; - short max_index = grid.y * n - 2; - - constexpr int read_width = 2; - for (short e = 0; e < (elems_per_thread / read_width); e++) { - short index = read_width * tg_idx + read_width * threads_per_tg * e; - index = metal::min(index, max_index); - // vectorized reads - out[batch_idx + index] = pre_out(buf[index]); - out[batch_idx + index + 1] = pre_out(buf[index + 1]); - } - max_index += 1; - if (elems_per_thread % 2 != 0) { - short index = tg_idx + - read_width * threads_per_tg * (elems_per_thread / read_width); - index = metal::min(index, max_index); - out[batch_idx + index] = pre_out(buf[index]); - } - } - - // Padded IO for Bluestein's algorithm - METAL_FUNC void load_padded(int length, const device float2* w_k) const { - size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length; - int fft_idx = elem.z; - int m = grid.z; - - threadgroup float2* seq_buf = buf + elem.y * n; - for (int e = 0; e < elems_per_thread; e++) { - int index = metal::min(fft_idx + e * m, n - 1); - if (index < length) { - float2 elem = post_in(in[batch_idx + index]); - seq_buf[index] = complex_mul(elem, w_k[index]); - } else { - seq_buf[index] = 0.0; - } - } - } - - METAL_FUNC void write_padded(int length, const device float2* w_k) const { - size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length; - int fft_idx = elem.z; - int m = grid.z; - float2 inv_factor = {1.0f / n, -1.0f / n}; - - threadgroup float2* seq_buf = buf + elem.y * n; - for (int e = 0; e < elems_per_thread; e++) { - int index = metal::min(fft_idx + e * m, n - 1); - if (index < length) { - float2 elem = seq_buf[index + length - 1] * inv_factor; - out[batch_idx + index] = pre_out(complex_mul(elem, w_k[index]), length); - } - } - } - - // Strided IO for four step FFT - METAL_FUNC void compute_strided_indices(int stride, int overall_n) { - // Use the batch threadgroup dimension to coalesce memory accesses: - // e.g. stride = 12 - // device | shared mem - // 0 1 2 3 | 0 12 - - - // - - - - | 1 13 - - - // - - - - | 2 14 - - - // 12 13 14 15 | 3 15 - - - int coalesce_width = grid.y; - int tg_idx = elem.y * grid.z + elem.z; - int outer_batch_size = stride / coalesce_width; - - int strided_batch_idx = (elem.x % outer_batch_size) * coalesce_width + - overall_n * (elem.x / outer_batch_size); - strided_device_idx = strided_batch_idx + - tg_idx / coalesce_width * elems_per_thread * stride + - tg_idx % coalesce_width; - strided_shared_idx = (tg_idx % coalesce_width) * n + - tg_idx / coalesce_width * elems_per_thread; - } - - // Four Step FFT First Step - METAL_FUNC void load_strided(int stride, int overall_n) { - compute_strided_indices(stride, overall_n); - for (int e = 0; e < elems_per_thread; e++) { - buf[strided_shared_idx + e] = - post_in(in[strided_device_idx + e * stride]); - } - } - - METAL_FUNC void write_strided(int stride, int overall_n) { - for (int e = 0; e < elems_per_thread; e++) { - float2 output = buf[strided_shared_idx + e]; - int combined_idx = (strided_device_idx + e * stride) % overall_n; - int ij = (combined_idx / stride) * (combined_idx % stride); - // Apply four step twiddles at end of first step - float2 twiddle = get_twiddle(ij, overall_n); - out[strided_device_idx + e * stride] = complex_mul(output, twiddle); - } - } -}; - -// Four Step FFT Second Step -template <> -METAL_FUNC void ReadWriter::load_strided( - int stride, - int overall_n) { - // Silence compiler warnings - (void)stride; - (void)overall_n; - // Don't invert between steps - bool default_inv = inv; - inv = false; - load(); - inv = default_inv; -} - -template <> -METAL_FUNC void ReadWriter::write_strided( - int stride, - int overall_n) { - compute_strided_indices(stride, overall_n); - for (int e = 0; e < elems_per_thread; e++) { - float2 output = buf[strided_shared_idx + e]; - out[strided_device_idx + e * stride] = pre_out(output, overall_n); - } -} - -// For RFFT, we interleave batches of two real sequences into one complex one: -// -// z_k = x_k + j.y_k -// X_k = (Z_k + Z_(N-k)*) / 2 -// Y_k = -j * ((Z_k - Z_(N-k)*) / 2) -// -// This roughly doubles the throughput over the regular FFT. -template <> -METAL_FUNC bool ReadWriter::out_of_bounds() const { - int grid_index = elem.x * grid.y + elem.y; - // We pack two sequences into one for RFFTs - return grid_index * 2 >= batch_size; -} - -template <> -METAL_FUNC void ReadWriter::load() const { - size_t batch_idx = size_t(elem.x * grid.y) * n * 2 + elem.y * n * 2; - threadgroup float2* seq_buf = buf + elem.y * n; - - // No out of bounds accesses on odd batch sizes - int grid_index = elem.x * grid.y + elem.y; - short next_in = - batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n; - - short m = grid.z; - short fft_idx = elem.z; - - for (int e = 0; e < elems_per_thread; e++) { - int index = metal::min(fft_idx + e * m, n - 1); - seq_buf[index].x = in[batch_idx + index]; - seq_buf[index].y = in[batch_idx + index + next_in]; - } -} - -template <> -METAL_FUNC void ReadWriter::write() const { - short n_over_2 = (n / 2) + 1; - - size_t batch_idx = - size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2; - threadgroup float2* seq_buf = buf + elem.y * n; - - int grid_index = elem.x * grid.y + elem.y; - short next_out = - batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n_over_2; - - float2 conj = {1, -1}; - float2 minus_j = {0, -1}; - - short m = grid.z; - short fft_idx = elem.z; - - for (int e = 0; e < elems_per_thread / 2 + 1; e++) { - int index = metal::min(fft_idx + e * m, n_over_2 - 1); - // x_0 = z_0.real - // y_0 = z_0.imag - if (index == 0) { - out[batch_idx + index] = {seq_buf[index].x, 0}; - out[batch_idx + index + next_out] = {seq_buf[index].y, 0}; - } else { - float2 x_k = seq_buf[index]; - float2 x_n_minus_k = seq_buf[n - index] * conj; - out[batch_idx + index] = (x_k + x_n_minus_k) / 2; - out[batch_idx + index + next_out] = - complex_mul(((x_k - x_n_minus_k) / 2), minus_j); - } - } -} - -template <> -METAL_FUNC void ReadWriter::load_padded( - int length, - const device float2* w_k) const { - size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2; - threadgroup float2* seq_buf = buf + elem.y * n; - - // No out of bounds accesses on odd batch sizes - int grid_index = elem.x * grid.y + elem.y; - short next_in = - batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : length; - - short m = grid.z; - short fft_idx = elem.z; - - for (int e = 0; e < elems_per_thread; e++) { - int index = metal::min(fft_idx + e * m, n - 1); - if (index < length) { - float2 elem = - float2(in[batch_idx + index], in[batch_idx + index + next_in]); - seq_buf[index] = complex_mul(elem, w_k[index]); - } else { - seq_buf[index] = 0; - } - } -} - -template <> -METAL_FUNC void ReadWriter::write_padded( - int length, - const device float2* w_k) const { - int length_over_2 = (length / 2) + 1; - size_t batch_idx = - size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2; - threadgroup float2* seq_buf = buf + elem.y * n + length - 1; - - int grid_index = elem.x * grid.y + elem.y; - short next_out = batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 - ? 0 - : length_over_2; - - float2 conj = {1, -1}; - float2 inv_factor = {1.0f / n, -1.0f / n}; - float2 minus_j = {0, -1}; - - short m = grid.z; - short fft_idx = elem.z; - - for (int e = 0; e < elems_per_thread / 2 + 1; e++) { - int index = metal::min(fft_idx + e * m, length_over_2 - 1); - // x_0 = z_0.real - // y_0 = z_0.imag - if (index == 0) { - float2 elem = complex_mul(w_k[index], seq_buf[index] * inv_factor); - out[batch_idx + index] = float2(elem.x, 0); - out[batch_idx + index + next_out] = float2(elem.y, 0); - } else { - float2 x_k = complex_mul(w_k[index], seq_buf[index] * inv_factor); - float2 x_n_minus_k = complex_mul( - w_k[length - index], seq_buf[length - index] * inv_factor); - x_n_minus_k *= conj; - // w_k should happen before this extraction - out[batch_idx + index] = (x_k + x_n_minus_k) / 2; - out[batch_idx + index + next_out] = - complex_mul(((x_k - x_n_minus_k) / 2), minus_j); - } - } -} - -// For IRFFT, we do the opposite -// -// Z_k = X_k + j.Y_k -// x_k = Re(Z_k) -// Y_k = Imag(Z_k) -template <> -METAL_FUNC bool ReadWriter::out_of_bounds() const { - int grid_index = elem.x * grid.y + elem.y; - // We pack two sequences into one for IRFFTs - return grid_index * 2 >= batch_size; -} - -template <> -METAL_FUNC void ReadWriter::load() const { - short n_over_2 = (n / 2) + 1; - size_t batch_idx = - size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2; - threadgroup float2* seq_buf = buf + elem.y * n; - - // No out of bounds accesses on odd batch sizes - int grid_index = elem.x * grid.y + elem.y; - short next_in = - batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n_over_2; - - short m = grid.z; - short fft_idx = elem.z; - - float2 conj = {1, -1}; - float2 plus_j = {0, 1}; - - for (int t = 0; t < elems_per_thread / 2 + 1; t++) { - int index = metal::min(fft_idx + t * m, n_over_2 - 1); - float2 x = in[batch_idx + index]; - float2 y = in[batch_idx + index + next_in]; - // NumPy forces first input to be real - bool first_val = index == 0; - // NumPy forces last input on even irffts to be real - bool last_val = n % 2 == 0 && index == n_over_2 - 1; - if (first_val || last_val) { - x = float2(x.x, 0); - y = float2(y.x, 0); - } - seq_buf[index] = x + complex_mul(y, plus_j); - seq_buf[index].y = -seq_buf[index].y; - if (index > 0 && !last_val) { - seq_buf[n - index] = (x * conj) + complex_mul(y * conj, plus_j); - seq_buf[n - index].y = -seq_buf[n - index].y; - } - } -} - -template <> -METAL_FUNC void ReadWriter::write() const { - int batch_idx = elem.x * grid.y * n * 2 + elem.y * n * 2; - threadgroup float2* seq_buf = buf + elem.y * n; - - int grid_index = elem.x * grid.y + elem.y; - short next_out = - batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : n; - - short m = grid.z; - short fft_idx = elem.z; - - for (int e = 0; e < elems_per_thread; e++) { - int index = metal::min(fft_idx + e * m, n - 1); - out[batch_idx + index] = seq_buf[index].x / n; - out[batch_idx + index + next_out] = seq_buf[index].y / -n; - } -} - -template <> -METAL_FUNC void ReadWriter::load_padded( - int length, - const device float2* w_k) const { - int n_over_2 = (n / 2) + 1; - int length_over_2 = (length / 2) + 1; - - size_t batch_idx = - size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2; - threadgroup float2* seq_buf = buf + elem.y * n; - - // No out of bounds accesses on odd batch sizes - int grid_index = elem.x * grid.y + elem.y; - short next_in = batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 - ? 0 - : length_over_2; - - short m = grid.z; - short fft_idx = elem.z; - - float2 conj = {1, -1}; - float2 plus_j = {0, 1}; - - for (int t = 0; t < elems_per_thread / 2 + 1; t++) { - int index = metal::min(fft_idx + t * m, n_over_2 - 1); - float2 x = in[batch_idx + index]; - float2 y = in[batch_idx + index + next_in]; - if (index < length_over_2) { - bool last_val = length % 2 == 0 && index == length_over_2 - 1; - if (last_val) { - x = float2(x.x, 0); - y = float2(y.x, 0); - } - float2 elem1 = x + complex_mul(y, plus_j); - seq_buf[index] = complex_mul(elem1 * conj, w_k[index]); - if (index > 0 && !last_val) { - float2 elem2 = (x * conj) + complex_mul(y * conj, plus_j); - seq_buf[length - index] = - complex_mul(elem2 * conj, w_k[length - index]); - } - } else { - short pad_index = metal::min(length + (index - length_over_2) * 2, n - 2); - seq_buf[pad_index] = 0; - seq_buf[pad_index + 1] = 0; - } - } -} - -template <> -METAL_FUNC void ReadWriter::write_padded( - int length, - const device float2* w_k) const { - size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2; - threadgroup float2* seq_buf = buf + elem.y * n + length - 1; - - int grid_index = elem.x * grid.y + elem.y; - short next_out = - batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 ? 0 : length; - - short m = grid.z; - short fft_idx = elem.z; - - float2 inv_factor = {1.0f / n, -1.0f / n}; - for (int e = 0; e < elems_per_thread; e++) { - int index = fft_idx + e * m; - if (index < length) { - float2 output = complex_mul(seq_buf[index] * inv_factor, w_k[index]); - out[batch_idx + index] = output.x / length; - out[batch_idx + index + next_out] = output.y / -length; - } - } -} - -// Four Step RFFT -template <> -METAL_FUNC void -ReadWriter::load_strided( - int stride, - int overall_n) { - // Silence compiler warnings - (void)stride; - (void)overall_n; - // Don't invert between steps - bool default_inv = inv; - inv = false; - load(); - inv = default_inv; -} - -template <> -METAL_FUNC void -ReadWriter::write_strided( - int stride, - int overall_n) { - int overall_n_over_2 = overall_n / 2 + 1; - int coalesce_width = grid.y; - int tg_idx = elem.y * grid.z + elem.z; - int outer_batch_size = stride / coalesce_width; - - int strided_batch_idx = (elem.x % outer_batch_size) * coalesce_width + - overall_n_over_2 * (elem.x / outer_batch_size); - strided_device_idx = strided_batch_idx + - tg_idx / coalesce_width * elems_per_thread / 2 * stride + - tg_idx % coalesce_width; - strided_shared_idx = (tg_idx % coalesce_width) * n + - tg_idx / coalesce_width * elems_per_thread / 2; - for (int e = 0; e < elems_per_thread / 2; e++) { - float2 output = buf[strided_shared_idx + e]; - out[strided_device_idx + e * stride] = output; - } - - // Add on n/2 + 1 element - if (tg_idx == 0 && elem.x % outer_batch_size == 0) { - out[strided_batch_idx + overall_n / 2] = buf[n / 2]; - } -} - -// Four Step IRFFT -template <> -METAL_FUNC void -ReadWriter::load_strided( - int stride, - int overall_n) { - int overall_n_over_2 = overall_n / 2 + 1; - auto conj = float2(1, -1); - - compute_strided_indices(stride, overall_n); - // Translate indices in terms of N - k - for (int e = 0; e < elems_per_thread; e++) { - int device_idx = strided_device_idx + e * stride; - int overall_batch = device_idx / overall_n; - int overall_index = device_idx % overall_n; - if (overall_index < overall_n_over_2) { - device_idx -= overall_batch * (overall_n - overall_n_over_2); - buf[strided_shared_idx + e] = in[device_idx] * conj; - } else { - int conj_idx = overall_n - overall_index; - device_idx = overall_batch * overall_n_over_2 + conj_idx; - buf[strided_shared_idx + e] = in[device_idx]; - } - } -} - -template <> -METAL_FUNC void -ReadWriter::load_strided( - int stride, - int overall_n) { - // Silence compiler warnings - (void)stride; - (void)overall_n; - bool default_inv = inv; - inv = false; - load(); - inv = default_inv; -} - -template <> -METAL_FUNC void -ReadWriter::write_strided( - int stride, - int overall_n) { - compute_strided_indices(stride, overall_n); - - for (int e = 0; e < elems_per_thread; e++) { - out[strided_device_idx + e * stride] = - pre_out(buf[strided_shared_idx + e], overall_n).x; - } -} diff --git a/Source/Cmlx/mlx-generated/metal/fp4.h b/Source/Cmlx/mlx-generated/metal/fp4.h deleted file mode 100644 index 25642f20..00000000 --- a/Source/Cmlx/mlx-generated/metal/fp4.h +++ /dev/null @@ -1,48 +0,0 @@ -#pragma once - -struct fp4_e2m1 { - fp4_e2m1(float x) { - if (metal::isnan(x)) { - bits = 0x7; - return; - } - - const uint8_t sign_bit = (metal::signbit(x)) ? 0x8 : 0x0; - x = metal::abs(x); - - if (x > 5.0f) { - bits = 0x7; - } else if (x >= 3.5f) { - bits = 0x6; - } else if (x > 2.5f) { - bits = 0x5; - } else if (x >= 1.75f) { - bits = 0x4; - } else if (x > 1.25f) { - bits = 0x3; - } else if (x >= 0.75f) { - bits = 0x2; - } else if (x > 0.25f) { - bits = 0x1; - } else { - bits = 0x0; - } - bits |= sign_bit; - } - - operator float16_t() { - half converted = as_type(ushort((bits & 7) << 9)); - converted *= 16384.0; - return bits & 8 ? -converted : converted; - } - - operator float() { - return static_cast(this->operator float16_t()); - } - - operator bfloat16_t() { - return static_cast(this->operator float16_t()); - } - - uint8_t bits; -}; diff --git a/Source/Cmlx/mlx-generated/metal/fp8.h b/Source/Cmlx/mlx-generated/metal/fp8.h deleted file mode 100644 index 60d34be6..00000000 --- a/Source/Cmlx/mlx-generated/metal/fp8.h +++ /dev/null @@ -1,80 +0,0 @@ -#pragma once - -struct fp8_e4m3 { - template - fp8_e4m3(T f) { - // From PyTorch - // https://github.com/pytorch/pytorch/blob/e3643e1e0e923f0fc063dfab6f45c956d568919d/c10/util/Float8_e4m3fn.h#L148 - uint32_t fp8_max = 543 << 21; - uint32_t denorm_mask = 141 << 23; - uint32_t f_bits = as_type(static_cast(f)); - uint32_t sign = f_bits & 0x80000000; - f_bits ^= sign; - if (f_bits >= fp8_max) { - // Default behavior saturates to min/max - bits = 0x7E; - } else { - if (f_bits < (121 << 23)) { - f_bits = as_type( - as_type(f_bits) + as_type(denorm_mask)); - bits = static_cast(f_bits - denorm_mask); - } else { - // resulting mantissa is odd - uint8_t mant_odd = (f_bits >> 20) & 1; - f_bits += ((uint32_t)(7 - 127) << 23) + 0x7FFFF; - f_bits += mant_odd; - bits = static_cast(f_bits >> 20); - } - } - bits |= static_cast(sign >> 24); - } - - operator float16_t() { - uint16_t v = (bits & 127) << 7; - half converted = as_type(v); - converted *= 256.0; - auto sign = bits & 128; - return (sign ? -converted : converted); - } - - operator bfloat16_t() { - return static_cast(this->operator float16_t()); - } - - operator float() { - return static_cast(this->operator float16_t()); - } - - uint8_t bits; -}; - -struct fp8_e8m0 { - fp8_e8m0(float x) { - if (!metal::isfinite(x)) { - bits = 0xFF; - return; - } - if (x < 0.0f) { - bits = 0x00; - return; - } - float le = metal::log2(x); - int n = int(metal::round(le)); - - n = n < -127 ? -127 : n; - n = n > 127 ? 127 : n; - bits = static_cast(n + 127); - } - - operator bfloat16_t() { - uint16_t out = (bits == 0 ? 0x40 : (static_cast(bits) << 7)); - return as_type(out); - } - - operator float() { - uint32_t out = (bits == 0 ? 0x400000 : (static_cast(bits) << 23)); - return as_type(out); - } - - uint8_t bits; -}; diff --git a/Source/Cmlx/mlx-generated/metal/fp_quantized.h b/Source/Cmlx/mlx-generated/metal/fp_quantized.h deleted file mode 100644 index eef3f2cf..00000000 --- a/Source/Cmlx/mlx-generated/metal/fp_quantized.h +++ /dev/null @@ -1,1850 +0,0 @@ -// Copyright © 2025 Apple Inc. - -#include -#include - -#include "fp4.h" -#include "fp8.h" - -constant bool align_M [[function_constant(200)]]; -constant bool align_N [[function_constant(201)]]; -constant bool align_K [[function_constant(202)]]; - -using namespace metal; - -#define MLX_MTL_CONST static constant constexpr const - -MLX_MTL_CONST int SIMD_SIZE = 32; -MLX_MTL_CONST int QUAD_SIZE = 4; - -template -inline constexpr short get_pack_factor() { - return wsize / bits; -} - -template -inline constexpr short get_bytes_per_pack() { - return wsize / 8; -} - -template -static inline T dequantize_scale(uint8_t s) { - if constexpr (group_size == 16) { - // Use nv scale - return T(*(thread fp8_e4m3*)(&s)); - } else { - return T(*(thread fp8_e8m0*)(&s)); - } -} - -template -struct Quantize { - uint8_t operator()(float x) { - if (bits == 8) { - return fp8_e4m3(x).bits; - } else { - return fp4_e2m1(x).bits; - } - } -}; - -template -struct Dequantize { - U operator()(uint8_t x) { - if constexpr (bits == 8) { - return U(*(thread fp8_e4m3*)(&x)); - } else { - return U(*(thread fp4_e2m1*)(&x)); - } - } -}; - -template -inline void load_vector(const device T* x, thread U* x_thread) { -#pragma unroll - for (int i = 0; i < values_per_thread; i++) { - x_thread[i] = x[i]; - } -} - -template -inline void load_vector_safe(const device T* x, thread U* x_thread, int N) { - for (int i = 0; i < N; i++) { - x_thread[i] = x[i]; - } - - for (int i = N; i < values_per_thread; i++) { - x_thread[i] = 0; - } -} - -template -inline U qdot(const device uint8_t* w, const thread U* x_thread, U scale) { - U accum = 0; - if constexpr (bits == 4) { - const device uint16_t* ws = (const device uint16_t*)w; - for (int i = 0; i < (values_per_thread / 4); i++) { - accum += - (x_thread[4 * i] * Dequantize<4>{}(ws[i]) + - x_thread[4 * i + 1] * Dequantize<4>{}(ws[i] >> 4) + - x_thread[4 * i + 2] * Dequantize<4>{}(ws[i] >> 8) + - x_thread[4 * i + 3] * Dequantize<4>{}(ws[i] >> 12)); - } - } else { - for (int i = 0; i < values_per_thread; i++) { - accum += x_thread[i] * Dequantize<8>{}(w[i]); - } - } - - return scale * accum; -} - -template -inline U -qdot_safe(const device uint8_t* w, const thread U* x_thread, U scale, int N) { - U accum = 0; - - if constexpr (bits == 4) { - const device uint16_t* ws = (const device uint16_t*)w; - for (int i = 0; i < (N / 4); i++) { - accum += - (x_thread[4 * i] * Dequantize<4>{}(ws[i]) + - x_thread[4 * i + 1] * Dequantize<4>{}(ws[i] >> 4) + - x_thread[4 * i + 2] * Dequantize<4>{}(ws[i] >> 8) + - x_thread[4 * i + 3] * Dequantize<4>{}(ws[i] >> 12)); - } - } else { - for (int i = 0; i < N; i++) { - accum += x_thread[i] * Dequantize<8>{}(w[i]); - } - } - return scale * accum; -} - -template -inline void qouter(const thread uint8_t* w, U x, U scale, thread U* result) { - if constexpr (bits == 4) { - for (int i = 0; i < (values_per_thread / 2); i++) { - result[2 * i] += x * scale * Dequantize<4>{}(w[i]); - result[2 * i + 1] += x * scale * Dequantize<4>{}(w[i] >> 4); - } - } else { - for (int i = 0; i < values_per_thread; i++) { - result[i] += x * scale * Dequantize<8>{}(w[i]); - } - } -} - -template -inline void dequantize(uint8_t w, U scale, threadgroup U* w_local) { - if constexpr (bits == 4) { - w_local[0] = scale * Dequantize<4, U>{}(w); - w_local[1] = scale * Dequantize<4, U>{}(w >> 4); - } else { - w_local[0] = scale * Dequantize<8, U>{}(w); - } -} - -template < - typename T, - short BROWS, - short BCOLS, - short dst_ld, - short reduction_dim, - short tgp_size, - short group_size, - short bits> -struct QuantizedBlockLoader { - MLX_MTL_CONST short pack_factor = get_pack_factor<8, bits>(); - MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack(); - MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor; - MLX_MTL_CONST short n_reads = - (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size; - MLX_MTL_CONST short group_steps = group_size < BCOLS ? 1 : group_size / BCOLS; - MLX_MTL_CONST short scale_step = group_size < BCOLS ? BCOLS / group_size : 1; - - static_assert( - (n_reads * pack_factor) <= group_size, - "The number of reads per thread must be less than the group size."); - - const int src_ld; - const int tile_stride; - short group_step_cnt; - const int group_stride; - - const short thread_idx; - const short bi; - const short bj; - - threadgroup T* dst; - const device uint8_t* src; - const device uint8_t* scales; - - QuantizedBlockLoader( - const device uint8_t* src_, - const device uint8_t* scales_, - const int src_ld_, - threadgroup T* dst_, - ushort simd_group_id [[simdgroup_index_in_threadgroup]], - ushort simd_lane_id [[thread_index_in_simdgroup]]) - : src_ld(src_ld_), - tile_stride( - reduction_dim ? BCOLS_PACKED * bytes_per_pack - : BROWS * src_ld * bytes_per_pack / pack_factor), - group_step_cnt(0), - group_stride(BROWS * src_ld / group_size), - thread_idx(simd_group_id * 32 + simd_lane_id), - bi(n_reads * thread_idx / BCOLS_PACKED), - bj((n_reads * thread_idx) % BCOLS_PACKED), - dst(dst_ + bi * dst_ld + bj * pack_factor), - src(src_ + bi * src_ld * bytes_per_pack / pack_factor + - bj * bytes_per_pack), - scales( - scales_ + bi * src_ld / group_size + - (bj * pack_factor) / group_size) {} - - void load_unsafe() const { - if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { - return; - } - - T scale = dequantize_scale(*scales); - for (int i = 0; i < n_reads; i++) { - dequantize( - src[i * bytes_per_pack], scale, dst + i * pack_factor); - } - } - - void load_safe(short2 src_tile_dim) const { - if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { - return; - } - - if (reduction_dim == 1 && bi >= src_tile_dim.x) { - for (int i = 0; i < n_reads * pack_factor; i++) { - dst[i] = T(0); - } - return; - } - - if (reduction_dim == 0 && bi >= src_tile_dim.y) { - for (int i = 0; i < n_reads * pack_factor; i++) { - dst[i] = T(0); - } - return; - } - - T scale = dequantize_scale(*scales); - for (int i = 0; i < n_reads; i++) { - dequantize( - src[i * bytes_per_pack], scale, dst + i * pack_factor); - } - } - - void next() { - src += tile_stride; - if (reduction_dim == 1) { - if (group_steps > 1) { - group_step_cnt++; - if (group_step_cnt == group_steps) { - group_step_cnt = 0; - scales++; - } - } else { - scales += scale_step; - } - } else { - scales += group_stride; - } - } -}; - -template -METAL_FUNC void fp_qmv_quad_impl( - const device uint32_t* w, - const device uint8_t* scales, - const device T* x, - device T* y, - constant int& in_vec_size, - const constant int& out_vec_size, - uint3 tid [[threadgroup_position_in_grid]], - uint quad_gid [[quadgroup_index_in_threadgroup]], - uint quad_lid [[thread_index_in_quadgroup]]) { - constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE; - constexpr int pack_factor = get_pack_factor<32, bits>(); - constexpr int values_per_thread = D / QUAD_SIZE; - constexpr int steps_per_thread = - values_per_thread < group_size ? 1 : values_per_thread / group_size; - constexpr int values_per_step = values_per_thread / steps_per_thread; - constexpr int packs_per_thread = values_per_thread / pack_factor; - constexpr int packs_per_step = values_per_step / pack_factor; - constexpr int results_per_quadgroup = 8; - - typedef float U; - - thread U x_thread[values_per_thread]; - thread U result[results_per_quadgroup] = {0}; - - // Adjust positions - const int in_vec_size_w = in_vec_size / pack_factor; - const int in_vec_size_g = in_vec_size / group_size; - const int out_row = tid.y * quads_per_simd * results_per_quadgroup + quad_gid; - - w += out_row * in_vec_size_w + quad_lid * packs_per_thread; - scales += - out_row * in_vec_size_g + (quad_lid * values_per_thread) / group_size; - x += tid.x * in_vec_size + quad_lid * values_per_thread; - y += tid.x * out_vec_size + out_row; - - load_vector(x, x_thread); - - for (int row = 0; row < results_per_quadgroup; row++) { - auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd); - const device uint8_t* sl = scales + row * in_vec_size_g * quads_per_simd; -#pragma unroll - for (int k = 0; k < steps_per_thread; ++k) { - U s = dequantize_scale(sl[0]); - if (row * quads_per_simd + out_row < out_vec_size) { - result[row] += qdot( - wl, x_thread + k * values_per_step, s); - } - sl++; - wl += (sizeof(uint32_t) / sizeof(uint8_t)) * packs_per_step; - } - } - - for (int row = 0; row < results_per_quadgroup; row++) { - result[row] = quad_sum(result[row]); - if (quad_lid == 0 && row * quads_per_simd + out_row < out_vec_size) { - y[row * quads_per_simd] = static_cast(result[row]); - } - } -} - -template -METAL_FUNC void fp_qmv_fast_impl( - const device uint32_t* w, - const device uint8_t* scales, - const device T* x, - device T* y, - const constant int& in_vec_size, - const constant int& out_vec_size, - uint3 tid [[threadgroup_position_in_grid]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - constexpr int packs_per_thread = 2; - constexpr int num_simdgroups = 2; - constexpr int results_per_simdgroup = 4; - constexpr int pack_factor = get_pack_factor<32, bits>(); - constexpr int bytes_per_pack = get_bytes_per_pack<32>(); - constexpr int values_per_thread = pack_factor * packs_per_thread; - constexpr int block_size = values_per_thread * SIMD_SIZE; - constexpr int scale_step_per_thread = group_size / values_per_thread; - - const device uint8_t* ws = (const device uint8_t*)w; - - typedef float U; - thread U x_thread[values_per_thread]; - thread U result[results_per_simdgroup] = {0}; - - // Adjust positions - const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; - const int in_vec_size_g = in_vec_size / group_size; - const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + - simd_gid * results_per_simdgroup; - - ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack; - scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; - x += tid.x * in_vec_size + simd_lid * values_per_thread; - y += tid.x * out_vec_size + out_row; - - for (int k = 0; k < in_vec_size; k += block_size) { - load_vector(x, x_thread); - - for (int row = 0; row < results_per_simdgroup; row++) { - auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); - const device auto* sl = scales + row * in_vec_size_g; - - U s = dequantize_scale(sl[0]); - result[row] += qdot(wl, x_thread, s); - } - - ws += block_size * bytes_per_pack / pack_factor; - scales += block_size / group_size; - x += block_size; - } - - for (int row = 0; row < results_per_simdgroup; row++) { - result[row] = simd_sum(result[row]); - if (simd_lid == 0) { - y[row] = static_cast(result[row]); - } - } -} - -template -METAL_FUNC void fp_qmv_impl( - const device uint32_t* w, - const device uint8_t* scales, - const device T* x, - device T* y, - const constant int& in_vec_size, - const constant int& out_vec_size, - uint3 tid [[threadgroup_position_in_grid]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - constexpr int num_simdgroups = 2; - constexpr int results_per_simdgroup = 4; - constexpr int packs_per_thread = 1; - constexpr int pack_factor = get_pack_factor<32, bits>(); - constexpr int bytes_per_pack = get_bytes_per_pack<32>(); - - constexpr int values_per_thread = pack_factor * packs_per_thread; - constexpr int block_size = values_per_thread * SIMD_SIZE; - constexpr int scale_step_per_thread = group_size / values_per_thread; - - const device uint8_t* ws = (const device uint8_t*)w; - - typedef float U; - - thread U x_thread[values_per_thread]; - thread U result[results_per_simdgroup] = {0}; - - // Adjust positions - const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; - const int in_vec_size_g = in_vec_size / group_size; - const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + - simd_gid * results_per_simdgroup; - const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row); - - if (out_row >= out_vec_size) { - return; - } - - // In this case we need to properly guard all our reads because there isn't - // even 1 tile in the matrix - if (out_vec_size < (num_simdgroups * results_per_simdgroup)) { - ws += - out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack; - scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; - x += tid.x * in_vec_size + simd_lid * values_per_thread; - y += tid.x * out_vec_size + out_row; - - int k = 0; - for (; k < in_vec_size - block_size; k += block_size) { - load_vector(x, x_thread); - - for (int row = 0; - row < results_per_simdgroup && out_row + row < out_vec_size; - row++) { - auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); - const device auto* sl = scales + row * in_vec_size_g; - - uint8_t s = sl[0]; - result[row] += qdot(wl, x_thread, s); - } - - ws += block_size * bytes_per_pack / pack_factor; - scales += block_size / group_size; - x += block_size; - } - const int remaining = clamp( - static_cast(in_vec_size - k - simd_lid * values_per_thread), - 0, - values_per_thread); - if (remaining > 0) { - load_vector_safe(x, x_thread, remaining); - - for (int row = 0; - row < results_per_simdgroup && out_row + row < out_vec_size; - row++) { - auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); - const device auto* sl = scales + row * in_vec_size_g; - - U s = dequantize_scale(sl[0]); - result[row] += qdot(wl, x_thread, s); - } - } - - for (int row = 0; - row < results_per_simdgroup && out_row + row < out_vec_size; - row++) { - result[row] = simd_sum(result[row]); - if (simd_lid == 0) { - y[row] = static_cast(result[row]); - } - } - } - - // In this case the last tile is moved back to redo some output values - else { - ws += used_out_row * in_vec_size_w + - simd_lid * packs_per_thread * bytes_per_pack; - scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread; - x += tid.x * in_vec_size + simd_lid * values_per_thread; - y += tid.x * out_vec_size + used_out_row; - - int k = 0; - for (; k < in_vec_size - block_size; k += block_size) { - load_vector(x, x_thread); - - for (int row = 0; row < results_per_simdgroup; row++) { - auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); - const device auto* sl = scales + row * in_vec_size_g; - - U s = dequantize_scale(sl[0]); - result[row] += qdot(wl, x_thread, s); - } - - ws += block_size * bytes_per_pack / pack_factor; - scales += block_size / group_size; - x += block_size; - } - const int remaining = clamp( - static_cast(in_vec_size - k - simd_lid * values_per_thread), - 0, - values_per_thread); - if (remaining > 0) { - load_vector_safe(x, x_thread, remaining); - - for (int row = 0; row < results_per_simdgroup; row++) { - auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); - const device auto* sl = scales + row * in_vec_size_g; - - U s = dequantize_scale(sl[0]); - result[row] += - qdot_safe(wl, x_thread, s, remaining); - } - } - for (int row = 0; row < results_per_simdgroup; row++) { - result[row] = simd_sum(result[row]); - if (simd_lid == 0) { - y[row] = static_cast(result[row]); - } - } - } -} - -template -METAL_FUNC void fp_qvm_impl( - const device uint32_t* w, - const device uint8_t* scales, - const device T* x, - device T* y, - const int in_vec_size, - const int out_vec_size, - uint3 tid [[threadgroup_position_in_grid]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - constexpr int num_simdgroups = 2; - constexpr int pack_factor = get_pack_factor<32, bits>(); - constexpr int bytes_per_pack = get_bytes_per_pack(); - - constexpr int tn = group_size / pack_factor; - constexpr int block_size = SIMD_SIZE; - - using W_T = uint32_t; - const device W_T* ws = (const device W_T*)w; - - typedef float U; - typedef struct { - W_T wi[tn * bytes_per_pack]; - } vec_w; - - thread vec_w w_local; - thread U result[tn * pack_factor] = {0}; - thread U scale = 0; - thread U x_local = 0; - - // Adjust positions - const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor; - const int out_vec_size_g = out_vec_size / group_size; - // 32 * (tid.y * 2 + simd_gid) - int out_col = pack_factor * tn * (tid.y * num_simdgroups + simd_gid); - ws += out_col * bytes_per_pack / pack_factor + simd_lid * out_vec_size_w; - scales += out_col / group_size + simd_lid * out_vec_size_g; - x += tid.x * in_vec_size + simd_lid; - y += tid.x * out_vec_size + out_col; - - if (out_col >= out_vec_size) { - return; - } - - // Loop over in_vec in blocks of block_size - int remaining = in_vec_size % block_size; - if (remaining == 0) { - for (int i = 0; i < in_vec_size; i += block_size) { - x_local = *x; - scale = dequantize_scale(*scales); - w_local = *((device vec_w*)ws); - qouter( - (thread uint8_t*)&w_local, x_local, scale, result); - - x += block_size; - scales += block_size * out_vec_size_g; - ws += block_size * out_vec_size_w; - } - } else { - for (int i = block_size; i < in_vec_size; i += block_size) { - x_local = *x; - scale = dequantize_scale(*scales); - w_local = *((device vec_w*)ws); - - qouter( - (thread uint8_t*)&w_local, x_local, scale, result); - - x += block_size; - scales += block_size * out_vec_size_g; - ws += block_size * out_vec_size_w; - } - if (static_cast(simd_lid) < remaining) { - x_local = *x; - scale = dequantize_scale(*scales); - w_local = *((device vec_w*)ws); - } else { - x_local = 0; - scale = 0; - } - qouter( - (thread uint8_t*)&w_local, x_local, scale, result); - } - -// Accumulate in the simdgroup -#pragma clang loop unroll(full) - for (int k = 0; k < tn * pack_factor; k++) { - result[k] = simd_sum(result[k]); - } - - // Store the result - if (simd_lid == 0) { -#pragma clang loop unroll(full) - for (int k = 0; k < tn * pack_factor; k++) { - y[k] = static_cast(result[k]); - } - } -} - -template < - typename T, - const int group_size, - const int bits, - const bool aligned_N, - const int BM = 32, - const int BK = 32, - const int BN = 32> -METAL_FUNC void fp_qmm_t_impl( - const device uint32_t* w, - const device uint8_t* scales, - const device T* x, - device T* y, - threadgroup T* Xs, - threadgroup T* Ws, - const constant int& K, - const constant int& N, - const constant int& M, - uint3 tid [[threadgroup_position_in_grid]], - uint lid [[thread_index_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); - static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); - - (void)lid; - - constexpr int WM = 2; - constexpr int WN = 2; - constexpr int pack_factor = get_pack_factor<8, bits>(); - constexpr int bytes_per_pack = get_bytes_per_pack(); - - constexpr int BK_padded = (BK + 16 / sizeof(T)); - - // Instantiate the appropriate BlockMMA and Loader - using mma_t = mlx::steel:: - BlockMMA; - using loader_x_t = - mlx::steel::BlockLoader; - using loader_w_t = QuantizedBlockLoader< - T, - BN, - BK, - BK_padded, - 1, - WM * WN * SIMD_SIZE, - group_size, - bits>; - - // Set the block - const int K_w = K * bytes_per_pack / pack_factor; - const int K_g = K / group_size; - const int y_row = tid.y * BM; - const int y_col = tid.x * BN; - - auto wl = (const device uint8_t*)w; - - x += y_row * static_cast(K); - wl += y_col * K_w; - scales += y_col * K_g; - y += y_row * static_cast(N) + y_col; - - // Make the x loader and mma operation - const short num_els = min(BM, M - y_row); - const short num_outs = min(BN, N - y_col); - loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); - loader_w_t loader_w(wl, scales, K, Ws, simd_gid, simd_lid); - mma_t mma_op(simd_gid, simd_lid); - - if (num_els < BM) { - if (!aligned_N && num_outs < BN) { - for (int k = 0; k < K; k += BK) { - threadgroup_barrier(mem_flags::mem_threadgroup); - loader_x.load_safe(short2(BK, num_els)); - loader_w.load_safe(short2(BK, num_outs)); - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(Xs, Ws); - loader_x.next(); - loader_w.next(); - } - } else { - for (int k = 0; k < K; k += BK) { - threadgroup_barrier(mem_flags::mem_threadgroup); - loader_x.load_safe(short2(BK, num_els)); - loader_w.load_unsafe(); - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(Xs, Ws); - loader_x.next(); - loader_w.next(); - } - } - } else { - if (!aligned_N && num_outs < BN) { - for (int k = 0; k < K; k += BK) { - threadgroup_barrier(mem_flags::mem_threadgroup); - loader_x.load_unsafe(); - loader_w.load_safe(short2(BK, num_outs)); - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(Xs, Ws); - loader_x.next(); - loader_w.next(); - } - } else { - for (int k = 0; k < K; k += BK) { - threadgroup_barrier(mem_flags::mem_threadgroup); - loader_x.load_unsafe(); - loader_w.load_unsafe(); - threadgroup_barrier(mem_flags::mem_threadgroup); - - mma_op.mma(Xs, Ws); - loader_x.next(); - loader_w.next(); - } - } - } - - // Store results to device memory - threadgroup_barrier(mem_flags::mem_threadgroup); - if (num_els < BM || num_outs < BN) { - mma_op.store_result_safe(y, N, short2(num_outs, num_els)); - } else { - mma_op.store_result(y, N); - } -} - -template < - typename T, - const int group_size, - const int bits, - const int BM = 32, - const int BK = 32, - const int BN = 32> -METAL_FUNC void fp_qmm_n_impl( - const device uint32_t* w, - const device uint8_t* scales, - const device T* x, - device T* y, - threadgroup T* Xs, - threadgroup T* Ws, - const constant int& K, - const constant int& N, - const constant int& M, - uint3 tid [[threadgroup_position_in_grid]], - uint lid [[thread_index_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); - static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); - - (void)lid; - - constexpr int WM = 2; - constexpr int WN = 2; - constexpr int pack_factor = get_pack_factor<8, bits>(); - constexpr int bytes_per_pack = get_bytes_per_pack(); - - constexpr int BK_padded = (BK + 16 / sizeof(T)); - constexpr int BN_padded = (BN + 16 / sizeof(T)); - - // Instantiate the appropriate BlockMMA and Loader - using mma_t = mlx::steel:: - BlockMMA; - using loader_x_t = mlx::steel:: - BlockLoader; - using loader_w_t = QuantizedBlockLoader< - T, - BK, - BN, - BN_padded, - 0, - WM * WN * SIMD_SIZE, - group_size, - bits>; - - auto wl = (const device uint8_t*)w; - - // Set the block - const int y_row = tid.y * BM; - const int y_col = tid.x * BN; - x += y_row * static_cast(K); - wl += y_col * bytes_per_pack / pack_factor; - scales += y_col / group_size; - y += y_row * static_cast(N) + y_col; - - // Make the x loader and mma operation - const short num_els = min(BM, M - y_row); - loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); - loader_w_t loader_w(wl, scales, N, Ws, simd_gid, simd_lid); - mma_t mma_op(simd_gid, simd_lid); - - if (num_els < BM) { - if ((K % BK) != 0) { - const int k_blocks = K / BK; - for (int k = 0; k < k_blocks; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - loader_x.load_safe(short2(BK, num_els)); - loader_w.load_unsafe(); - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(Xs, Ws); - loader_x.next(); - loader_w.next(); - } - const short num_k = K - k_blocks * BK; - threadgroup_barrier(mem_flags::mem_threadgroup); - loader_x.load_safe(short2(num_k, num_els)); - loader_w.load_safe(short2(BN, num_k)); - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(Xs, Ws); - } else { - for (int k = 0; k < K; k += BK) { - threadgroup_barrier(mem_flags::mem_threadgroup); - loader_x.load_safe(short2(BK, num_els)); - loader_w.load_unsafe(); - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(Xs, Ws); - loader_x.next(); - loader_w.next(); - } - } - } else { - if ((K % BK) != 0) { - const int k_blocks = K / BK; - for (int k = 0; k < k_blocks; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - loader_x.load_unsafe(); - loader_w.load_unsafe(); - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(Xs, Ws); - loader_x.next(); - loader_w.next(); - } - const short num_k = K - k_blocks * BK; - threadgroup_barrier(mem_flags::mem_threadgroup); - loader_x.load_safe(short2(num_k, BM)); - loader_w.load_safe(short2(BN, num_k)); - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(Xs, Ws); - } else { - for (int k = 0; k < K; k += BK) { - threadgroup_barrier(mem_flags::mem_threadgroup); - loader_x.load_unsafe(); - loader_w.load_unsafe(); - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(Xs, Ws); - loader_x.next(); - loader_w.next(); - } - } - } - - // Store results to device memory - threadgroup_barrier(mem_flags::mem_threadgroup); - if (num_els < BM) { - mma_op.store_result_safe(y, N, short2(BN, num_els)); - } else { - mma_op.store_result(y, N); - } -} - -template -METAL_FUNC void adjust_matrix_offsets( - const device T*& x, - const device uint32_t*& w, - const device uint8_t*& scales, - device T*& y, - int output_stride, - const constant int& x_batch_ndims, - const constant int* x_shape, - const constant int64_t* x_strides, - const constant int& w_batch_ndims, - const constant int* w_shape, - const constant int64_t* w_strides, - const constant int64_t* s_strides, - uint3 tid [[threadgroup_position_in_grid]]) { - // Set the input/output matrices - uint32_t x_idx = tid.z; - uint32_t w_idx = tid.z; - if (x_batch_ndims == 1) { - x += x_idx * x_strides[0]; - } else { - x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); - } - if (w_batch_ndims == 1) { - w += w_idx * w_strides[0]; - scales += w_idx * s_strides[0]; - } else { - ulong2 idx = elem_to_loc_broadcast( - w_idx, w_shape, w_strides, s_strides, w_batch_ndims); - w += idx.x; - scales += idx.y; - } - y += tid.z * output_stride; -} - -template -METAL_FUNC void adjust_matrix_offsets( - const device T*& x, - const device uint32_t*& w, - const device uint8_t*& scales, - const device uint32_t* lhs_indices, - const device uint32_t* rhs_indices, - device T*& y, - int output_stride, - const constant int& batch_ndims, - const constant int* batch_shape, - const constant int64_t* lhs_strides, - const constant int64_t* rhs_strides, - const constant int& x_batch_ndims, - const constant int* x_shape, - const constant int64_t* x_strides, - const constant int& w_batch_ndims, - const constant int* w_shape, - const constant int64_t* w_strides, - const constant int64_t* s_strides, - uint3 tid [[threadgroup_position_in_grid]]) { - // Set the input/output matrices - uint32_t x_idx; - uint32_t w_idx; - if (batch_ndims == 1) { - x_idx = lhs_indices[tid.z * lhs_strides[0]]; - w_idx = rhs_indices[tid.z * rhs_strides[0]]; - } else { - ulong2 idx = elem_to_loc_broadcast( - tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims); - x_idx = lhs_indices[idx.x]; - w_idx = rhs_indices[idx.y]; - } - if (x_batch_ndims == 1) { - x += x_idx * x_strides[0]; - } else { - x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); - } - if (w_batch_ndims == 1) { - w += w_idx * w_strides[0]; - scales += w_idx * s_strides[0]; - } else { - ulong2 idx = elem_to_loc_broadcast( - w_idx, w_shape, w_strides, s_strides, w_batch_ndims); - w += idx.x; - scales += idx.y; - } - y += tid.z * output_stride; -} - -template -[[kernel]] void fp_qmv_quad( - const device uint32_t* w, - const device uint8_t* scales, - const device T* x, - device T* y, - const constant int& in_vec_size, - const constant int& out_vec_size, - const constant int& x_batch_ndims, - const constant int* x_shape, - const constant int64_t* x_strides, - const constant int& w_batch_ndims, - const constant int* w_shape, - const constant int64_t* w_strides, - const constant int64_t* s_strides, - uint3 tid [[threadgroup_position_in_grid]], - uint quad_gid [[quadgroup_index_in_threadgroup]], - uint quad_lid [[thread_index_in_quadgroup]]) { - if (batched) { - int M = x_shape[x_batch_ndims]; - adjust_matrix_offsets( - x, - w, - scales, - y, - out_vec_size * M, - x_batch_ndims, - x_shape, - x_strides, - w_batch_ndims, - w_shape, - w_strides, - s_strides, - tid); - } - fp_qmv_quad_impl( - w, scales, x, y, in_vec_size, out_vec_size, tid, quad_gid, quad_lid); -} - -template -[[kernel]] void fp_qmv_fast( - const device uint32_t* w, - const device uint8_t* scales, - const device T* x, - device T* y, - const constant int& in_vec_size, - const constant int& out_vec_size, - const constant int& x_batch_ndims, - const constant int* x_shape, - const constant int64_t* x_strides, - const constant int& w_batch_ndims, - const constant int* w_shape, - const constant int64_t* w_strides, - const constant int64_t* s_strides, - uint3 tid [[threadgroup_position_in_grid]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - if (batched) { - int M = x_shape[x_batch_ndims]; - adjust_matrix_offsets( - x, - w, - scales, - y, - out_vec_size * M, - x_batch_ndims, - x_shape, - x_strides, - w_batch_ndims, - w_shape, - w_strides, - s_strides, - tid); - } - fp_qmv_fast_impl( - w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); -} - -template -[[kernel]] void fp_qmv( - const device uint32_t* w, - const device uint8_t* scales, - const device T* x, - device T* y, - const constant int& in_vec_size, - const constant int& out_vec_size, - const constant int& x_batch_ndims, - const constant int* x_shape, - const constant int64_t* x_strides, - const constant int& w_batch_ndims, - const constant int* w_shape, - const constant int64_t* w_strides, - const constant int64_t* s_strides, - uint3 tid [[threadgroup_position_in_grid]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - if (batched) { - int M = x_shape[x_batch_ndims]; - adjust_matrix_offsets( - x, - w, - scales, - y, - out_vec_size * M, - x_batch_ndims, - x_shape, - x_strides, - w_batch_ndims, - w_shape, - w_strides, - s_strides, - tid); - } - fp_qmv_impl( - w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); -} - -template -[[kernel]] void fp_qvm( - const device uint32_t* w, - const device uint8_t* scales, - const device T* x, - device T* y, - const constant int& in_vec_size, - const constant int& out_vec_size, - const constant int& x_batch_ndims, - const constant int* x_shape, - const constant int64_t* x_strides, - const constant int& w_batch_ndims, - const constant int* w_shape, - const constant int64_t* w_strides, - const constant int64_t* s_strides, - uint3 tid [[threadgroup_position_in_grid]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - if (batched) { - int M = x_shape[x_batch_ndims]; - adjust_matrix_offsets( - x, - w, - scales, - y, - out_vec_size * M, - x_batch_ndims, - x_shape, - x_strides, - w_batch_ndims, - w_shape, - w_strides, - s_strides, - tid); - } - fp_qvm_impl( - w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); -} - -template -[[kernel]] void fp_qvm_split_k( - const device uint32_t* w, - const device uint8_t* scales, - const device T* x, - device T* y, - const constant int& in_vec_size, - const constant int& out_vec_size, - const constant int& x_batch_ndims, - const constant int* x_shape, - const constant int64_t* x_strides, - const constant int& w_batch_ndims, - const constant int* w_shape, - const constant int64_t* w_strides, - const constant int64_t* s_strides, - const constant int& final_block_size, - uint3 tid [[threadgroup_position_in_grid]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - int M = x_shape[x_batch_ndims]; - adjust_matrix_offsets( - x, - w, - scales, - y, - out_vec_size * M, - x_batch_ndims, - x_shape, - x_strides, - w_batch_ndims, - w_shape, - w_strides, - s_strides, - tid); - - // When (in_vec_size % split_k != 0) the final block needs to be smaller - int in_vec_size_adj = - tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size; - - fp_qvm_impl( - w, scales, x, y, in_vec_size_adj, out_vec_size, tid, simd_gid, simd_lid); -} - -template < - typename T, - const int group_size, - const int bits, - const bool aligned_N, - const bool batched, - const int BM = 32, - const int BK = 32, - const int BN = 32> -[[kernel]] void fp_qmm_t( - const device uint32_t* w, - const device uint8_t* scales, - const device T* x, - device T* y, - const constant int& K, - const constant int& N, - const constant int& M, - const constant int& x_batch_ndims, - const constant int* x_shape, - const constant int64_t* x_strides, - const constant int& w_batch_ndims, - const constant int* w_shape, - const constant int64_t* w_strides, - const constant int64_t* s_strides, - uint3 tid [[threadgroup_position_in_grid]], - uint lid [[thread_index_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - (void)lid; - - constexpr int BK_padded = (BK + 16 / sizeof(T)); - - threadgroup T Xs[BM * BK_padded]; - threadgroup T Ws[BN * BK_padded]; - - if (batched) { - adjust_matrix_offsets( - x, - w, - scales, - y, - M * N, - x_batch_ndims, - x_shape, - x_strides, - w_batch_ndims, - w_shape, - w_strides, - s_strides, - tid); - } - fp_qmm_t_impl( - w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); -} - -template < - typename T, - const int group_size, - const int bits, - const bool batched, - const int BM = 32, - const int BK = 32, - const int BN = 32> -[[kernel]] void fp_qmm_n( - const device uint32_t* w, - const device uint8_t* scales, - const device T* x, - device T* y, - const constant int& K, - const constant int& N, - const constant int& M, - const constant int& x_batch_ndims, - const constant int* x_shape, - const constant int64_t* x_strides, - const constant int& w_batch_ndims, - const constant int* w_shape, - const constant int64_t* w_strides, - const constant int64_t* s_strides, - uint3 tid [[threadgroup_position_in_grid]], - uint lid [[thread_index_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - (void)lid; - - constexpr int BK_padded = (BK + 16 / sizeof(T)); - constexpr int BN_padded = (BN + 16 / sizeof(T)); - - threadgroup T Xs[BM * BK_padded]; - threadgroup T Ws[BK * BN_padded]; - - if (batched) { - adjust_matrix_offsets( - x, - w, - scales, - y, - M * N, - x_batch_ndims, - x_shape, - x_strides, - w_batch_ndims, - w_shape, - w_strides, - s_strides, - tid); - } - - fp_qmm_n_impl( - w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); -} - -template -[[kernel]] void fp_gather_qmv_fast( - const device uint32_t* w, - const device uint8_t* scales, - const device T* x, - const device uint32_t* lhs_indices, - const device uint32_t* rhs_indices, - device T* y, - const constant int& in_vec_size, - const constant int& out_vec_size, - const constant int& x_batch_ndims, - const constant int* x_shape, - const constant int64_t* x_strides, - const constant int& w_batch_ndims, - const constant int* w_shape, - const constant int64_t* w_strides, - const constant int64_t* s_strides, - const constant int& batch_ndims, - const constant int* batch_shape, - const constant int64_t* lhs_strides, - const constant int64_t* rhs_strides, - uint3 tid [[threadgroup_position_in_grid]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - int M = x_shape[x_batch_ndims]; - adjust_matrix_offsets( - x, - w, - scales, - lhs_indices, - rhs_indices, - y, - out_vec_size * M, - batch_ndims, - batch_shape, - lhs_strides, - rhs_strides, - x_batch_ndims, - x_shape, - x_strides, - w_batch_ndims, - w_shape, - w_strides, - s_strides, - tid); - fp_qmv_fast_impl( - w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); -} - -template -[[kernel]] void fp_gather_qmv( - const device uint32_t* w, - const device uint8_t* scales, - const device T* x, - const device uint32_t* lhs_indices, - const device uint32_t* rhs_indices, - device T* y, - const constant int& in_vec_size, - const constant int& out_vec_size, - const constant int& x_batch_ndims, - const constant int* x_shape, - const constant int64_t* x_strides, - const constant int& w_batch_ndims, - const constant int* w_shape, - const constant int64_t* w_strides, - const constant int64_t* s_strides, - const constant int& batch_ndims, - const constant int* batch_shape, - const constant int64_t* lhs_strides, - const constant int64_t* rhs_strides, - uint3 tid [[threadgroup_position_in_grid]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - int M = x_shape[x_batch_ndims]; - adjust_matrix_offsets( - x, - w, - scales, - lhs_indices, - rhs_indices, - y, - out_vec_size * M, - batch_ndims, - batch_shape, - lhs_strides, - rhs_strides, - x_batch_ndims, - x_shape, - x_strides, - w_batch_ndims, - w_shape, - w_strides, - s_strides, - tid); - fp_qmv_impl( - w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); -} - -template -[[kernel]] void fp_gather_qvm( - const device uint32_t* w, - const device uint8_t* scales, - const device T* x, - const device uint32_t* lhs_indices, - const device uint32_t* rhs_indices, - device T* y, - const constant int& in_vec_size, - const constant int& out_vec_size, - const constant int& x_batch_ndims, - const constant int* x_shape, - const constant int64_t* x_strides, - const constant int& w_batch_ndims, - const constant int* w_shape, - const constant int64_t* w_strides, - const constant int64_t* s_strides, - const constant int& batch_ndims, - const constant int* batch_shape, - const constant int64_t* lhs_strides, - const constant int64_t* rhs_strides, - uint3 tid [[threadgroup_position_in_grid]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - int M = x_shape[x_batch_ndims]; - adjust_matrix_offsets( - x, - w, - scales, - lhs_indices, - rhs_indices, - y, - out_vec_size * M, - batch_ndims, - batch_shape, - lhs_strides, - rhs_strides, - x_batch_ndims, - x_shape, - x_strides, - w_batch_ndims, - w_shape, - w_strides, - s_strides, - tid); - fp_qvm_impl( - w, scales, x, y, in_vec_size, out_vec_size, tid, simd_gid, simd_lid); -} - -template < - typename T, - const int group_size, - const int bits, - const bool aligned_N, - const int BM = 32, - const int BK = 32, - const int BN = 32> -[[kernel]] void fp_gather_qmm_t( - const device uint32_t* w, - const device uint8_t* scales, - const device T* x, - const device uint32_t* lhs_indices, - const device uint32_t* rhs_indices, - device T* y, - const constant int& K, - const constant int& N, - const constant int& M, - const constant int& x_batch_ndims, - const constant int* x_shape, - const constant int64_t* x_strides, - const constant int& w_batch_ndims, - const constant int* w_shape, - const constant int64_t* w_strides, - const constant int64_t* s_strides, - const constant int& batch_ndims, - const constant int* batch_shape, - const constant int64_t* lhs_strides, - const constant int64_t* rhs_strides, - uint3 tid [[threadgroup_position_in_grid]], - uint lid [[thread_index_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - (void)lid; - - constexpr int BK_padded = (BK + 16 / sizeof(T)); - - threadgroup T Xs[BM * BK_padded]; - threadgroup T Ws[BN * BK_padded]; - - adjust_matrix_offsets( - x, - w, - scales, - lhs_indices, - rhs_indices, - y, - M * N, - batch_ndims, - batch_shape, - lhs_strides, - rhs_strides, - x_batch_ndims, - x_shape, - x_strides, - w_batch_ndims, - w_shape, - w_strides, - s_strides, - tid); - fp_qmm_t_impl( - w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); -} - -template < - typename T, - const int group_size, - const int bits, - const int BM = 32, - const int BK = 32, - const int BN = 32> -[[kernel]] void fp_gather_qmm_n( - const device uint32_t* w, - const device uint8_t* scales, - const device T* x, - const device uint32_t* lhs_indices, - const device uint32_t* rhs_indices, - device T* y, - const constant int& K, - const constant int& N, - const constant int& M, - const constant int& x_batch_ndims, - const constant int* x_shape, - const constant int64_t* x_strides, - const constant int& w_batch_ndims, - const constant int* w_shape, - const constant int64_t* w_strides, - const constant int64_t* s_strides, - const constant int& batch_ndims, - const constant int* batch_shape, - const constant int64_t* lhs_strides, - const constant int64_t* rhs_strides, - uint3 tid [[threadgroup_position_in_grid]], - uint lid [[thread_index_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - (void)lid; - - constexpr int BK_padded = (BK + 16 / sizeof(T)); - constexpr int BN_padded = (BN + 16 / sizeof(T)); - - threadgroup T Xs[BM * BK_padded]; - threadgroup T Ws[BK * BN_padded]; - - adjust_matrix_offsets( - x, - w, - scales, - lhs_indices, - rhs_indices, - y, - M * N, - batch_ndims, - batch_shape, - lhs_strides, - rhs_strides, - x_batch_ndims, - x_shape, - x_strides, - w_batch_ndims, - w_shape, - w_strides, - s_strides, - tid); - fp_qmm_n_impl( - w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); -} - -template < - typename T, - int group_size, - int bits, - int BM, - int BN, - int BK, - int WM, - int WN, - bool transpose> -[[kernel]] void fp_gather_qmm_rhs( - const device T* x, - const device uint32_t* w, - const device uint8_t* scales, - const device uint32_t* indices, - device T* y, - const constant int& M, - const constant int& N, - const constant int& K, - uint3 tid [[threadgroup_position_in_grid]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]]) { - constexpr int pack_factor = get_pack_factor<8, bits>(); - constexpr int bytes_per_pack = get_bytes_per_pack(); - constexpr int BK_padded = (BK + 16 / sizeof(T)); - constexpr int BN_padded = (BN + 16 / sizeof(T)); - - using mma_t = mlx::steel::BlockMMA< - T, - T, - BM, - BN, - BK, - WM, - WN, - false, - transpose, - BK_padded, - transpose ? BK_padded : BN_padded>; - using loader_x_t = - mlx::steel::BlockLoader; - using loader_w_t = QuantizedBlockLoader< - T, - transpose ? BN : BK, - transpose ? BK : BN, - transpose ? BK_padded : BN_padded, - transpose, - WM * WN * SIMD_SIZE, - group_size, - bits>; - - threadgroup T Xs[BM * BK_padded]; - threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded]; - - // Compute the block - const int K_w = K * bytes_per_pack / pack_factor; - const int K_g = K / group_size; - const int N_w = N * bytes_per_pack / pack_factor; - const int N_g = N / group_size; - const int K_it = K / BK; - const size_t stride_w = transpose ? N * K_w : K * N_w; - const size_t stride_s = transpose ? N * K_g : K * N_g; - const int y_row = tid.y * BM; - const int y_col = tid.x * BN; - const size_t y_row_long = size_t(y_row); - const size_t y_col_long = size_t(y_col); - - // Prepare threadgroup bounds - const short tgp_bm = align_M ? BM : short(min(BM, M - y_row)); - const short tgp_bn = align_N ? BN : short(min(BN, N - y_col)); - - // Calculate the final tiles in the case that K is not aligned - const int k_remain = K - K_it * BK; - const short2 tile_x = short2(k_remain, tgp_bm); - const short2 tile_w = - transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); - - // Move x and output to the correct block - auto wl = (const device uint8_t*)w; - x += y_row_long * K; - y += y_row_long * N + y_col_long; - wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor; - scales += transpose ? y_col_long * K_g : y_col / group_size; - - // Do as many matmuls as necessary - uint32_t index; - short offset; - uint32_t index_next = indices[y_row]; - short offset_next = 0; - int n = 0; - while (n < tgp_bm) { - n++; - offset = offset_next; - index = index_next; - offset_next = tgp_bm; - for (; n < tgp_bm; n++) { - if (indices[y_row + n] != index) { - offset_next = n; - index_next = indices[y_row + n]; - break; - } - } - threadgroup_barrier(mem_flags::mem_none); - - // Prepare threadgroup mma operation - thread mma_t mma_op(simd_group_id, simd_lane_id); - - // Prepare threadgroup loading operations - thread loader_x_t loader_x(x, K, Xs, simd_group_id, simd_lane_id); - thread loader_w_t loader_w( - wl + index * stride_w, - scales + index * stride_s, - transpose ? K : N, - Ws, - simd_group_id, - simd_lane_id); - - // Matrices are all aligned check nothing - if (align_M && align_N) { - gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it); - if (!align_K) { - threadgroup_barrier(mem_flags::mem_threadgroup); - gemm_loop_finalize(Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); - } - - // Store results to device memory - if (offset_next - offset == BM) { - mma_op.store_result(y, N); - } else { - mma_op.store_result_slice( - y, N, short2(0, offset), short2(BN, offset_next)); - } - } else { - // Tile aligned so check outside of the hot loop - if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { - gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it); - if (!align_K) { - threadgroup_barrier(mem_flags::mem_threadgroup); - gemm_loop_finalize( - Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); - } - - // Store results to device memory - if (offset_next - offset == BM) { - mma_op.store_result(y, N); - } else { - mma_op.store_result_slice( - y, N, short2(0, offset), short2(BN, offset_next)); - } - } - - // Tile partially aligned check rows - else if (align_N || tgp_bn == BN) { - gemm_loop_unaligned( - Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); - if (!align_K) { - threadgroup_barrier(mem_flags::mem_threadgroup); - gemm_loop_finalize( - Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); - } - mma_op.store_result_slice( - y, N, short2(0, offset), short2(BN, offset_next)); - } - - // Tile partially aligned check cols - else if (align_M || tgp_bm == BM) { - gemm_loop_unaligned( - Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); - if (!align_K) { - threadgroup_barrier(mem_flags::mem_threadgroup); - gemm_loop_finalize( - Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); - } - mma_op.store_result_slice( - y, N, short2(0, offset), short2(tgp_bn, offset_next)); - } - - // Nothing aligned so check both rows and cols - else { - gemm_loop_unaligned( - Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); - if (!align_K) { - threadgroup_barrier(mem_flags::mem_threadgroup); - gemm_loop_finalize( - Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); - } - mma_op.store_result_slice( - y, N, short2(0, offset), short2(tgp_bn, offset_next)); - } - } - } -} - -template -[[kernel]] void fp_quantize( - const device T* w [[buffer(0)]], - device uint8_t* out [[buffer(1)]], - device uint8_t* scales [[buffer(2)]], - uint2 tidx [[thread_position_in_grid]], - uint2 grid_dim [[threads_per_grid]]) { - constexpr bool use_mx_scale = group_size == 32; - size_t index = tidx.x + grid_dim.x * size_t(tidx.y); - - float scale; - float w_thread = w[index]; - if (use_mx_scale) { - scale = simd_max(abs(w_thread)); - } else { - float w_max_l = simd_max(tidx.x < 16 ? abs(w_thread) : 0.0); - float w_max_r = simd_max(tidx.x >= 16 ? abs(w_thread) : 0.0); - scale = tidx.x < 16 ? w_max_l : w_max_r; - } - scale /= bits == 4 ? 6.0f : 448.0f; - - using ScaleType = metal::conditional_t; - auto s = ScaleType(scale); - uint8_t q_scale = s.bits; - scale = float(s); - - size_t gindex = index / group_size; - if (index % group_size == 0) { - scales[gindex] = q_scale; - } - - uint8_t output = Quantize{}(scale == 0 ? 0.0f : w_thread / scale); - if (bits == 4) { - uint8_t sval = simd_shuffle_down(output, 1); - output |= sval << bits; - } - constexpr int pack_factor = bits == 8 ? 1 : 2; - if (index % pack_factor == 0) { - out[index / pack_factor] = output; - } -} - -template -[[kernel]] void fp_dequantize( - const device uint8_t* w [[buffer(0)]], - const device uint8_t* scales [[buffer(1)]], - device T* out [[buffer(3)]], - uint2 index [[thread_position_in_grid]], - uint2 grid_dim [[threads_per_grid]]) { - constexpr bool use_mx_scale = group_size == 32; - constexpr int pack_factor = bits == 8 ? 1 : 2; - size_t offset = index.x + grid_dim.x * size_t(index.y); - size_t oindex = offset * pack_factor; - size_t gindex = oindex / group_size; - - out += oindex; - - using ScaleType = metal::conditional_t; - auto q_scale = ((device ScaleType*)(scales))[gindex]; - auto scale = float(q_scale); - - uint val = w[offset]; -#pragma clang loop unroll(full) - for (int i = 0; i < pack_factor; i++) { - uint8_t d; - if (bits == 4) { - d = (val >> (bits * i)) & 0x0f; - } else if (bits == 8) { - d = val; - } - out[i] = static_cast(scale * Dequantize{}(d)); - } -} - -template -[[kernel]] void fp_quantize_dequantize( - const device T* w [[buffer(0)]], - device T* out [[buffer(1)]], - uint2 tidx [[thread_position_in_grid]], - uint2 grid_dim [[threads_per_grid]]) { - constexpr bool use_mx_scale = group_size == 32; - size_t index = tidx.x + grid_dim.x * size_t(tidx.y); - - float scale; - float w_thread = w[index]; - if (use_mx_scale) { - scale = simd_max(abs(w_thread)); - } else { - float w_max_l = simd_max(tidx.x < 16 ? abs(w_thread) : 0.0); - float w_max_r = simd_max(tidx.x >= 16 ? abs(w_thread) : 0.0); - scale = tidx.x < 16 ? w_max_l : w_max_r; - } - scale /= bits == 4 ? 6.0f : 448.0f; - - using ScaleType = metal::conditional_t; - auto s = ScaleType(scale); - scale = float(s); - - uint8_t output = Quantize{}(scale == 0 ? 0.0f : w_thread / scale); - - out[index] = static_cast(scale * Dequantize{}(output)); -} diff --git a/Source/Cmlx/mlx-generated/metal/fp_quantized_nax.h b/Source/Cmlx/mlx-generated/metal/fp_quantized_nax.h deleted file mode 100644 index 38d9fb65..00000000 --- a/Source/Cmlx/mlx-generated/metal/fp_quantized_nax.h +++ /dev/null @@ -1,1044 +0,0 @@ -// Copyright © 2025 Apple Inc. - -#include -#include - -#include "fp4.h" -#include "fp8.h" - -constant bool align_M [[function_constant(200)]]; -constant bool align_N [[function_constant(201)]]; -constant bool align_K [[function_constant(202)]]; - -using namespace metal; - -#define MLX_MTL_CONST static constant constexpr const - -MLX_MTL_CONST int SIMD_SIZE = 32; -MLX_MTL_CONST int QUAD_SIZE = 4; - -template -inline constexpr short get_pack_factor() { - return wsize / bits; -} - -template -inline constexpr short get_bytes_per_pack() { - return wsize / 8; -} - -template -static inline T dequantize_scale(uint8_t s) { - if constexpr (group_size == 16) { - // Use nv scale - return T(*(thread fp8_e4m3*)(&s)); - } else { - return T(*(thread fp8_e8m0*)(&s)); - } -} - -template -struct Quantize { - uint8_t operator()(float x) { - if (bits == 8) { - return fp8_e4m3(x).bits; - } else { - return fp4_e2m1(x).bits; - } - } -}; - -template -struct Dequantize { - U operator()(uint8_t x) { - if constexpr (bits == 8) { - return U(*(thread fp8_e4m3*)(&x)); - } else { - return U(*(thread fp4_e2m1*)(&x)); - } - } -}; - -template -inline void dequantize(uint8_t w, U scale, threadgroup U* w_local) { - if constexpr (bits == 4) { - w_local[0] = scale * Dequantize<4, U>{}(w); - w_local[1] = scale * Dequantize<4, U>{}(w >> 4); - } else { - w_local[0] = scale * Dequantize<8, U>{}(w); - } -} - -template < - typename T, - short BROWS, - short BCOLS, - short dst_ld, - short reduction_dim, - short tgp_size, - short group_size, - short bits> -struct QuantizedBlockLoader { - MLX_MTL_CONST short pack_factor = get_pack_factor<8, bits>(); - MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack(); - MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor; - MLX_MTL_CONST short n_reads = - (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size; - - MLX_MTL_CONST short n_reads_per_scale = (n_reads * pack_factor) <= group_size - ? n_reads - : (group_size / pack_factor); - MLX_MTL_CONST short n_steps_per_read = n_reads / n_reads_per_scale; - - MLX_MTL_CONST short n_groups = BCOLS / group_size; - - const int src_ld; - const int tile_stride; - const int group_stride; - - const short thread_idx; - const short bi; - const short bj; - - const short group_id; - - threadgroup T* dst; - const device uint8_t* src; - const device uint8_t* scales; - - QuantizedBlockLoader( - const device uint8_t* src_, - const device uint8_t* scales_, - const int src_ld_, - threadgroup T* dst_, - ushort simd_group_id [[simdgroup_index_in_threadgroup]], - ushort simd_lane_id [[thread_index_in_simdgroup]]) - : src_ld(src_ld_), - tile_stride( - reduction_dim ? BCOLS_PACKED * bytes_per_pack - : BROWS * src_ld * bytes_per_pack / pack_factor), - group_stride(BROWS * src_ld / group_size), - thread_idx(simd_group_id * 32 + simd_lane_id), - bi(n_reads * thread_idx / BCOLS_PACKED), - bj((n_reads * thread_idx) % BCOLS_PACKED), - group_id((bj * pack_factor) / group_size), - dst(dst_ + bi * dst_ld + bj * pack_factor), - src(src_ + bi * src_ld * bytes_per_pack / pack_factor + - bj * bytes_per_pack), - scales(scales_ + bi * src_ld / group_size + group_id) {} - - void load_unsafe() const { - if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { - return; - } - - int k = 0; - for (int i = 0; i < n_steps_per_read; i++) { - T scale = dequantize_scale(scales[i]); - for (int j = 0; j < n_reads_per_scale; j++) { - dequantize( - src[k * bytes_per_pack], scale, dst + k * pack_factor); - k++; - } - } - } - - void load_safe(short2 src_tile_dim) const { - if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { - return; - } - - if (reduction_dim == 1 && bi >= src_tile_dim.x) { - for (int i = 0; i < n_reads * pack_factor; i++) { - dst[i] = T(0); - } - return; - } - - if (reduction_dim == 0 && bi >= src_tile_dim.y) { - for (int i = 0; i < n_reads * pack_factor; i++) { - dst[i] = T(0); - } - return; - } - - int k = 0; - for (int i = 0; i < n_steps_per_read; i++) { - T scale = dequantize_scale(scales[i]); - for (int j = 0; j < n_reads_per_scale; j++) { - dequantize( - src[k * bytes_per_pack], scale, dst + k * pack_factor); - k++; - } - } - } - - void next() { - src += tile_stride; - if (reduction_dim == 1) { - scales += n_groups; - } else { - scales += n_groups * group_stride; - } - } -}; - -using namespace mlx::steel; - -template < - typename T, - const int group_size, - const int bits, - const bool aligned_N, - const int BM = 64, - const int BK = 64, - const int BN = 64, - const int WM = 2, - const int WN = 2, - typename Wtype = bfloat> -METAL_FUNC void fp_qmm_t_impl( - const device uint32_t* w, - const device uint8_t* scales, - const device T* x, - device T* y, - threadgroup Wtype* Ws, - const constant int& K, - const constant int& N, - const constant int& M, - uint3 tid [[threadgroup_position_in_grid]], - uint lid [[thread_index_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); - static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); - - (void)lid; - - constexpr int pack_factor = get_pack_factor<8, bits>(); - constexpr int bytes_per_pack = get_bytes_per_pack(); - - constexpr int BK_padded = (BK + 16 / sizeof(Wtype)); - - // Instantiate Loader - using loader_w_t = QuantizedBlockLoader< - Wtype, - BN, - BK, - BK_padded, - 1, - WM * WN * SIMD_SIZE, - group_size, - bits>; - - // Set the block - const int K_w = K * bytes_per_pack / pack_factor; - const int K_g = K / group_size; - const int y_row = tid.y * BM; - const int y_col = tid.x * BN; - - auto wl = (const device uint8_t*)w; - - x += y_row * static_cast(K); - wl += y_col * K_w; - scales += y_col * K_g; - y += y_row * static_cast(N) + y_col; - - // Make the weight loader - loader_w_t loader_w(wl, scales, K, Ws, simd_gid, simd_lid); - - constexpr short UM = 16; - constexpr short UN = 32; - constexpr short UK = 16; - constexpr short SM = BM / WM; - constexpr short SN = BN / WN; - constexpr short SK = 32; - - constexpr short TM = SM / UM; - constexpr short TN = SN / UN; - constexpr short TK = SK / UK; - - const short tm = SM * (simd_gid / WN); - const short tn = SN * (simd_gid % WN); - - constexpr bool transpose_a = false; - constexpr bool transpose_b = true; - - const short sgp_sm = min(SM, short(M - (y_row + tm))); - const bool is_unaligned_sm = (sgp_sm != SM); - - const short sgp_sn = aligned_N ? SN : min(SN, short(N - (y_col + tn))); - - const short tgp_bn = aligned_N ? BN : min(BN, int(N - (y_col))); - const bool is_unaligned_bn = aligned_N ? false : (tgp_bn != BN); - - using AccumType = float; - - using ASubTile = NAXSubTile; - using BSubTile = NAXSubTile; - using DSubTile = NAXSubTile; - - NAXTile Dtile; - - Dtile.clear(); - - x += tm * K; - - dispatch_bool(!is_unaligned_sm, [&](auto kAlignedM) { - dispatch_bool(aligned_N || !is_unaligned_bn, [&](auto kAlignedN) { - for (int k = 0; k < K; k += BK) { - threadgroup_barrier(mem_flags::mem_threadgroup); - if constexpr (kAlignedN.value) { - loader_w.load_unsafe(); - } else { - loader_w.load_safe(short2(BK, tgp_bn)); - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - STEEL_PRAGMA_NO_UNROLL - for (int kk1 = 0; kk1 < BK; kk1 += SK) { - NAXTile Atile; - NAXTile Btile; - - volatile int compiler_barrier; - - if constexpr (kAlignedM.value) { - Atile.load(x + kk1, K); - } else { - Atile.load_safe(x + kk1, K, short2(SK, sgp_sm)); - } - - Btile.template load(Ws + tn * BK_padded + kk1); - - tile_matmad_nax( - Dtile, - Atile, - metal::bool_constant{}, - Btile, - metal::bool_constant{}); - - (void)compiler_barrier; - } - - x += BK; - loader_w.next(); - } - - // Store results to device memory - threadgroup_barrier(mem_flags::mem_threadgroup); - - if constexpr (kAlignedM.value && kAlignedN.value) { - Dtile.store(y + tm * N + tn, N); - } else if (kAlignedM.value && sgp_sn == SN) { - Dtile.store(y + tm * N + tn, N); - } else { - Dtile.store_safe(y + tm * N + tn, N, short2(sgp_sn, sgp_sm)); - } - }); - }); -} - -template < - typename T, - const int group_size, - const int bits, - const int BM = 64, - const int BK = 64, - const int BN = 64, - const int WM = 2, - const int WN = 2, - typename Wtype = bfloat> -METAL_FUNC void fp_qmm_n_impl( - const device uint32_t* w, - const device uint8_t* scales, - const device T* x, - device T* y, - threadgroup T* Ws, - const constant int& K, - const constant int& N, - const constant int& M, - uint3 tid [[threadgroup_position_in_grid]], - uint lid [[thread_index_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); - static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); - - (void)lid; - (void)M; - - constexpr int pack_factor = get_pack_factor<8, bits>(); - constexpr int bytes_per_pack = get_bytes_per_pack(); - - constexpr int BN_padded = (BN + 16 / sizeof(T)); - - using loader_w_t = QuantizedBlockLoader< - T, - BK, - BN, - BN_padded, - 0, - WM * WN * SIMD_SIZE, - group_size, - bits>; - - // Set the block - const int K_w = K * bytes_per_pack / pack_factor; - const int K_g = K / group_size; - const int y_row = tid.y * BM; - const int y_col = tid.x * BN; - - auto wl = (const device uint8_t*)w; - - x += y_row * static_cast(K); - wl += y_col * K_w; - scales += y_col * K_g; - y += y_row * static_cast(N) + y_col; - - // Make the x loader and mma operation - // const short num_els = min(BM, M - y_row); - // const short num_outs = min(BN, N - y_col); - loader_w_t loader_w(wl, scales, K, Ws, simd_gid, simd_lid); - - constexpr short UM = 16; - constexpr short UN = 32; - constexpr short UK = 16; - constexpr short SM = BM / WM; - constexpr short SN = BN / WN; - constexpr short SK = 32; - - constexpr short TM = SM / UM; - constexpr short TN = SN / UN; - constexpr short TK = SK / UK; - - const short tm = SM * (simd_gid / WN); - const short tn = SN * (simd_gid % WN); - - const short ldb_tgp = BN_padded; - - constexpr bool transpose_a = false; - constexpr bool transpose_b = false; - - using AccumType = float; - - using ASubTile = NAXSubTile; - using BSubTile = NAXSubTile; - using DSubTile = NAXSubTile; - - NAXTile Dtile; - - Dtile.clear(); - - x += tm * K; - - for (int k = 0; k < K; k += BK) { - threadgroup_barrier(mem_flags::mem_threadgroup); - loader_w.load_unsafe(); - threadgroup_barrier(mem_flags::mem_threadgroup); - - STEEL_PRAGMA_NO_UNROLL - for (int kk1 = 0; kk1 < BK; kk1 += SK) { - NAXTile Atile; - NAXTile Btile; - - volatile int compiler_barrier; - - Atile.load(x + kk1, K); - Btile.template load(Ws + tn + kk1 * ldb_tgp); - - tile_matmad_nax( - Dtile, - Atile, - metal::bool_constant{}, - Btile, - metal::bool_constant{}); - - (void)compiler_barrier; - } - - x += BK; - loader_w.next(); - } - - // Store results to device memory - threadgroup_barrier(mem_flags::mem_threadgroup); - - Dtile.store(y + tm * N + tn, N); -} - -template -METAL_FUNC void adjust_matrix_offsets( - const device T*& x, - const device uint32_t*& w, - const device S*& scales, - device T*& y, - int output_stride, - const constant int& x_batch_ndims, - const constant int* x_shape, - const constant int64_t* x_strides, - const constant int& w_batch_ndims, - const constant int* w_shape, - const constant int64_t* w_strides, - const constant int64_t* s_strides, - uint3 tid [[threadgroup_position_in_grid]]) { - // Set the input/output matrices - uint32_t x_idx = tid.z; - uint32_t w_idx = tid.z; - if (x_batch_ndims == 1) { - x += x_idx * x_strides[0]; - } else { - x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); - } - if (w_batch_ndims == 1) { - w += w_idx * w_strides[0]; - scales += w_idx * s_strides[0]; - } else { - ulong2 idx = elem_to_loc_broadcast( - w_idx, w_shape, w_strides, s_strides, w_batch_ndims); - w += idx.x; - scales += idx.y; - } - y += tid.z * output_stride; -} - -template -METAL_FUNC void adjust_matrix_offsets( - const device T*& x, - const device uint32_t*& w, - const device S*& scales, - const device uint32_t* lhs_indices, - const device uint32_t* rhs_indices, - device T*& y, - int output_stride, - const constant int& batch_ndims, - const constant int* batch_shape, - const constant int64_t* lhs_strides, - const constant int64_t* rhs_strides, - const constant int& x_batch_ndims, - const constant int* x_shape, - const constant int64_t* x_strides, - const constant int& w_batch_ndims, - const constant int* w_shape, - const constant int64_t* w_strides, - const constant int64_t* s_strides, - uint3 tid [[threadgroup_position_in_grid]]) { - // Set the input/output matrices - uint32_t x_idx; - uint32_t w_idx; - if (batch_ndims == 1) { - x_idx = lhs_indices[tid.z * lhs_strides[0]]; - w_idx = rhs_indices[tid.z * rhs_strides[0]]; - } else { - ulong2 idx = elem_to_loc_broadcast( - tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims); - x_idx = lhs_indices[idx.x]; - w_idx = rhs_indices[idx.y]; - } - if (x_batch_ndims == 1) { - x += x_idx * x_strides[0]; - } else { - x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); - } - if (w_batch_ndims == 1) { - w += w_idx * w_strides[0]; - scales += w_idx * s_strides[0]; - } else { - ulong2 idx = elem_to_loc_broadcast( - w_idx, w_shape, w_strides, s_strides, w_batch_ndims); - w += idx.x; - scales += idx.y; - } - y += tid.z * output_stride; -} - -template < - typename T, - const int group_size, - const int bits, - const bool aligned_N, - const bool batched, - const int BM = 64, - const int BK = 64, - const int BN = 64, - const int WM = 2, - const int WN = 2, - typename Wtype = bfloat> -[[kernel]] void fp_qmm_t_nax( - const device uint32_t* w, - const device uint8_t* scales, - const device T* x, - device T* y, - const constant int& K, - const constant int& N, - const constant int& M, - const constant int& x_batch_ndims, - const constant int* x_shape, - const constant int64_t* x_strides, - const constant int& w_batch_ndims, - const constant int* w_shape, - const constant int64_t* w_strides, - const constant int64_t* s_strides, - uint3 tid [[threadgroup_position_in_grid]], - uint lid [[thread_index_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - (void)lid; - - constexpr int BK_padded = (BK + 16 / sizeof(Wtype)); - - threadgroup Wtype Ws[BN * BK_padded]; - - if (batched) { - adjust_matrix_offsets( - x, - w, - scales, - y, - M * N, - x_batch_ndims, - x_shape, - x_strides, - w_batch_ndims, - w_shape, - w_strides, - s_strides, - tid); - } - fp_qmm_t_impl( - w, scales, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid); -} - -template < - typename T, - const int group_size, - const int bits, - const bool batched, - const int BM = 64, - const int BK = 64, - const int BN = 64, - const int WM = 2, - const int WN = 2, - typename Wtype = bfloat> -[[kernel]] void fp_qmm_n_nax( - const device uint32_t* w, - const device uint8_t* scales, - const device T* x, - device T* y, - const constant int& K, - const constant int& N, - const constant int& M, - const constant int& x_batch_ndims, - const constant int* x_shape, - const constant int64_t* x_strides, - const constant int& w_batch_ndims, - const constant int* w_shape, - const constant int64_t* w_strides, - const constant int64_t* s_strides, - uint3 tid [[threadgroup_position_in_grid]], - uint lid [[thread_index_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - (void)lid; - - constexpr int BK_padded = (BK + 16 / sizeof(T)); - constexpr int BN_padded = (BN + 16 / sizeof(T)); - - threadgroup T Xs[BM * BK_padded]; - threadgroup T Ws[BK * BN_padded]; - - if (batched) { - adjust_matrix_offsets( - x, - w, - scales, - y, - M * N, - x_batch_ndims, - x_shape, - x_strides, - w_batch_ndims, - w_shape, - w_strides, - s_strides, - tid); - } - - fp_qmm_n_impl( - w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); -} - -template < - typename T, - const int group_size, - const int bits, - const bool aligned_N, - const int BM = 64, - const int BK = 64, - const int BN = 64, - const int WM = 2, - const int WN = 2, - typename Wtype = bfloat> -[[kernel]] void fp_gather_qmm_t_nax( - const device uint32_t* w, - const device uint8_t* scales, - const device T* x, - const device uint32_t* lhs_indices, - const device uint32_t* rhs_indices, - device T* y, - const constant int& K, - const constant int& N, - const constant int& M, - const constant int& x_batch_ndims, - const constant int* x_shape, - const constant int64_t* x_strides, - const constant int& w_batch_ndims, - const constant int* w_shape, - const constant int64_t* w_strides, - const constant int64_t* s_strides, - const constant int& batch_ndims, - const constant int* batch_shape, - const constant int64_t* lhs_strides, - const constant int64_t* rhs_strides, - uint3 tid [[threadgroup_position_in_grid]], - uint lid [[thread_index_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - (void)lid; - - constexpr int BK_padded = (BK + 16 / sizeof(Wtype)); - - threadgroup Wtype Ws[BN * BK_padded]; - - adjust_matrix_offsets( - x, - w, - scales, - lhs_indices, - rhs_indices, - y, - M * N, - batch_ndims, - batch_shape, - lhs_strides, - rhs_strides, - x_batch_ndims, - x_shape, - x_strides, - w_batch_ndims, - w_shape, - w_strides, - s_strides, - tid); - fp_qmm_t_impl( - w, scales, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid); -} - -template < - typename T, - const int group_size, - const int bits, - const int BM = 64, - const int BK = 64, - const int BN = 64, - const int WM = 2, - const int WN = 2, - typename Wtype = bfloat> -[[kernel]] void fp_gather_qmm_n_nax( - const device uint32_t* w, - const device uint8_t* scales, - const device T* x, - const device uint32_t* lhs_indices, - const device uint32_t* rhs_indices, - device T* y, - const constant int& K, - const constant int& N, - const constant int& M, - const constant int& x_batch_ndims, - const constant int* x_shape, - const constant int64_t* x_strides, - const constant int& w_batch_ndims, - const constant int* w_shape, - const constant int64_t* w_strides, - const constant int64_t* s_strides, - const constant int& batch_ndims, - const constant int* batch_shape, - const constant int64_t* lhs_strides, - const constant int64_t* rhs_strides, - uint3 tid [[threadgroup_position_in_grid]], - uint lid [[thread_index_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - (void)lid; - - constexpr int BK_padded = (BK + 16 / sizeof(T)); - constexpr int BN_padded = (BN + 16 / sizeof(T)); - - threadgroup T Xs[BM * BK_padded]; - threadgroup T Ws[BK * BN_padded]; - - adjust_matrix_offsets( - x, - w, - scales, - lhs_indices, - rhs_indices, - y, - M * N, - batch_ndims, - batch_shape, - lhs_strides, - rhs_strides, - x_batch_ndims, - x_shape, - x_strides, - w_batch_ndims, - w_shape, - w_strides, - s_strides, - tid); - fp_qmm_n_impl( - w, scales, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); -} - -template < - typename T, - int group_size, - const int bits, - int BM, - int BN, - int BK, - int WM, - int WN, - bool transpose, - typename Wtype = bfloat> -[[kernel]] void fp_gather_qmm_rhs_nax( - const device T* x, - const device uint32_t* w, - const device uint8_t* scales, - const device uint32_t* indices, - device T* y, - const constant int& M, - const constant int& N, - const constant int& K, - uint3 tid [[threadgroup_position_in_grid]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]]) { - constexpr int pack_factor = get_pack_factor<8, bits>(); - constexpr int bytes_per_pack = get_bytes_per_pack(); - constexpr int BK_padded = (BK + 16 / sizeof(Wtype)); - constexpr int BN_padded = (BN + 16 / sizeof(Wtype)); - - using loader_w_t = QuantizedBlockLoader< - Wtype, - transpose ? BN : BK, - transpose ? BK : BN, - transpose ? BK_padded : BN_padded, - transpose, - WM * WN * SIMD_SIZE, - group_size, - bits>; - - threadgroup Wtype Ws[transpose ? BN * BK_padded : BK * BN_padded]; - - // Compute the block - const int K_w = K * bytes_per_pack / pack_factor; - const int K_g = K / group_size; - const int N_w = N * bytes_per_pack / pack_factor; - const int N_g = N / group_size; - const int K_it = K / BK; - const size_t stride_w = transpose ? N * K_w : K * N_w; - const size_t stride_s = transpose ? N * K_g : K * N_g; - const int y_row = tid.y * BM; - const int y_col = tid.x * BN; - const size_t y_row_long = size_t(y_row); - const size_t y_col_long = size_t(y_col); - - // Prepare threadgroup bounds - const short tgp_bm = align_M ? BM : short(min(BM, M - y_row)); - const short tgp_bn = align_N ? BN : short(min(BN, N - y_col)); - - // Calculate the final tiles in the case that K is not aligned - const int k_remain = K - K_it * BK; - const short2 tile_w = - transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); - - // Move x and output to the correct block - auto wl = (const device uint8_t*)w; - x += y_row_long * K; - y += y_row_long * N + y_col_long; - wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor; - scales += transpose ? y_col_long * K_g : y_col / group_size; - - constexpr short UM = 16; - constexpr short UN = 32; - constexpr short UK = 16; - constexpr short SM = BM / WM; - constexpr short SN = BN / WN; - constexpr short SK = 32; - - constexpr short TM = SM / UM; - constexpr short TN = SN / UN; - constexpr short TK = SK / UK; - - const short tm = SM * (simd_group_id / WN); - const short tn = SN * (simd_group_id % WN); - - const short sgp_sm = - align_M ? SM : min(SM, short(max(0, (M - (y_row + tm))))); - const short sgp_sn = - align_N ? SN : min(SN, short(max(0, (N - (y_col + tn))))); - - const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM); - const bool is_unaligned_bn = align_N ? false : (tgp_bn != BN); - - constexpr short BR = transpose ? TN : TK; - constexpr short BC = transpose ? TK : TN; - - using AccumType = float; - - using ASubTile = NAXSubTile; - using BSubTile = NAXSubTile; - using DSubTile = NAXSubTile; - - // Do as many matmuls as necessary - uint32_t index; - short offset; - uint32_t index_next = indices[y_row]; - short offset_next = 0; - int n = 0; - while (n < tgp_bm) { - n++; - offset = offset_next; - index = index_next; - offset_next = tgp_bm; - for (; n < tgp_bm; n++) { - if (indices[y_row + n] != index) { - offset_next = n; - index_next = indices[y_row + n]; - break; - } - } - threadgroup_barrier(mem_flags::mem_none); - - // Prepare threadgroup mma operation - NAXTile Dtile; - - Dtile.clear(); - - const device T* xn = x + tm * K; - - // Prepare threadgroup loading operations - thread loader_w_t loader_w( - wl + index * stride_w, - scales + index * stride_s, - transpose ? K : N, - Ws, - simd_group_id, - simd_lane_id); - - dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) { - dispatch_bool(align_N || !is_unaligned_bn, [&](auto kAlignedN) { - for (int k = 0; k < K_it; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - if constexpr (kAlignedN.value) { - loader_w.load_unsafe(); - } else { - loader_w.load_safe( - transpose ? short2(BK, tgp_bn) : short2(tgp_bn, BK)); - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - STEEL_PRAGMA_NO_UNROLL - for (int kk1 = 0; kk1 < BK; kk1 += SK) { - NAXTile Atile; - NAXTile Btile; - - volatile int compiler_barrier; - - if constexpr (kAlignedM.value) { - Atile.load(xn + kk1, K); - } else { - Atile.load_safe(xn + kk1, K, short2(SK, sgp_sm)); - } - - if constexpr (transpose) { - Btile.template load( - Ws + tn * BK_padded + kk1); - } else { - Btile.template load( - Ws + tn + kk1 * BN_padded); - } - - tile_matmad_nax( - Dtile, - Atile, - metal::bool_constant{}, - Btile, - metal::bool_constant{}); - - (void)compiler_barrier; - } - - xn += BK; - loader_w.next(); - } - - if (!align_K) { - threadgroup_barrier(mem_flags::mem_threadgroup); - loader_w.load_safe(tile_w); - threadgroup_barrier(mem_flags::mem_threadgroup); - - STEEL_PRAGMA_NO_UNROLL - for (int kk1 = 0; kk1 < BK; kk1 += SK) { - NAXTile Atile; - NAXTile Btile; - - volatile int compiler_barrier; - - const short psk = min(int(SK), max(0, (BK - kk1))); - Atile.load_safe(xn + kk1, K, short2(psk, sgp_sm)); - - if constexpr (transpose) { - Btile.template load( - Ws + tn * BK_padded + kk1); - } else { - Btile.template load( - Ws + tn + kk1 * BN_padded); - } - - tile_matmad_nax( - Dtile, - Atile, - metal::bool_constant{}, - Btile, - metal::bool_constant{}); - - (void)compiler_barrier; - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - const short m_lo_lim = min(int(sgp_sm), max(0, offset - tm)); - const short m_hi_lim = min(int(sgp_sm), max(0, offset_next - tm)); - - // Store results to device memory - if constexpr (kAlignedN.value) { - if (m_lo_lim == 0 && m_hi_lim == SM) { - Dtile.store(y + tm * N + tn, N); - } else { - Dtile.store_slice( - y + tm * N + tn, N, short2(0, m_lo_lim), short2(SN, m_hi_lim)); - } - } else { - Dtile.store_slice( - y + tm * N + tn, - N, - short2(0, m_lo_lim), - short2(sgp_sn, m_hi_lim)); - } - }); - }); - } -} diff --git a/Source/Cmlx/mlx-generated/metal/gemv.metal b/Source/Cmlx/mlx-generated/metal/gemv.metal deleted file mode 100644 index 89403d3d..00000000 --- a/Source/Cmlx/mlx-generated/metal/gemv.metal +++ /dev/null @@ -1,868 +0,0 @@ -// Copyright © 2023-2024 Apple Inc. - -#include -#include - -#include "utils.h" - -#include "steel/utils.h" - -using namespace metal; - -/////////////////////////////////////////////////////////////////////////////// -/// Matrix vector multiplication -/////////////////////////////////////////////////////////////////////////////// - -#define MLX_MTL_CONST static constant constexpr const - -template -struct DefaultAccT { - using type = float; -}; -template <> -struct DefaultAccT { - using type = complex64_t; -}; - -template < - typename T, - const int BM, /* Threadgroup rows (in simdgroups) */ - const int BN, /* Threadgroup cols (in simdgroups) */ - const int SM, /* Simdgroup rows (in threads) */ - const int SN, /* Simdgroup cols (in threads) */ - const int TM, /* Thread rows (in elements) */ - const int TN, /* Thread cols (in elements) */ - const bool kDoAxpby, /* Do out = alpha * out + beta * bias */ - typename AccT = typename DefaultAccT::type> -struct GEMVKernel { - using acc_type = AccT; - - MLX_MTL_CONST int threadsM = BM * SM; - MLX_MTL_CONST int threadsN = BN * SN; - - MLX_MTL_CONST int blockM = threadsM * TM; - MLX_MTL_CONST int blockN = threadsN * TN; - - static_assert(SM * SN == 32, "simdgroup can only have 32 threads"); - - static_assert( - SN == 4 || SN == 8 || SN == 16 || SN == 32, - "gemv block must have a width of 4, 8, 16, or 32"); - - // - The matrix of size (M = out_vec_size, K = in_vec_size) is divided up - // into blocks of (blockM, blockN) divided among threadgroups - // - Every thread works on a block of (TM, TN) - // - We assume each threadgroup has (threadsN, threadsM, 1) threads - // - // 1. A thread loads TN elements each from mat along TM rows - // and the corresponding scalar from the vector - // 2. The thread then multiplies and adds to accumulate its local result for - // the block - // 3. At the end, each thread has accumulated results over all blocks across - // the rows. These are then summed up across the threadgroup - // 4. Each threadgroup writes its accumulated blockM outputs - // - // Edge case handling: - // - The threadgroup with the largest tid has blocks that exceed the matrix - // * The blocks that start outside the matrix are never read (thread results - // remain zero) - // * The last thread that partially overlaps with the matrix is shifted - // inwards such that the thread block fits exactly in the matrix - - MLX_MTL_CONST short tgp_mem_size = BN > 1 ? BN*(blockM + TM) : 0; - MLX_MTL_CONST bool needs_tgp_reduction = BN > 1; - - template - static METAL_FUNC void - load_unsafe(const device T* src, thread U dst[TN], const int src_offset = 0) { - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - dst[tn] = static_cast(src[src_offset + tn]); - } - } - - template - static METAL_FUNC void load_safe( - const device T* src, - thread U dst[TN], - const int src_offset = 0, - const int src_size = TN) { - if (src_offset + TN <= src_size) { - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - dst[tn] = static_cast(src[src_offset + tn]); - } - } else { // Edgecase - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - dst[tn] = src_offset + tn < src_size - ? static_cast(src[src_offset + tn]) - : U(0); - } - } - } - - static METAL_FUNC void run( - const device T* mat [[buffer(0)]], - const device T* in_vec [[buffer(1)]], - const device T* bias [[buffer(2)]], - device T* out_vec [[buffer(3)]], - const constant int& in_vec_size [[buffer(4)]], - const constant int& out_vec_size [[buffer(5)]], - const constant int& matrix_ld [[buffer(6)]], - const constant float& alpha [[buffer(7)]], - const constant float& beta [[buffer(8)]], - const constant int& bias_stride [[buffer(14)]], - threadgroup AccT* tgp_memory [[threadgroup(0)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - // Appease compiler - (void)lid; - - // Thread local accumulation results - thread AccT result[TM] = {0}; - thread T inter[TN]; - thread AccT v_coeff[TN]; - - const int thrM = SN != 32 ? simd_lid / SN : 0; - const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid); - - const int sgN = BN != 1 ? (simd_gid % BN) : 0; - - const int simdM = BN != 1 ? SM * (simd_gid / BN) : int(SM * simd_gid); - const int simdN = BN != 1 ? SN * (simd_gid % BN) : 0; - - int bm = (simdM + thrM) * TM; - int bn = (simdN + thrN) * TN; - - // Block position - int out_row = tid.x * blockM + bm; - - // Exit simdgroup if rows out of bound - if (out_row >= out_vec_size) - return; - - // Adjust tail simdgroup to ensure in bound reads - out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM; - - // Advance matrix - mat += out_row * matrix_ld; - - constexpr const uniform loop_stride = make_uniform(blockN); - const uniform in_size = make_uniform(in_vec_size); - const uniform n_iter = in_size / loop_stride; - const uniform last_iter = loop_stride * n_iter; - const uniform leftover = in_size - last_iter; - - // Loop over in_vec in blocks of blockN - for (int i = 0; i < n_iter; ++i) { - load_unsafe(in_vec, v_coeff, bn); - - // Per thread work loop - int mat_offset = 0; - MLX_MTL_PRAGMA_UNROLL - for (int tm = 0; tm < TM; tm++) { - // Load for the row - load_unsafe(mat, inter, mat_offset + bn); - - // Accumulate results - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - result[tm] += inter[tn] * v_coeff[tn]; - } - - mat_offset += matrix_ld; - } - - bn += blockN; - } - - if (leftover > 0) { - load_safe(in_vec, v_coeff, bn, in_size); - - // Per thread work loop - MLX_MTL_PRAGMA_UNROLL - for (int tm = 0; tm < TM; tm++) { - // Load for the row - load_safe(&mat[tm * matrix_ld], inter, bn, in_size); - - // Accumulate results - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - result[tm] += inter[tn] * v_coeff[tn]; - } - } - } - - // Simdgroup accumulations - MLX_MTL_PRAGMA_UNROLL - for (int tm = 0; tm < TM; tm++) { - MLX_MTL_PRAGMA_UNROLL - for (ushort sn = (SN / 2); sn >= 1; sn >>= 1) { - result[tm] += simd_shuffle_down(result[tm], sn); - } - } - - // Threadgroup accumulation results - if (needs_tgp_reduction) { - threadgroup AccT* tgp_results = tgp_memory + sgN * (blockM + TM) + bm; - if (thrN == 0) { - MLX_MTL_PRAGMA_UNROLL - for (int tm = 0; tm < TM; tm++) { - tgp_results[tm] = result[tm]; - } - - threadgroup_barrier(mem_flags::mem_none); - - if (sgN == 0) { - MLX_MTL_PRAGMA_UNROLL - for (int sgn = 1; sgn < BN; sgn++) { - MLX_MTL_PRAGMA_UNROLL - for (int tm = 0; tm < TM; tm++) { - result[tm] += tgp_results[sgn * (blockM + TM) + tm]; - } - } - } - } - } - - // Write outputs - if (simdN == 0 && thrN == 0) { - MLX_MTL_PRAGMA_UNROLL - for (int tm = 0; tm < TM; tm++) { - if (kDoAxpby) { - out_vec[out_row + tm] = - static_cast(alpha) * static_cast(result[tm]) + - static_cast(beta) * bias[(out_row + tm) * bias_stride]; - } else { - out_vec[out_row + tm] = static_cast(result[tm]); - } - } - } - } -}; - -/////////////////////////////////////////////////////////////////////////////// -/// Vector matrix multiplication -/////////////////////////////////////////////////////////////////////////////// - -template < - typename T, - const int BM, /* Threadgroup rows (in simdgroups) */ - const int BN, /* Threadgroup cols (in simdgroups) */ - const int SM, /* Simdgroup rows (in threads) */ - const int SN, /* Simdgroup cols (in threads) */ - const int TM, /* Thread rows (in elements) */ - const int TN, /* Thread cols (in elements) */ - const bool kDoAxpby, /* Do out = alpha * out + beta * bias */ - typename AccT = typename DefaultAccT::type> -struct GEMVTKernel { - using acc_type = AccT; - - MLX_MTL_CONST int threadsM = BM * SM; - MLX_MTL_CONST int threadsN = BN * SN; - - MLX_MTL_CONST int blockM = threadsM * TM; - MLX_MTL_CONST int blockN = threadsN * TN; - - static_assert(SM * SN == 32, "simdgroup can only have 32 threads"); - - // - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up - // into blocks of (blockM, blockN) divided among threadgroups - // - Every thread works on a block of (TM, TN) - // - We assume each threadgroup has (threadsN, threadsM, 1) threads - // - // 1. A thread loads TN elements each from mat along TM contiguous rows - // and the corresponding scalar from the vector - // 2. The thread then accumulates its local result for the block - // 3. At the end, each thread has accumulated results over all blocks across - // the rows. These are then summed up across the threadgroup - // 4. Each threadgroup writes its accumulated BN * TN outputs - // - // Edge case handling: - // - The threadgroup with the largest tid has blocks that exceed the matrix - // * The blocks that start outside the matrix are never read (thread results - // remain zero) - // * The last thread that partially overlaps with the matrix is shifted - // inwards such that the thread block fits exactly in the matrix - - MLX_MTL_CONST short tgp_mem_size = BM > 1 ? BM*(blockN + TN) : 0; - MLX_MTL_CONST bool needs_tgp_reduction = BM > 1; - - static METAL_FUNC void run( - const device T* mat [[buffer(0)]], - const device T* in_vec [[buffer(1)]], - const device T* bias [[buffer(2)]], - device T* out_vec [[buffer(3)]], - const constant int& in_vec_size [[buffer(4)]], - const constant int& out_vec_size [[buffer(5)]], - const constant int& marix_ld [[buffer(6)]], - const constant float& alpha [[buffer(7)]], - const constant float& beta [[buffer(8)]], - const constant int& bias_stride [[buffer(14)]], - threadgroup AccT* tgp_memory [[threadgroup(0)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - // Appease compiler - (void)lid; - - // Thread local accumulation results - AccT result[TN] = {0}; - T inter[TN]; - AccT v_coeff[TM]; - const int thrM = SN != 32 ? simd_lid / SN : 0; - const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid); - - const int sgM = BN != 1 ? (simd_gid / BN) : int(simd_gid); - const int sgN = BN != 1 ? (simd_gid % BN) : 0; - - const int simdM = SM * sgM; - const int simdN = SN * sgN; - - int cm = (simdM + thrM); - int cn = (simdN + thrN); - - int bm = cm * TM; - int bn = cn * TN; - - int out_col = tid.x * blockN + bn; - - constexpr const uniform loop_stride = make_uniform(blockM); - const uniform in_size = make_uniform(in_vec_size); - const uniform n_iter = in_size / loop_stride; - const uniform last_iter = loop_stride * n_iter; - const uniform leftover = in_size - last_iter; - - // Edgecase handling - if (out_col < out_vec_size) { - out_col = out_col + TN < out_vec_size ? out_col : out_vec_size - TN; - - // Per thread accumulation main loop - for (int i = 0; i < n_iter; ++i) { - // Adding a threadgroup_barrier improves performance slightly - // This is possibly it may help exploit cache better - threadgroup_barrier(mem_flags::mem_none); - - MLX_MTL_PRAGMA_UNROLL - for (int tm = 0; tm < TM; tm++) { - v_coeff[tm] = static_cast(in_vec[bm + tm]); - } - - MLX_MTL_PRAGMA_UNROLL - for (int tm = 0; tm < TM; tm++) { - auto vc = static_cast(v_coeff[tm]); - for (int tn = 0; tn < TN; tn++) { - inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn]; - } - for (int tn = 0; tn < TN; tn++) { - result[tn] += vc * inter[tn]; - } - } - - bm += blockM; - } - - if (leftover > 0) { - for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) { - v_coeff[tm] = static_cast(in_vec[bm + tm]); - - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn]; - } - - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - result[tn] += v_coeff[tm] * inter[tn]; - } - } - } - } - - // Simdgroup accumulations - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - MLX_MTL_PRAGMA_UNROLL - for (ushort sm = (SM / 2); sm >= 1; sm >>= 1) { - result[tn] += simd_shuffle_down(result[tn], SN * sm); - } - } - - // Threadgroup accumulation results - if (needs_tgp_reduction) { - threadgroup AccT* tgp_results = tgp_memory + sgM * (blockN + TN) + bn; - if (thrM == 0) { - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - tgp_results[tn] = result[tn]; - } - - threadgroup_barrier(mem_flags::mem_none); - - if (sgM == 0) { - MLX_MTL_PRAGMA_UNROLL - for (int sgm = 1; sgm < BM; sgm++) { - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - result[tn] += tgp_results[sgm * (blockN + TN) + tn]; - } - } - } - } - } - - // Threadgroup accumulation and writing out results - if (cm == 0 && out_col < out_vec_size) { - MLX_MTL_PRAGMA_UNROLL - for (int j = 0; j < TN; j++) { - if (kDoAxpby) { - out_vec[out_col + j] = - static_cast(alpha) * static_cast(result[j]) + - static_cast(beta) * bias[(out_col + j) * bias_stride]; - } else { - out_vec[out_col + j] = static_cast(result[j]); - } - } - } - } -}; - -/////////////////////////////////////////////////////////////////////////////// -/// Matrix vector multiplication -/////////////////////////////////////////////////////////////////////////////// - -template < - typename T, - const int BM, /* Threadgroup rows (in simdgroups) */ - const int BN, /* Threadgroup cols (in simdgroups) */ - const int SM, /* Simdgroup rows (in threads) */ - const int SN, /* Simdgroup cols (in threads) */ - const int TM, /* Thread rows (in elements) */ - const int TN, /* Thread cols (in elements) */ - const bool kDoNCBatch, /* Batch ndim > 1 */ - const bool kDoAxpby> /* Do out = alpha * out + beta * bias */ -[[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv( - const device T* mat [[buffer(0)]], - const device T* in_vec [[buffer(1)]], - const device T* bias [[buffer(2)]], - device T* out_vec [[buffer(3)]], - const constant int& in_vec_size [[buffer(4)]], - const constant int& out_vec_size [[buffer(5)]], - const constant int& marix_ld [[buffer(6)]], - const constant float& alpha [[buffer(7)]], - const constant float& beta [[buffer(8)]], - const constant int& batch_ndim [[buffer(9)]], - const constant int* batch_shape [[buffer(10)]], - const constant int64_t* vector_batch_stride [[buffer(11)]], - const constant int64_t* matrix_batch_stride [[buffer(12)]], - const constant int64_t* bias_batch_stride [[buffer(13)]], - const constant int& bias_stride [[buffer(14)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - using gemv_kernel = GEMVKernel; - threadgroup typename gemv_kernel::acc_type tgp_memory - [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size]; - - // Update batch offsets - if (kDoNCBatch) { - in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim); - mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim); - - if (kDoAxpby) { - bias += elem_to_loc(tid.z, batch_shape, bias_batch_stride, batch_ndim); - } - - } else { - in_vec += tid.z * vector_batch_stride[0]; - mat += tid.z * matrix_batch_stride[0]; - - if (kDoAxpby) { - bias += tid.z * bias_batch_stride[0]; - } - } - - out_vec += tid.z * out_vec_size; - - gemv_kernel::run( - mat, - in_vec, - bias, - out_vec, - in_vec_size, - out_vec_size, - marix_ld, - alpha, - beta, - bias_stride, - gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory, - tid, - lid, - simd_gid, - simd_lid); -} - -#define instantiate_gemv_helper( \ - name, itype, bm, bn, sm, sn, tm, tn, nc, axpby) \ - instantiate_kernel( \ - "gemv_" #name "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm \ - "_tn" #tn "_nc" #nc "_axpby" #axpby, \ - gemv, \ - itype, \ - bm, \ - bn, \ - sm, \ - sn, \ - tm, \ - tn, \ - nc, \ - axpby) - -// clang-format off -#define instantiate_gemv(name, itype, bm, bn, sm, sn, tm, tn) \ - instantiate_gemv_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 0) \ - instantiate_gemv_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 1) \ - instantiate_gemv_helper(name, itype, bm, bn, sm, sn, tm, tn, 1, 0) \ - instantiate_gemv_helper(name, itype, bm, bn, sm, sn, tm, tn, 1, 1) // clang-format on - -// clang-format off -#define instantiate_gemv_blocks(name, itype) \ - instantiate_gemv(name, itype, 1, 8, 1, 32, 4, 4) \ - instantiate_gemv(name, itype, 1, 8, 1, 32, 1, 4) \ - instantiate_gemv(name, itype, 1, 1, 8, 4, 4, 4) \ - instantiate_gemv(name, itype, 1, 1, 8, 4, 1, 4) \ - instantiate_gemv(name, itype, 4, 1, 1, 32, 1, 4) \ - instantiate_gemv(name, itype, 4, 1, 1, 32, 4, 4) \ - instantiate_gemv(name, itype, 8, 1, 1, 32, 4, 4) // clang-format on - -instantiate_gemv_blocks(float32, float); -instantiate_gemv_blocks(float16, half); -instantiate_gemv_blocks(bfloat16, bfloat16_t); -instantiate_gemv_blocks(complex64, complex64_t); - -template < - typename T, - const int BM, /* Threadgroup rows (in simdgroups) */ - const int BN, /* Threadgroup cols (in simdgroups) */ - const int SM, /* Simdgroup rows (in threads) */ - const int SN, /* Simdgroup cols (in threads) */ - const int TM, /* Thread rows (in elements) */ - const int TN> /* Thread cols (in elements) */ -[[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv_gather( - const device T* mat [[buffer(0)]], - const device T* in_vec [[buffer(1)]], - const device T* bias [[buffer(2)]], - device T* out_vec [[buffer(3)]], - const constant int& in_vec_size [[buffer(4)]], - const constant int& out_vec_size [[buffer(5)]], - const constant int& marix_ld [[buffer(6)]], - const constant float& alpha [[buffer(7)]], - const constant float& beta [[buffer(8)]], - const constant int& batch_ndim [[buffer(9)]], - const constant int* batch_shape [[buffer(10)]], - const constant int64_t* index_batch_strides [[buffer(11)]], - const constant int& vector_batch_ndim [[buffer(12)]], - const constant int* vector_batch_shape [[buffer(13)]], - const constant int64_t* vector_batch_stride [[buffer(14)]], - const constant int& matrix_batch_ndim [[buffer(15)]], - const constant int* matrix_batch_shape [[buffer(16)]], - const constant int64_t* matrix_batch_stride [[buffer(17)]], - const constant uint32_t* vec_indices [[buffer(18)]], - const constant uint32_t* mat_indices [[buffer(19)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - using gemv_kernel = GEMVKernel; - threadgroup typename gemv_kernel::acc_type tgp_memory - [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size]; - - uint32_t indx_vec; - uint32_t indx_mat; - - // Update batch offsets - if (batch_ndim > 1) { - const constant auto* veci_bstrides = index_batch_strides; - const constant auto* mati_bstrides = index_batch_strides + batch_ndim; - - ulong2 batch_offsets = elem_to_loc_broadcast( - tid.z, batch_shape, veci_bstrides, mati_bstrides, batch_ndim); - - indx_vec = vec_indices[batch_offsets.x]; - indx_mat = mat_indices[batch_offsets.y]; - - } else { - indx_vec = vec_indices[index_batch_strides[0] * tid.z]; - indx_mat = mat_indices[index_batch_strides[batch_ndim] * tid.z]; - } - - if (vector_batch_ndim > 1) { - in_vec += elem_to_loc( - indx_vec, vector_batch_shape, vector_batch_stride, vector_batch_ndim); - } else { - in_vec += indx_vec * vector_batch_stride[0]; - } - - if (matrix_batch_ndim > 1) { - mat += elem_to_loc( - indx_mat, matrix_batch_shape, matrix_batch_stride, matrix_batch_ndim); - } else { - mat += indx_mat * matrix_batch_stride[0]; - } - - out_vec += tid.z * out_vec_size; - - gemv_kernel::run( - mat, - in_vec, - bias, - out_vec, - in_vec_size, - out_vec_size, - marix_ld, - alpha, - beta, - batch_ndim, // Not used - gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory, - tid, - lid, - simd_gid, - simd_lid); -} - -// clang-format off -#define instantiate_gemv_bs_helper(nm, itype, bm, bn, sm, sn, tm, tn) \ - instantiate_kernel( \ - "gemv_gather_" #nm "_bm" #bm "_bn" #bn "_sm" #sm \ - "_sn" #sn "_tm" #tm "_tn" #tn, \ - gemv_gather, itype, bm, bn, sm, sn, tm, tn) - -#define instantiate_gemv_bs_blocks(name, itype) \ - instantiate_gemv_bs_helper(name, itype, 4, 1, 1, 32, 1, 4) \ - instantiate_gemv_bs_helper(name, itype, 4, 1, 1, 32, 4, 4) \ - instantiate_gemv_bs_helper(name, itype, 8, 1, 1, 32, 4, 4) // clang-format on - -instantiate_gemv_bs_blocks(float32, float); -instantiate_gemv_bs_blocks(float16, half); -instantiate_gemv_bs_blocks(bfloat16, bfloat16_t); -instantiate_gemv_bs_blocks(complex64, complex64_t); - -/////////////////////////////////////////////////////////////////////////////// -/// Vector matrix multiplication -/////////////////////////////////////////////////////////////////////////////// - -template < - typename T, - const int BM, /* Threadgroup rows (in simdgroups) */ - const int BN, /* Threadgroup cols (in simdgroups) */ - const int SM, /* Simdgroup rows (in threads) */ - const int SN, /* Simdgroup cols (in threads) */ - const int TM, /* Thread rows (in elements) */ - const int TN, /* Thread cols (in elements) */ - const bool kDoNCBatch, /* Batch ndim > 1 */ - const bool kDoAxpby> /* Do out = alpha * out + beta * bias */ -[[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv_t( - const device T* mat [[buffer(0)]], - const device T* in_vec [[buffer(1)]], - const device T* bias [[buffer(2)]], - device T* out_vec [[buffer(3)]], - const constant int& in_vec_size [[buffer(4)]], - const constant int& out_vec_size [[buffer(5)]], - const constant int& marix_ld [[buffer(6)]], - const constant float& alpha [[buffer(7)]], - const constant float& beta [[buffer(8)]], - const constant int& batch_ndim [[buffer(9)]], - const constant int* batch_shape [[buffer(10)]], - const constant int64_t* vector_batch_stride [[buffer(11)]], - const constant int64_t* matrix_batch_stride [[buffer(12)]], - const constant int64_t* bias_batch_stride [[buffer(13)]], - const constant int& bias_stride [[buffer(14)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - using gemv_kernel = GEMVTKernel; - threadgroup typename gemv_kernel::acc_type tgp_memory - [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size]; - - // Update batch offsets - if (kDoNCBatch) { - in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim); - mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim); - - if (kDoAxpby) { - bias += elem_to_loc(tid.z, batch_shape, bias_batch_stride, batch_ndim); - } - - } else { - in_vec += tid.z * vector_batch_stride[0]; - mat += tid.z * matrix_batch_stride[0]; - - if (kDoAxpby) { - bias += tid.z * bias_batch_stride[0]; - } - } - - out_vec += tid.z * out_vec_size; - - gemv_kernel::run( - mat, - in_vec, - bias, - out_vec, - in_vec_size, - out_vec_size, - marix_ld, - alpha, - beta, - bias_stride, - gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory, - tid, - lid, - simd_gid, - simd_lid); -} - -// clang-format off -#define instantiate_gemv_t_helper( \ - name, itype, bm, bn, sm, sn, tm, tn, nc, axpby) \ - instantiate_kernel( \ - "gemv_t_" #name "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn \ - "_tm" #tm "_tn" #tn "_nc" #nc "_axpby" #axpby, \ - gemv_t, itype, bm, bn, sm, sn, tm, tn, nc, axpby) - -#define instantiate_gemv_t(name, itype, bm, bn, sm, sn, tm, tn) \ - instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 0) \ - instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 1) \ - instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 1, 0) \ - instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 1, 1) // clang-format on - -// clang-format off -#define instantiate_gemv_t_blocks(name, itype) \ - instantiate_gemv_t(name, itype, 1, 2, 8, 4, 4, 1) \ - instantiate_gemv_t(name, itype, 1, 2, 8, 4, 4, 4) \ - instantiate_gemv_t(name, itype, 1, 4, 8, 4, 4, 4) \ - instantiate_gemv_t(name, itype, 1, 16, 8, 4, 4, 4) \ - instantiate_gemv_t(name, itype, 1, 16, 4, 8, 4, 4) // clang-format on - -// clang-format off -instantiate_gemv_t_blocks(float32, float); -instantiate_gemv_t_blocks(float16, half); -instantiate_gemv_t_blocks(bfloat16, bfloat16_t); -instantiate_gemv_t_blocks(complex64, complex64_t); // clang-format on - -template < - typename T, - const int BM, /* Threadgroup rows (in simdgroups) */ - const int BN, /* Threadgroup cols (in simdgroups) */ - const int SM, /* Simdgroup rows (in threads) */ - const int SN, /* Simdgroup cols (in threads) */ - const int TM, /* Thread rows (in elements) */ - const int TN> /* Thread cols (in elements) */ -[[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv_t_gather( - const device T* mat [[buffer(0)]], - const device T* in_vec [[buffer(1)]], - const device T* bias [[buffer(2)]], - device T* out_vec [[buffer(3)]], - const constant int& in_vec_size [[buffer(4)]], - const constant int& out_vec_size [[buffer(5)]], - const constant int& marix_ld [[buffer(6)]], - const constant float& alpha [[buffer(7)]], - const constant float& beta [[buffer(8)]], - const constant int& batch_ndim [[buffer(9)]], - const constant int* batch_shape [[buffer(10)]], - const constant int64_t* index_batch_strides [[buffer(11)]], - const constant int& vector_batch_ndim [[buffer(12)]], - const constant int* vector_batch_shape [[buffer(13)]], - const constant int64_t* vector_batch_stride [[buffer(14)]], - const constant int& matrix_batch_ndim [[buffer(15)]], - const constant int* matrix_batch_shape [[buffer(16)]], - const constant int64_t* matrix_batch_stride [[buffer(17)]], - const constant uint32_t* vec_indices [[buffer(18)]], - const constant uint32_t* mat_indices [[buffer(19)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - using gemv_kernel = GEMVTKernel; - threadgroup typename gemv_kernel::acc_type tgp_memory - [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size]; - - uint32_t indx_vec; - uint32_t indx_mat; - - // Update batch offsets - if (batch_ndim > 1) { - const constant auto* veci_bstrides = index_batch_strides; - const constant auto* mati_bstrides = index_batch_strides + batch_ndim; - - ulong2 batch_offsets = elem_to_loc_broadcast( - tid.z, batch_shape, veci_bstrides, mati_bstrides, batch_ndim); - - indx_vec = vec_indices[batch_offsets.x]; - indx_mat = mat_indices[batch_offsets.y]; - - } else { - indx_vec = vec_indices[index_batch_strides[0] * tid.z]; - indx_mat = mat_indices[index_batch_strides[batch_ndim] * tid.z]; - } - - if (vector_batch_ndim > 1) { - in_vec += elem_to_loc( - indx_vec, vector_batch_shape, vector_batch_stride, vector_batch_ndim); - } else { - in_vec += indx_vec * vector_batch_stride[0]; - } - - if (matrix_batch_ndim > 1) { - mat += elem_to_loc( - indx_mat, matrix_batch_shape, matrix_batch_stride, matrix_batch_ndim); - } else { - mat += indx_mat * matrix_batch_stride[0]; - } - - out_vec += tid.z * out_vec_size; - - gemv_kernel::run( - mat, - in_vec, - bias, - out_vec, - in_vec_size, - out_vec_size, - marix_ld, - alpha, - beta, - batch_ndim, // Not used, - gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory, - tid, - lid, - simd_gid, - simd_lid); -} - -// clang-format off -#define instantiate_gemv_t_bs_helper( \ - nm, itype, bm, bn, sm, sn, tm, tn) \ - instantiate_kernel( \ - "gemv_t_gather_" #nm "_bm" #bm "_bn" #bn "_sm" #sm \ - "_sn" #sn "_tm" #tm "_tn" #tn, \ - gemv_t_gather, itype, bm, bn, sm, sn, tm, tn) - -#define instantiate_gemv_t_bs_blocks(name, itype) \ - instantiate_gemv_t_bs_helper(name, itype, 1, 2, 8, 4, 4, 1) \ - instantiate_gemv_t_bs_helper(name, itype, 1, 2, 8, 4, 4, 4) \ - instantiate_gemv_t_bs_helper(name, itype, 1, 4, 8, 4, 4, 4) \ - instantiate_gemv_t_bs_helper(name, itype, 1, 16, 8, 4, 4, 4) \ - instantiate_gemv_t_bs_helper(name, itype, 1, 16, 4, 8, 4, 4) // clang-format on - -// clang-format off -instantiate_gemv_t_bs_blocks(float32, float); -instantiate_gemv_t_bs_blocks(float16, half); -instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t); -instantiate_gemv_t_bs_blocks(complex64, complex64_t); // clang-format on diff --git a/Source/Cmlx/mlx-generated/metal/gemv_masked.h b/Source/Cmlx/mlx-generated/metal/gemv_masked.h deleted file mode 100644 index 9d4fac23..00000000 --- a/Source/Cmlx/mlx-generated/metal/gemv_masked.h +++ /dev/null @@ -1,827 +0,0 @@ -// Copyright © 2023-2024 Apple Inc. - -#include "steel/utils.h" - -using namespace metal; - -#define MLX_MTL_CONST static constant constexpr const -#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)") - -struct _NoMask { - char x; - - constexpr METAL_FUNC operator bool() { - return true; - } - constexpr METAL_FUNC operator bool() const threadgroup { - return true; - } - constexpr METAL_FUNC operator bool() const device { - return true; - } - constexpr METAL_FUNC operator bool() const constant { - return true; - } -}; - -typedef struct _NoMask nomask_t; - -template -struct ScaleOp { - OutT scale; - - METAL_FUNC OutT apply(InT x) const { - return static_cast(x) * scale; - } -}; - -template < - typename T, - typename out_mask_t, - typename op_mask_t, - const int BM, /* Threadgroup rows (in simdgroups) */ - const int BN, /* Threadgroup cols (in simdgroups) */ - const int SM, /* Simdgroup rows (in threads) */ - const int SN, /* Simdgroup cols (in threads) */ - const int TM, /* Thread rows (in elements) */ - const int TN, /* Thread cols (in elements) */ - typename AccT = float> -struct GEMVKernel { - MLX_MTL_CONST int threadsM = BM * SM; - MLX_MTL_CONST int threadsN = BN * SN; - - MLX_MTL_CONST int blockM = threadsM * TM; - MLX_MTL_CONST int blockN = threadsN * TN; - - static_assert(SM * SN == 32, "simdgroup can only have 32 threads"); - - static_assert( - SN == 8 || SN == 16 || SN == 32, - "gemv block must have a width of 8, 16, or 32"); - - static_assert(blockN >= blockM, "Masked gemv must have blockN >= blockM"); - - MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v; - MLX_MTL_CONST bool has_output_mask = !metal::is_same_v; - - MLX_MTL_CONST bool has_mul_operand_mask = - has_operand_mask && !metal::is_same_v; - MLX_MTL_CONST bool has_mul_output_mask = - has_output_mask && !metal::is_same_v; - - // - The matrix of size (M = out_vec_size, K = in_vec_size) is divided up - // into blocks of (blockM, blockN) divided among threadgroups - // - Every thread works on a block of (TM, TN) - // - We assume each threadgroup has (threadsN, threadsM, 1) threads - // - // 1. A thread loads TN elements each from mat along TM rows - // and the corresponding scalar from the vector - // 2. The thread then multiplies and adds to accumulate its local result for - // the block - // 3. At the end, each thread has accumulated results over all blocks across - // the rows. These are then summed up across the threadgroup - // 4. Each threadgroup writes its accumulated blockM outputs - // - // Edge case handling: - // - The threadgroup with the largest tid has blocks that exceed the matrix - // * The blocks that start outside the matrix are never read (thread results - // remain zero) - // * The last thread that partially overlaps with the matrix is shifted - // inwards such that the thread block fits exactly in the matrix - - MLX_MTL_CONST short tgp_mem_size = BN > 1 ? BN*(blockM + TM) : 0; - MLX_MTL_CONST bool needs_tgp_reduction = BN > 1; - - template - static METAL_FUNC void - load_unsafe(const device T* src, thread U dst[TN], const int src_offset = 0) { - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - dst[tn] = static_cast(src[src_offset + tn]); - } - } - - template - static METAL_FUNC void load_safe( - const device T* src, - thread U dst[TN], - const int src_offset = 0, - const int src_size = TN) { - if (src_offset + TN <= src_size) { - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - dst[tn] = static_cast(src[src_offset + tn]); - } - } else { // Edgecase - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - dst[tn] = src_offset + tn < src_size - ? static_cast(src[src_offset + tn]) - : U(0); - } - } - } - - static METAL_FUNC void run( - const device T* mat [[buffer(0)]], - const device T* in_vec [[buffer(1)]], - device T* out_vec [[buffer(3)]], - const constant int& in_vec_size [[buffer(4)]], - const constant int& out_vec_size [[buffer(5)]], - const constant int& matrix_ld [[buffer(6)]], - const device out_mask_t* out_mask [[buffer(20)]], - const device op_mask_t* mat_mask [[buffer(21)]], - const device op_mask_t* vec_mask [[buffer(22)]], - const constant int* mask_strides [[buffer(23)]], - threadgroup AccT* tgp_memory [[threadgroup(0)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - // Appease compiler - (void)lid; - - // Thread local accumulation results - thread AccT result[TM] = {0}; - thread T inter[TN]; - thread AccT v_coeff[TN]; - - const int thrM = SN != 32 ? simd_lid / SN : 0; - const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid); - - const int sgN = BN != 1 ? (simd_gid % BN) : 0; - - const int simdM = BN != 1 ? SM * (simd_gid / BN) : int(SM * simd_gid); - const int simdN = BN != 1 ? SN * (simd_gid % BN) : 0; - - int bm = (simdM + thrM) * TM; - int bn = (simdN + thrN) * TN; - - // Block position - int out_row = tid.x * blockM + bm; - - // Exit simdgroup if rows out of bound - if (out_row >= out_vec_size) - return; - - // Adjust tail simdgroup to ensure in bound reads - out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM; - - // Prepare mask offsets - const constant int* out_mask_strides = mask_strides; - const constant int* mat_mask_strides = - mask_strides + (has_output_mask ? 2 : 0); - const constant int* vec_mask_strides = - mat_mask_strides + (has_operand_mask ? 2 : 0); - - const int m_block_idx = blockN > blockM ? out_row / blockN : int(tid.x); - - const int out_mask_offset = - !has_output_mask ? 0 : m_block_idx * out_mask_strides[1]; - - int mat_mask_offset = - !has_operand_mask ? 0 : m_block_idx * mat_mask_strides[1]; - int vec_mask_offset = 0; - const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[0]; - const int vec_mask_step = !has_operand_mask ? 0 : vec_mask_strides[1]; - - T out_scale{1}; - - // Check output mask - if (has_output_mask) { - auto mask_out = out_mask[out_mask_offset]; - - // Write zeros and return if mask is 0 - if (!mask_out) { - if (simdN == 0 && thrN == 0) { - MLX_MTL_PRAGMA_UNROLL - for (int tm = 0; tm < TM; tm++) { - out_vec[out_row + tm] = T(0.); - } - } - - return; - } - - // Store scalar if multiplicative mask - if (has_mul_output_mask) { - out_scale = T(mask_out); - } - } - - // Advance matrix - mat += out_row * matrix_ld; - - // Prepare for loop - constexpr const uniform loop_stride = make_uniform(blockN); - const uniform in_size = make_uniform(in_vec_size); - const uniform n_iter = in_size / loop_stride; - const uniform last_iter = loop_stride * n_iter; - const uniform leftover = in_size - last_iter; - - // Loop over in_vec in blocks of blockN - for (int i = 0; i < n_iter; ++i) { - if (!has_operand_mask || - (bool(mat_mask[mat_mask_offset]) && - bool(vec_mask[vec_mask_offset]))) { - T block_scale{1}; - if (has_mul_operand_mask) { - block_scale = - T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); - } - - load_unsafe(in_vec, v_coeff, bn); - - // Apply scale - if (has_mul_operand_mask) { - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - v_coeff[tn] *= block_scale; - } - } - - // Per thread work loop - int mat_offset = 0; - MLX_MTL_PRAGMA_UNROLL - for (int tm = 0; tm < TM; tm++) { - // Load for the row - load_unsafe(mat, inter, mat_offset + bn); - - // Accumulate results - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - result[tm] += inter[tn] * v_coeff[tn]; - } - - mat_offset += matrix_ld; - } - } - - bn += blockN; - mat_mask_offset += mat_mask_step; - vec_mask_offset += vec_mask_step; - } - - if (leftover > 0) { - if (!has_operand_mask || - (bool(mat_mask[mat_mask_offset]) && - bool(vec_mask[vec_mask_offset]))) { - T block_scale{1}; - if (has_mul_operand_mask) { - block_scale = - T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); - } - - load_safe(in_vec, v_coeff, bn, in_size); - - // Apply scale - if (has_mul_operand_mask) { - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - v_coeff[tn] *= block_scale; - } - } - - // Per thread work loop - MLX_MTL_PRAGMA_UNROLL - for (int tm = 0; tm < TM; tm++) { - // Load for the row - load_safe(&mat[tm * matrix_ld], inter, bn, in_size); - - // Accumulate results - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - result[tm] += inter[tn] * v_coeff[tn]; - } - } - } - } - - // Apply out scale - if (has_mul_output_mask) { - MLX_MTL_PRAGMA_UNROLL - for (int tm = 0; tm < TM; tm++) { - result[tm] *= out_scale; - } - } - - // Simdgroup accumulations - MLX_MTL_PRAGMA_UNROLL - for (int tm = 0; tm < TM; tm++) { - MLX_MTL_PRAGMA_UNROLL - for (ushort sn = (SN / 2); sn >= 1; sn >>= 1) { - result[tm] += simd_shuffle_down(result[tm], sn); - } - } - - // Threadgroup accumulation results - if (needs_tgp_reduction) { - threadgroup AccT* tgp_results = tgp_memory + sgN * (blockM + TM) + bm; - if (thrN == 0) { - MLX_MTL_PRAGMA_UNROLL - for (int tm = 0; tm < TM; tm++) { - tgp_results[tm] = result[tm]; - } - - threadgroup_barrier(mem_flags::mem_none); - - if (sgN == 0) { - MLX_MTL_PRAGMA_UNROLL - for (int sgn = 1; sgn < BN; sgn++) { - MLX_MTL_PRAGMA_UNROLL - for (int tm = 0; tm < TM; tm++) { - result[tm] += tgp_results[sgn * (blockM + TM) + tm]; - } - } - } - } - } - - // Write outputs - if (simdN == 0 && thrN == 0) { - MLX_MTL_PRAGMA_UNROLL - for (int tm = 0; tm < TM; tm++) { - out_vec[out_row + tm] = static_cast(result[tm]); - } - } - } -}; - -/////////////////////////////////////////////////////////////////////////////// -/// Vector matrix multiplication -/////////////////////////////////////////////////////////////////////////////// - -template < - typename T, - typename out_mask_t, - typename op_mask_t, - const int BM, /* Threadgroup rows (in simdgroups) */ - const int BN, /* Threadgroup cols (in simdgroups) */ - const int SM, /* Simdgroup rows (in threads) */ - const int SN, /* Simdgroup cols (in threads) */ - const int TM, /* Thread rows (in elements) */ - const int TN, /* Thread cols (in elements) */ - typename AccT = float> -struct GEMVTKernel { - MLX_MTL_CONST int threadsM = BM * SM; - MLX_MTL_CONST int threadsN = BN * SN; - - MLX_MTL_CONST int blockM = threadsM * TM; - MLX_MTL_CONST int blockN = threadsN * TN; - - static_assert(SM * SN == 32, "simdgroup can only have 32 threads"); - - MLX_MTL_CONST bool has_operand_mask = !metal::is_same_v; - MLX_MTL_CONST bool has_output_mask = !metal::is_same_v; - - MLX_MTL_CONST bool has_mul_operand_mask = - has_operand_mask && !metal::is_same_v; - MLX_MTL_CONST bool has_mul_output_mask = - has_output_mask && !metal::is_same_v; - - // - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up - // into blocks of (blockM, blockN) divided among threadgroups - // - Every thread works on a block of (TM, TN) - // - We assume each threadgroup has (threadsN, threadsM, 1) threads - // - // 1. A thread loads TN elements each from mat along TM contiguous rows - // and the corresponding scalar from the vector - // 2. The thread then accumulates its local result for the block - // 3. At the end, each thread has accumulated results over all blocks across - // the rows. These are then summed up across the threadgroup - // 4. Each threadgroup writes its accumulated BN * TN outputs - // - // Edge case handling: - // - The threadgroup with the largest tid has blocks that exceed the matrix - // * The blocks that start outside the matrix are never read (thread results - // remain zero) - // * The last thread that partially overlaps with the matrix is shifted - // inwards such that the thread block fits exactly in the matrix - - MLX_MTL_CONST short tgp_mem_size = BM > 1 ? BM*(blockN + TN) : 0; - MLX_MTL_CONST bool needs_tgp_reduction = BM > 1; - - static METAL_FUNC void run( - const device T* mat [[buffer(0)]], - const device T* in_vec [[buffer(1)]], - device T* out_vec [[buffer(3)]], - const constant int& in_vec_size [[buffer(4)]], - const constant int& out_vec_size [[buffer(5)]], - const constant int& marix_ld [[buffer(6)]], - const device out_mask_t* out_mask [[buffer(20)]], - const device op_mask_t* mat_mask [[buffer(21)]], - const device op_mask_t* vec_mask [[buffer(22)]], - const constant int* mask_strides [[buffer(23)]], - threadgroup AccT* tgp_memory [[threadgroup(0)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - // Appease compiler - (void)lid; - - // Thread local accumulation results - AccT result[TN] = {0}; - T inter[TN]; - AccT v_coeff[TM]; - - const int thrM = SN != 32 ? simd_lid / SN : 0; - const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid); - - const int sgM = BN != 1 ? (simd_gid / BN) : int(simd_gid); - const int sgN = BN != 1 ? (simd_gid % BN) : 0; - - const int simdM = SM * sgM; - const int simdN = SN * sgN; - - int cm = (simdM + thrM); - int cn = (simdN + thrN); - - int bm = cm * TM; - int bn = cn * TN; - - int out_col = tid.x * blockN + bn; - - // Prepare mask offsets - const constant int* out_mask_strides = mask_strides; - const constant int* mat_mask_strides = - out_mask_strides + (has_output_mask ? 2 : 0); - const constant int* vec_mask_strides = - mat_mask_strides + (has_operand_mask ? 2 : 0); - - const int n_block_idx = blockM > blockN ? out_col / blockM : int(tid.x); - - const int out_mask_offset = - !has_output_mask ? 0 : n_block_idx; // * out_mask_strides[0]; - - int mat_mask_offset = - !has_operand_mask ? 0 : n_block_idx * mat_mask_strides[0]; - int vec_mask_offset = 0; - const int mat_mask_step = !has_operand_mask ? 0 : mat_mask_strides[1]; - const int vec_mask_step = !has_operand_mask ? 0 : vec_mask_strides[0]; - - T out_scale{1}; - - // Check output mask - if (has_output_mask) { - auto mask_out = out_mask[out_mask_offset]; - - // Write zeros and return if mask is 0 - if (!mask_out) { - if (cm == 0 && out_col < out_vec_size) { - if (out_col + TN <= out_vec_size) { - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - out_vec[out_col + tn] = T(0.); - } - } else { - for (int tn = 0; tn < TN && (out_col + tn) < out_vec_size; tn++) { - out_vec[out_col + tn] = T(0.); - } - } - } - - return; - } - - // Store scalar if multiplicative mask - if (has_mul_output_mask) { - out_scale = T(mask_out); - } - } - - // Prepare for loop - constexpr const uniform loop_stride = make_uniform(blockM); - const uniform in_size = make_uniform(in_vec_size); - const uniform n_iter = in_size / loop_stride; - const uniform last_iter = loop_stride * n_iter; - const uniform leftover = in_size - last_iter; - - // Edgecase handling - if (out_col < out_vec_size) { - out_col = (out_col + TN) <= out_vec_size ? out_col : out_vec_size - TN; - - // Per thread accumulation main loop - for (int i = 0; i < n_iter; ++i) { - // Adding a threadgroup_barrier improves performance slightly - // This is possibly it may help exploit cache better - threadgroup_barrier(mem_flags::mem_none); - - if (!has_operand_mask || - (bool(mat_mask[mat_mask_offset]) && - bool(vec_mask[vec_mask_offset]))) { - T block_scale{1}; - if (has_mul_operand_mask) { - block_scale = - T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); - } - - MLX_MTL_PRAGMA_UNROLL - for (int tm = 0; tm < TM; tm++) { - v_coeff[tm] = static_cast(in_vec[bm + tm]); - } - - // Apply scale - if (has_mul_operand_mask) { - MLX_MTL_PRAGMA_UNROLL - for (int tm = 0; tm < TM; tm++) { - v_coeff[tm] *= block_scale; - } - } - - MLX_MTL_PRAGMA_UNROLL - for (int tm = 0; tm < TM; tm++) { - for (int tn = 0; tn < TN; tn++) { - inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn]; - } - for (int tn = 0; tn < TN; tn++) { - result[tn] += v_coeff[tm] * inter[tn]; - } - } - } - - bm += blockM; - mat_mask_offset += mat_mask_step; - vec_mask_offset += vec_mask_step; - } - - if (leftover > 0) { - if (!has_operand_mask || - (bool(mat_mask[mat_mask_offset]) && - bool(vec_mask[vec_mask_offset]))) { - T block_scale{1}; - if (has_mul_operand_mask) { - block_scale = - T(mat_mask[mat_mask_offset]) * T(vec_mask[vec_mask_offset]); - } - - for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) { - v_coeff[tm] = static_cast(in_vec[bm + tm]); - - if (has_mul_operand_mask) { - v_coeff[tm] *= block_scale; - } - - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - inter[tn] = mat[(bm + tm) * marix_ld + out_col + tn]; - } - - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - result[tn] += v_coeff[tm] * inter[tn]; - } - } - } - } - } - - // Apply out scale - if (has_mul_output_mask) { - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - result[tn] *= out_scale; - } - } - - // Simdgroup accumulations - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - MLX_MTL_PRAGMA_UNROLL - for (ushort sm = (SM / 2); sm >= 1; sm >>= 1) { - result[tn] += simd_shuffle_down(result[tn], SN * sm); - } - } - - // Threadgroup accumulation results - if (needs_tgp_reduction) { - threadgroup AccT* tgp_results = tgp_memory + sgM * (blockN + TN) + bn; - if (thrM == 0) { - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - tgp_results[tn] = result[tn]; - } - - threadgroup_barrier(mem_flags::mem_none); - - if (sgM == 0) { - MLX_MTL_PRAGMA_UNROLL - for (int sgm = 1; sgm < BM; sgm++) { - MLX_MTL_PRAGMA_UNROLL - for (int tn = 0; tn < TN; tn++) { - result[tn] += tgp_results[sgm * (blockN + TN) + tn]; - } - } - } - } - } - - // Threadgroup accumulation and writing out results - if (cm == 0 && out_col < out_vec_size) { - MLX_MTL_PRAGMA_UNROLL - for (int j = 0; j < TN; j++) { - out_vec[out_col + j] = static_cast(result[j]); - } - } - } -}; - -/////////////////////////////////////////////////////////////////////////////// -/// Matrix vector multiplication -/////////////////////////////////////////////////////////////////////////////// - -template < - typename T, - typename out_mask_t, - typename op_mask_t, - const int BM, /* Threadgroup rows (in simdgroups) */ - const int BN, /* Threadgroup cols (in simdgroups) */ - const int SM, /* Simdgroup rows (in threads) */ - const int SN, /* Simdgroup cols (in threads) */ - const int TM, /* Thread rows (in elements) */ - const int TN, /* Thread cols (in elements) */ - const bool kDoNCBatch> /* Batch ndim > 1 */ -[[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv_masked( - const device T* mat [[buffer(0)]], - const device T* in_vec [[buffer(1)]], - device T* out_vec [[buffer(3)]], - const constant int& in_vec_size [[buffer(4)]], - const constant int& out_vec_size [[buffer(5)]], - const constant int& marix_ld [[buffer(6)]], - const constant int& batch_ndim [[buffer(9)]], - const constant int* batch_shape [[buffer(10)]], - const constant int64_t* vector_batch_stride [[buffer(11)]], - const constant int64_t* matrix_batch_stride [[buffer(12)]], - const device out_mask_t* out_mask [[buffer(20)]], - const device op_mask_t* mat_mask [[buffer(21)]], - const device op_mask_t* vec_mask [[buffer(22)]], - const constant int* mask_strides [[buffer(23)]], - const constant int64_t* mask_batch_strides [[buffer(24)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - using gemv_kernel = - GEMVKernel; - threadgroup float tgp_memory - [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size]; - - constexpr bool has_operand_mask = !metal::is_same_v; - constexpr bool has_output_mask = !metal::is_same_v; - - // Update batch offsets - if (kDoNCBatch) { - in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim); - mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim); - - if (has_output_mask) { - out_mask += - elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim); - mask_batch_strides += batch_ndim; - } - - if (has_operand_mask) { - const constant auto* mask_strides_mat = mask_batch_strides; - const constant auto* mask_strides_vec = mask_strides_mat + batch_ndim; - - ulong2 batch_offsets = elem_to_loc_broadcast( - tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim); - - mat_mask += batch_offsets.x; - vec_mask += batch_offsets.y; - } - - } else { - in_vec += tid.z * vector_batch_stride[0]; - mat += tid.z * matrix_batch_stride[0]; - - if (has_output_mask) { - out_mask += tid.z * mask_batch_strides[0]; - mask_batch_strides += batch_ndim; - } - - if (has_operand_mask) { - mat_mask += tid.z * mask_batch_strides[0]; - vec_mask += tid.z * mask_batch_strides[batch_ndim]; - } - } - - out_vec += tid.z * out_vec_size; - - gemv_kernel::run( - mat, - in_vec, - out_vec, - in_vec_size, - out_vec_size, - marix_ld, - out_mask, - mat_mask, - vec_mask, - mask_strides, - gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory, - tid, - lid, - simd_gid, - simd_lid); -} - -/////////////////////////////////////////////////////////////////////////////// -/// Vector matrix multiplication -/////////////////////////////////////////////////////////////////////////////// - -template < - typename T, - typename out_mask_t, - typename op_mask_t, - const int BM, /* Threadgroup rows (in simdgroups) */ - const int BN, /* Threadgroup cols (in simdgroups) */ - const int SM, /* Simdgroup rows (in threads) */ - const int SN, /* Simdgroup cols (in threads) */ - const int TM, /* Thread rows (in elements) */ - const int TN, /* Thread cols (in elements) */ - const bool kDoNCBatch> /* Batch ndim > 1 */ -[[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv_t_masked( - const device T* mat [[buffer(0)]], - const device T* in_vec [[buffer(1)]], - device T* out_vec [[buffer(3)]], - const constant int& in_vec_size [[buffer(4)]], - const constant int& out_vec_size [[buffer(5)]], - const constant int& marix_ld [[buffer(6)]], - const constant int& batch_ndim [[buffer(9)]], - const constant int* batch_shape [[buffer(10)]], - const constant int64_t* vector_batch_stride [[buffer(11)]], - const constant int64_t* matrix_batch_stride [[buffer(12)]], - const device out_mask_t* out_mask [[buffer(20)]], - const device op_mask_t* mat_mask [[buffer(21)]], - const device op_mask_t* vec_mask [[buffer(22)]], - const constant int* mask_strides [[buffer(23)]], - const constant int64_t* mask_batch_strides [[buffer(24)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - using gemv_kernel = - GEMVTKernel; - threadgroup float tgp_memory - [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size]; - - constexpr bool has_operand_mask = !metal::is_same_v; - constexpr bool has_output_mask = !metal::is_same_v; - - // Update batch offsets - if (kDoNCBatch) { - in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim); - mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim); - - if (has_output_mask) { - out_mask += - elem_to_loc(tid.z, batch_shape, mask_batch_strides, batch_ndim); - mask_batch_strides += batch_ndim; - } - - if (has_operand_mask) { - const constant auto* mask_strides_mat = mask_batch_strides; - const constant auto* mask_strides_vec = mask_strides_mat + batch_ndim; - - ulong2 batch_offsets = elem_to_loc_broadcast( - tid.z, batch_shape, mask_strides_mat, mask_strides_vec, batch_ndim); - - mat_mask += batch_offsets.x; - vec_mask += batch_offsets.y; - } - - } else { - in_vec += tid.z * vector_batch_stride[0]; - mat += tid.z * matrix_batch_stride[0]; - - if (has_output_mask) { - out_mask += tid.z * mask_batch_strides[0]; - mask_batch_strides += batch_ndim; - } - - if (has_operand_mask) { - mat_mask += tid.z * mask_batch_strides[0]; - vec_mask += tid.z * mask_batch_strides[batch_ndim]; - } - } - - out_vec += tid.z * out_vec_size; - - gemv_kernel::run( - mat, - in_vec, - out_vec, - in_vec_size, - out_vec_size, - marix_ld, - out_mask, - mat_mask, - vec_mask, - mask_strides, - gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory, - tid, - lid, - simd_gid, - simd_lid); -} diff --git a/Source/Cmlx/mlx-generated/metal/hadamard.h b/Source/Cmlx/mlx-generated/metal/hadamard.h deleted file mode 100644 index d6c08f17..00000000 --- a/Source/Cmlx/mlx-generated/metal/hadamard.h +++ /dev/null @@ -1,182 +0,0 @@ -// Copyright © 2024 Apple Inc. -#include -#include - -#include "steel/defines.h" - -using namespace metal; - -// Thread local Hadamard transform for 2^R -template -METAL_FUNC void radix_func(thread float* x) { - constexpr short logR = __builtin_ctz(R); - short h = 1; - STEEL_PRAGMA_UNROLL - for (short s = 0; s < logR; s++) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < R / 2; i++) { - short k = i & (h - 1); - short j = ((i - k) << 1) + k; - float a = x[j]; - float b = x[j + h]; - x[j] = a + b; - x[j + h] = a - b; - } - h <<= 1; - } -} - -template -[[kernel]] void hadamard_n( - const device T* in [[buffer(0)]], - device T* out [[buffer(1)]], - constant const float& scale, - uint3 elem [[thread_position_in_grid]], - uint3 grid [[threads_per_grid]]) { - // Compute a Hadamard transform of size N = 2^k - // - // Equivalent to: - // from scipy.linalg import hadamard - // y = hadamard(len(x)) @ x - - constexpr short num_threads = N / max_radix; - constexpr short logN = __builtin_ctz(N); - constexpr short logR = __builtin_ctz(max_radix); - constexpr short num_steps = logN / logR; - constexpr short logFinal = logN % logR; - constexpr short final_radix = 1 << (logFinal); - - int batch_idx = elem.y * N * stride + elem.z; - short i = elem.x; - - threadgroup T buf[N]; - - // Read values from device - if (stride == 1) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < max_radix / read_width; j++) { - short index = j * read_width * num_threads + i * read_width; - STEEL_PRAGMA_UNROLL - for (short r = 0; r < read_width; r++) { - buf[index + r] = in[batch_idx + index + r]; - } - } - } else { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < max_radix; j++) { - buf[j * num_threads + i] = in[batch_idx + (j * num_threads + i) * stride]; - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - float x[max_radix]; - short h = 1; - - STEEL_PRAGMA_UNROLL - for (short s = 0; s < num_steps; s++) { - short k = i & (h - 1); - short j = ((i - k) << logR) + k; - - STEEL_PRAGMA_UNROLL - for (short r = 0; r < max_radix; r++) { - x[r] = buf[j + h * r]; - } - - radix_func(x); - - STEEL_PRAGMA_UNROLL - for (short r = 0; r < max_radix; r++) { - buf[j + h * r] = T(x[r]); - } - - h <<= logR; - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - // Do the final radix - // e.g. max_radix = 16 - // N = 1024 = 16 * 16 * 4 - if (final_radix > 1) { - // Each thread does multiple butterflies - STEEL_PRAGMA_UNROLL - for (int t = 0; t < max_radix / final_radix; t++) { - short index = i + t * num_threads; - short k = index & (h - 1); - short j = ((index - k) << logFinal) + k; - STEEL_PRAGMA_UNROLL - for (short r = 0; r < final_radix; r++) { - x[r] = buf[j + h * r]; - } - - radix_func(x); - - STEEL_PRAGMA_UNROLL - for (short r = 0; r < final_radix; r++) { - buf[j + h * r] = T(x[r]); - } - } - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - // Write values to device - if (stride == 1) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < max_radix / read_width; j++) { - short index = j * read_width * num_threads + i * read_width; - STEEL_PRAGMA_UNROLL - for (short r = 0; r < read_width; r++) { - out[batch_idx + index + r] = T(buf[index + r] * scale); - } - } - } else { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < max_radix; j++) { - out[batch_idx + (j * num_threads + i) * stride] = - buf[j * num_threads + i]; - } - } -} - -template -[[kernel]] void hadamard_m( - const device T* in [[buffer(0)]], - device T* out [[buffer(1)]], - constant const float& scale, - uint3 elem [[thread_position_in_grid]], - uint3 grid [[threads_per_grid]]) { - // Compute a Hadamard transform of size M - // using a naive O(M^2) codelet. - // - // This kernel is the second stage in the computation - // of a Hadamard transform of size M*N where N = 2^k. - - int index = elem.x * grid.y + elem.y; - short i = index % (N / read_width); - int batch_idx = index / (N / read_width) * M * N; - - float x[read_width][M]; - STEEL_PRAGMA_UNROLL - for (short c = 0; c < M; c++) { - STEEL_PRAGMA_UNROLL - for (short r = 0; r < read_width; r++) { - x[r][c] = in[batch_idx + c * N + i * read_width + r]; - } - } - - STEEL_PRAGMA_UNROLL - for (short r = 0; r < read_width; r++) { - // This function is JIT compiled for M - // using the Hadamard matrix strings in `metal/hadamard.cpp` - hadamard_radix_m(x[r]); - } - - // Write back to device - STEEL_PRAGMA_UNROLL - for (short c = 0; c < M; c++) { - STEEL_PRAGMA_UNROLL - for (short r = 0; r < read_width; r++) { - out[batch_idx + c * N + i * read_width + r] = T(x[r][c] * scale); - } - } -} diff --git a/Source/Cmlx/mlx-generated/metal/indexing/gather.h b/Source/Cmlx/mlx-generated/metal/indexing/gather.h deleted file mode 100644 index d99c46c6..00000000 --- a/Source/Cmlx/mlx-generated/metal/indexing/gather.h +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright © 2024 Apple Inc. - -#pragma once - -#include "../indexing/indexing.h" - -template -METAL_FUNC void gather_impl( - const device T* src [[buffer(0)]], - device T* out [[buffer(1)]], - const constant int* src_shape [[buffer(2)]], - const constant int64_t* src_strides [[buffer(3)]], - const constant size_t& src_ndim [[buffer(4)]], - const constant int* slice_sizes [[buffer(5)]], - const constant int* axes [[buffer(6)]], - const thread Indices& indices, - uint3 index [[thread_position_in_grid]], - uint3 grid_dim [[threads_per_grid]]) { - LocT src_idx = 0; - for (int i = 0; i < NIDX; ++i) { - LocT idx_loc; - if (IDX_NDIM == 0) { - idx_loc = 0; - } else if (IDX_NDIM == 1) { - idx_loc = index.x * static_cast(indices.strides[indices.ndim * i]); - } else { - idx_loc = index.x * static_cast(indices.strides[indices.ndim * i]); - idx_loc += indices.row_contiguous[i] - ? index.y - : elem_to_loc( - index.y, - &indices.shapes[indices.ndim * i + 1], - &indices.strides[indices.ndim * i + 1], - indices.ndim - 1); - } - auto ax = axes[i]; - auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], src_shape[ax]); - src_idx += static_cast(idx_val) * static_cast(src_strides[ax]); - } - - auto src_offset = - elem_to_loc(index.z, slice_sizes, src_strides, src_ndim); - - LocT out_idx = index.z; - if (IDX_NDIM == 1) { - out_idx += static_cast(grid_dim.z) * index.x; - } else if (IDX_NDIM >= 2) { - out_idx += grid_dim.z * (index.x * static_cast(grid_dim.y) + index.y); - } - out[out_idx] = src[src_offset + src_idx]; -} diff --git a/Source/Cmlx/mlx-generated/metal/indexing/gather_axis.h b/Source/Cmlx/mlx-generated/metal/indexing/gather_axis.h deleted file mode 100644 index bf490ade..00000000 --- a/Source/Cmlx/mlx-generated/metal/indexing/gather_axis.h +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright © 2025 Apple Inc. - -#pragma once - -template -[[kernel]] void gather_axis( - const device T* src [[buffer(0)]], - const device IdxT* indices [[buffer(1)]], - device T* out [[buffer(2)]], - const constant int* shape [[buffer(3)]], - const constant int64_t* src_strides [[buffer(4)]], - const constant int64_t* idx_strides [[buffer(5)]], - const constant size_t& ndim [[buffer(6)]], - const constant int& axis [[buffer(7)]], - const constant int& axis_size [[buffer(8)]], - const constant size_t& src_ax_stride [[buffer(9)]], - const constant size_t& idx_ax_stride [[buffer(10)]], - uint3 index [[thread_position_in_grid]], - uint3 grid_dim [[threads_per_grid]]) { - LocT elem_idx = index.z * static_cast(grid_dim.x); - LocT out_idx = elem_idx * grid_dim.y + index.x; - - LocT idx_loc = index.y * static_cast(idx_ax_stride); - if (IdxC) { - idx_loc += out_idx; - } else { - idx_loc += elem_to_loc(elem_idx + index.x, shape, idx_strides, ndim); - } - - auto idx_val = indices[idx_loc]; - if (is_signed_v) { - idx_val = (idx_val < 0) ? idx_val + axis_size : idx_val; - } - - LocT src_idx = idx_val * static_cast(src_ax_stride); - if (SrcC) { - src_idx += elem_idx * axis_size + index.x; - } else { - src_idx += elem_to_loc(elem_idx + index.x, shape, src_strides, ndim); - } - - out_idx += index.y * static_cast(grid_dim.x); - out[out_idx] = src[src_idx]; -} diff --git a/Source/Cmlx/mlx-generated/metal/indexing/gather_front.h b/Source/Cmlx/mlx-generated/metal/indexing/gather_front.h deleted file mode 100644 index 2cd6eb41..00000000 --- a/Source/Cmlx/mlx-generated/metal/indexing/gather_front.h +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright © 2025 Apple Inc. - -#pragma once - -#include "../indexing/indexing.h" - -template -[[kernel]] void gather_front( - const device T* src, - const device IdxT* indices, - device T* out, - const constant int64_t& stride, - const constant int& size, - uint2 index [[thread_position_in_grid]], - uint2 grid_dim [[threads_per_grid]]) { - auto idx = offset_neg_idx(indices[index.y], size); - LocT src_idx = static_cast(stride) * idx; - LocT out_idx = static_cast(stride) * index.y; - - int s_idx = N * index.x; - for (int i = 0; i < N && s_idx < stride; ++i, ++s_idx) { - out[out_idx + s_idx] = src[src_idx + s_idx]; - } -} diff --git a/Source/Cmlx/mlx-generated/metal/indexing/indexing.h b/Source/Cmlx/mlx-generated/metal/indexing/indexing.h deleted file mode 100644 index 2a4b4f92..00000000 --- a/Source/Cmlx/mlx-generated/metal/indexing/indexing.h +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright © 2023-2024 Apple Inc. - -#pragma once - -#include - -template -struct Indices { - const array buffers; - const constant int* shapes; - const constant int64_t* strides; - const constant bool* row_contiguous; - const int ndim; -}; - -template -METAL_FUNC size_t offset_neg_idx(IdxT idx, int size) { - if (is_unsigned_v) { - return idx; - } else { - return (idx < 0) ? idx + size : idx; - } -} diff --git a/Source/Cmlx/mlx-generated/metal/indexing/masked_scatter.h b/Source/Cmlx/mlx-generated/metal/indexing/masked_scatter.h deleted file mode 100644 index 2ba54740..00000000 --- a/Source/Cmlx/mlx-generated/metal/indexing/masked_scatter.h +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright © 2025 Apple Inc. - -#pragma once - -constant mlx::os_log logger("mlx", "masked_assign"); - -template -[[kernel]] void masked_assign_impl( - const device bool* mask [[buffer(0)]], - const device uint* scatter_offsets [[buffer(1)]], - const device T* src [[buffer(2)]], - device T* out [[buffer(3)]], - const constant int* src_shapes [[buffer(4)]], - const constant int64_t* src_strides [[buffer(5)]], - const constant int& src_ndim [[buffer(6)]], - const constant int64_t& src_batch_size [[buffer(7)]], - const constant int64_t& mask_batch_size [[buffer(8)]], - uint idx [[thread_position_in_grid]]) { - const bool mask_value = mask[idx]; - if (!mask_value) { - return; - } - - const uint src_index = scatter_offsets[idx]; - if (src_index >= src_batch_size) { - logger.log_debug("Out of bound read from src"); - return; - } - - const uint batch_idx = idx / mask_batch_size; - - if (src_contiguous) { - out[idx] = src[batch_idx * src_batch_size + src_index]; - } else { - out[idx] = src[elem_to_loc( - batch_idx * src_batch_size + src_index, - src_shapes, - src_strides, - src_ndim)]; - } -} diff --git a/Source/Cmlx/mlx-generated/metal/indexing/scatter.h b/Source/Cmlx/mlx-generated/metal/indexing/scatter.h deleted file mode 100644 index 99e65d20..00000000 --- a/Source/Cmlx/mlx-generated/metal/indexing/scatter.h +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright © 2024 Apple Inc. - -#pragma once - -#include "../indexing/indexing.h" - -template < - typename T, - typename IdxT, - typename Op, - int NIDX, - bool UPD_ROW_CONTIG, - int NWORK, - typename LocT> -METAL_FUNC void scatter_impl( - const device T* updates, - device mlx_atomic* out, - const constant int* upd_shape, - const constant int64_t* upd_strides, - const constant size_t& upd_ndim, - const constant size_t& upd_size, - const constant int* out_shape, - const constant int64_t* out_strides, - const constant size_t& out_ndim, - const constant int* axes, - const constant size_t& idx_size, - const thread Indices& indices, - uint2 gid [[thread_position_in_grid]]) { - Op op; - - auto ind_idx = gid.y * NWORK; - LocT out_offset = 0; - if (upd_size > 1) { - out_offset = elem_to_loc( - gid.x, upd_shape + indices.ndim, out_strides, out_ndim); - } - - for (int j = 0; j < NWORK && ind_idx < idx_size; ++j, ind_idx++) { - LocT out_idx = out_offset; - for (int i = 0; i < NIDX; ++i) { - auto idx_loc = indices.row_contiguous[i] - ? ind_idx - : elem_to_loc( - ind_idx, - &indices.shapes[indices.ndim * i], - &indices.strides[indices.ndim * i], - indices.ndim); - auto ax = axes[i]; - auto idx_val = offset_neg_idx(indices.buffers[i][idx_loc], out_shape[ax]); - out_idx += - static_cast(idx_val) * static_cast(out_strides[ax]); - } - auto upd_idx = ind_idx * static_cast(upd_size) + gid.x; - if constexpr (!UPD_ROW_CONTIG) { - upd_idx = elem_to_loc(upd_idx, upd_shape, upd_strides, upd_ndim); - } - op.atomic_update(out, updates[upd_idx], out_idx); - } -} diff --git a/Source/Cmlx/mlx-generated/metal/indexing/scatter_axis.h b/Source/Cmlx/mlx-generated/metal/indexing/scatter_axis.h deleted file mode 100644 index 73fd7ab4..00000000 --- a/Source/Cmlx/mlx-generated/metal/indexing/scatter_axis.h +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright © 2025 Apple Inc. - -#pragma once - -template < - typename T, - typename IdxT, - typename LocT, - typename Op, - bool UpdC, - bool IdxC> -[[kernel]] void scatter_axis( - const device T* upd [[buffer(0)]], - const device IdxT* indices [[buffer(1)]], - device mlx_atomic* out [[buffer(2)]], - const constant int* shape [[buffer(3)]], - const constant int64_t* upd_strides [[buffer(4)]], - const constant int64_t* idx_strides [[buffer(5)]], - const constant size_t& ndim [[buffer(6)]], - const constant int& axis [[buffer(7)]], - const constant int& out_axis_size [[buffer(8)]], - const constant size_t& upd_ax_stride [[buffer(9)]], - const constant size_t& idx_ax_stride [[buffer(10)]], - uint3 index [[thread_position_in_grid]], - uint3 grid_dim [[threads_per_grid]]) { - Op op; - - LocT elem_idx = index.z * static_cast(grid_dim.x); - - LocT idx_loc = index.y * static_cast(idx_ax_stride); - if (IdxC) { - idx_loc += elem_idx * grid_dim.y + index.x; - } else { - idx_loc += elem_to_loc(elem_idx + index.x, shape, idx_strides, ndim); - } - - auto idx_val = indices[idx_loc]; - if (is_signed_v) { - idx_val = (idx_val < 0) ? idx_val + out_axis_size : idx_val; - } - - LocT upd_idx = index.y * static_cast(upd_ax_stride); - if (UpdC) { - upd_idx += elem_idx * grid_dim.y + index.x; - } else { - upd_idx += elem_to_loc(elem_idx + index.x, shape, upd_strides, ndim); - } - - LocT out_idx = elem_idx * static_cast(out_axis_size) + - idx_val * grid_dim.x + index.x; - op.atomic_update(out, upd[upd_idx], out_idx); -} diff --git a/Source/Cmlx/mlx-generated/metal/layer_norm.metal b/Source/Cmlx/mlx-generated/metal/layer_norm.metal deleted file mode 100644 index e1c862c9..00000000 --- a/Source/Cmlx/mlx-generated/metal/layer_norm.metal +++ /dev/null @@ -1,433 +0,0 @@ -// Copyright © 2024 Apple Inc. - -#include -#include - -#include "utils.h" - -using namespace metal; - -constant bool has_w [[function_constant(20)]]; - -template -inline void initialize_buffer( - threadgroup float* xs, - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - if (simd_group_id == 0) { - for (int i = 0; i < N; i++) { - xs[N * simd_lane_id + i] = 0; - } - } - threadgroup_barrier(mem_flags::mem_threadgroup); -} - -template -inline void threadgroup_sum( - thread float* x, - threadgroup float* xs, - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - for (int i = 0; i < N; i++) { - x[i] = simd_sum(x[i]); - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (simd_lane_id == 0) { - for (int i = 0; i < N; i++) { - xs[N * simd_group_id + i] = x[i]; - } - } - threadgroup_barrier(mem_flags::mem_threadgroup); - for (int i = 0; i < N; i++) { - x[i] = xs[N * simd_lane_id + i]; - x[i] = simd_sum(x[i]); - } -} - -template -[[kernel]] void layer_norm_single_row( - const device T* x, - const device T* w, - const device T* b, - device T* out, - constant float& eps, - constant uint& axis_size, - constant uint& w_stride, - constant uint& b_stride, - uint gid [[threadgroup_position_in_grid]], - uint lid [[thread_position_in_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - constexpr int SIMD_SIZE = 32; - - // Initialize the registers and threadgroup memory - float thread_x[N_READS] = {0}; - threadgroup float local_buffer[SIMD_SIZE] = {0}; - initialize_buffer(local_buffer, simd_lane_id, simd_group_id); - - // Advance the pointers - x += gid * size_t(axis_size) + lid * N_READS; - w += w_stride * lid * N_READS; - b += b_stride * lid * N_READS; - out += gid * size_t(axis_size) + lid * N_READS; - - // Compute some variables for reading writing etc - const bool safe = lid * N_READS + N_READS <= axis_size; - const int n = axis_size - lid * N_READS; - - // Read the inputs - if (safe) { - for (int i = 0; i < N_READS; i++) { - thread_x[i] = x[i]; - } - } else { - for (int i = 0; i < n; i++) { - thread_x[i] = x[i]; - } - } - - // Compute the mean - float mean = 0; - for (int i = 0; i < N_READS; i++) { - mean += thread_x[i]; - } - threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id); - mean /= axis_size; - - // Compute the normalizer - float normalizer = 0; - if (!safe) { - for (int i = n; i < N_READS; i++) { - thread_x[i] = mean; - } - } - for (int i = 0; i < N_READS; i++) { - thread_x[i] -= mean; - normalizer += thread_x[i] * thread_x[i]; - } - threadgroup_sum(&normalizer, local_buffer, simd_lane_id, simd_group_id); - normalizer = metal::precise::rsqrt(normalizer / axis_size + eps); - - // Write the outputs - if (safe) { - for (int i = 0; i < N_READS; i++) { - thread_x[i] *= normalizer; - out[i] = w[w_stride * i] * static_cast(thread_x[i]) + b[b_stride * i]; - } - } else { - for (int i = 0; i < n; i++) { - thread_x[i] *= normalizer; - out[i] = w[w_stride * i] * static_cast(thread_x[i]) + b[b_stride * i]; - } - } -} - -template -[[kernel]] void layer_norm_looped( - const device T* x, - const device T* w, - const device T* b, - device T* out, - constant float& eps, - constant uint& axis_size, - constant uint& w_stride, - constant uint& b_stride, - uint gid [[threadgroup_position_in_grid]], - uint lid [[thread_position_in_threadgroup]], - uint lsize [[threads_per_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - constexpr int SIMD_SIZE = 32; - - threadgroup float local_buffer[SIMD_SIZE]; - initialize_buffer(local_buffer, simd_lane_id, simd_group_id); - - x += gid * size_t(axis_size) + lid * N_READS; - w += w_stride * lid * N_READS; - b += b_stride * lid * N_READS; - - // Compute the mean - float mean = 0; - for (uint r = 0; r < axis_size; r += lsize * N_READS) { - if (r + lid * N_READS + N_READS <= axis_size) { - for (int i = 0; i < N_READS; i++) { - mean += x[i + r]; - } - } else { - for (int i = 0; i < N_READS; i++) { - if ((r + lid * N_READS + i) < axis_size) { - mean += x[i + r]; - } - } - } - } - threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id); - mean /= axis_size; - - // Compute the normalizer - float normalizer = 0; - for (uint r = 0; r < axis_size; r += lsize * N_READS) { - if (r + lid * N_READS + N_READS <= axis_size) { - for (int i = 0; i < N_READS; i++) { - float t = x[i + r] - mean; - normalizer += t * t; - } - } else { - for (int i = 0; i < N_READS; i++) { - if ((r + lid * N_READS + i) < axis_size) { - float t = x[i + r] - mean; - normalizer += t * t; - } - } - } - } - threadgroup_sum(&normalizer, local_buffer, simd_lane_id, simd_group_id); - normalizer = metal::precise::rsqrt(normalizer / axis_size + eps); - - // Write the outputs - out += gid * size_t(axis_size) + lid * N_READS; - for (uint r = 0; r < axis_size; r += lsize * N_READS) { - if (r + lid * N_READS + N_READS <= axis_size) { - for (int i = 0; i < N_READS; i++) { - float xi = (x[r + i] - mean) * normalizer; - out[r + i] = - w[w_stride * (i + r)] * static_cast(xi) + b[b_stride * (i + r)]; - } - } else { - for (int i = 0; i < N_READS; i++) { - if ((r + lid * N_READS + i) < axis_size) { - float xi = (x[r + i] - mean) * normalizer; - out[r + i] = w[w_stride * (i + r)] * static_cast(xi) + - b[b_stride * (i + r)]; - } - } - } - } -} - -template -[[kernel]] void vjp_layer_norm_single_row( - const device T* x, - const device T* w, - const device T* g, - device T* gx, - device T* gw, - constant float& eps, - constant uint& axis_size, - constant uint& w_stride, - uint gid [[threadgroup_position_in_grid]], - uint lid [[thread_position_in_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - constexpr int SIMD_SIZE = 32; - - // Advance the input pointers - x += gid * size_t(axis_size) + lid * N_READS; - g += gid * size_t(axis_size) + lid * N_READS; - w += w_stride * lid * N_READS; - - // Initialize the registers and threadgroup memory - float thread_x[N_READS] = {0}; - float thread_w[N_READS] = {0}; - float thread_g[N_READS] = {0}; - threadgroup float local_buffer[3 * SIMD_SIZE]; - initialize_buffer<3>(local_buffer, simd_lane_id, simd_group_id); - - // Compute some variables for reading writing etc - const bool safe = lid * N_READS + N_READS <= axis_size; - const int n = axis_size - lid * N_READS; - - // Read the inputs - if (safe) { - for (int i = 0; i < N_READS; i++) { - thread_x[i] = x[i]; - thread_g[i] = g[i]; - thread_w[i] = w[i * w_stride]; - } - } else { - for (int i = 0; i < n; i++) { - thread_x[i] = x[i]; - thread_g[i] = g[i]; - thread_w[i] = w[i * w_stride]; - } - } - - // Compute the mean - float mean = 0; - for (int i = 0; i < N_READS; i++) { - mean += thread_x[i]; - } - threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id); - mean /= axis_size; - - // Compute the neccesary scaling factors using the mean - if (!safe) { - for (int i = n; i < N_READS; i++) { - thread_x[i] = mean; - } - } - float factors[3] = {0}; - constexpr int meanwg = 0; - constexpr int meanwgxc = 1; - constexpr int normalizer2 = 2; - for (int i = 0; i < N_READS; i++) { - thread_x[i] -= mean; - factors[meanwg] += thread_w[i] * thread_g[i]; - factors[meanwgxc] += thread_w[i] * thread_g[i] * thread_x[i]; - factors[normalizer2] += thread_x[i] * thread_x[i]; - } - threadgroup_sum<3>(factors, local_buffer, simd_lane_id, simd_group_id); - factors[meanwg] /= axis_size; - factors[meanwgxc] /= axis_size; - factors[normalizer2] = 1 / (factors[normalizer2] / axis_size + eps); - float normalizer = metal::precise::sqrt(factors[normalizer2]); - - // Write the outputs - gx += gid * size_t(axis_size) + lid * N_READS; - gw += gid * size_t(axis_size) + lid * N_READS; - if (safe) { - for (int i = 0; i < N_READS; i++) { - thread_x[i] *= normalizer; - gx[i] = static_cast( - normalizer * (thread_w[i] * thread_g[i] - factors[meanwg]) - - thread_x[i] * factors[meanwgxc] * factors[normalizer2]); - if (has_w) { - gw[i] = static_cast(thread_g[i] * thread_x[i]); - } - } - } else { - for (int i = 0; i < n; i++) { - thread_x[i] *= normalizer; - gx[i] = static_cast( - normalizer * (thread_w[i] * thread_g[i] - factors[meanwg]) - - thread_x[i] * factors[meanwgxc] * factors[normalizer2]); - if (has_w) { - gw[i] = static_cast(thread_g[i] * thread_x[i]); - } - } - } -} - -template -[[kernel]] void vjp_layer_norm_looped( - const device T* x, - const device T* w, - const device T* g, - device T* gx, - device T* gw, - constant float& eps, - constant uint& axis_size, - constant uint& w_stride, - uint gid [[threadgroup_position_in_grid]], - uint lid [[thread_position_in_threadgroup]], - uint lsize [[threads_per_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - constexpr int SIMD_SIZE = 32; - - // Advance the input pointers - x += gid * size_t(axis_size) + lid * N_READS; - g += gid * size_t(axis_size) + lid * N_READS; - w += w_stride * lid * N_READS; - - threadgroup float local_buffer[3 * SIMD_SIZE]; - initialize_buffer<3>(local_buffer, simd_lane_id, simd_group_id); - - // Compute the mean - float mean = 0; - for (uint r = 0; r < axis_size; r += lsize * N_READS) { - if (r + lid * N_READS + N_READS <= axis_size) { - for (int i = 0; i < N_READS; i++) { - mean += x[i + r]; - } - } else { - for (int i = 0; i < N_READS; i++) { - if ((r + lid * N_READS + i) < axis_size) { - mean += x[i + r]; - } - } - } - } - threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id); - mean /= axis_size; - - // Compute the neccesary scaling factors using the mean - float factors[3] = {0}; - constexpr int meanwg = 0; - constexpr int meanwgxc = 1; - constexpr int normalizer2 = 2; - for (uint r = 0; r < axis_size; r += lsize * N_READS) { - if (r + lid * N_READS + N_READS <= axis_size) { - for (int i = 0; i < N_READS; i++) { - float t = x[i + r] - mean; - float wi = w[(i + r) * w_stride]; - float gi = g[i + r]; - float wg = wi * gi; - factors[meanwg] += wg; - factors[meanwgxc] += wg * t; - factors[normalizer2] += t * t; - } - } else { - for (int i = 0; i < N_READS; i++) { - if ((r + lid * N_READS + i) < axis_size) { - float t = x[i + r] - mean; - float wi = w[(i + r) * w_stride]; - float gi = g[i + r]; - float wg = wi * gi; - factors[meanwg] += wg; - factors[meanwgxc] += wg * t; - factors[normalizer2] += t * t; - } - } - } - } - threadgroup_sum<3>(factors, local_buffer, simd_lane_id, simd_group_id); - factors[meanwg] /= axis_size; - factors[meanwgxc] /= axis_size; - factors[normalizer2] = 1 / (factors[normalizer2] / axis_size + eps); - float normalizer = metal::precise::sqrt(factors[normalizer2]); - - // Write the outputs - gx += gid * size_t(axis_size) + lid * N_READS; - gw += gid * size_t(axis_size) + lid * N_READS; - for (uint r = 0; r < axis_size; r += lsize * N_READS) { - if (r + lid * N_READS + N_READS <= axis_size) { - for (int i = 0; i < N_READS; i++) { - float xi = (x[i + r] - mean) * normalizer; - float wi = w[(i + r) * w_stride]; - float gi = g[i + r]; - gx[i + r] = static_cast( - normalizer * (wi * gi - factors[meanwg]) - - xi * factors[meanwgxc] * factors[normalizer2]); - if (has_w) { - gw[i + r] = static_cast(gi * xi); - } - } - } else { - for (int i = 0; i < N_READS; i++) { - if ((r + lid * N_READS + i) < axis_size) { - float xi = (x[i + r] - mean) * normalizer; - float wi = w[(i + r) * w_stride]; - float gi = g[i + r]; - gx[i + r] = static_cast( - normalizer * (wi * gi - factors[meanwg]) - - xi * factors[meanwgxc] * factors[normalizer2]); - if (has_w) { - gw[i + r] = static_cast(gi * xi); - } - } - } - } - } -} - -// clang-format off -#define instantiate_layer_norm(name, itype) \ - instantiate_kernel("layer_norm" #name, layer_norm_single_row, itype) \ - instantiate_kernel("vjp_layer_norm" #name, vjp_layer_norm_single_row, itype) \ - instantiate_kernel("layer_norm_looped" #name, layer_norm_looped, itype) \ - instantiate_kernel("vjp_layer_norm_looped" #name, vjp_layer_norm_looped, itype) - -instantiate_layer_norm(float32, float) -instantiate_layer_norm(float16, half) -instantiate_layer_norm(bfloat16, bfloat16_t) // clang-format on diff --git a/Source/Cmlx/mlx-generated/metal/logging.h b/Source/Cmlx/mlx-generated/metal/logging.h deleted file mode 100644 index 7b3ee046..00000000 --- a/Source/Cmlx/mlx-generated/metal/logging.h +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright © 2025 Apple Inc. - -#pragma once - -#if defined(__METAL_VERSION__) && (__METAL_VERSION__ >= 320) -#include - -namespace mlx { -using os_log = metal::os_log; -} // namespace mlx - -#else - -namespace mlx { -struct os_log { - constexpr os_log(constant char*, constant char*) constant {} - - template - void log_debug(constant char*, Args...) const {} - - template - void log_debug(constant char*, Args...) const constant {} -}; -} // namespace mlx - -#endif \ No newline at end of file diff --git a/Source/Cmlx/mlx-generated/metal/logsumexp.h b/Source/Cmlx/mlx-generated/metal/logsumexp.h deleted file mode 100644 index c746050b..00000000 --- a/Source/Cmlx/mlx-generated/metal/logsumexp.h +++ /dev/null @@ -1,140 +0,0 @@ -// Copyright © 2025 Apple Inc. - -template -[[kernel]] void logsumexp( - const device T* in, - device T* out, - constant int& axis_size, - uint gid [[threadgroup_position_in_grid]], - uint _lid [[thread_position_in_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - int lid = _lid; - - constexpr int SIMD_SIZE = 32; - - threadgroup AccT local_max[SIMD_SIZE]; - threadgroup AccT local_normalizer[SIMD_SIZE]; - - AccT ld[N_READS]; - - in += gid * size_t(axis_size) + lid * N_READS; - if (lid * N_READS + N_READS <= axis_size) { - for (int i = 0; i < N_READS; i++) { - ld[i] = AccT(in[i]); - } - } else { - for (int i = 0; i < N_READS; i++) { - ld[i] = - ((lid * N_READS + i) < axis_size) ? AccT(in[i]) : Limits::min; - } - } - if (simd_group_id == 0) { - local_max[simd_lane_id] = Limits::min; - local_normalizer[simd_lane_id] = 0; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Get the max - AccT maxval = Limits::finite_min; - for (int i = 0; i < N_READS; i++) { - maxval = (maxval < ld[i]) ? ld[i] : maxval; - } - maxval = simd_max(maxval); - if (simd_lane_id == 0) { - local_max[simd_group_id] = maxval; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (simd_group_id == 0) { - maxval = simd_max(local_max[simd_lane_id]); - if (simd_lane_id == 0) { - local_max[0] = maxval; - } - } - threadgroup_barrier(mem_flags::mem_threadgroup); - maxval = local_max[0]; - - // Compute exp(x_i - maxval) and store the partial sums in local_normalizer - AccT normalizer = 0; - for (int i = 0; i < N_READS; i++) { - normalizer += fast::exp(ld[i] - maxval); - } - normalizer = simd_sum(normalizer); - if (simd_lane_id == 0) { - local_normalizer[simd_group_id] = normalizer; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (simd_group_id == 0) { - normalizer = simd_sum(local_normalizer[simd_lane_id]); - if (simd_lane_id == 0) { - out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval); - } - } -} - -template -[[kernel]] void logsumexp_looped( - const device T* in, - device T* out, - constant int& axis_size, - uint gid [[threadgroup_position_in_grid]], - uint lid [[thread_position_in_threadgroup]], - uint lsize [[threads_per_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - in += gid * size_t(axis_size); - - constexpr int SIMD_SIZE = 32; - - threadgroup AccT local_max[SIMD_SIZE]; - threadgroup AccT local_normalizer[SIMD_SIZE]; - - // Get the max and the normalizer in one go - AccT prevmax; - AccT maxval = Limits::finite_min; - AccT normalizer = 0; - for (int r = 0; r < static_cast(ceildiv(axis_size, N_READS * lsize)); - r++) { - int offset = r * lsize * N_READS + lid * N_READS; - AccT vals[N_READS]; - if (offset + N_READS <= axis_size) { - for (int i = 0; i < N_READS; i++) { - vals[i] = AccT(in[offset + i]); - } - } else { - for (int i = 0; i < N_READS; i++) { - vals[i] = - (offset + i < axis_size) ? AccT(in[offset + i]) : Limits::min; - } - } - prevmax = maxval; - for (int i = 0; i < N_READS; i++) { - maxval = (maxval < vals[i]) ? vals[i] : maxval; - } - normalizer *= fast::exp(prevmax - maxval); - for (int i = 0; i < N_READS; i++) { - normalizer += fast::exp(vals[i] - maxval); - } - } - prevmax = maxval; - maxval = simd_max(maxval); - normalizer *= fast::exp(prevmax - maxval); - normalizer = simd_sum(normalizer); - - prevmax = maxval; - if (simd_lane_id == 0) { - local_max[simd_group_id] = maxval; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - maxval = simd_max(local_max[simd_lane_id]); - normalizer *= fast::exp(prevmax - maxval); - if (simd_lane_id == 0) { - local_normalizer[simd_group_id] = normalizer; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - normalizer = simd_sum(local_normalizer[simd_lane_id]); - - if (lid == 0) { - out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval); - } -} diff --git a/Source/Cmlx/mlx-generated/metal/quantized.h b/Source/Cmlx/mlx-generated/metal/quantized.h deleted file mode 100644 index 5ac4c6e1..00000000 --- a/Source/Cmlx/mlx-generated/metal/quantized.h +++ /dev/null @@ -1,2508 +0,0 @@ -// Copyright © 2023-2024 Apple Inc. - -#include -#include - -constant bool align_M [[function_constant(200)]]; -constant bool align_N [[function_constant(201)]]; -constant bool align_K [[function_constant(202)]]; - -using namespace metal; - -#define MLX_MTL_CONST static constant constexpr const - -MLX_MTL_CONST int SIMD_SIZE = 32; -MLX_MTL_CONST int QUAD_SIZE = 4; - -template -inline constexpr short get_pack_factor() { - return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits); -} - -template -inline constexpr short get_bytes_per_pack() { - constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; - return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3); -} - -template -inline U load_vector(const device T* x, thread U* x_thread) { - static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || - bits == 8, - "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); - - U sum = 0; - - if (bits == 2) { - for (int i = 0; i < values_per_thread; i += 4) { - sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; - x_thread[i] = x[i]; - x_thread[i + 1] = x[i + 1] / 4.0f; - x_thread[i + 2] = x[i + 2] / 16.0f; - x_thread[i + 3] = x[i + 3] / 64.0f; - } - } - - else if (bits == 3) { - for (int i = 0; i < values_per_thread; i += 8) { - sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + - x[i + 6] + x[i + 7]; - x_thread[i] = x[i]; - x_thread[i + 1] = x[i + 1] / 8.0f; - x_thread[i + 2] = x[i + 2] / 64.0f; - x_thread[i + 3] = x[i + 3] / 2.0f; - x_thread[i + 4] = x[i + 4] / 16.0f; - x_thread[i + 5] = x[i + 5] / 128.0f; - x_thread[i + 6] = x[i + 6] / 4.0f; - x_thread[i + 7] = x[i + 7] / 32.0f; - } - } - - else if (bits == 4) { - for (int i = 0; i < values_per_thread; i += 4) { - sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; - x_thread[i] = x[i]; - x_thread[i + 1] = x[i + 1] / 16.0f; - x_thread[i + 2] = x[i + 2] / 256.0f; - x_thread[i + 3] = x[i + 3] / 4096.0f; - } - } - - else if (bits == 5) { - for (int i = 0; i < values_per_thread; i += 8) { - sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + - x[i + 6] + x[i + 7]; - x_thread[i] = x[i]; - x_thread[i + 1] = x[i + 1] / 32.0f; - x_thread[i + 2] = x[i + 2] / 4.0f; - x_thread[i + 3] = x[i + 3] / 128.0f; - x_thread[i + 4] = x[i + 4] / 16.0f; - x_thread[i + 5] = x[i + 5] / 2.0f; - x_thread[i + 6] = x[i + 6] / 64.0f; - x_thread[i + 7] = x[i + 7] / 8.0f; - } - } - - else if (bits == 6) { - for (int i = 0; i < values_per_thread; i += 4) { - sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; - x_thread[i] = x[i]; - x_thread[i + 1] = x[i + 1] / 64.0f; - x_thread[i + 2] = x[i + 2] / 16.0f; - x_thread[i + 3] = x[i + 3] / 4.0f; - } - } - - else if (bits == 8) { - for (int i = 0; i < values_per_thread; i++) { - sum += x[i]; - x_thread[i] = x[i]; - } - } - - return sum; -} - -template -inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { - static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || - bits == 8, - "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); - - U sum = 0; - - if (bits == 2) { - for (int i = 0; i < N; i += 4) { - sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; - x_thread[i] = x[i]; - x_thread[i + 1] = x[i + 1] / 4.0f; - x_thread[i + 2] = x[i + 2] / 16.0f; - x_thread[i + 3] = x[i + 3] / 64.0f; - } - } - - else if (bits == 3) { - for (int i = 0; i < N; i += 8) { - sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + - x[i + 6] + x[i + 7]; - - x_thread[i] = x[i]; - x_thread[i + 1] = x[i + 1] / 8.0f; - x_thread[i + 2] = x[i + 2] / 64.0f; - x_thread[i + 3] = x[i + 3] / 2.0f; - x_thread[i + 4] = x[i + 4] / 16.0f; - x_thread[i + 5] = x[i + 5] / 128.0f; - x_thread[i + 6] = x[i + 6] / 4.0f; - x_thread[i + 7] = x[i + 7] / 32.0f; - } - } - - else if (bits == 4) { - for (int i = 0; i < N; i += 4) { - sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; - x_thread[i] = x[i]; - x_thread[i + 1] = x[i + 1] / 16.0f; - x_thread[i + 2] = x[i + 2] / 256.0f; - x_thread[i + 3] = x[i + 3] / 4096.0f; - } - } - - else if (bits == 5) { - for (int i = 0; i < N; i += 8) { - sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + - x[i + 6] + x[i + 7]; - x_thread[i] = x[i]; - x_thread[i + 1] = x[i + 1] / 32.0f; - x_thread[i + 2] = x[i + 2] / 4.0f; - x_thread[i + 3] = x[i + 3] / 128.0f; - x_thread[i + 4] = x[i + 4] / 16.0f; - x_thread[i + 5] = x[i + 5] / 2.0f; - x_thread[i + 6] = x[i + 6] / 64.0f; - x_thread[i + 7] = x[i + 7] / 8.0f; - } - } - - else if (bits == 6) { - for (int i = 0; i < N; i += 4) { - sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; - x_thread[i] = x[i]; - x_thread[i + 1] = x[i + 1] / 64.0f; - x_thread[i + 2] = x[i + 2] / 16.0f; - x_thread[i + 3] = x[i + 3] / 4.0f; - } - } - - else if (bits == 8) { - for (int i = 0; i < N; i++) { - sum += x[i]; - x_thread[i] = x[i]; - } - } - - for (int i = N; i < values_per_thread; i++) { - x_thread[i] = 0; - } - - return sum; -} - -template -inline U qdot( - const device uint8_t* w, - const thread U* x_thread, - U scale, - U bias, - U sum) { - static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || - bits == 8, - "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); - - U accum = 0; - - if (bits == 2) { - for (int i = 0; i < (values_per_thread / 4); i++) { - accum += - (x_thread[4 * i] * (w[i] & 0x03) + - x_thread[4 * i + 1] * (w[i] & 0x0c) + - x_thread[4 * i + 2] * (w[i] & 0x30) + - x_thread[4 * i + 3] * (w[i] & 0xc0)); - } - } - - else if (bits == 3) { - for (int i = 0; i < (values_per_thread / 8); i++) { - x_thread += 8 * i; - w += 3 * i; - - accum += (w[0] & 0x07) * x_thread[0]; - accum += (w[0] & 0x38) * x_thread[1]; - accum += (w[0] & 0xc0) * x_thread[2]; - accum += (w[1] & 0x01) * (x_thread[2] * 256.0f); - - accum += (w[1] & 0x0e) * x_thread[3]; - accum += (w[1] & 0x70) * x_thread[4]; - accum += (w[1] & 0x80) * x_thread[5]; - accum += (w[2] & 0x03) * (x_thread[5] * 256.0f); - - accum += (w[2] & 0x1c) * x_thread[6]; - accum += (w[2] & 0xe0) * x_thread[7]; - } - } - - else if (bits == 4) { - const device uint16_t* ws = (const device uint16_t*)w; - for (int i = 0; i < (values_per_thread / 4); i++) { - accum += - (x_thread[4 * i] * (ws[i] & 0x000f) + - x_thread[4 * i + 1] * (ws[i] & 0x00f0) + - x_thread[4 * i + 2] * (ws[i] & 0x0f00) + - x_thread[4 * i + 3] * (ws[i] & 0xf000)); - } - } - - else if (bits == 5) { - for (int i = 0; i < (values_per_thread / 8); i++) { - x_thread += 8 * i; - w += 5 * i; - - accum += (w[0] & 0x1f) * x_thread[0]; - accum += (w[0] & 0xe0) * x_thread[1]; - accum += (w[1] & 0x3) * (x_thread[1] * 256.0f); - accum += (w[1] & 0x7c) * x_thread[2]; - accum += (w[1] & 0x80) * x_thread[3]; - accum += (w[2] & 0xf) * (x_thread[3] * 256.0f); - accum += (w[2] & 0xf0) * x_thread[4]; - accum += (w[3] & 0x1) * (x_thread[4] * 256.0f); - accum += (w[3] & 0x3e) * x_thread[5]; - accum += (w[3] & 0xc0) * x_thread[6]; - accum += (w[4] & 0x7) * (x_thread[6] * 256.0f); - accum += (w[4] & 0xf8) * x_thread[7]; - } - } - - else if (bits == 6) { - for (int i = 0; i < (values_per_thread / 4); i++) { - x_thread += 4 * i; - w += 3 * i; - - accum += (w[0] & 0x3f) * x_thread[0]; - - accum += (w[0] & 0xc0) * x_thread[1]; - accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f); - - accum += (w[1] & 0xf0) * x_thread[2]; - accum += (w[2] & 0x03) * (x_thread[2] * 256.0f); - - accum += (w[2] & 0xfc) * x_thread[3]; - } - } - - else if (bits == 8) { - for (int i = 0; i < values_per_thread; i++) { - accum += x_thread[i] * w[i]; - } - } - - return scale * accum + sum * bias; -} - -template -inline U qdot_safe( - const device uint8_t* w, - const thread U* x_thread, - U scale, - U bias, - U sum, - int N) { - static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || - bits == 8, - "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); - - U accum = 0; - - if (bits == 2) { - for (int i = 0; i < (N / 4); i++) { - accum += - (x_thread[4 * i] * (w[i] & 0x03) + - x_thread[4 * i + 1] * (w[i] & 0x0c) + - x_thread[4 * i + 2] * (w[i] & 0x30) + - x_thread[4 * i + 3] * (w[i] & 0xc0)); - } - } - - else if (bits == 3) { - for (int i = 0; i < (N / 8); i++) { - x_thread += 8 * i; - w += 3 * i; - - accum += (w[0] & 0x07) * x_thread[0]; - accum += (w[0] & 0x38) * x_thread[1]; - accum += (w[0] & 0xc0) * x_thread[2]; - accum += (w[1] & 0x01) * (x_thread[2] * 256.0f); - - accum += (w[1] & 0x0e) * x_thread[3]; - accum += (w[1] & 0x70) * x_thread[4]; - accum += (w[1] & 0x80) * x_thread[5]; - accum += (w[2] & 0x03) * (x_thread[5] * 256.0f); - - accum += (w[2] & 0x1c) * x_thread[6]; - accum += (w[2] & 0xe0) * x_thread[7]; - } - } - - else if (bits == 4) { - const device uint16_t* ws = (const device uint16_t*)w; - for (int i = 0; i < (N / 4); i++) { - accum += - (x_thread[4 * i] * (ws[i] & 0x000f) + - x_thread[4 * i + 1] * (ws[i] & 0x00f0) + - x_thread[4 * i + 2] * (ws[i] & 0x0f00) + - x_thread[4 * i + 3] * (ws[i] & 0xf000)); - } - } - - else if (bits == 5) { - for (int i = 0; i < (N / 8); i++) { - x_thread += 8 * i; - w += 5 * i; - - accum += (w[0] & 0x1f) * x_thread[0]; - accum += (w[0] & 0xe0) * x_thread[1]; - accum += (w[1] & 0x3) * (x_thread[1] * 256.0f); - accum += (w[1] & 0x7c) * x_thread[2]; - accum += (w[1] & 0x80) * x_thread[3]; - accum += (w[2] & 0xf) * (x_thread[3] * 256.0f); - accum += (w[2] & 0xf0) * x_thread[4]; - accum += (w[3] & 0x1) * (x_thread[4] * 256.0f); - accum += (w[3] & 0x3e) * x_thread[5]; - accum += (w[3] & 0xc0) * x_thread[6]; - accum += (w[4] & 0x7) * (x_thread[6] * 256.0f); - accum += (w[4] & 0xf8) * x_thread[7]; - } - } - - else if (bits == 6) { - for (int i = 0; i < (N / 4); i++) { - x_thread += 4 * i; - w += 3 * i; - - accum += (w[0] & 0x3f) * x_thread[0]; - - accum += (w[0] & 0xc0) * x_thread[1]; - accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f); - - accum += (w[1] & 0xf0) * x_thread[2]; - accum += (w[2] & 0x03) * (x_thread[2] * 256.0f); - - accum += (w[2] & 0xfc) * x_thread[3]; - } - } - - else if (bits == 8) { - for (int i = 0; i < N; i++) { - accum += x_thread[i] * w[i]; - } - } - - return scale * accum + sum * bias; -} - -template -inline void -qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { - static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || - bits == 8, - "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); - - if (bits == 2) { - U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f}; - for (int i = 0; i < (values_per_thread / 4); i++) { - result[4 * i] += x * (s[0] * (w[i] & 0x03) + bias); - result[4 * i + 1] += x * (s[1] * (w[i] & 0x0c) + bias); - result[4 * i + 2] += x * (s[2] * (w[i] & 0x30) + bias); - result[4 * i + 3] += x * (s[3] * (w[i] & 0xc0) + bias); - } - } - - else if (bits == 3) { - for (int i = 0; i < (values_per_thread / 8); i++) { - uint8_t w0 = w[3 * i]; - uint8_t w1 = w[3 * i + 1]; - uint8_t w2 = w[3 * i + 2]; - - result[8 * i] += x * ((w0 & 0x7) * scale + bias); - result[8 * i + 1] += x * (((w0 & 0x38) >> 3) * scale + bias); - result[8 * i + 2] += - x * ((((w0 & 0xc0) >> 6) + ((w1 & 0x1) << 2)) * scale + bias); - result[8 * i + 3] += x * (((w1 & 0xe) >> 1) * scale + bias); - result[8 * i + 4] += x * (((w1 & 0x70) >> 4) * scale + bias); - result[8 * i + 5] += - x * ((((w1 & 0x80) >> 7) + ((w2 & 0x3) << 1)) * scale + bias); - result[8 * i + 6] += x * (((w2 & 0x1c) >> 2) * scale + bias); - result[8 * i + 7] += x * (((w2 & 0xe0) >> 5) * scale + bias); - } - } - - else if (bits == 4) { - U s[2] = {scale, scale / 16.0f}; - for (int i = 0; i < (values_per_thread / 2); i++) { - result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias); - result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias); - } - } - - else if (bits == 5) { - for (int i = 0; i < (values_per_thread / 8); i++) { - uint8_t w0 = w[5 * i]; - uint8_t w1 = w[5 * i + 1]; - uint8_t w2 = w[5 * i + 2]; - uint8_t w3 = w[5 * i + 3]; - uint8_t w4 = w[5 * i + 4]; - result[8 * i] += x * ((w0 & 0x1f) * scale + bias); - result[8 * i + 1] += - x * ((((w0 & 0xe0) >> 5) + ((w1 & 0x3) << 3)) * scale + bias); - result[8 * i + 2] += x * (((w1 & 0x7c) >> 2) * scale + bias); - result[8 * i + 3] += - x * ((((w1 & 0x80) >> 7) + ((w2 & 0xf) << 1)) * scale + bias); - result[8 * i + 4] += - x * ((((w2 & 0xf0) >> 4) + ((w3 & 0x1) << 4)) * scale + bias); - result[8 * i + 5] += x * (((w3 & 0x3e) >> 1) * scale + bias); - result[8 * i + 6] += - x * ((((w3 & 0xc0) >> 6) + ((w4 & 0x7) << 2)) * scale + bias); - result[8 * i + 7] += x * (((w4 & 0xf8) >> 3) * scale + bias); - } - } - - else if (bits == 6) { - for (int i = 0; i < (values_per_thread / 4); i++) { - uint8_t w0 = w[3 * i]; - uint8_t w1 = w[3 * i + 1]; - uint8_t w2 = w[3 * i + 2]; - - result[4 * i] += x * ((w0 & 0x3f) * scale + bias); - result[4 * i + 1] += - x * ((((w0 >> 6) & 0x03) + ((w1 & 0x0f) << 2)) * scale + bias); - result[4 * i + 2] += - x * ((((w1 >> 4) & 0x0f) + ((w2 & 0x03) << 4)) * scale + bias); - result[4 * i + 3] += x * (((w2 >> 2) & 0x3f) * scale + bias); - } - } - - else if (bits == 8) { - for (int i = 0; i < values_per_thread; i++) { - result[i] += x * (scale * w[i] + bias); - } - } -} - -template -inline void -dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { - static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || - bits == 8, - "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); - - if (bits == 2) { - U s[4] = { - scale, - scale / static_cast(4.0f), - scale / static_cast(16.0f), - scale / static_cast(64.0f)}; - for (int i = 0; i < (N / 4); i++) { - w_local[4 * i] = s[0] * (w[i] & 0x03) + bias; - w_local[4 * i + 1] = s[1] * (w[i] & 0x0c) + bias; - w_local[4 * i + 2] = s[2] * (w[i] & 0x30) + bias; - w_local[4 * i + 3] = s[3] * (w[i] & 0xc0) + bias; - } - } - - else if (bits == 3) { - for (int i = 0; i < (N / 8); i++) { - w_local += 8 * i; - w += 3 * i; - - w_local[0] = (w[0] & 0x7) * scale + bias; - w_local[1] = ((w[0] & 0x38) >> 3) * scale + bias; - w_local[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias; - w_local[3] = ((w[1] & 0xe) >> 1) * scale + bias; - w_local[4] = ((w[1] & 0x70) >> 4) * scale + bias; - w_local[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias; - w_local[6] = ((w[2] & 0x1c) >> 2) * scale + bias; - w_local[7] = ((w[2] & 0xe0) >> 5) * scale + bias; - } - } - - else if (bits == 4) { - U s[2] = {scale, scale / static_cast(16.0f)}; - for (int i = 0; i < (N / 2); i++) { - w_local[2 * i] = s[0] * (w[i] & 0x0f) + bias; - w_local[2 * i + 1] = s[1] * (w[i] & 0xf0) + bias; - } - } - - else if (bits == 5) { - for (int i = 0; i < (N / 8); i++) { - w_local += 8 * i; - w += 5 * i; - - w_local[0] = (w[0] & 0x1f) * scale + bias; - w_local[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias; - w_local[2] = ((w[1] & 0x7c) >> 2) * scale + bias; - w_local[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias; - w_local[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias; - w_local[5] = ((w[3] & 0x3e) >> 1) * scale + bias; - w_local[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias; - w_local[7] = ((w[4] & 0xf8) >> 3) * scale + bias; - } - } - - else if (bits == 6) { - for (int i = 0; i < (N / 4); i++) { - w_local += 4 * i; - w += 3 * i; - w_local[0] = (w[0] & 0x3f) * scale + bias; - w_local[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias; - w_local[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias; - w_local[3] = ((w[2] >> 2) & 0x3f) * scale + bias; - } - } - - else if (bits == 8) { - for (int i = 0; i < N; i++) { - w_local[i] = scale * w[i] + bias; - } - } -} - -template < - typename T, - short BROWS, - short BCOLS, - short dst_ld, - short reduction_dim, - short tgp_size, - short group_size, - short bits> -struct QuantizedBlockLoader { - static_assert( - BCOLS <= group_size, - "The group size should be larger than the columns"); - static_assert( - group_size % BCOLS == 0, - "The group size should be divisible by the columns"); - static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || - bits == 8, - "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); - - MLX_MTL_CONST short pack_factor = get_pack_factor(); - MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack(); - MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor; - MLX_MTL_CONST short n_reads = - (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size; - MLX_MTL_CONST short group_steps = group_size / BCOLS; - - const int src_ld; - const int tile_stride; - short group_step_cnt; - const int group_stride; - - const short thread_idx; - const short bi; - const short bj; - - threadgroup T* dst; - const device uint8_t* src; - const device T* scales; - const device T* biases; - - QuantizedBlockLoader( - const device uint8_t* src_, - const device T* scales_, - const device T* biases_, - const int src_ld_, - threadgroup T* dst_, - ushort simd_group_id [[simdgroup_index_in_threadgroup]], - ushort simd_lane_id [[thread_index_in_simdgroup]]) - : src_ld(src_ld_), - tile_stride( - reduction_dim ? BCOLS_PACKED * bytes_per_pack - : BROWS * src_ld * bytes_per_pack / pack_factor), - group_step_cnt(0), - group_stride(BROWS * src_ld / group_size), - thread_idx(simd_group_id * 32 + simd_lane_id), - bi(n_reads * thread_idx / BCOLS_PACKED), - bj((n_reads * thread_idx) % BCOLS_PACKED), - dst(dst_ + bi * dst_ld + bj * pack_factor), - src(src_ + bi * src_ld * bytes_per_pack / pack_factor + - bj * bytes_per_pack), - scales(scales_ + bi * src_ld / group_size), - biases(biases_ + bi * src_ld / group_size) {} - - void load_unsafe() const { - if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { - return; - } - - T scale = *scales; - T bias = *biases; - for (int i = 0; i < n_reads; i++) { - dequantize( - src + i * bytes_per_pack, scale, bias, dst + i * pack_factor); - } - } - - void load_safe(short2 src_tile_dim) const { - if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { - return; - } - - if (reduction_dim == 1 && bi >= src_tile_dim.x) { - for (int i = 0; i < n_reads * pack_factor; i++) { - dst[i] = T(0); - } - return; - } - - if (reduction_dim == 0 && bi >= src_tile_dim.y) { - for (int i = 0; i < n_reads * pack_factor; i++) { - dst[i] = T(0); - } - return; - } - - T scale = *scales; - T bias = *biases; - for (int i = 0; i < n_reads; i++) { - dequantize( - (device uint8_t*)(src + i * bytes_per_pack), - scale, - bias, - dst + i * pack_factor); - } - } - - void next() { - src += tile_stride; - if (reduction_dim == 1) { - if (group_steps > 1) { - group_step_cnt++; - if (group_step_cnt == group_steps) { - group_step_cnt = 0; - scales++; - biases++; - } - } else { - scales++; - biases++; - } - } else { - scales += group_stride; - biases += group_stride; - } - } -}; - -template -METAL_FUNC void qmv_quad_impl( - const device uint32_t* w, - const device T* scales, - const device T* biases, - const device T* x, - device T* y, - constant int& in_vec_size, - const constant int& out_vec_size, - uint3 tid [[threadgroup_position_in_grid]], - uint quad_gid [[quadgroup_index_in_threadgroup]], - uint quad_lid [[thread_index_in_quadgroup]]) { - constexpr int quads_per_simd = SIMD_SIZE / QUAD_SIZE; - constexpr int pack_factor = 32 / bits; - constexpr int values_per_thread = D / QUAD_SIZE; - constexpr int packs_per_thread = values_per_thread / pack_factor; - constexpr int scale_step_per_thread = group_size / values_per_thread; - constexpr int results_per_quadgroup = 8; - - typedef float U; - - thread U x_thread[values_per_thread]; - thread U result[results_per_quadgroup] = {0}; - - // Adjust positions - const int in_vec_size_w = in_vec_size / pack_factor; - const int in_vec_size_g = in_vec_size / group_size; - const int out_row = tid.y * quads_per_simd * results_per_quadgroup + quad_gid; - - w += out_row * in_vec_size_w + quad_lid * packs_per_thread; - scales += out_row * in_vec_size_g + quad_lid / scale_step_per_thread; - biases += out_row * in_vec_size_g + quad_lid / scale_step_per_thread; - x += tid.x * in_vec_size + quad_lid * values_per_thread; - y += tid.x * out_vec_size + out_row; - - U sum = load_vector(x, x_thread); - - for (int row = 0; row < results_per_quadgroup; row++) { - auto wl = (const device uint8_t*)(w + row * in_vec_size_w * quads_per_simd); - const device T* sl = scales + row * in_vec_size_g * quads_per_simd; - const device T* bl = biases + row * in_vec_size_g * quads_per_simd; - - U s = sl[0]; - U b = bl[0]; - if (row * quads_per_simd + out_row < out_vec_size) { - result[row] += qdot(wl, x_thread, s, b, sum); - } - } - - for (int row = 0; row < results_per_quadgroup; row++) { - result[row] = quad_sum(result[row]); - if (quad_lid == 0 && row * quads_per_simd + out_row < out_vec_size) { - y[row * quads_per_simd] = static_cast(result[row]); - } - } -} - -template -METAL_FUNC void qmv_fast_impl( - const device uint32_t* w, - const device T* scales, - const device T* biases, - const device T* x, - device T* y, - const constant int& in_vec_size, - const constant int& out_vec_size, - uint3 tid [[threadgroup_position_in_grid]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - constexpr int packs_per_thread = bits == 2 ? 1 : 2; - constexpr int num_simdgroups = 2; - constexpr int results_per_simdgroup = 4; - constexpr int pack_factor = get_pack_factor(); - constexpr int bytes_per_pack = get_bytes_per_pack(); - constexpr int values_per_thread = pack_factor * packs_per_thread; - constexpr int block_size = values_per_thread * SIMD_SIZE; - constexpr int scale_step_per_thread = group_size / values_per_thread; - - const device uint8_t* ws = (const device uint8_t*)w; - - typedef float U; - - thread U x_thread[values_per_thread]; - thread U result[results_per_simdgroup] = {0}; - - // Adjust positions - const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; - const int in_vec_size_g = in_vec_size / group_size; - const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + - simd_gid * results_per_simdgroup; - - ws += out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack; - scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; - biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; - x += tid.x * in_vec_size + simd_lid * values_per_thread; - y += tid.x * out_vec_size + out_row; - - for (int k = 0; k < in_vec_size; k += block_size) { - U sum = load_vector(x, x_thread); - - for (int row = 0; row < results_per_simdgroup; row++) { - auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); - const device T* sl = scales + row * in_vec_size_g; - const device T* bl = biases + row * in_vec_size_g; - - U s = sl[0]; - U b = bl[0]; - result[row] += qdot(wl, x_thread, s, b, sum); - } - - ws += block_size * bytes_per_pack / pack_factor; - scales += block_size / group_size; - biases += block_size / group_size; - x += block_size; - } - - for (int row = 0; row < results_per_simdgroup; row++) { - result[row] = simd_sum(result[row]); - if (simd_lid == 0) { - y[row] = static_cast(result[row]); - } - } -} - -template -METAL_FUNC void qmv_impl( - const device uint32_t* w, - const device T* scales, - const device T* biases, - const device T* x, - device T* y, - const constant int& in_vec_size, - const constant int& out_vec_size, - uint3 tid [[threadgroup_position_in_grid]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - constexpr int num_simdgroups = 2; - constexpr int results_per_simdgroup = 4; - constexpr int packs_per_thread = 1; - constexpr int pack_factor = get_pack_factor(); - constexpr int bytes_per_pack = get_bytes_per_pack(); - - constexpr int values_per_thread = pack_factor * packs_per_thread; - constexpr int block_size = values_per_thread * SIMD_SIZE; - constexpr int scale_step_per_thread = group_size / values_per_thread; - - const device uint8_t* ws = (const device uint8_t*)w; - - typedef float U; - - thread U x_thread[values_per_thread]; - thread U result[results_per_simdgroup] = {0}; - - // Adjust positions - const int in_vec_size_w = in_vec_size * bytes_per_pack / pack_factor; - const int in_vec_size_g = in_vec_size / group_size; - const int out_row = tid.y * (num_simdgroups * results_per_simdgroup) + - simd_gid * results_per_simdgroup; - const int used_out_row = min(out_vec_size - results_per_simdgroup, out_row); - - if (out_row >= out_vec_size) { - return; - } - - // In this case we need to properly guard all our reads because there isn't - // even 1 tile in the matrix - if (out_vec_size < (num_simdgroups * results_per_simdgroup)) { - ws += - out_row * in_vec_size_w + simd_lid * packs_per_thread * bytes_per_pack; - scales += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; - biases += out_row * in_vec_size_g + simd_lid / scale_step_per_thread; - x += tid.x * in_vec_size + simd_lid * values_per_thread; - y += tid.x * out_vec_size + out_row; - - int k = 0; - for (; k < in_vec_size - block_size; k += block_size) { - U sum = load_vector(x, x_thread); - - for (int row = 0; - row < results_per_simdgroup && out_row + row < out_vec_size; - row++) { - auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); - const device T* sl = scales + row * in_vec_size_g; - const device T* bl = biases + row * in_vec_size_g; - - U s = sl[0]; - U b = bl[0]; - result[row] += - qdot(wl, x_thread, s, b, sum); - } - - ws += block_size * bytes_per_pack / pack_factor; - scales += block_size / group_size; - biases += block_size / group_size; - x += block_size; - } - const int remaining = clamp( - static_cast(in_vec_size - k - simd_lid * values_per_thread), - 0, - values_per_thread); - if (remaining > 0) { - U sum = load_vector_safe( - x, x_thread, remaining); - - for (int row = 0; - row < results_per_simdgroup && out_row + row < out_vec_size; - row++) { - auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); - const device T* sl = scales + row * in_vec_size_g; - const device T* bl = biases + row * in_vec_size_g; - - U s = sl[0]; - U b = bl[0]; - result[row] += qdot_safe( - wl, x_thread, s, b, sum, remaining); - } - } - - for (int row = 0; - row < results_per_simdgroup && out_row + row < out_vec_size; - row++) { - result[row] = simd_sum(result[row]); - if (simd_lid == 0) { - y[row] = static_cast(result[row]); - } - } - } - - // In this case the last tile is moved back to redo some output values - else { - ws += used_out_row * in_vec_size_w + - simd_lid * packs_per_thread * bytes_per_pack; - scales += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread; - biases += used_out_row * in_vec_size_g + simd_lid / scale_step_per_thread; - x += tid.x * in_vec_size + simd_lid * values_per_thread; - y += tid.x * out_vec_size + used_out_row; - - int k = 0; - for (; k < in_vec_size - block_size; k += block_size) { - U sum = load_vector(x, x_thread); - - for (int row = 0; row < results_per_simdgroup; row++) { - auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); - const device T* sl = scales + row * in_vec_size_g; - const device T* bl = biases + row * in_vec_size_g; - - U s = sl[0]; - U b = bl[0]; - result[row] += - qdot(wl, x_thread, s, b, sum); - } - - ws += block_size * bytes_per_pack / pack_factor; - scales += block_size / group_size; - biases += block_size / group_size; - x += block_size; - } - const int remaining = clamp( - static_cast(in_vec_size - k - simd_lid * values_per_thread), - 0, - values_per_thread); - if (remaining > 0) { - U sum = load_vector_safe( - x, x_thread, remaining); - - for (int row = 0; row < results_per_simdgroup; row++) { - auto wl = (const device uint8_t*)(ws + row * in_vec_size_w); - const device T* sl = scales + row * in_vec_size_g; - const device T* bl = biases + row * in_vec_size_g; - - U s = sl[0]; - U b = bl[0]; - result[row] += qdot_safe( - wl, x_thread, s, b, sum, remaining); - } - } - for (int row = 0; row < results_per_simdgroup; row++) { - result[row] = simd_sum(result[row]); - if (simd_lid == 0) { - y[row] = static_cast(result[row]); - } - } - } -} - -template -METAL_FUNC void qvm_impl( - const device uint32_t* w, - const device T* scales, - const device T* biases, - const device T* x, - device T* y, - const int in_vec_size, - const int out_vec_size, - uint3 tid [[threadgroup_position_in_grid]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; - constexpr int num_simdgroups = 2; - constexpr int pack_factor = get_pack_factor(); - constexpr int bytes_per_pack = get_bytes_per_pack(); - - constexpr int tn = 32 / pack_factor; - constexpr int block_size = SIMD_SIZE; - - using W_T = - typename ConditionalType::type; - const device W_T* ws = (const device W_T*)w; - - typedef float U; - typedef struct { - W_T wi[tn * bytes_per_pack]; - } vec_w; - - thread vec_w w_local; - thread U result[tn * pack_factor] = {0}; - thread U scale = 1; - thread U bias = 0; - thread U x_local = 0; - - // Adjust positions - const int out_vec_size_w = out_vec_size * bytes_per_pack / pack_factor; - const int out_vec_size_g = out_vec_size / group_size; - int out_col = pack_factor * tn * (tid.y * num_simdgroups + simd_gid); - ws += out_col * bytes_per_pack / pack_factor + simd_lid * out_vec_size_w; - scales += out_col / group_size + simd_lid * out_vec_size_g; - biases += out_col / group_size + simd_lid * out_vec_size_g; - x += tid.x * in_vec_size + simd_lid; - y += tid.x * out_vec_size + out_col; - - if (out_col >= out_vec_size) { - return; - } - - // Loop over in_vec in blocks of block_size - int remaining = in_vec_size % block_size; - if (remaining == 0) { - for (int i = 0; i < in_vec_size; i += block_size) { - x_local = *x; - scale = *scales; - bias = *biases; - w_local = *((device vec_w*)ws); - qouter( - (thread uint8_t*)&w_local, x_local, scale, bias, result); - - x += block_size; - scales += block_size * out_vec_size_g; - biases += block_size * out_vec_size_g; - ws += block_size * out_vec_size_w; - } - } else { - for (int i = block_size; i < in_vec_size; i += block_size) { - x_local = *x; - scale = *scales; - bias = *biases; - w_local = *((device vec_w*)ws); - - qouter( - (thread uint8_t*)&w_local, x_local, scale, bias, result); - - x += block_size; - scales += block_size * out_vec_size_g; - biases += block_size * out_vec_size_g; - ws += block_size * out_vec_size_w; - } - if (static_cast(simd_lid) < remaining) { - x_local = *x; - scale = *scales; - bias = *biases; - w_local = *((device vec_w*)ws); - } else { - x_local = 0; - scale = 0; - bias = 0; - } - qouter( - (thread uint8_t*)&w_local, x_local, scale, bias, result); - } - -// Accumulate in the simdgroup -#pragma clang loop unroll(full) - for (int k = 0; k < tn * pack_factor; k++) { - result[k] = simd_sum(result[k]); - } - - // Store the result - if (simd_lid == 0) { -#pragma clang loop unroll(full) - for (int k = 0; k < tn * pack_factor; k++) { - y[k] = static_cast(result[k]); - } - } -} - -template < - typename T, - const int group_size, - const int bits, - const bool aligned_N, - const int BM = 32, - const int BK = 32, - const int BN = 32> -METAL_FUNC void qmm_t_impl( - const device uint32_t* w, - const device T* scales, - const device T* biases, - const device T* x, - device T* y, - threadgroup T* Xs, - threadgroup T* Ws, - const constant int& K, - const constant int& N, - const constant int& M, - uint3 tid [[threadgroup_position_in_grid]], - uint lid [[thread_index_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); - static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); - - (void)lid; - - constexpr int WM = 2; - constexpr int WN = 2; - constexpr int pack_factor = get_pack_factor(); - constexpr int bytes_per_pack = get_bytes_per_pack(); - - constexpr int BK_padded = (BK + 16 / sizeof(T)); - - // Instantiate the appropriate BlockMMA and Loader - using mma_t = mlx::steel:: - BlockMMA; - using loader_x_t = - mlx::steel::BlockLoader; - using loader_w_t = QuantizedBlockLoader< - T, - BN, - BK, - BK_padded, - 1, - WM * WN * SIMD_SIZE, - group_size, - bits>; - - // Set the block - const int K_w = K * bytes_per_pack / pack_factor; - const int K_g = K / group_size; - const int y_row = tid.y * BM; - const int y_col = tid.x * BN; - - auto wl = (const device uint8_t*)w; - - x += y_row * static_cast(K); - wl += y_col * K_w; - scales += y_col * K_g; - biases += y_col * K_g; - y += y_row * static_cast(N) + y_col; - - // Make the x loader and mma operation - const short num_els = min(BM, M - y_row); - const short num_outs = min(BN, N - y_col); - loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); - loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid); - mma_t mma_op(simd_gid, simd_lid); - - if (num_els < BM) { - if (!aligned_N && num_outs < BN) { - for (int k = 0; k < K; k += BK) { - threadgroup_barrier(mem_flags::mem_threadgroup); - loader_x.load_safe(short2(BK, num_els)); - loader_w.load_safe(short2(BK, num_outs)); - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(Xs, Ws); - loader_x.next(); - loader_w.next(); - } - } else { - for (int k = 0; k < K; k += BK) { - threadgroup_barrier(mem_flags::mem_threadgroup); - loader_x.load_safe(short2(BK, num_els)); - loader_w.load_unsafe(); - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(Xs, Ws); - loader_x.next(); - loader_w.next(); - } - } - } else { - if (!aligned_N && num_outs < BN) { - for (int k = 0; k < K; k += BK) { - threadgroup_barrier(mem_flags::mem_threadgroup); - loader_x.load_unsafe(); - loader_w.load_safe(short2(BK, num_outs)); - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(Xs, Ws); - loader_x.next(); - loader_w.next(); - } - } else { - for (int k = 0; k < K; k += BK) { - threadgroup_barrier(mem_flags::mem_threadgroup); - loader_x.load_unsafe(); - loader_w.load_unsafe(); - threadgroup_barrier(mem_flags::mem_threadgroup); - - mma_op.mma(Xs, Ws); - loader_x.next(); - loader_w.next(); - } - } - } - - // Store results to device memory - threadgroup_barrier(mem_flags::mem_threadgroup); - if (num_els < BM || num_outs < BN) { - mma_op.store_result_safe(y, N, short2(num_outs, num_els)); - } else { - mma_op.store_result(y, N); - } -} - -template < - typename T, - const int group_size, - const int bits, - const int BM = 32, - const int BK = 32, - const int BN = 32> -METAL_FUNC void qmm_n_impl( - const device uint32_t* w, - const device T* scales, - const device T* biases, - const device T* x, - device T* y, - threadgroup T* Xs, - threadgroup T* Ws, - const constant int& K, - const constant int& N, - const constant int& M, - uint3 tid [[threadgroup_position_in_grid]], - uint lid [[thread_index_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); - static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); - - (void)lid; - - constexpr int WM = 2; - constexpr int WN = 2; - constexpr int pack_factor = get_pack_factor(); - constexpr int bytes_per_pack = get_bytes_per_pack(); - - constexpr int BK_padded = (BK + 16 / sizeof(T)); - constexpr int BN_padded = (BN + 16 / sizeof(T)); - - // Instantiate the appropriate BlockMMA and Loader - using mma_t = mlx::steel:: - BlockMMA; - using loader_x_t = mlx::steel:: - BlockLoader; - using loader_w_t = QuantizedBlockLoader< - T, - BK, - BN, - BN_padded, - 0, - WM * WN * SIMD_SIZE, - group_size, - bits>; - - auto wl = (const device uint8_t*)w; - - // Set the block - const int y_row = tid.y * BM; - const int y_col = tid.x * BN; - x += y_row * static_cast(K); - wl += y_col * bytes_per_pack / pack_factor; - scales += y_col / group_size; - biases += y_col / group_size; - y += y_row * static_cast(N) + y_col; - - // Make the x loader and mma operation - const short num_els = min(BM, M - y_row); - loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); - loader_w_t loader_w(wl, scales, biases, N, Ws, simd_gid, simd_lid); - mma_t mma_op(simd_gid, simd_lid); - - if (num_els < BM) { - if ((K % BK) != 0) { - const int k_blocks = K / BK; - for (int k = 0; k < k_blocks; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - loader_x.load_safe(short2(BK, num_els)); - loader_w.load_unsafe(); - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(Xs, Ws); - loader_x.next(); - loader_w.next(); - } - const short num_k = K - k_blocks * BK; - threadgroup_barrier(mem_flags::mem_threadgroup); - loader_x.load_safe(short2(num_k, num_els)); - loader_w.load_safe(short2(BN, num_k)); - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(Xs, Ws); - } else { - for (int k = 0; k < K; k += BK) { - threadgroup_barrier(mem_flags::mem_threadgroup); - loader_x.load_safe(short2(BK, num_els)); - loader_w.load_unsafe(); - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(Xs, Ws); - loader_x.next(); - loader_w.next(); - } - } - } else { - if ((K % BK) != 0) { - const int k_blocks = K / BK; - for (int k = 0; k < k_blocks; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - loader_x.load_unsafe(); - loader_w.load_unsafe(); - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(Xs, Ws); - loader_x.next(); - loader_w.next(); - } - const short num_k = K - k_blocks * BK; - threadgroup_barrier(mem_flags::mem_threadgroup); - loader_x.load_safe(short2(num_k, BM)); - loader_w.load_safe(short2(BN, num_k)); - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(Xs, Ws); - } else { - for (int k = 0; k < K; k += BK) { - threadgroup_barrier(mem_flags::mem_threadgroup); - loader_x.load_unsafe(); - loader_w.load_unsafe(); - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(Xs, Ws); - loader_x.next(); - loader_w.next(); - } - } - } - - // Store results to device memory - threadgroup_barrier(mem_flags::mem_threadgroup); - if (num_els < BM) { - mma_op.store_result_safe(y, N, short2(BN, num_els)); - } else { - mma_op.store_result(y, N); - } -} - -template -METAL_FUNC void adjust_matrix_offsets( - const device T*& x, - const device uint32_t*& w, - const device T*& scales, - const device T*& biases, - device T*& y, - int output_stride, - const constant int& x_batch_ndims, - const constant int* x_shape, - const constant int64_t* x_strides, - const constant int& w_batch_ndims, - const constant int* w_shape, - const constant int64_t* w_strides, - const constant int64_t* s_strides, - const constant int64_t* b_strides, - uint3 tid [[threadgroup_position_in_grid]]) { - // Set the input/output matrices - uint32_t x_idx = tid.z; - uint32_t w_idx = tid.z; - if (x_batch_ndims == 1) { - x += x_idx * x_strides[0]; - } else { - x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); - } - if (w_batch_ndims == 1) { - w += w_idx * w_strides[0]; - scales += w_idx * s_strides[0]; - biases += w_idx * b_strides[0]; - } else { - ulong3 idx = elem_to_loc_broadcast( - w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims); - w += idx.x; - scales += idx.y; - biases += idx.z; - } - y += tid.z * output_stride; -} - -template -METAL_FUNC void adjust_matrix_offsets( - const device T*& x, - const device uint32_t*& w, - const device T*& scales, - const device T*& biases, - const device uint32_t* lhs_indices, - const device uint32_t* rhs_indices, - device T*& y, - int output_stride, - const constant int& batch_ndims, - const constant int* batch_shape, - const constant int64_t* lhs_strides, - const constant int64_t* rhs_strides, - const constant int& x_batch_ndims, - const constant int* x_shape, - const constant int64_t* x_strides, - const constant int& w_batch_ndims, - const constant int* w_shape, - const constant int64_t* w_strides, - const constant int64_t* s_strides, - const constant int64_t* b_strides, - uint3 tid [[threadgroup_position_in_grid]]) { - // Set the input/output matrices - uint32_t x_idx; - uint32_t w_idx; - if (batch_ndims == 1) { - x_idx = lhs_indices[tid.z * lhs_strides[0]]; - w_idx = rhs_indices[tid.z * rhs_strides[0]]; - } else { - ulong2 idx = elem_to_loc_broadcast( - tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims); - x_idx = lhs_indices[idx.x]; - w_idx = rhs_indices[idx.y]; - } - if (x_batch_ndims == 1) { - x += x_idx * x_strides[0]; - } else { - x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); - } - if (w_batch_ndims == 1) { - w += w_idx * w_strides[0]; - scales += w_idx * s_strides[0]; - biases += w_idx * b_strides[0]; - } else { - ulong3 idx = elem_to_loc_broadcast( - w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims); - w += idx.x; - scales += idx.y; - biases += idx.z; - } - y += tid.z * output_stride; -} - -template -[[kernel]] void affine_qmv_quad( - const device uint32_t* w [[buffer(0)]], - const device T* scales [[buffer(1)]], - const device T* biases [[buffer(2)]], - const device T* x [[buffer(3)]], - device T* y [[buffer(4)]], - const constant int& in_vec_size [[buffer(5)]], - const constant int& out_vec_size [[buffer(6)]], - const constant int& x_batch_ndims [[buffer(7)]], - const constant int* x_shape [[buffer(8)]], - const constant int64_t* x_strides [[buffer(9)]], - const constant int& w_batch_ndims [[buffer(10)]], - const constant int* w_shape [[buffer(11)]], - const constant int64_t* w_strides [[buffer(12)]], - const constant int64_t* s_strides [[buffer(13)]], - const constant int64_t* b_strides [[buffer(14)]], - uint3 tid [[threadgroup_position_in_grid]], - uint quad_gid [[quadgroup_index_in_threadgroup]], - uint quad_lid [[thread_index_in_quadgroup]]) { - if (batched) { - int M = x_shape[x_batch_ndims]; - adjust_matrix_offsets( - x, - w, - scales, - biases, - y, - out_vec_size * M, - x_batch_ndims, - x_shape, - x_strides, - w_batch_ndims, - w_shape, - w_strides, - s_strides, - b_strides, - tid); - } - qmv_quad_impl( - w, - scales, - biases, - x, - y, - in_vec_size, - out_vec_size, - tid, - quad_gid, - quad_lid); -} - -template -[[kernel]] void affine_qmv_fast( - const device uint32_t* w [[buffer(0)]], - const device T* scales [[buffer(1)]], - const device T* biases [[buffer(2)]], - const device T* x [[buffer(3)]], - device T* y [[buffer(4)]], - const constant int& in_vec_size [[buffer(5)]], - const constant int& out_vec_size [[buffer(6)]], - const constant int& x_batch_ndims [[buffer(7)]], - const constant int* x_shape [[buffer(8)]], - const constant int64_t* x_strides [[buffer(9)]], - const constant int& w_batch_ndims [[buffer(10)]], - const constant int* w_shape [[buffer(11)]], - const constant int64_t* w_strides [[buffer(12)]], - const constant int64_t* s_strides [[buffer(13)]], - const constant int64_t* b_strides [[buffer(14)]], - uint3 tid [[threadgroup_position_in_grid]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - if (batched) { - int M = x_shape[x_batch_ndims]; - adjust_matrix_offsets( - x, - w, - scales, - biases, - y, - out_vec_size * M, - x_batch_ndims, - x_shape, - x_strides, - w_batch_ndims, - w_shape, - w_strides, - s_strides, - b_strides, - tid); - } - qmv_fast_impl( - w, - scales, - biases, - x, - y, - in_vec_size, - out_vec_size, - tid, - simd_gid, - simd_lid); -} - -template -[[kernel]] void affine_qmv( - const device uint32_t* w [[buffer(0)]], - const device T* scales [[buffer(1)]], - const device T* biases [[buffer(2)]], - const device T* x [[buffer(3)]], - device T* y [[buffer(4)]], - const constant int& in_vec_size [[buffer(5)]], - const constant int& out_vec_size [[buffer(6)]], - const constant int& x_batch_ndims [[buffer(7)]], - const constant int* x_shape [[buffer(8)]], - const constant int64_t* x_strides [[buffer(9)]], - const constant int& w_batch_ndims [[buffer(10)]], - const constant int* w_shape [[buffer(11)]], - const constant int64_t* w_strides [[buffer(12)]], - const constant int64_t* s_strides [[buffer(13)]], - const constant int64_t* b_strides [[buffer(14)]], - uint3 tid [[threadgroup_position_in_grid]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - if (batched) { - int M = x_shape[x_batch_ndims]; - adjust_matrix_offsets( - x, - w, - scales, - biases, - y, - out_vec_size * M, - x_batch_ndims, - x_shape, - x_strides, - w_batch_ndims, - w_shape, - w_strides, - s_strides, - b_strides, - tid); - } - qmv_impl( - w, - scales, - biases, - x, - y, - in_vec_size, - out_vec_size, - tid, - simd_gid, - simd_lid); -} - -template -[[kernel]] void affine_qvm( - const device uint32_t* w [[buffer(0)]], - const device T* scales [[buffer(1)]], - const device T* biases [[buffer(2)]], - const device T* x [[buffer(3)]], - device T* y [[buffer(4)]], - const constant int& in_vec_size [[buffer(5)]], - const constant int& out_vec_size [[buffer(6)]], - const constant int& x_batch_ndims [[buffer(7)]], - const constant int* x_shape [[buffer(8)]], - const constant int64_t* x_strides [[buffer(9)]], - const constant int& w_batch_ndims [[buffer(10)]], - const constant int* w_shape [[buffer(11)]], - const constant int64_t* w_strides [[buffer(12)]], - const constant int64_t* s_strides [[buffer(13)]], - const constant int64_t* b_strides [[buffer(14)]], - uint3 tid [[threadgroup_position_in_grid]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - if (batched) { - int M = x_shape[x_batch_ndims]; - adjust_matrix_offsets( - x, - w, - scales, - biases, - y, - out_vec_size * M, - x_batch_ndims, - x_shape, - x_strides, - w_batch_ndims, - w_shape, - w_strides, - s_strides, - b_strides, - tid); - } - qvm_impl( - w, - scales, - biases, - x, - y, - in_vec_size, - out_vec_size, - tid, - simd_gid, - simd_lid); -} - -template -[[kernel]] void affine_qvm_split_k( - const device uint32_t* w [[buffer(0)]], - const device T* scales [[buffer(1)]], - const device T* biases [[buffer(2)]], - const device T* x [[buffer(3)]], - device T* y [[buffer(4)]], - const constant int& in_vec_size [[buffer(5)]], - const constant int& out_vec_size [[buffer(6)]], - const constant int& x_batch_ndims [[buffer(7)]], - const constant int* x_shape [[buffer(8)]], - const constant int64_t* x_strides [[buffer(9)]], - const constant int& w_batch_ndims [[buffer(10)]], - const constant int* w_shape [[buffer(11)]], - const constant int64_t* w_strides [[buffer(12)]], - const constant int64_t* s_strides [[buffer(13)]], - const constant int64_t* b_strides [[buffer(14)]], - const constant int& final_block_size [[buffer(15)]], - uint3 tid [[threadgroup_position_in_grid]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - int M = x_shape[x_batch_ndims]; - adjust_matrix_offsets( - x, - w, - scales, - biases, - y, - out_vec_size * M, - x_batch_ndims, - x_shape, - x_strides, - w_batch_ndims, - w_shape, - w_strides, - s_strides, - b_strides, - tid); - - // When (in_vec_size % split_k != 0) the final block needs to be smaller - int in_vec_size_adj = - tid.z % split_k == split_k - 1 ? final_block_size : in_vec_size; - - qvm_impl( - w, - scales, - biases, - x, - y, - in_vec_size_adj, - out_vec_size, - tid, - simd_gid, - simd_lid); -} - -template < - typename T, - const int group_size, - const int bits, - const bool aligned_N, - const bool batched, - const int BM = 32, - const int BK = 32, - const int BN = 32> -[[kernel]] void affine_qmm_t( - const device uint32_t* w [[buffer(0)]], - const device T* scales [[buffer(1)]], - const device T* biases [[buffer(2)]], - const device T* x [[buffer(3)]], - device T* y [[buffer(4)]], - const constant int& K [[buffer(5)]], - const constant int& N [[buffer(6)]], - const constant int& M [[buffer(7)]], - const constant int& x_batch_ndims [[buffer(8)]], - const constant int* x_shape [[buffer(9)]], - const constant int64_t* x_strides [[buffer(10)]], - const constant int& w_batch_ndims [[buffer(11)]], - const constant int* w_shape [[buffer(12)]], - const constant int64_t* w_strides [[buffer(13)]], - const constant int64_t* s_strides [[buffer(14)]], - const constant int64_t* b_strides [[buffer(15)]], - uint3 tid [[threadgroup_position_in_grid]], - uint lid [[thread_index_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - (void)lid; - - constexpr int BK_padded = (BK + 16 / sizeof(T)); - - threadgroup T Xs[BM * BK_padded]; - threadgroup T Ws[BN * BK_padded]; - - if (batched) { - adjust_matrix_offsets( - x, - w, - scales, - biases, - y, - M * N, - x_batch_ndims, - x_shape, - x_strides, - w_batch_ndims, - w_shape, - w_strides, - s_strides, - b_strides, - tid); - } - qmm_t_impl( - w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); -} - -template < - typename T, - const int group_size, - const int bits, - const bool batched, - const int BM = 32, - const int BK = 32, - const int BN = 32> -[[kernel]] void affine_qmm_n( - const device uint32_t* w [[buffer(0)]], - const device T* scales [[buffer(1)]], - const device T* biases [[buffer(2)]], - const device T* x [[buffer(3)]], - device T* y [[buffer(4)]], - const constant int& K [[buffer(5)]], - const constant int& N [[buffer(6)]], - const constant int& M [[buffer(7)]], - const constant int& x_batch_ndims [[buffer(8)]], - const constant int* x_shape [[buffer(9)]], - const constant int64_t* x_strides [[buffer(10)]], - const constant int& w_batch_ndims [[buffer(11)]], - const constant int* w_shape [[buffer(12)]], - const constant int64_t* w_strides [[buffer(13)]], - const constant int64_t* s_strides [[buffer(14)]], - const constant int64_t* b_strides [[buffer(15)]], - uint3 tid [[threadgroup_position_in_grid]], - uint lid [[thread_index_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - (void)lid; - - constexpr int BK_padded = (BK + 16 / sizeof(T)); - constexpr int BN_padded = (BN + 16 / sizeof(T)); - - threadgroup T Xs[BM * BK_padded]; - threadgroup T Ws[BK * BN_padded]; - - if (batched) { - adjust_matrix_offsets( - x, - w, - scales, - biases, - y, - M * N, - x_batch_ndims, - x_shape, - x_strides, - w_batch_ndims, - w_shape, - w_strides, - s_strides, - b_strides, - tid); - } - - qmm_n_impl( - w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); -} - -template -[[kernel]] void affine_gather_qmv_fast( - const device uint32_t* w [[buffer(0)]], - const device T* scales [[buffer(1)]], - const device T* biases [[buffer(2)]], - const device T* x [[buffer(3)]], - const device uint32_t* lhs_indices [[buffer(4)]], - const device uint32_t* rhs_indices [[buffer(5)]], - device T* y [[buffer(6)]], - const constant int& in_vec_size [[buffer(7)]], - const constant int& out_vec_size [[buffer(8)]], - const constant int& x_batch_ndims [[buffer(9)]], - const constant int* x_shape [[buffer(10)]], - const constant int64_t* x_strides [[buffer(11)]], - const constant int& w_batch_ndims [[buffer(12)]], - const constant int* w_shape [[buffer(13)]], - const constant int64_t* w_strides [[buffer(14)]], - const constant int64_t* s_strides [[buffer(15)]], - const constant int64_t* b_strides [[buffer(16)]], - const constant int& batch_ndims [[buffer(17)]], - const constant int* batch_shape [[buffer(18)]], - const constant int64_t* lhs_strides [[buffer(19)]], - const constant int64_t* rhs_strides [[buffer(20)]], - uint3 tid [[threadgroup_position_in_grid]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - int M = x_shape[x_batch_ndims]; - adjust_matrix_offsets( - x, - w, - scales, - biases, - lhs_indices, - rhs_indices, - y, - out_vec_size * M, - batch_ndims, - batch_shape, - lhs_strides, - rhs_strides, - x_batch_ndims, - x_shape, - x_strides, - w_batch_ndims, - w_shape, - w_strides, - s_strides, - b_strides, - tid); - qmv_fast_impl( - w, - scales, - biases, - x, - y, - in_vec_size, - out_vec_size, - tid, - simd_gid, - simd_lid); -} - -template -[[kernel]] void affine_gather_qmv( - const device uint32_t* w [[buffer(0)]], - const device T* scales [[buffer(1)]], - const device T* biases [[buffer(2)]], - const device T* x [[buffer(3)]], - const device uint32_t* lhs_indices [[buffer(4)]], - const device uint32_t* rhs_indices [[buffer(5)]], - device T* y [[buffer(6)]], - const constant int& in_vec_size [[buffer(7)]], - const constant int& out_vec_size [[buffer(8)]], - const constant int& x_batch_ndims [[buffer(9)]], - const constant int* x_shape [[buffer(10)]], - const constant int64_t* x_strides [[buffer(11)]], - const constant int& w_batch_ndims [[buffer(12)]], - const constant int* w_shape [[buffer(13)]], - const constant int64_t* w_strides [[buffer(14)]], - const constant int64_t* s_strides [[buffer(15)]], - const constant int64_t* b_strides [[buffer(16)]], - const constant int& batch_ndims [[buffer(17)]], - const constant int* batch_shape [[buffer(18)]], - const constant int64_t* lhs_strides [[buffer(19)]], - const constant int64_t* rhs_strides [[buffer(20)]], - uint3 tid [[threadgroup_position_in_grid]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - int M = x_shape[x_batch_ndims]; - adjust_matrix_offsets( - x, - w, - scales, - biases, - lhs_indices, - rhs_indices, - y, - out_vec_size * M, - batch_ndims, - batch_shape, - lhs_strides, - rhs_strides, - x_batch_ndims, - x_shape, - x_strides, - w_batch_ndims, - w_shape, - w_strides, - s_strides, - b_strides, - tid); - qmv_impl( - w, - scales, - biases, - x, - y, - in_vec_size, - out_vec_size, - tid, - simd_gid, - simd_lid); -} - -template -[[kernel]] void affine_gather_qvm( - const device uint32_t* w [[buffer(0)]], - const device T* scales [[buffer(1)]], - const device T* biases [[buffer(2)]], - const device T* x [[buffer(3)]], - const device uint32_t* lhs_indices [[buffer(4)]], - const device uint32_t* rhs_indices [[buffer(5)]], - device T* y [[buffer(6)]], - const constant int& in_vec_size [[buffer(7)]], - const constant int& out_vec_size [[buffer(8)]], - const constant int& x_batch_ndims [[buffer(9)]], - const constant int* x_shape [[buffer(10)]], - const constant int64_t* x_strides [[buffer(11)]], - const constant int& w_batch_ndims [[buffer(12)]], - const constant int* w_shape [[buffer(13)]], - const constant int64_t* w_strides [[buffer(14)]], - const constant int64_t* s_strides [[buffer(15)]], - const constant int64_t* b_strides [[buffer(16)]], - const constant int& batch_ndims [[buffer(17)]], - const constant int* batch_shape [[buffer(18)]], - const constant int64_t* lhs_strides [[buffer(19)]], - const constant int64_t* rhs_strides [[buffer(20)]], - uint3 tid [[threadgroup_position_in_grid]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - int M = x_shape[x_batch_ndims]; - adjust_matrix_offsets( - x, - w, - scales, - biases, - lhs_indices, - rhs_indices, - y, - out_vec_size * M, - batch_ndims, - batch_shape, - lhs_strides, - rhs_strides, - x_batch_ndims, - x_shape, - x_strides, - w_batch_ndims, - w_shape, - w_strides, - s_strides, - b_strides, - tid); - qvm_impl( - w, - scales, - biases, - x, - y, - in_vec_size, - out_vec_size, - tid, - simd_gid, - simd_lid); -} - -template < - typename T, - const int group_size, - const int bits, - const bool aligned_N, - const int BM = 32, - const int BK = 32, - const int BN = 32> -[[kernel]] void affine_gather_qmm_t( - const device uint32_t* w [[buffer(0)]], - const device T* scales [[buffer(1)]], - const device T* biases [[buffer(2)]], - const device T* x [[buffer(3)]], - const device uint32_t* lhs_indices [[buffer(4)]], - const device uint32_t* rhs_indices [[buffer(5)]], - device T* y [[buffer(6)]], - const constant int& K [[buffer(7)]], - const constant int& N [[buffer(8)]], - const constant int& M [[buffer(9)]], - const constant int& x_batch_ndims [[buffer(10)]], - const constant int* x_shape [[buffer(11)]], - const constant int64_t* x_strides [[buffer(12)]], - const constant int& w_batch_ndims [[buffer(13)]], - const constant int* w_shape [[buffer(14)]], - const constant int64_t* w_strides [[buffer(15)]], - const constant int64_t* s_strides [[buffer(16)]], - const constant int64_t* b_strides [[buffer(17)]], - const constant int& batch_ndims [[buffer(18)]], - const constant int* batch_shape [[buffer(19)]], - const constant int64_t* lhs_strides [[buffer(20)]], - const constant int64_t* rhs_strides [[buffer(21)]], - uint3 tid [[threadgroup_position_in_grid]], - uint lid [[thread_index_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - (void)lid; - - constexpr int BK_padded = (BK + 16 / sizeof(T)); - - threadgroup T Xs[BM * BK_padded]; - threadgroup T Ws[BN * BK_padded]; - - adjust_matrix_offsets( - x, - w, - scales, - biases, - lhs_indices, - rhs_indices, - y, - M * N, - batch_ndims, - batch_shape, - lhs_strides, - rhs_strides, - x_batch_ndims, - x_shape, - x_strides, - w_batch_ndims, - w_shape, - w_strides, - s_strides, - b_strides, - tid); - qmm_t_impl( - w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); -} - -template < - typename T, - const int group_size, - const int bits, - const int BM = 32, - const int BK = 32, - const int BN = 32> -[[kernel]] void affine_gather_qmm_n( - const device uint32_t* w [[buffer(0)]], - const device T* scales [[buffer(1)]], - const device T* biases [[buffer(2)]], - const device T* x [[buffer(3)]], - const device uint32_t* lhs_indices [[buffer(4)]], - const device uint32_t* rhs_indices [[buffer(5)]], - device T* y [[buffer(6)]], - const constant int& K [[buffer(7)]], - const constant int& N [[buffer(8)]], - const constant int& M [[buffer(9)]], - const constant int& x_batch_ndims [[buffer(10)]], - const constant int* x_shape [[buffer(11)]], - const constant int64_t* x_strides [[buffer(12)]], - const constant int& w_batch_ndims [[buffer(13)]], - const constant int* w_shape [[buffer(14)]], - const constant int64_t* w_strides [[buffer(15)]], - const constant int64_t* s_strides [[buffer(16)]], - const constant int64_t* b_strides [[buffer(17)]], - const constant int& batch_ndims [[buffer(18)]], - const constant int* batch_shape [[buffer(19)]], - const constant int64_t* lhs_strides [[buffer(20)]], - const constant int64_t* rhs_strides [[buffer(21)]], - uint3 tid [[threadgroup_position_in_grid]], - uint lid [[thread_index_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - (void)lid; - - constexpr int BK_padded = (BK + 16 / sizeof(T)); - constexpr int BN_padded = (BN + 16 / sizeof(T)); - - threadgroup T Xs[BM * BK_padded]; - threadgroup T Ws[BK * BN_padded]; - - adjust_matrix_offsets( - x, - w, - scales, - biases, - lhs_indices, - rhs_indices, - y, - M * N, - batch_ndims, - batch_shape, - lhs_strides, - rhs_strides, - x_batch_ndims, - x_shape, - x_strides, - w_batch_ndims, - w_shape, - w_strides, - s_strides, - b_strides, - tid); - qmm_n_impl( - w, scales, biases, x, y, Xs, Ws, K, N, M, tid, lid, simd_gid, simd_lid); -} - -template < - typename T, - int group_size, - int bits, - int BM, - int BN, - int BK, - int WM, - int WN, - bool transpose> -[[kernel]] void affine_gather_qmm_rhs( - const device T* x [[buffer(0)]], - const device uint32_t* w [[buffer(1)]], - const device T* scales [[buffer(2)]], - const device T* biases [[buffer(3)]], - const device uint32_t* indices [[buffer(4)]], - device T* y [[buffer(5)]], - const constant int& M [[buffer(6)]], - const constant int& N [[buffer(7)]], - const constant int& K [[buffer(8)]], - uint3 tid [[threadgroup_position_in_grid]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]]) { - constexpr int pack_factor = get_pack_factor(); - constexpr int bytes_per_pack = get_bytes_per_pack(); - constexpr int BK_padded = (BK + 16 / sizeof(T)); - constexpr int BN_padded = (BN + 16 / sizeof(T)); - - using mma_t = mlx::steel::BlockMMA< - T, - T, - BM, - BN, - BK, - WM, - WN, - false, - transpose, - BK_padded, - transpose ? BK_padded : BN_padded>; - using loader_x_t = - mlx::steel::BlockLoader; - using loader_w_t = QuantizedBlockLoader< - T, - transpose ? BN : BK, - transpose ? BK : BN, - transpose ? BK_padded : BN_padded, - transpose, - WM * WN * SIMD_SIZE, - group_size, - bits>; - - threadgroup T Xs[BM * BK_padded]; - threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded]; - - // Compute the block - const int K_w = K * bytes_per_pack / pack_factor; - const int K_g = K / group_size; - const int N_w = N * bytes_per_pack / pack_factor; - const int N_g = N / group_size; - const int K_it = K / BK; - const size_t stride_w = transpose ? N * K_w : K * N_w; - const size_t stride_s = transpose ? N * K_g : K * N_g; - const int y_row = tid.y * BM; - const int y_col = tid.x * BN; - const size_t y_row_long = size_t(y_row); - const size_t y_col_long = size_t(y_col); - - // Prepare threadgroup bounds - const short tgp_bm = align_M ? BM : short(min(BM, M - y_row)); - const short tgp_bn = align_N ? BN : short(min(BN, N - y_col)); - - // Calculate the final tiles in the case that K is not aligned - const int k_remain = K - K_it * BK; - const short2 tile_x = short2(k_remain, tgp_bm); - const short2 tile_w = - transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); - - // Move x and output to the correct block - auto wl = (const device uint8_t*)w; - x += y_row_long * K; - y += y_row_long * N + y_col_long; - wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor; - scales += transpose ? y_col_long * K_g : y_col / group_size; - biases += transpose ? y_col_long * K_g : y_col / group_size; - - // Do as many matmuls as necessary - uint32_t index; - short offset; - uint32_t index_next = indices[y_row]; - short offset_next = 0; - int n = 0; - while (n < tgp_bm) { - n++; - offset = offset_next; - index = index_next; - offset_next = tgp_bm; - for (; n < tgp_bm; n++) { - if (indices[y_row + n] != index) { - offset_next = n; - index_next = indices[y_row + n]; - break; - } - } - threadgroup_barrier(mem_flags::mem_none); - - // Prepare threadgroup mma operation - thread mma_t mma_op(simd_group_id, simd_lane_id); - - // Prepare threadgroup loading operations - thread loader_x_t loader_x(x, K, Xs, simd_group_id, simd_lane_id); - thread loader_w_t loader_w( - wl + index * stride_w, - scales + index * stride_s, - biases + index * stride_s, - transpose ? K : N, - Ws, - simd_group_id, - simd_lane_id); - - // Matrices are all aligned check nothing - if (align_M && align_N) { - gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it); - if (!align_K) { - threadgroup_barrier(mem_flags::mem_threadgroup); - gemm_loop_finalize(Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); - } - - // Store results to device memory - if (offset_next - offset == BM) { - mma_op.store_result(y, N); - } else { - mma_op.store_result_slice( - y, N, short2(0, offset), short2(BN, offset_next)); - } - } else { - // Tile aligned so check outside of the hot loop - if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { - gemm_loop_aligned(Xs, Ws, mma_op, loader_x, loader_w, K_it); - if (!align_K) { - threadgroup_barrier(mem_flags::mem_threadgroup); - gemm_loop_finalize( - Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); - } - - // Store results to device memory - if (offset_next - offset == BM) { - mma_op.store_result(y, N); - } else { - mma_op.store_result_slice( - y, N, short2(0, offset), short2(BN, offset_next)); - } - } - - // Tile partially aligned check rows - else if (align_N || tgp_bn == BN) { - gemm_loop_unaligned( - Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); - if (!align_K) { - threadgroup_barrier(mem_flags::mem_threadgroup); - gemm_loop_finalize( - Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); - } - mma_op.store_result_slice( - y, N, short2(0, offset), short2(BN, offset_next)); - } - - // Tile partially aligned check cols - else if (align_M || tgp_bm == BM) { - gemm_loop_unaligned( - Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); - if (!align_K) { - threadgroup_barrier(mem_flags::mem_threadgroup); - gemm_loop_finalize( - Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); - } - mma_op.store_result_slice( - y, N, short2(0, offset), short2(tgp_bn, offset_next)); - } - - // Nothing aligned so check both rows and cols - else { - gemm_loop_unaligned( - Xs, Ws, mma_op, loader_x, loader_w, K_it, tgp_bm, tgp_bn, BK); - if (!align_K) { - threadgroup_barrier(mem_flags::mem_threadgroup); - gemm_loop_finalize( - Xs, Ws, mma_op, loader_x, loader_w, tile_x, tile_w); - } - mma_op.store_result_slice( - y, N, short2(0, offset), short2(tgp_bn, offset_next)); - } - } - } -} - -template -[[kernel]] void affine_quantize( - const device T* w [[buffer(0)]], - device uint8_t* out [[buffer(1)]], - device T* scales [[buffer(2)]], - device T* biases [[buffer(3)]], - uint2 index [[thread_position_in_grid]], - uint2 grid_dim [[threads_per_grid]]) { - constexpr float eps = 1e-7; - constexpr int simd_size = 32; - constexpr float n_bins = (1 << bits) - 1; - constexpr int pack_factor = get_pack_factor(); - constexpr int bytes_per_pack = get_bytes_per_pack(); - constexpr int values_per_reduce = group_size / simd_size; - constexpr int writes_per_reduce = pack_factor / values_per_reduce; - constexpr int writes_per_pack = - writes_per_reduce > 1 ? 1 : values_per_reduce / pack_factor; - constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; - - static_assert( - group_size % simd_size == 0, - "Group size must be divisible by simd size."); - - size_t offset = index.x + grid_dim.x * size_t(index.y); - size_t in_index = offset * values_per_reduce; - size_t out_index = power_of_2_bits - ? offset * writes_per_pack - : offset * bytes_per_pack / writes_per_reduce; - - float w_thread[values_per_reduce]; - float w_min = Limits::max; - float w_max = 0; - -#pragma clang loop unroll(full) - for (int i = 0; i < values_per_reduce; i++) { - float val = w[in_index + i]; - w_thread[i] = val; - w_min = min(w_min, val); - w_max = max(w_max, val); - } - - w_min = simd_min(w_min); - w_max = simd_max(w_max); - - float scale = max((w_max - w_min) / n_bins, eps); - bool side = abs(w_min) > abs(w_max); - scale = side ? scale : -scale; - float edge = side ? w_min : w_max; - float q0 = round(edge / scale); - bool at_zero = q0 == 0.0f; - scale = at_zero ? scale : edge / q0; - float bias = at_zero ? 0 : edge; - - // Write out the scales and biases - size_t gindex = in_index / group_size; - if (in_index % group_size == 0) { - scales[gindex] = static_cast(scale); - biases[gindex] = static_cast(bias); - } - - using OutType = metal::conditional_t; - OutType output = 0; - -#pragma clang loop unroll(full) - for (int i = 0; i < values_per_reduce; i++) { - uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins); - if (bits == 8) { - output = val; - } else { - output |= val << (bits * (i % pack_factor)); - } - - if (pack_factor < values_per_reduce && i % pack_factor == pack_factor - 1) { - out[out_index + i / pack_factor] = output; - output = 0; - } else { -#pragma clang loop unroll(full) - for (int j = 1; j < writes_per_reduce; j++) { - uint8_t sval = simd_shuffle_down(val, j); - output |= static_cast(sval) - << (bits * (j * values_per_reduce + i)); - } - } - } - if (bits == 3 || bits == 6) { - if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) { - out[out_index] = output & 0xff; - out[out_index + 1] = (output & 0xff00) >> 8; - out[out_index + 2] = (output & 0xff0000) >> 16; - } - } else if (bits == 5) { - if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) { - out[out_index] = output & 0xff; - out[out_index + 1] = (output & 0xff00) >> 8; - out[out_index + 2] = (output & 0xff0000) >> 16; - out[out_index + 3] = (output & 0xff000000) >> 24; - out[out_index + 4] = (output & 0xff00000000) >> 32; - } - } else { - if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) { - out[out_index / writes_per_reduce] = output; - } - } -} - -template -[[kernel]] void affine_dequantize( - const device uint8_t* w [[buffer(0)]], - const device T* scales [[buffer(1)]], - const device T* biases [[buffer(2)]], - device T* out [[buffer(3)]], - uint2 index [[thread_position_in_grid]], - uint2 grid_dim [[threads_per_grid]]) { - constexpr int pack_factor = get_pack_factor(); - constexpr int bytes_per_pack = get_bytes_per_pack(); - - size_t offset = index.x + grid_dim.x * size_t(index.y); - size_t oindex = offset * pack_factor; - size_t gindex = oindex / group_size; - T scale = scales[gindex]; - T bias = biases[gindex]; - - out += oindex; - - if (bits == 3) { - w += offset * bytes_per_pack; - out[0] = (w[0] & 0x7) * scale + bias; - out[1] = ((w[0] & 0x38) >> 3) * scale + bias; - out[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias; - out[3] = ((w[1] & 0xe) >> 1) * scale + bias; - out[4] = ((w[1] & 0x70) >> 4) * scale + bias; - out[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias; - out[6] = ((w[2] & 0x1c) >> 2) * scale + bias; - out[7] = ((w[2] & 0xe0) >> 5) * scale + bias; - } else if (bits == 5) { - w += offset * bytes_per_pack; - out[0] = (w[0] & 0x1f) * scale + bias; - out[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias; - out[2] = ((w[1] & 0x7c) >> 2) * scale + bias; - out[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias; - out[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias; - out[5] = ((w[3] & 0x3e) >> 1) * scale + bias; - out[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias; - out[7] = ((w[4] & 0xf8) >> 3) * scale + bias; - } else if (bits == 6) { - w += offset * bytes_per_pack; - out[0] = (w[0] & 0x3f) * scale + bias; - out[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias; - out[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias; - out[3] = ((w[2] >> 2) & 0x3f) * scale + bias; - } else { - uint val = w[offset]; -#pragma clang loop unroll(full) - for (int i = 0; i < pack_factor; i++) { - uint8_t d; - if (bits == 2) { - d = (val >> (bits * i)) & 0x03; - } else if (bits == 4) { - d = (val >> (bits * i)) & 0x0f; - } else if (bits == 8) { - d = val; - } - out[i] = scale * d + bias; - } - } -} diff --git a/Source/Cmlx/mlx-generated/metal/quantized_nax.h b/Source/Cmlx/mlx-generated/metal/quantized_nax.h deleted file mode 100644 index c26ff646..00000000 --- a/Source/Cmlx/mlx-generated/metal/quantized_nax.h +++ /dev/null @@ -1,1705 +0,0 @@ -// Copyright © 2023-2024 Apple Inc. - -#include -#include - -using namespace metal; -using namespace mlx::steel; - -constant bool align_M [[function_constant(200)]]; -constant bool align_N [[function_constant(201)]]; -constant bool align_K [[function_constant(202)]]; - -using namespace metal; - -#define MLX_MTL_CONST static constant constexpr const - -MLX_MTL_CONST int SIMD_SIZE = 32; -MLX_MTL_CONST int QUAD_SIZE = 4; - -template -inline constexpr short get_pack_factor() { - return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits); -} - -template -inline constexpr short get_bytes_per_pack() { - constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; - return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3); -} - -template -inline U load_vector(const device T* x, thread U* x_thread) { - static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || - bits == 8, - "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); - - U sum = 0; - - if (bits == 2) { - for (int i = 0; i < values_per_thread; i += 4) { - sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; - x_thread[i] = x[i]; - x_thread[i + 1] = x[i + 1] / 4.0f; - x_thread[i + 2] = x[i + 2] / 16.0f; - x_thread[i + 3] = x[i + 3] / 64.0f; - } - } - - else if (bits == 3) { - for (int i = 0; i < values_per_thread; i += 8) { - sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + - x[i + 6] + x[i + 7]; - x_thread[i] = x[i]; - x_thread[i + 1] = x[i + 1] / 8.0f; - x_thread[i + 2] = x[i + 2] / 64.0f; - x_thread[i + 3] = x[i + 3] / 2.0f; - x_thread[i + 4] = x[i + 4] / 16.0f; - x_thread[i + 5] = x[i + 5] / 128.0f; - x_thread[i + 6] = x[i + 6] / 4.0f; - x_thread[i + 7] = x[i + 7] / 32.0f; - } - } - - else if (bits == 4) { - for (int i = 0; i < values_per_thread; i += 4) { - sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; - x_thread[i] = x[i]; - x_thread[i + 1] = x[i + 1] / 16.0f; - x_thread[i + 2] = x[i + 2] / 256.0f; - x_thread[i + 3] = x[i + 3] / 4096.0f; - } - } - - else if (bits == 5) { - for (int i = 0; i < values_per_thread; i += 8) { - sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + - x[i + 6] + x[i + 7]; - x_thread[i] = x[i]; - x_thread[i + 1] = x[i + 1] / 32.0f; - x_thread[i + 2] = x[i + 2] / 4.0f; - x_thread[i + 3] = x[i + 3] / 128.0f; - x_thread[i + 4] = x[i + 4] / 16.0f; - x_thread[i + 5] = x[i + 5] / 2.0f; - x_thread[i + 6] = x[i + 6] / 64.0f; - x_thread[i + 7] = x[i + 7] / 8.0f; - } - } - - else if (bits == 6) { - for (int i = 0; i < values_per_thread; i += 4) { - sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; - x_thread[i] = x[i]; - x_thread[i + 1] = x[i + 1] / 64.0f; - x_thread[i + 2] = x[i + 2] / 16.0f; - x_thread[i + 3] = x[i + 3] / 4.0f; - } - } - - else if (bits == 8) { - for (int i = 0; i < values_per_thread; i++) { - sum += x[i]; - x_thread[i] = x[i]; - } - } - - return sum; -} - -template -inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { - static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || - bits == 8, - "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); - - U sum = 0; - - if (bits == 2) { - for (int i = 0; i < N; i += 4) { - sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; - x_thread[i] = x[i]; - x_thread[i + 1] = x[i + 1] / 4.0f; - x_thread[i + 2] = x[i + 2] / 16.0f; - x_thread[i + 3] = x[i + 3] / 64.0f; - } - } - - else if (bits == 3) { - for (int i = 0; i < N; i += 8) { - sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + - x[i + 6] + x[i + 7]; - - x_thread[i] = x[i]; - x_thread[i + 1] = x[i + 1] / 8.0f; - x_thread[i + 2] = x[i + 2] / 64.0f; - x_thread[i + 3] = x[i + 3] / 2.0f; - x_thread[i + 4] = x[i + 4] / 16.0f; - x_thread[i + 5] = x[i + 5] / 128.0f; - x_thread[i + 6] = x[i + 6] / 4.0f; - x_thread[i + 7] = x[i + 7] / 32.0f; - } - } - - else if (bits == 4) { - for (int i = 0; i < N; i += 4) { - sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; - x_thread[i] = x[i]; - x_thread[i + 1] = x[i + 1] / 16.0f; - x_thread[i + 2] = x[i + 2] / 256.0f; - x_thread[i + 3] = x[i + 3] / 4096.0f; - } - } - - else if (bits == 5) { - for (int i = 0; i < N; i += 8) { - sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + - x[i + 6] + x[i + 7]; - x_thread[i] = x[i]; - x_thread[i + 1] = x[i + 1] / 32.0f; - x_thread[i + 2] = x[i + 2] / 4.0f; - x_thread[i + 3] = x[i + 3] / 128.0f; - x_thread[i + 4] = x[i + 4] / 16.0f; - x_thread[i + 5] = x[i + 5] / 2.0f; - x_thread[i + 6] = x[i + 6] / 64.0f; - x_thread[i + 7] = x[i + 7] / 8.0f; - } - } - - else if (bits == 6) { - for (int i = 0; i < N; i += 4) { - sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; - x_thread[i] = x[i]; - x_thread[i + 1] = x[i + 1] / 64.0f; - x_thread[i + 2] = x[i + 2] / 16.0f; - x_thread[i + 3] = x[i + 3] / 4.0f; - } - } - - else if (bits == 8) { - for (int i = 0; i < N; i++) { - sum += x[i]; - x_thread[i] = x[i]; - } - } - - for (int i = N; i < values_per_thread; i++) { - x_thread[i] = 0; - } - - return sum; -} - -template -inline U qdot( - const device uint8_t* w, - const thread U* x_thread, - U scale, - U bias, - U sum) { - static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || - bits == 8, - "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); - - U accum = 0; - - if (bits == 2) { - for (int i = 0; i < (values_per_thread / 4); i++) { - accum += - (x_thread[4 * i] * (w[i] & 0x03) + - x_thread[4 * i + 1] * (w[i] & 0x0c) + - x_thread[4 * i + 2] * (w[i] & 0x30) + - x_thread[4 * i + 3] * (w[i] & 0xc0)); - } - } - - else if (bits == 3) { - for (int i = 0; i < (values_per_thread / 8); i++) { - x_thread += 8 * i; - w += 3 * i; - - accum += (w[0] & 0x07) * x_thread[0]; - accum += (w[0] & 0x38) * x_thread[1]; - accum += (w[0] & 0xc0) * x_thread[2]; - accum += (w[1] & 0x01) * (x_thread[2] * 256.0f); - - accum += (w[1] & 0x0e) * x_thread[3]; - accum += (w[1] & 0x70) * x_thread[4]; - accum += (w[1] & 0x80) * x_thread[5]; - accum += (w[2] & 0x03) * (x_thread[5] * 256.0f); - - accum += (w[2] & 0x1c) * x_thread[6]; - accum += (w[2] & 0xe0) * x_thread[7]; - } - } - - else if (bits == 4) { - const device uint16_t* ws = (const device uint16_t*)w; - for (int i = 0; i < (values_per_thread / 4); i++) { - accum += - (x_thread[4 * i] * (ws[i] & 0x000f) + - x_thread[4 * i + 1] * (ws[i] & 0x00f0) + - x_thread[4 * i + 2] * (ws[i] & 0x0f00) + - x_thread[4 * i + 3] * (ws[i] & 0xf000)); - } - } - - else if (bits == 5) { - for (int i = 0; i < (values_per_thread / 8); i++) { - x_thread += 8 * i; - w += 5 * i; - - accum += (w[0] & 0x1f) * x_thread[0]; - accum += (w[0] & 0xe0) * x_thread[1]; - accum += (w[1] & 0x3) * (x_thread[1] * 256.0f); - accum += (w[1] & 0x7c) * x_thread[2]; - accum += (w[1] & 0x80) * x_thread[3]; - accum += (w[2] & 0xf) * (x_thread[3] * 256.0f); - accum += (w[2] & 0xf0) * x_thread[4]; - accum += (w[3] & 0x1) * (x_thread[4] * 256.0f); - accum += (w[3] & 0x3e) * x_thread[5]; - accum += (w[3] & 0xc0) * x_thread[6]; - accum += (w[4] & 0x7) * (x_thread[6] * 256.0f); - accum += (w[4] & 0xf8) * x_thread[7]; - } - } - - else if (bits == 6) { - for (int i = 0; i < (values_per_thread / 4); i++) { - x_thread += 4 * i; - w += 3 * i; - - accum += (w[0] & 0x3f) * x_thread[0]; - - accum += (w[0] & 0xc0) * x_thread[1]; - accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f); - - accum += (w[1] & 0xf0) * x_thread[2]; - accum += (w[2] & 0x03) * (x_thread[2] * 256.0f); - - accum += (w[2] & 0xfc) * x_thread[3]; - } - } - - else if (bits == 8) { - for (int i = 0; i < values_per_thread; i++) { - accum += x_thread[i] * w[i]; - } - } - - return scale * accum + sum * bias; -} - -template -inline U qdot_safe( - const device uint8_t* w, - const thread U* x_thread, - U scale, - U bias, - U sum, - int N) { - static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || - bits == 8, - "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); - - U accum = 0; - - if (bits == 2) { - for (int i = 0; i < (N / 4); i++) { - accum += - (x_thread[4 * i] * (w[i] & 0x03) + - x_thread[4 * i + 1] * (w[i] & 0x0c) + - x_thread[4 * i + 2] * (w[i] & 0x30) + - x_thread[4 * i + 3] * (w[i] & 0xc0)); - } - } - - else if (bits == 3) { - for (int i = 0; i < (N / 8); i++) { - x_thread += 8 * i; - w += 3 * i; - - accum += (w[0] & 0x07) * x_thread[0]; - accum += (w[0] & 0x38) * x_thread[1]; - accum += (w[0] & 0xc0) * x_thread[2]; - accum += (w[1] & 0x01) * (x_thread[2] * 256.0f); - - accum += (w[1] & 0x0e) * x_thread[3]; - accum += (w[1] & 0x70) * x_thread[4]; - accum += (w[1] & 0x80) * x_thread[5]; - accum += (w[2] & 0x03) * (x_thread[5] * 256.0f); - - accum += (w[2] & 0x1c) * x_thread[6]; - accum += (w[2] & 0xe0) * x_thread[7]; - } - } - - else if (bits == 4) { - const device uint16_t* ws = (const device uint16_t*)w; - for (int i = 0; i < (N / 4); i++) { - accum += - (x_thread[4 * i] * (ws[i] & 0x000f) + - x_thread[4 * i + 1] * (ws[i] & 0x00f0) + - x_thread[4 * i + 2] * (ws[i] & 0x0f00) + - x_thread[4 * i + 3] * (ws[i] & 0xf000)); - } - } - - else if (bits == 5) { - for (int i = 0; i < (N / 8); i++) { - x_thread += 8 * i; - w += 5 * i; - - accum += (w[0] & 0x1f) * x_thread[0]; - accum += (w[0] & 0xe0) * x_thread[1]; - accum += (w[1] & 0x3) * (x_thread[1] * 256.0f); - accum += (w[1] & 0x7c) * x_thread[2]; - accum += (w[1] & 0x80) * x_thread[3]; - accum += (w[2] & 0xf) * (x_thread[3] * 256.0f); - accum += (w[2] & 0xf0) * x_thread[4]; - accum += (w[3] & 0x1) * (x_thread[4] * 256.0f); - accum += (w[3] & 0x3e) * x_thread[5]; - accum += (w[3] & 0xc0) * x_thread[6]; - accum += (w[4] & 0x7) * (x_thread[6] * 256.0f); - accum += (w[4] & 0xf8) * x_thread[7]; - } - } - - else if (bits == 6) { - for (int i = 0; i < (N / 4); i++) { - x_thread += 4 * i; - w += 3 * i; - - accum += (w[0] & 0x3f) * x_thread[0]; - - accum += (w[0] & 0xc0) * x_thread[1]; - accum += (w[1] & 0x0f) * (x_thread[1] * 256.0f); - - accum += (w[1] & 0xf0) * x_thread[2]; - accum += (w[2] & 0x03) * (x_thread[2] * 256.0f); - - accum += (w[2] & 0xfc) * x_thread[3]; - } - } - - else if (bits == 8) { - for (int i = 0; i < N; i++) { - accum += x_thread[i] * w[i]; - } - } - - return scale * accum + sum * bias; -} - -template -inline void -qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { - static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || - bits == 8, - "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); - - if (bits == 2) { - U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f}; - for (int i = 0; i < (values_per_thread / 4); i++) { - result[4 * i] += x * (s[0] * (w[i] & 0x03) + bias); - result[4 * i + 1] += x * (s[1] * (w[i] & 0x0c) + bias); - result[4 * i + 2] += x * (s[2] * (w[i] & 0x30) + bias); - result[4 * i + 3] += x * (s[3] * (w[i] & 0xc0) + bias); - } - } - - else if (bits == 3) { - for (int i = 0; i < (values_per_thread / 8); i++) { - uint8_t w0 = w[3 * i]; - uint8_t w1 = w[3 * i + 1]; - uint8_t w2 = w[3 * i + 2]; - - result[8 * i] += x * ((w0 & 0x7) * scale + bias); - result[8 * i + 1] += x * (((w0 & 0x38) >> 3) * scale + bias); - result[8 * i + 2] += - x * ((((w0 & 0xc0) >> 6) + ((w1 & 0x1) << 2)) * scale + bias); - result[8 * i + 3] += x * (((w1 & 0xe) >> 1) * scale + bias); - result[8 * i + 4] += x * (((w1 & 0x70) >> 4) * scale + bias); - result[8 * i + 5] += - x * ((((w1 & 0x80) >> 7) + ((w2 & 0x3) << 1)) * scale + bias); - result[8 * i + 6] += x * (((w2 & 0x1c) >> 2) * scale + bias); - result[8 * i + 7] += x * (((w2 & 0xe0) >> 5) * scale + bias); - } - } - - else if (bits == 4) { - U s[2] = {scale, scale / 16.0f}; - for (int i = 0; i < (values_per_thread / 2); i++) { - result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias); - result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias); - } - } - - else if (bits == 5) { - for (int i = 0; i < (values_per_thread / 8); i++) { - uint8_t w0 = w[5 * i]; - uint8_t w1 = w[5 * i + 1]; - uint8_t w2 = w[5 * i + 2]; - uint8_t w3 = w[5 * i + 3]; - uint8_t w4 = w[5 * i + 4]; - result[8 * i] += x * ((w0 & 0x1f) * scale + bias); - result[8 * i + 1] += - x * ((((w0 & 0xe0) >> 5) + ((w1 & 0x3) << 3)) * scale + bias); - result[8 * i + 2] += x * (((w1 & 0x7c) >> 2) * scale + bias); - result[8 * i + 3] += - x * ((((w1 & 0x80) >> 7) + ((w2 & 0xf) << 1)) * scale + bias); - result[8 * i + 4] += - x * ((((w2 & 0xf0) >> 4) + ((w3 & 0x1) << 4)) * scale + bias); - result[8 * i + 5] += x * (((w3 & 0x3e) >> 1) * scale + bias); - result[8 * i + 6] += - x * ((((w3 & 0xc0) >> 6) + ((w4 & 0x7) << 2)) * scale + bias); - result[8 * i + 7] += x * (((w4 & 0xf8) >> 3) * scale + bias); - } - } - - else if (bits == 6) { - for (int i = 0; i < (values_per_thread / 4); i++) { - uint8_t w0 = w[3 * i]; - uint8_t w1 = w[3 * i + 1]; - uint8_t w2 = w[3 * i + 2]; - - result[4 * i] += x * ((w0 & 0x3f) * scale + bias); - result[4 * i + 1] += - x * ((((w0 >> 6) & 0x03) + ((w1 & 0x0f) << 2)) * scale + bias); - result[4 * i + 2] += - x * ((((w1 >> 4) & 0x0f) + ((w2 & 0x03) << 4)) * scale + bias); - result[4 * i + 3] += x * (((w2 >> 2) & 0x3f) * scale + bias); - } - } - - else if (bits == 8) { - for (int i = 0; i < values_per_thread; i++) { - result[i] += x * (scale * w[i] + bias); - } - } -} - -template -inline void -dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { - static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || - bits == 8, - "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); - - if (bits == 2) { - U s[4] = { - scale, - scale / static_cast(4.0f), - scale / static_cast(16.0f), - scale / static_cast(64.0f)}; - for (int i = 0; i < (N / 4); i++) { - w_local[4 * i] = s[0] * (w[i] & 0x03) + bias; - w_local[4 * i + 1] = s[1] * (w[i] & 0x0c) + bias; - w_local[4 * i + 2] = s[2] * (w[i] & 0x30) + bias; - w_local[4 * i + 3] = s[3] * (w[i] & 0xc0) + bias; - } - } - - else if (bits == 3) { - for (int i = 0; i < (N / 8); i++) { - w_local += 8 * i; - w += 3 * i; - - w_local[0] = (w[0] & 0x7) * scale + bias; - w_local[1] = ((w[0] & 0x38) >> 3) * scale + bias; - w_local[2] = (((w[0] & 0xc0) >> 6) + ((w[1] & 0x1) << 2)) * scale + bias; - w_local[3] = ((w[1] & 0xe) >> 1) * scale + bias; - w_local[4] = ((w[1] & 0x70) >> 4) * scale + bias; - w_local[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias; - w_local[6] = ((w[2] & 0x1c) >> 2) * scale + bias; - w_local[7] = ((w[2] & 0xe0) >> 5) * scale + bias; - } - } - - else if (bits == 4) { - U s[2] = {scale, scale / static_cast(16.0f)}; - for (int i = 0; i < (N / 2); i++) { - w_local[2 * i] = s[0] * (w[i] & 0x0f) + bias; - w_local[2 * i + 1] = s[1] * (w[i] & 0xf0) + bias; - } - } - - else if (bits == 5) { - for (int i = 0; i < (N / 8); i++) { - w_local += 8 * i; - w += 5 * i; - - w_local[0] = (w[0] & 0x1f) * scale + bias; - w_local[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias; - w_local[2] = ((w[1] & 0x7c) >> 2) * scale + bias; - w_local[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias; - w_local[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias; - w_local[5] = ((w[3] & 0x3e) >> 1) * scale + bias; - w_local[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias; - w_local[7] = ((w[4] & 0xf8) >> 3) * scale + bias; - } - } - - else if (bits == 6) { - for (int i = 0; i < (N / 4); i++) { - w_local += 4 * i; - w += 3 * i; - w_local[0] = (w[0] & 0x3f) * scale + bias; - w_local[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias; - w_local[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias; - w_local[3] = ((w[2] >> 2) & 0x3f) * scale + bias; - } - } - - else if (bits == 8) { - for (int i = 0; i < N; i++) { - w_local[i] = scale * w[i] + bias; - } - } -} - -template < - typename T, - short BROWS, - short BCOLS, - short dst_ld, - short reduction_dim, - short tgp_size, - short group_size, - short bits> -struct QuantizedBlockLoader { - static_assert( - BCOLS <= group_size, - "The group size should be larger than the columns"); - static_assert( - group_size % BCOLS == 0, - "The group size should be divisible by the columns"); - static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || - bits == 8, - "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); - - MLX_MTL_CONST short pack_factor = get_pack_factor(); - MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack(); - MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor; - MLX_MTL_CONST short n_reads = - (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size; - MLX_MTL_CONST short group_steps = group_size / BCOLS; - - const int src_ld; - const int tile_stride; - short group_step_cnt; - const int group_stride; - - const short thread_idx; - const short bi; - const short bj; - - threadgroup T* dst; - const device uint8_t* src; - const device T* scales; - const device T* biases; - - QuantizedBlockLoader( - const device uint8_t* src_, - const device T* scales_, - const device T* biases_, - const int src_ld_, - threadgroup T* dst_, - ushort simd_group_id [[simdgroup_index_in_threadgroup]], - ushort simd_lane_id [[thread_index_in_simdgroup]]) - : src_ld(src_ld_), - tile_stride( - reduction_dim ? BCOLS_PACKED * bytes_per_pack - : BROWS * src_ld * bytes_per_pack / pack_factor), - group_step_cnt(0), - group_stride(BROWS * src_ld / group_size), - thread_idx(simd_group_id * 32 + simd_lane_id), - bi(n_reads * thread_idx / BCOLS_PACKED), - bj((n_reads * thread_idx) % BCOLS_PACKED), - dst(dst_ + bi * dst_ld + bj * pack_factor), - src(src_ + bi * src_ld * bytes_per_pack / pack_factor + - bj * bytes_per_pack), - scales(scales_ + bi * src_ld / group_size), - biases(biases_ + bi * src_ld / group_size) {} - - void load_unsafe() const { - if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { - return; - } - - T scale = *scales; - T bias = *biases; - for (int i = 0; i < n_reads; i++) { - dequantize( - src + i * bytes_per_pack, scale, bias, dst + i * pack_factor); - } - } - - void load_safe(short2 src_tile_dim) const { - if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { - return; - } - - if (reduction_dim == 1 && bi >= src_tile_dim.x) { - for (int i = 0; i < n_reads * pack_factor; i++) { - dst[i] = T(0); - } - return; - } - - if (reduction_dim == 0 && bi >= src_tile_dim.y) { - for (int i = 0; i < n_reads * pack_factor; i++) { - dst[i] = T(0); - } - return; - } - - T scale = *scales; - T bias = *biases; - for (int i = 0; i < n_reads; i++) { - dequantize( - (device uint8_t*)(src + i * bytes_per_pack), - scale, - bias, - dst + i * pack_factor); - } - } - - void next() { - src += tile_stride; - if (reduction_dim == 1) { - if (group_steps > 1) { - group_step_cnt++; - if (group_step_cnt == group_steps) { - group_step_cnt = 0; - scales++; - biases++; - } - } else { - scales++; - biases++; - } - } else { - scales += group_stride; - biases += group_stride; - } - } -}; - -template < - typename T, - short BROWS, - short BCOLS, - short dst_ld, - short reduction_dim, - short tgp_size, - short bits> -struct QuantizedBlockLoader< - T, - BROWS, - BCOLS, - dst_ld, - reduction_dim, - tgp_size, - 32, - bits> { - MLX_MTL_CONST short group_size = 32; - - static_assert( - BCOLS % group_size == 0, - "The group size should be divisible by the columns"); - static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || - bits == 8, - "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); - - MLX_MTL_CONST short pack_factor = get_pack_factor(); - MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack(); - MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor; - MLX_MTL_CONST short n_reads = - (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size; - MLX_MTL_CONST short n_groups = BCOLS / group_size; - - static_assert( - (BCOLS_PACKED / n_reads) == n_groups, - "Other configurations are not yet supported"); - - const int src_ld; - const int tile_stride; - const int group_stride; - - const short thread_idx; - const short bi; - const short bj; - - const short group_id; - - threadgroup T* dst; - const device uint8_t* src; - const device T* scales; - const device T* biases; - - QuantizedBlockLoader( - const device uint8_t* src_, - const device T* scales_, - const device T* biases_, - const int src_ld_, - threadgroup T* dst_, - ushort simd_group_id [[simdgroup_index_in_threadgroup]], - ushort simd_lane_id [[thread_index_in_simdgroup]]) - : src_ld(src_ld_), - tile_stride( - reduction_dim ? BCOLS_PACKED * bytes_per_pack - : BROWS * src_ld * bytes_per_pack / pack_factor), - group_stride(BROWS * src_ld / group_size), - thread_idx(simd_group_id * 32 + simd_lane_id), - bi(n_reads * thread_idx / BCOLS_PACKED), - bj((n_reads * thread_idx) % BCOLS_PACKED), - group_id((bj * pack_factor) / group_size), - dst(dst_ + bi * dst_ld + bj * pack_factor), - src(src_ + bi * src_ld * bytes_per_pack / pack_factor + - bj * bytes_per_pack), - scales(scales_ + bi * src_ld / group_size + group_id), - biases(biases_ + bi * src_ld / group_size + group_id) {} - - void load_unsafe() const { - if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { - return; - } - - T scale = *scales; - T bias = *biases; - for (int i = 0; i < n_reads; i++) { - dequantize( - src + i * bytes_per_pack, scale, bias, dst + i * pack_factor); - } - } - - void load_safe(short2 src_tile_dim) const { - if (BCOLS_PACKED * BROWS < tgp_size && bi >= BROWS) { - return; - } - - if (reduction_dim == 1 && bi >= src_tile_dim.x) { - for (int i = 0; i < n_reads * pack_factor; i++) { - dst[i] = T(0); - } - return; - } - - if (reduction_dim == 0 && bi >= src_tile_dim.y) { - for (int i = 0; i < n_reads * pack_factor; i++) { - dst[i] = T(0); - } - return; - } - - T scale = *scales; - T bias = *biases; - for (int i = 0; i < n_reads; i++) { - dequantize( - (device uint8_t*)(src + i * bytes_per_pack), - scale, - bias, - dst + i * pack_factor); - } - } - - void next() { - src += tile_stride; - if (reduction_dim == 1) { - // if (group_steps > 1) { - // group_step_cnt++; - // if (group_step_cnt == group_steps) { - // group_step_cnt = 0; - // scales++; - // biases++; - // } - // } else { - scales += n_groups; - biases += n_groups; - // } - } else { - scales += n_groups * group_stride; - biases += n_groups * group_stride; - } - } -}; - -template -METAL_FUNC void adjust_matrix_offsets( - const device T*& x, - const device uint32_t*& w, - const device T*& scales, - const device T*& biases, - device T*& y, - int output_stride, - const constant int& x_batch_ndims, - const constant int* x_shape, - const constant int64_t* x_strides, - const constant int& w_batch_ndims, - const constant int* w_shape, - const constant int64_t* w_strides, - const constant int64_t* s_strides, - const constant int64_t* b_strides, - uint3 tid [[threadgroup_position_in_grid]]) { - // Set the input/output matrices - uint32_t x_idx = tid.z; - uint32_t w_idx = tid.z; - if (x_batch_ndims == 1) { - x += x_idx * x_strides[0]; - } else { - x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); - } - if (w_batch_ndims == 1) { - w += w_idx * w_strides[0]; - scales += w_idx * s_strides[0]; - biases += w_idx * b_strides[0]; - } else { - ulong3 idx = elem_to_loc_broadcast( - w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims); - w += idx.x; - scales += idx.y; - biases += idx.z; - } - y += tid.z * output_stride; -} - -template -METAL_FUNC void adjust_matrix_offsets( - const device T*& x, - const device uint32_t*& w, - const device T*& scales, - const device T*& biases, - const device uint32_t* lhs_indices, - const device uint32_t* rhs_indices, - device T*& y, - int output_stride, - const constant int& batch_ndims, - const constant int* batch_shape, - const constant int64_t* lhs_strides, - const constant int64_t* rhs_strides, - const constant int& x_batch_ndims, - const constant int* x_shape, - const constant int64_t* x_strides, - const constant int& w_batch_ndims, - const constant int* w_shape, - const constant int64_t* w_strides, - const constant int64_t* s_strides, - const constant int64_t* b_strides, - uint3 tid [[threadgroup_position_in_grid]]) { - // Set the input/output matrices - uint32_t x_idx; - uint32_t w_idx; - if (batch_ndims == 1) { - x_idx = lhs_indices[tid.z * lhs_strides[0]]; - w_idx = rhs_indices[tid.z * rhs_strides[0]]; - } else { - ulong2 idx = elem_to_loc_broadcast( - tid.z, batch_shape, lhs_strides, rhs_strides, batch_ndims); - x_idx = lhs_indices[idx.x]; - w_idx = rhs_indices[idx.y]; - } - if (x_batch_ndims == 1) { - x += x_idx * x_strides[0]; - } else { - x += elem_to_loc(x_idx, x_shape, x_strides, x_batch_ndims); - } - if (w_batch_ndims == 1) { - w += w_idx * w_strides[0]; - scales += w_idx * s_strides[0]; - biases += w_idx * b_strides[0]; - } else { - ulong3 idx = elem_to_loc_broadcast( - w_idx, w_shape, w_strides, s_strides, b_strides, w_batch_ndims); - w += idx.x; - scales += idx.y; - biases += idx.z; - } - y += tid.z * output_stride; -} - -template < - typename T, - const int group_size, - const int bits, - const bool aligned_N, - const int BM = 64, - const int BK = 64, - const int BN = 64, - const int WM = 2, - const int WN = 2> -METAL_FUNC void qmm_t_nax_tgp_impl( - const device uint32_t* w, - const device T* scales, - const device T* biases, - const device T* x, - device T* y, - threadgroup T* Ws, - const constant int& K, - const constant int& N, - const constant int& M, - uint3 tid [[threadgroup_position_in_grid]], - uint lid [[thread_index_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); - static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); - - (void)lid; - - constexpr int pack_factor = get_pack_factor(); - constexpr int bytes_per_pack = get_bytes_per_pack(); - - constexpr int BK_padded = (BK + 16 / sizeof(T)); - - using loader_w_t = QuantizedBlockLoader< - T, - BN, - BK, - BK_padded, - 1, - WM * WN * SIMD_SIZE, - group_size, - bits>; - - // Set the block - const int K_w = K * bytes_per_pack / pack_factor; - const int K_g = K / group_size; - const int y_row = tid.y * BM; - const int y_col = tid.x * BN; - - auto wl = (const device uint8_t*)w; - - x += y_row * static_cast(K); - wl += y_col * K_w; - scales += y_col * K_g; - biases += y_col * K_g; - y += y_row * static_cast(N) + y_col; - - // Make the weight loader - loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid); - - constexpr short UM = 16; - constexpr short UN = 32; - constexpr short UK = 16; - constexpr short SM = BM / WM; - constexpr short SN = BN / WN; - constexpr short SK = 32; - - constexpr short TM = SM / UM; - constexpr short TN = SN / UN; - constexpr short TK = SK / UK; - - const short tm = SM * (simd_gid / WN); - const short tn = SN * (simd_gid % WN); - - constexpr bool transpose_a = false; - constexpr bool transpose_b = true; - - const short sgp_sm = min(SM, short(M - (y_row + tm))); - const bool is_unaligned_sm = (sgp_sm != SM); - - const short sgp_sn = aligned_N ? SN : min(SN, short(N - (y_col + tn))); - - const short tgp_bn = aligned_N ? BN : min(BN, int(N - (y_col))); - const bool is_unaligned_bn = aligned_N ? false : (tgp_bn != BN); - - using AccumType = float; - - using ASubTile = NAXSubTile; - using BSubTile = NAXSubTile; - using DSubTile = NAXSubTile; - - NAXTile Dtile; - - Dtile.clear(); - - x += tm * K; - - dispatch_bool(!is_unaligned_sm, [&](auto kAlignedM) { - dispatch_bool(aligned_N || !is_unaligned_bn, [&](auto kAlignedN) { - for (int k = 0; k < K; k += BK) { - threadgroup_barrier(mem_flags::mem_threadgroup); - if constexpr (kAlignedN.value) { - loader_w.load_unsafe(); - } else { - loader_w.load_safe(short2(BK, tgp_bn)); - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - STEEL_PRAGMA_NO_UNROLL - for (int kk1 = 0; kk1 < BK; kk1 += SK) { - NAXTile Atile; - NAXTile Btile; - - volatile int compiler_barrier; - - if constexpr (kAlignedM.value) { - Atile.load(x + kk1, K); - } else { - Atile.load_safe(x + kk1, K, short2(SK, sgp_sm)); - } - - Btile.template load(Ws + tn * BK_padded + kk1); - - tile_matmad_nax( - Dtile, - Atile, - metal::bool_constant{}, - Btile, - metal::bool_constant{}); - - (void)compiler_barrier; - } - - x += BK; - loader_w.next(); - } - - // Store results to device memory - threadgroup_barrier(mem_flags::mem_threadgroup); - - if constexpr (kAlignedM.value && kAlignedN.value) { - Dtile.store(y + tm * N + tn, N); - } else if (kAlignedM.value && sgp_sn == SN) { - Dtile.store(y + tm * N + tn, N); - } else { - Dtile.store_safe(y + tm * N + tn, N, short2(sgp_sn, sgp_sm)); - } - }); - }); -} - -template < - typename T, - const int group_size, - const int bits, - const int BM = 64, - const int BK = 64, - const int BN = 64, - const int WM = 2, - const int WN = 2> -METAL_FUNC void qmm_n_nax_tgp_impl( - const device uint32_t* w, - const device T* scales, - const device T* biases, - const device T* x, - device T* y, - threadgroup T* Ws, - const constant int& K, - const constant int& N, - const constant int& M, - uint3 tid [[threadgroup_position_in_grid]], - uint lid [[thread_index_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - (void)lid; - (void)M; - - static_assert(BK >= SIMD_SIZE, "BK should be larger than SIMD_SIZE"); - static_assert(BK % SIMD_SIZE == 0, "BK should be divisible by SIMD_SIZE"); - - constexpr int pack_factor = get_pack_factor(); - constexpr int bytes_per_pack = get_bytes_per_pack(); - - constexpr int BN_padded = (BN + 16 / sizeof(T)); - - using loader_w_t = QuantizedBlockLoader< - T, - BK, - BN, - BN_padded, - 0, - WM * WN * SIMD_SIZE, - group_size, - bits>; - - // Set the block - const int K_w = K * bytes_per_pack / pack_factor; - const int K_g = K / group_size; - const int y_row = tid.y * BM; - const int y_col = tid.x * BN; - - auto wl = (const device uint8_t*)w; - - x += y_row * static_cast(K); - wl += y_col * K_w; - scales += y_col * K_g; - biases += y_col * K_g; - y += y_row * static_cast(N) + y_col; - - // Make the x loader and mma operation - // const short num_els = min(BM, M - y_row); - // const short num_outs = min(BN, N - y_col); - loader_w_t loader_w(wl, scales, biases, K, Ws, simd_gid, simd_lid); - - constexpr short UM = 16; - constexpr short UN = 32; - constexpr short UK = 16; - constexpr short SM = BM / WM; - constexpr short SN = BN / WN; - constexpr short SK = 32; - - constexpr short TM = SM / UM; - constexpr short TN = SN / UN; - constexpr short TK = SK / UK; - - const short tm = SM * (simd_gid / WN); - const short tn = SN * (simd_gid % WN); - - const short ldb_tgp = BN_padded; - - constexpr bool transpose_a = false; - constexpr bool transpose_b = false; - - using AccumType = float; - - using ASubTile = NAXSubTile; - using BSubTile = NAXSubTile; - using DSubTile = NAXSubTile; - - NAXTile Dtile; - - Dtile.clear(); - - x += tm * K; - - for (int k = 0; k < K; k += BK) { - threadgroup_barrier(mem_flags::mem_threadgroup); - loader_w.load_unsafe(); - threadgroup_barrier(mem_flags::mem_threadgroup); - - STEEL_PRAGMA_NO_UNROLL - for (int kk1 = 0; kk1 < BK; kk1 += SK) { - NAXTile Atile; - NAXTile Btile; - - volatile int compiler_barrier; - - Atile.load(x + kk1, K); - Btile.template load(Ws + tn + kk1 * ldb_tgp); - - tile_matmad_nax( - Dtile, - Atile, - metal::bool_constant{}, - Btile, - metal::bool_constant{}); - - (void)compiler_barrier; - } - - x += BK; - loader_w.next(); - } - - // Store results to device memory - threadgroup_barrier(mem_flags::mem_threadgroup); - - Dtile.store(y + tm * N + tn, N); -} - -template < - typename T, - const int group_size, - const int bits, - const bool aligned_N, - const bool batched, - const int BM = 64, - const int BK = 32, - const int BN = 64, - const int WM = 2, - const int WN = 2> -[[kernel]] void affine_qmm_t_nax( - const device uint32_t* w [[buffer(0)]], - const device T* scales [[buffer(1)]], - const device T* biases [[buffer(2)]], - const device T* x [[buffer(3)]], - device T* y [[buffer(4)]], - const constant int& K [[buffer(5)]], - const constant int& N [[buffer(6)]], - const constant int& M [[buffer(7)]], - const constant int& x_batch_ndims [[buffer(8)]], - const constant int* x_shape [[buffer(9)]], - const constant int64_t* x_strides [[buffer(10)]], - const constant int& w_batch_ndims [[buffer(11)]], - const constant int* w_shape [[buffer(12)]], - const constant int64_t* w_strides [[buffer(13)]], - const constant int64_t* s_strides [[buffer(14)]], - const constant int64_t* b_strides [[buffer(15)]], - uint3 tid [[threadgroup_position_in_grid]], - uint lid [[thread_index_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - (void)lid; - - constexpr int BK_padded = (BK + 16 / sizeof(T)); - - threadgroup T Ws[BN * BK_padded]; - - if (batched) { - adjust_matrix_offsets( - x, - w, - scales, - biases, - y, - M * N, - x_batch_ndims, - x_shape, - x_strides, - w_batch_ndims, - w_shape, - w_strides, - s_strides, - b_strides, - tid); - } - qmm_t_nax_tgp_impl( - w, scales, biases, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid); -} - -template < - typename T, - const int group_size, - const int bits, - const bool batched, - const int BM = 64, - const int BK = 64, - const int BN = 64, - const int WM = 2, - const int WN = 2> -[[kernel]] void affine_qmm_n_nax( - const device uint32_t* w [[buffer(0)]], - const device T* scales [[buffer(1)]], - const device T* biases [[buffer(2)]], - const device T* x [[buffer(3)]], - device T* y [[buffer(4)]], - const constant int& K [[buffer(5)]], - const constant int& N [[buffer(6)]], - const constant int& M [[buffer(7)]], - const constant int& x_batch_ndims [[buffer(8)]], - const constant int* x_shape [[buffer(9)]], - const constant int64_t* x_strides [[buffer(10)]], - const constant int& w_batch_ndims [[buffer(11)]], - const constant int* w_shape [[buffer(12)]], - const constant int64_t* w_strides [[buffer(13)]], - const constant int64_t* s_strides [[buffer(14)]], - const constant int64_t* b_strides [[buffer(15)]], - uint3 tid [[threadgroup_position_in_grid]], - uint lid [[thread_index_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - (void)lid; - - constexpr int BN_padded = (BN + 16 / sizeof(T)); - - threadgroup T Ws[BK * BN_padded]; - - if (batched) { - adjust_matrix_offsets( - x, - w, - scales, - biases, - y, - M * N, - x_batch_ndims, - x_shape, - x_strides, - w_batch_ndims, - w_shape, - w_strides, - s_strides, - b_strides, - tid); - } - - qmm_n_nax_tgp_impl( - w, scales, biases, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid); -} - -template < - typename T, - const int group_size, - const int bits, - const bool aligned_N, - const int BM = 64, - const int BK = 64, - const int BN = 64, - const int WM = 2, - const int WN = 2> -[[kernel]] void affine_gather_qmm_t_nax( - const device uint32_t* w [[buffer(0)]], - const device T* scales [[buffer(1)]], - const device T* biases [[buffer(2)]], - const device T* x [[buffer(3)]], - const device uint32_t* lhs_indices [[buffer(4)]], - const device uint32_t* rhs_indices [[buffer(5)]], - device T* y [[buffer(6)]], - const constant int& K [[buffer(7)]], - const constant int& N [[buffer(8)]], - const constant int& M [[buffer(9)]], - const constant int& x_batch_ndims [[buffer(10)]], - const constant int* x_shape [[buffer(11)]], - const constant int64_t* x_strides [[buffer(12)]], - const constant int& w_batch_ndims [[buffer(13)]], - const constant int* w_shape [[buffer(14)]], - const constant int64_t* w_strides [[buffer(15)]], - const constant int64_t* s_strides [[buffer(16)]], - const constant int64_t* b_strides [[buffer(17)]], - const constant int& batch_ndims [[buffer(18)]], - const constant int* batch_shape [[buffer(19)]], - const constant int64_t* lhs_strides [[buffer(20)]], - const constant int64_t* rhs_strides [[buffer(21)]], - uint3 tid [[threadgroup_position_in_grid]], - uint lid [[thread_index_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - (void)lid; - - constexpr int BK_padded = (BK + 16 / sizeof(T)); - - threadgroup T Ws[BN * BK_padded]; - - adjust_matrix_offsets( - x, - w, - scales, - biases, - lhs_indices, - rhs_indices, - y, - M * N, - batch_ndims, - batch_shape, - lhs_strides, - rhs_strides, - x_batch_ndims, - x_shape, - x_strides, - w_batch_ndims, - w_shape, - w_strides, - s_strides, - b_strides, - tid); - qmm_t_nax_tgp_impl( - w, scales, biases, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid); -} - -template < - typename T, - const int group_size, - const int bits, - const int BM = 64, - const int BK = 64, - const int BN = 64, - const int WM = 2, - const int WN = 2> -[[kernel]] void affine_gather_qmm_n_nax( - const device uint32_t* w [[buffer(0)]], - const device T* scales [[buffer(1)]], - const device T* biases [[buffer(2)]], - const device T* x [[buffer(3)]], - const device uint32_t* lhs_indices [[buffer(4)]], - const device uint32_t* rhs_indices [[buffer(5)]], - device T* y [[buffer(6)]], - const constant int& K [[buffer(7)]], - const constant int& N [[buffer(8)]], - const constant int& M [[buffer(9)]], - const constant int& x_batch_ndims [[buffer(10)]], - const constant int* x_shape [[buffer(11)]], - const constant int64_t* x_strides [[buffer(12)]], - const constant int& w_batch_ndims [[buffer(13)]], - const constant int* w_shape [[buffer(14)]], - const constant int64_t* w_strides [[buffer(15)]], - const constant int64_t* s_strides [[buffer(16)]], - const constant int64_t* b_strides [[buffer(17)]], - const constant int& batch_ndims [[buffer(18)]], - const constant int* batch_shape [[buffer(19)]], - const constant int64_t* lhs_strides [[buffer(20)]], - const constant int64_t* rhs_strides [[buffer(21)]], - uint3 tid [[threadgroup_position_in_grid]], - uint lid [[thread_index_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - (void)lid; - - constexpr int BN_padded = (BN + 16 / sizeof(T)); - - threadgroup T Ws[BK * BN_padded]; - - adjust_matrix_offsets( - x, - w, - scales, - biases, - lhs_indices, - rhs_indices, - y, - M * N, - batch_ndims, - batch_shape, - lhs_strides, - rhs_strides, - x_batch_ndims, - x_shape, - x_strides, - w_batch_ndims, - w_shape, - w_strides, - s_strides, - b_strides, - tid); - qmm_n_nax_tgp_impl( - w, scales, biases, x, y, Ws, K, N, M, tid, lid, simd_gid, simd_lid); -} - -template < - typename T, - int group_size, - int bits, - int BM, - int BN, - int BK, - int WM, - int WN, - bool transpose> -[[kernel]] void affine_gather_qmm_rhs_nax( - const device T* x [[buffer(0)]], - const device uint32_t* w [[buffer(1)]], - const device T* scales [[buffer(2)]], - const device T* biases [[buffer(3)]], - const device uint32_t* indices [[buffer(4)]], - device T* y [[buffer(5)]], - const constant int& M [[buffer(6)]], - const constant int& N [[buffer(7)]], - const constant int& K [[buffer(8)]], - uint3 tid [[threadgroup_position_in_grid]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]]) { - constexpr int pack_factor = get_pack_factor(); - constexpr int bytes_per_pack = get_bytes_per_pack(); - constexpr int BK_padded = (BK + 16 / sizeof(T)); - constexpr int BN_padded = (BN + 16 / sizeof(T)); - - using loader_w_t = QuantizedBlockLoader< - T, - transpose ? BN : BK, - transpose ? BK : BN, - transpose ? BK_padded : BN_padded, - transpose, - WM * WN * SIMD_SIZE, - group_size, - bits>; - - threadgroup T Ws[transpose ? BN * BK_padded : BK * BN_padded]; - - // Compute the block - const int K_w = K * bytes_per_pack / pack_factor; - const int K_g = K / group_size; - const int N_w = N * bytes_per_pack / pack_factor; - const int N_g = N / group_size; - const int K_it = K / BK; - const size_t stride_w = transpose ? N * K_w : K * N_w; - const size_t stride_s = transpose ? N * K_g : K * N_g; - const int y_row = tid.y * BM; - const int y_col = tid.x * BN; - const size_t y_row_long = size_t(y_row); - const size_t y_col_long = size_t(y_col); - - // Prepare threadgroup bounds - const short tgp_bm = align_M ? BM : short(min(BM, M - y_row)); - const short tgp_bn = align_N ? BN : short(min(BN, N - y_col)); - - // Calculate the final tiles in the case that K is not aligned - const int k_remain = K - K_it * BK; - const short2 tile_w = - transpose ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); - - // Move x and output to the correct block - auto wl = (const device uint8_t*)w; - x += y_row_long * K; - y += y_row_long * N + y_col_long; - wl += transpose ? y_col_long * K_w : y_col * bytes_per_pack / pack_factor; - scales += transpose ? y_col_long * K_g : y_col / group_size; - biases += transpose ? y_col_long * K_g : y_col / group_size; - - constexpr short UM = 16; - constexpr short UN = 32; - constexpr short UK = 16; - constexpr short SM = BM / WM; - constexpr short SN = BN / WN; - constexpr short SK = 32; - - constexpr short TM = SM / UM; - constexpr short TN = SN / UN; - constexpr short TK = SK / UK; - - const short tm = SM * (simd_group_id / WN); - const short tn = SN * (simd_group_id % WN); - - const short sgp_sm = - align_M ? SM : min(SM, short(max(0, (M - (y_row + tm))))); - const short sgp_sn = - align_N ? SN : min(SN, short(max(0, (N - (y_col + tn))))); - - const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM); - const bool is_unaligned_bn = align_N ? false : (tgp_bn != BN); - - constexpr short BR = transpose ? TN : TK; - constexpr short BC = transpose ? TK : TN; - - using AccumType = float; - - using ASubTile = NAXSubTile; - using BSubTile = NAXSubTile; - using DSubTile = NAXSubTile; - - // Do as many matmuls as necessary - uint32_t index; - short offset; - uint32_t index_next = indices[y_row]; - short offset_next = 0; - int n = 0; - while (n < tgp_bm) { - n++; - offset = offset_next; - index = index_next; - offset_next = tgp_bm; - for (; n < tgp_bm; n++) { - if (indices[y_row + n] != index) { - offset_next = n; - index_next = indices[y_row + n]; - break; - } - } - threadgroup_barrier(mem_flags::mem_none); - - NAXTile Dtile; - - Dtile.clear(); - - const device T* xn = x + tm * K; - - // Prepare threadgroup loading operations - thread loader_w_t loader_w( - wl + index * stride_w, - scales + index * stride_s, - biases + index * stride_s, - transpose ? K : N, - Ws, - simd_group_id, - simd_lane_id); - - dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) { - dispatch_bool(align_N || !is_unaligned_bn, [&](auto kAlignedN) { - for (int k = 0; k < K_it; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - if constexpr (kAlignedN.value) { - loader_w.load_unsafe(); - } else { - loader_w.load_safe( - transpose ? short2(BK, tgp_bn) : short2(tgp_bn, BK)); - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - STEEL_PRAGMA_NO_UNROLL - for (int kk1 = 0; kk1 < BK; kk1 += SK) { - NAXTile Atile; - NAXTile Btile; - - volatile int compiler_barrier; - - if constexpr (kAlignedM.value) { - Atile.load(xn + kk1, K); - } else { - Atile.load_safe(xn + kk1, K, short2(SK, sgp_sm)); - } - - if constexpr (transpose) { - Btile.template load(Ws + tn * BK_padded + kk1); - } else { - Btile.template load(Ws + tn + kk1 * BN_padded); - } - - tile_matmad_nax( - Dtile, - Atile, - metal::bool_constant{}, - Btile, - metal::bool_constant{}); - - (void)compiler_barrier; - } - - xn += BK; - loader_w.next(); - } - - if (!align_K) { - threadgroup_barrier(mem_flags::mem_threadgroup); - loader_w.load_safe(tile_w); - threadgroup_barrier(mem_flags::mem_threadgroup); - - STEEL_PRAGMA_NO_UNROLL - for (int kk1 = 0; kk1 < BK; kk1 += SK) { - NAXTile Atile; - NAXTile Btile; - - volatile int compiler_barrier; - - const short psk = min(int(SK), max(0, (BK - kk1))); - Atile.load_safe(xn + kk1, K, short2(psk, sgp_sm)); - - if constexpr (transpose) { - Btile.template load(Ws + tn * BK_padded + kk1); - } else { - Btile.template load(Ws + tn + kk1 * BN_padded); - } - - tile_matmad_nax( - Dtile, - Atile, - metal::bool_constant{}, - Btile, - metal::bool_constant{}); - - (void)compiler_barrier; - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - const short m_lo_lim = min(int(sgp_sm), max(0, offset - tm)); - const short m_hi_lim = min(int(sgp_sm), max(0, offset_next - tm)); - - // Store results to device memory - if constexpr (kAlignedN.value) { - if (m_lo_lim == 0 && m_hi_lim == SM) { - Dtile.store(y + tm * N + tn, N); - } else { - Dtile.store_slice( - y + tm * N + tn, N, short2(0, m_lo_lim), short2(SN, m_hi_lim)); - } - } else { - Dtile.store_slice( - y + tm * N + tn, - N, - short2(0, m_lo_lim), - short2(sgp_sn, m_hi_lim)); - } - }); - }); - } -} \ No newline at end of file diff --git a/Source/Cmlx/mlx-generated/metal/quantized_utils.h b/Source/Cmlx/mlx-generated/metal/quantized_utils.h deleted file mode 100644 index 38253f8f..00000000 --- a/Source/Cmlx/mlx-generated/metal/quantized_utils.h +++ /dev/null @@ -1,90 +0,0 @@ -// Copyright © 2023-2024 Apple Inc. - -#include -#include - -template -METAL_FUNC void gemm_loop_aligned( - threadgroup T* As, - threadgroup T* Bs, - thread mma_t& mma_op, - thread loader_a_t& loader_a, - thread loader_b_t& loader_b, - const int k_iterations) { - for (int k = 0; k < k_iterations; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Load elements into threadgroup memory - loader_a.load_unsafe(); - loader_b.load_unsafe(); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - - // Prepare for next iteration - loader_a.next(); - loader_b.next(); - } -} - -template < - bool rows_aligned, - bool cols_aligned, - bool transpose, - typename T, - typename mma_t, - typename loader_a_t, - typename loader_b_t> -METAL_FUNC void gemm_loop_unaligned( - threadgroup T* As, - threadgroup T* Bs, - thread mma_t& mma_op, - thread loader_a_t& loader_a, - thread loader_b_t& loader_b, - const int k_iterations, - const short tgp_bm, - const short tgp_bn, - const short tgp_bk) { - for (int k = 0; k < k_iterations; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Load elements into threadgroup memory - if (rows_aligned) { - loader_a.load_unsafe(); - } else { - loader_a.load_safe(short2(tgp_bk, tgp_bm)); - } - if (cols_aligned) { - loader_b.load_unsafe(); - } else { - loader_b.load_safe( - transpose ? short2(tgp_bk, tgp_bn) : short2(tgp_bn, tgp_bk)); - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - - // Prepare for next iteration - loader_a.next(); - loader_b.next(); - } -} - -template -METAL_FUNC void gemm_loop_finalize( - threadgroup T* As, - threadgroup T* Bs, - thread mma_t& mma_op, - thread loader_a_t& loader_a, - thread loader_b_t& loader_b, - const short2 tile_a, - const short2 tile_b) { - loader_a.load_safe(tile_a); - loader_b.load_safe(tile_b); - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(As, Bs); -} diff --git a/Source/Cmlx/mlx-generated/metal/random.metal b/Source/Cmlx/mlx-generated/metal/random.metal deleted file mode 100644 index eb6234d8..00000000 --- a/Source/Cmlx/mlx-generated/metal/random.metal +++ /dev/null @@ -1,103 +0,0 @@ -// Copyright © 2023 Apple Inc. - -#include "utils.h" - -static constexpr constant uint32_t rotations[2][4] = { - {13, 15, 26, 6}, - {17, 29, 16, 24}}; - -union rbits { - uint2 val; - uchar4 bytes[2]; -}; - -rbits threefry2x32_hash(const thread uint2& key, uint2 count) { - uint4 ks = {key.x, key.y, key.x ^ key.y ^ 0x1BD11BDA}; - - rbits v; - v.val.x = count.x + ks[0]; - v.val.y = count.y + ks[1]; - - for (int i = 0; i < 5; ++i) { - for (auto r : rotations[i % 2]) { - v.val.x += v.val.y; - v.val.y = (v.val.y << r) | (v.val.y >> (32 - r)); - v.val.y ^= v.val.x; - } - v.val.x += ks[(i + 1) % 3]; - v.val.y += ks[(i + 2) % 3] + i + 1; - } - - return v; -} - -[[kernel]] void rbitsc( - device const uint32_t* keys, - device char* out, - constant const bool& odd, - constant const uint& bytes_per_key, - uint2 grid_dim [[threads_per_grid]], - uint2 index [[thread_position_in_grid]]) { - auto kidx = 2 * index.x; - auto key = uint2(keys[kidx], keys[kidx + 1]); - auto half_size = grid_dim.y - odd; - out += index.x * bytes_per_key; - bool drop_last = odd && (index.y == half_size); - auto bits = threefry2x32_hash( - key, uint2(index.y, drop_last ? 0 : index.y + grid_dim.y)); - size_t idx = size_t(index.y) << 2; - for (int i = 0; i < 4; ++i) { - out[idx + i] = bits.bytes[0][i]; - } - if (!drop_last) { - idx = (drop_last ? 0 : size_t(index.y) + grid_dim.y) << 2; - if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) { - int edge_bytes = (bytes_per_key % 4); - for (int i = 0; i < edge_bytes; ++i) { - out[idx + i] = bits.bytes[1][i]; - } - } else { - for (int i = 0; i < 4; ++i) { - out[idx + i] = bits.bytes[1][i]; - } - } - } -} - -[[kernel]] void rbits( - device const uint32_t* keys, - device char* out, - constant const bool& odd, - constant const uint& bytes_per_key, - constant const int& ndim, - constant const int* key_shape, - constant const int64_t* key_strides, - uint2 grid_dim [[threads_per_grid]], - uint2 index [[thread_position_in_grid]]) { - auto kidx = 2 * index.x; - auto k1_elem = elem_to_loc(kidx, key_shape, key_strides, ndim); - auto k2_elem = elem_to_loc(kidx + 1, key_shape, key_strides, ndim); - auto key = uint2(keys[k1_elem], keys[k2_elem]); - auto half_size = grid_dim.y - odd; - out += size_t(index.x) * bytes_per_key; - bool drop_last = odd && (index.y == half_size); - auto bits = threefry2x32_hash( - key, uint2(index.y, drop_last ? 0 : index.y + grid_dim.y)); - size_t idx = size_t(index.y) << 2; - for (int i = 0; i < 4; ++i) { - out[idx + i] = bits.bytes[0][i]; - } - if (!drop_last) { - idx = (drop_last ? 0 : size_t(index.y) + grid_dim.y) << 2; - if ((index.y + 1) == half_size && (bytes_per_key % 4) > 0) { - int edge_bytes = (bytes_per_key % 4); - for (int i = 0; i < edge_bytes; ++i) { - out[idx + i] = bits.bytes[1][i]; - } - } else { - for (int i = 0; i < 4; ++i) { - out[idx + i] = bits.bytes[1][i]; - } - } - } -} diff --git a/Source/Cmlx/mlx-generated/metal/reduce.h b/Source/Cmlx/mlx-generated/metal/reduce.h deleted file mode 100644 index 8d1f609d..00000000 --- a/Source/Cmlx/mlx-generated/metal/reduce.h +++ /dev/null @@ -1,5 +0,0 @@ -#pragma once -#include "reduction/reduce_all.h" -#include "reduction/reduce_col.h" -#include "reduction/reduce_init.h" -#include "reduction/reduce_row.h" diff --git a/Source/Cmlx/mlx-generated/metal/reduce_utils.h b/Source/Cmlx/mlx-generated/metal/reduce_utils.h deleted file mode 100644 index f5ccc3f1..00000000 --- a/Source/Cmlx/mlx-generated/metal/reduce_utils.h +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright © 2024 Apple Inc. - -#pragma once - -#include "atomic.h" -#include "reduction/ops.h" diff --git a/Source/Cmlx/mlx-generated/metal/reduction/ops.h b/Source/Cmlx/mlx-generated/metal/reduction/ops.h deleted file mode 100644 index 11d8e83a..00000000 --- a/Source/Cmlx/mlx-generated/metal/reduction/ops.h +++ /dev/null @@ -1,275 +0,0 @@ -// Copyright © 2023-2024 Apple Inc. - -#pragma once - -#include -#include - -#define DEFINE_SIMD_REDUCE() \ - template = true> \ - T simd_reduce(T val) { \ - return simd_reduce_impl(val); \ - } \ - \ - template = true> \ - T simd_reduce(T val) { \ - for (short i = simd_size / 2; i > 0; i /= 2) { \ - val = operator()(val, simd_shuffle_down(val, i)); \ - } \ - return val; \ - } - -static constant constexpr const uint8_t simd_size = 32; - -union bool4_or_uint { - bool4 b; - unsigned int i; -}; - -struct None { - template - void atomic_update(device mlx_atomic* out, T val, size_t offset = 0) { - mlx_atomic_store_explicit(out, val, offset); - } -}; - -template -struct And { - DEFINE_SIMD_REDUCE() - - bool simd_reduce_impl(bool val) { - return simd_all(val); - } - - static constexpr constant bool init = true; - - void atomic_update( - device mlx_atomic* out, - bool val, - int elem_idx, - size_t offset = 0) { - if (!val) { - bool4_or_uint update; - update.b = {true, true, true, true}; - update.b[elem_idx] = false; - mlx_atomic_fetch_and_explicit(out, update.i, offset); - } - } - - void - atomic_update(device mlx_atomic* out, bool val, size_t offset = 0) { - if (!val) { - mlx_atomic_store_explicit(out, val, offset); - } - } - - // Non atomic update - void update(device bool* out, bool val) { - *out &= val; - } - - // Operator - bool operator()(bool a, bool b) { - return a && b; - } -}; - -template -struct Or { - DEFINE_SIMD_REDUCE() - - bool simd_reduce_impl(bool val) { - return simd_any(val); - } - - static constexpr constant bool init = false; - - void atomic_update( - device mlx_atomic* out, - bool val, - int elem_idx, - size_t offset = 0) { - if (val) { - bool4_or_uint update; - update.b = {false, false, false, false}; - update.b[elem_idx] = true; - mlx_atomic_fetch_or_explicit(out, update.i, offset); - } - } - - void - atomic_update(device mlx_atomic* out, bool val, size_t offset = 0) { - if (val) { - mlx_atomic_store_explicit(out, val, offset); - } - } - - // Non atomic update - void update(device bool* out, bool val) { - *out |= val; - } - - // Operator - bool operator()(bool a, bool b) { - return a || b; - } -}; - -template -struct Sum { - DEFINE_SIMD_REDUCE() - - template - T simd_reduce_impl(T val) { - return simd_sum(val); - } - - static constexpr constant U init = U(0); - - template - void atomic_update(device mlx_atomic* out, T val, size_t offset = 0) { - mlx_atomic_fetch_add_explicit(out, val, offset); - } - - // Operator - U operator()(U a, U b) { - return a + b; - } -}; - -template -struct Prod { - DEFINE_SIMD_REDUCE() - - template - T simd_reduce_impl(T val) { - return simd_product(val); - } - - static constexpr constant U init = U(1); - - template - void atomic_update(device mlx_atomic* out, T val, size_t offset = 0) { - mlx_atomic_fetch_mul_explicit(out, val, offset); - } - - // Operator - U operator()(U a, U b) { - return a * b; - } -}; - -template -struct Min { - DEFINE_SIMD_REDUCE() - - template - metal::enable_if_t, T> simd_reduce_impl(T val) { - return simd_min(val); - } - - template - metal::enable_if_t, T> simd_reduce_impl(T val) { - if (simd_any(val != val)) { - return static_cast(NAN); - } - return simd_min(val); - } - - static constexpr constant U init = Limits::max; - - template - void atomic_update(device mlx_atomic* out, T val, size_t offset = 0) { - mlx_atomic_fetch_min_explicit(out, val, offset); - } - - // Operator - template - metal::enable_if_t, T> operator()(T a, T b) { - return a < b ? a : b; - } - - template - metal::enable_if_t, T> operator()(T a, T b) { - if (metal::isnan(a) || metal::isnan(b)) { - return static_cast(NAN); - } else { - return a < b ? a : b; - } - } - - template <> - complex64_t operator()(complex64_t a, complex64_t b) { - bool real_is_nan = metal::isnan(a.real) || metal::isnan(b.real); - bool imag_is_nan = metal::isnan(a.imag) || metal::isnan(b.imag); - - if (!real_is_nan && !imag_is_nan) { - return a < b ? a : b; - } else if (real_is_nan && !imag_is_nan) { - return complex64_t( - static_cast(NAN), a.imag < b.imag ? a.imag : b.imag); - } else if (!real_is_nan && imag_is_nan) { - return complex64_t( - a.real < b.real ? a.real : b.real, static_cast(NAN)); - } else { - return complex64_t(static_cast(NAN), static_cast(NAN)); - } - }; -}; -template -struct Max { - DEFINE_SIMD_REDUCE() - - template - metal::enable_if_t, T> simd_reduce_impl(T val) { - return simd_max(val); - } - - template - metal::enable_if_t, T> simd_reduce_impl(T val) { - if (simd_any(val != val)) { - return static_cast(NAN); - } - return simd_max(val); - } - - static constexpr constant U init = Limits::min; - - template - void atomic_update(device mlx_atomic* out, T val, size_t offset = 0) { - mlx_atomic_fetch_max_explicit(out, val, offset); - } - - // Operator - template - metal::enable_if_t, T> operator()(T a, T b) { - return a > b ? a : b; - } - - template - metal::enable_if_t, T> operator()(T a, T b) { - if (metal::isnan(a) || metal::isnan(b)) { - return static_cast(NAN); - } else { - return a > b ? a : b; - } - } - - template <> - complex64_t operator()(complex64_t a, complex64_t b) { - bool real_is_nan = metal::isnan(a.real) || metal::isnan(b.real); - bool imag_is_nan = metal::isnan(a.imag) || metal::isnan(b.imag); - - if (!real_is_nan && !imag_is_nan) { - return a > b ? a : b; - } else if (real_is_nan && !imag_is_nan) { - return complex64_t( - static_cast(NAN), a.imag > b.imag ? a.imag : b.imag); - } else if (!real_is_nan && imag_is_nan) { - return complex64_t( - a.real > b.real ? a.real : b.real, static_cast(NAN)); - } else { - return complex64_t(static_cast(NAN), static_cast(NAN)); - } - } -}; diff --git a/Source/Cmlx/mlx-generated/metal/reduction/reduce_all.h b/Source/Cmlx/mlx-generated/metal/reduction/reduce_all.h deleted file mode 100644 index e0d08392..00000000 --- a/Source/Cmlx/mlx-generated/metal/reduction/reduce_all.h +++ /dev/null @@ -1,66 +0,0 @@ -// Copyright © 2023-2024 Apple Inc. - -template < - typename T, - typename U, - typename Op, - typename IdxT = int64_t, - int N_READS = REDUCE_N_READS> -[[kernel]] void all_reduce( - const device T* in [[buffer(0)]], - device U* out [[buffer(1)]], - const constant size_t& in_size [[buffer(2)]], - const constant size_t& row_size [[buffer(3)]], - uint3 gid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint3 lsize [[threads_per_threadgroup]], - uint simd_per_group [[simdgroups_per_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - Op op; - threadgroup U shared_vals[simd_size]; - - U total = Op::init; - IdxT start_idx = gid.y * IdxT(row_size); - IdxT actual_row = - (start_idx + row_size <= in_size) ? row_size : in_size - start_idx; - IdxT blocks = actual_row / (lsize.x * N_READS); - int extra = actual_row - blocks * (lsize.x * N_READS); - extra -= lid.x * N_READS; - start_idx += lid.x * N_READS; - in += start_idx; - - if (extra >= N_READS) { - blocks++; - extra = 0; - } - - for (IdxT b = 0; b < blocks; b++) { - for (int i = 0; i < N_READS; i++) { - total = op(static_cast(in[i]), total); - } - in += lsize.x * N_READS; - } - if (extra > 0) { - for (int i = 0; i < extra; i++) { - total = op(static_cast(in[i]), total); - } - } - - // Reduction within simd group - total = op.simd_reduce(total); - if (simd_per_group > 1) { - if (simd_lane_id == 0) { - shared_vals[simd_group_id] = total; - } - - // Reduction within thread group - threadgroup_barrier(mem_flags::mem_threadgroup); - total = lid.x < simd_per_group ? shared_vals[lid.x] : op.init; - total = op.simd_reduce(total); - } - - if (lid.x == 0) { - out[gid.y] = total; - } -} diff --git a/Source/Cmlx/mlx-generated/metal/reduction/reduce_col.h b/Source/Cmlx/mlx-generated/metal/reduction/reduce_col.h deleted file mode 100644 index c109faf0..00000000 --- a/Source/Cmlx/mlx-generated/metal/reduction/reduce_col.h +++ /dev/null @@ -1,398 +0,0 @@ -// Copyright © 2023-2024 Apple Inc. - -template -[[kernel]] void col_reduce_small( - const device T* in [[buffer(0)]], - device U* out [[buffer(1)]], - const constant size_t& reduction_size [[buffer(2)]], - const constant int64_t& reduction_stride [[buffer(3)]], - const constant int* shape [[buffer(4)]], - const constant int64_t* strides [[buffer(5)]], - const constant int& ndim [[buffer(6)]], - const constant int* reduce_shape [[buffer(7)]], - const constant int64_t* reduce_strides [[buffer(8)]], - const constant int& reduce_ndim [[buffer(9)]], - const constant size_t& non_col_reductions [[buffer(10)]], - uint3 gid [[threadgroup_position_in_grid]], - uint3 gsize [[threadgroups_per_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint3 lsize [[threads_per_threadgroup]]) { - constexpr int n_reads = 4; - Op op; - LoopedElemToLoc 2)> loop(reduce_ndim); - const device T* row; - - U totals[n_reads]; - for (int i = 0; i < n_reads; i++) { - totals[i] = Op::init; - } - - IdxT column = IdxT(gid.x) * lsize.x * n_reads + lid.x * n_reads; - if (column >= reduction_stride) { - return; - } - bool safe = column + n_reads <= reduction_stride; - - IdxT out_idx = gid.y + gsize.y * IdxT(gid.z); - IdxT in_idx = elem_to_loc(out_idx, shape, strides, ndim); - in += in_idx + column; - - IdxT total_rows = IdxT(non_col_reductions) * IdxT(reduction_size); - loop.next(lid.y, reduce_shape, reduce_strides); - for (IdxT r = lid.y; r < total_rows; r += lsize.y) { - row = in + loop.location(); - if (safe) { - for (int i = 0; i < n_reads; i++) { - totals[i] = op(static_cast(row[i]), totals[i]); - } - } else { - U vals[n_reads]; - for (int i = 0; i < n_reads; i++) { - vals[i] = - (column + i < reduction_stride) ? static_cast(row[i]) : op.init; - } - for (int i = 0; i < n_reads; i++) { - totals[i] = op(vals[i], totals[i]); - } - } - loop.next(lsize.y, reduce_shape, reduce_strides); - } - - if (lsize.y > 1) { - // lsize.y should be <= 8 - threadgroup U shared_vals[32 * 8 * n_reads]; - for (int i = 0; i < n_reads; i++) { - shared_vals[lid.y * lsize.x * n_reads + lid.x * n_reads + i] = totals[i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (lid.y == 0) { - for (int i = 0; i < n_reads; i++) { - totals[i] = shared_vals[lid.x * n_reads + i]; - } - for (uint j = 1; j < lsize.y; j++) { - for (int i = 0; i < n_reads; i++) { - totals[i] = - op(shared_vals[j * lsize.x * n_reads + lid.x * n_reads + i], - totals[i]); - } - } - } - } - - if (lid.y == 0) { - out += out_idx * IdxT(reduction_stride) + column; - if (safe) { - for (int i = 0; i < n_reads; i++) { - out[i] = totals[i]; - } - } else { - for (int i = 0; column + i < reduction_stride; i++) { - out[i] = totals[i]; - } - } - } -} - -template -[[kernel]] void col_reduce_longcolumn( - const device T* in [[buffer(0)]], - device U* out [[buffer(1)]], - const constant size_t& reduction_size [[buffer(2)]], - const constant size_t& reduction_stride [[buffer(3)]], - const constant int* shape [[buffer(4)]], - const constant int64_t* strides [[buffer(5)]], - const constant int& ndim [[buffer(6)]], - const constant int* reduce_shape [[buffer(7)]], - const constant int64_t* reduce_strides [[buffer(8)]], - const constant int& reduce_ndim [[buffer(9)]], - const constant size_t& non_col_reductions [[buffer(10)]], - const constant size_t& out_size [[buffer(11)]], - uint3 gid [[threadgroup_position_in_grid]], - uint3 gsize [[threadgroups_per_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint3 lsize [[threads_per_threadgroup]]) { - Op op; - LoopedElemToLoc 2)> loop(reduce_ndim); - const device T* row; - - IdxT out_idx = gid.x + gsize.x * IdxT(gid.y); - IdxT in_idx = elem_to_loc(out_idx, shape, strides, ndim); - in += in_idx + lid.x; - - U total = Op::init; - IdxT total_rows = IdxT(non_col_reductions) * IdxT(reduction_size); - loop.next(gid.z * lsize.y + lid.y, reduce_shape, reduce_strides); - for (IdxT r = gid.z * lsize.y + lid.y; r < total_rows; - r += lsize.y * gsize.z) { - row = in + loop.location(); - total = op(static_cast(*row), total); - loop.next(lsize.y * gsize.z, reduce_shape, reduce_strides); - } - - threadgroup U shared_vals[32 * 32]; - shared_vals[lid.y * lsize.x + lid.x] = total; - threadgroup_barrier(mem_flags::mem_threadgroup); - if (lid.y == 0) { - for (uint i = 1; i < lsize.y; i++) { - total = op(total, shared_vals[i * lsize.x + lid.x]); - } - out[gid.z * IdxT(out_size) + out_idx * IdxT(reduction_stride) + lid.x] = - total; - } -} - -/** - * Our approach is the following simple looped approach: - * 1. Each thread keeps running totals for BN / n_simdgroups outputs. - * 2. Load a tile BM, BN in registers and accumulate in the running totals - * 3. Move ahead by BM steps until the column axis and the non column - * reductions are exhausted. - * 6. If BM == 32 then transpose in SM and simd reduce the running totals. - * Otherwise write in shared memory and BN threads accumulate the running - * totals with a loop. - * 7. Write them to the output - */ -template < - typename T, - typename U, - typename Op, - typename IdxT, - int NDIMS, - int BM, - int BN> -[[kernel]] void col_reduce_looped( - const device T* in [[buffer(0)]], - device U* out [[buffer(1)]], - const constant size_t& reduction_size [[buffer(2)]], - const constant int64_t& reduction_stride [[buffer(3)]], - const constant int* shape [[buffer(4)]], - const constant int64_t* strides [[buffer(5)]], - const constant int& ndim [[buffer(6)]], - const constant int* reduce_shape [[buffer(7)]], - const constant int64_t* reduce_strides [[buffer(8)]], - const constant int& reduce_ndim [[buffer(9)]], - const constant size_t& non_col_reductions [[buffer(10)]], - uint3 gid [[threadgroup_position_in_grid]], - uint3 gsize [[threadgroups_per_grid]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - Op op; - constexpr int n_simdgroups = 8; - constexpr short tgp_size = n_simdgroups * simd_size; - constexpr short n_reads = (BM * BN) / tgp_size; - constexpr short n_read_blocks = BN / n_reads; - - threadgroup U shared_vals[BN * BM]; - U totals[n_reads]; - LoopedElemToLoc 2)> loop(reduce_ndim); - const device T* row; - - for (int i = 0; i < n_reads; i++) { - totals[i] = Op::init; - } - - short lid = simd_group_id * simd_size + simd_lane_id; - short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks); - IdxT column = BN * gid.x + offset.x; - bool safe = column + n_reads <= reduction_stride; - - IdxT out_idx = gid.y + gsize.y * IdxT(gid.z); - IdxT in_idx = elem_to_loc(out_idx, shape, strides, ndim); - in += in_idx + column; - - IdxT total = IdxT(non_col_reductions) * IdxT(reduction_size); - loop.next(offset.y, reduce_shape, reduce_strides); - for (IdxT r = offset.y; r < total; r += BM) { - row = in + loop.location(); - - if (safe) { - for (int i = 0; i < n_reads; i++) { - totals[i] = op(static_cast(row[i]), totals[i]); - } - } else { - U vals[n_reads]; - for (int i = 0; i < n_reads; i++) { - vals[i] = - (column + i < reduction_stride) ? static_cast(row[i]) : op.init; - } - for (int i = 0; i < n_reads; i++) { - totals[i] = op(vals[i], totals[i]); - } - } - - loop.next(BM, reduce_shape, reduce_strides); - } - - // We can use a simd reduction to accumulate across BM so each thread writes - // the partial output to SM and then each simdgroup does BN / n_simdgroups - // accumulations. - if (BM == 32) { - constexpr int n_outputs = BN / n_simdgroups; - static_assert( - BM != 32 || n_outputs == n_reads, - "The tile should be selected such that n_outputs == n_reads"); - for (int i = 0; i < n_reads; i++) { - shared_vals[offset.y * BN + offset.x + i] = totals[i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - short2 out_offset(simd_group_id * n_outputs, simd_lane_id); - for (int i = 0; i < n_outputs; i++) { - totals[i] = - op.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]); - } - - // Write the output. - if (simd_lane_id == 0) { - IdxT out_column = BN * gid.x + out_offset.x; - out += out_idx * IdxT(reduction_stride) + out_column; - if (out_column + n_outputs <= reduction_stride) { - for (int i = 0; i < n_outputs; i++) { - out[i] = totals[i]; - } - } else { - for (int i = 0; out_column + i < reduction_stride; i++) { - out[i] = totals[i]; - } - } - } - } - - // Each thread holds n_reads partial results. We write them all out to shared - // memory and threads with offset.y == 0 aggregate the columns and write the - // outputs. - else { - short x_block = offset.x / n_reads; - for (int i = 0; i < n_reads; i++) { - shared_vals[x_block * BM * n_reads + i * BM + offset.y] = totals[i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (offset.y == 0) { - for (int i = 0; i < n_reads; i++) { - for (int j = 1; j < BM; j++) { - totals[i] = - op(shared_vals[x_block * BM * n_reads + i * BM + j], totals[i]); - } - } - } - - // Write the output. - if (offset.y == 0) { - out += out_idx * IdxT(reduction_stride) + column; - if (safe) { - for (int i = 0; i < n_reads; i++) { - out[i] = totals[i]; - } - } else { - for (int i = 0; column + i < reduction_stride; i++) { - out[i] = totals[i]; - } - } - } - } -} - -template < - typename T, - typename U, - typename Op, - typename IdxT, - int NDIMS, - int BM, - int BN> -[[kernel]] void col_reduce_2pass( - const device T* in [[buffer(0)]], - device U* out [[buffer(1)]], - const constant size_t& reduction_size [[buffer(2)]], - const constant int64_t& reduction_stride [[buffer(3)]], - const constant int* shape [[buffer(4)]], - const constant int64_t* strides [[buffer(5)]], - const constant int& ndim [[buffer(6)]], - const constant int* reduce_shape [[buffer(7)]], - const constant int64_t* reduce_strides [[buffer(8)]], - const constant int& reduce_ndim [[buffer(9)]], - const constant size_t& non_col_reductions [[buffer(10)]], - const constant size_t& out_size [[buffer(11)]], - uint3 gid [[threadgroup_position_in_grid]], - uint3 gsize [[threadgroups_per_grid]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - Op op; - constexpr int n_simdgroups = 8; - constexpr short tgp_size = n_simdgroups * simd_size; - constexpr short n_reads = (BM * BN) / tgp_size; - constexpr short n_read_blocks = BN / n_reads; - constexpr int n_outputs = BN / n_simdgroups; - constexpr short outer_blocks = 32; - static_assert(BM == 32, "BM should be equal to 32"); - - threadgroup U shared_vals[BN * BM]; - U totals[n_reads]; - LoopedElemToLoc 2)> loop(reduce_ndim); - const device T* row; - - for (int i = 0; i < n_reads; i++) { - totals[i] = Op::init; - } - - short lid = simd_group_id * simd_size + simd_lane_id; - short2 offset((lid % n_read_blocks) * n_reads, lid / n_read_blocks); - IdxT column = BN * gid.x + offset.x; - bool safe = column + n_reads <= reduction_stride; - - IdxT full_idx = gid.y + gsize.y * IdxT(gid.z); - IdxT block_idx = full_idx / IdxT(out_size); - IdxT out_idx = full_idx % IdxT(out_size); - IdxT in_idx = elem_to_loc(out_idx, shape, strides, ndim); - in += in_idx + column; - - IdxT total = IdxT(non_col_reductions) * IdxT(reduction_size); - loop.next(offset.y + block_idx * BM, reduce_shape, reduce_strides); - for (IdxT r = offset.y + block_idx * BM; r < total; r += outer_blocks * BM) { - row = in + loop.location(); - - if (safe) { - for (int i = 0; i < n_reads; i++) { - totals[i] = op(static_cast(row[i]), totals[i]); - } - } else { - U vals[n_reads]; - for (int i = 0; i < n_reads; i++) { - vals[i] = - (column + i < reduction_stride) ? static_cast(row[i]) : op.init; - } - for (int i = 0; i < n_reads; i++) { - totals[i] = op(vals[i], totals[i]); - } - } - - loop.next(outer_blocks * BM, reduce_shape, reduce_strides); - } - - // We can use a simd reduction to accumulate across BM so each thread writes - // the partial output to SM and then each simdgroup does BN / n_simdgroups - // accumulations. - for (int i = 0; i < n_reads; i++) { - shared_vals[offset.y * BN + offset.x + i] = totals[i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - short2 out_offset(simd_group_id * n_outputs, simd_lane_id); - for (int i = 0; i < n_outputs; i++) { - totals[i] = - op.simd_reduce(shared_vals[out_offset.y * BN + out_offset.x + i]); - } - - // Write the output. - if (simd_lane_id == 0) { - IdxT out_column = BN * gid.x + out_offset.x; - out += full_idx * IdxT(reduction_stride) + out_column; - if (out_column + n_outputs <= reduction_stride) { - for (int i = 0; i < n_outputs; i++) { - out[i] = totals[i]; - } - } else { - for (int i = 0; out_column + i < reduction_stride; i++) { - out[i] = totals[i]; - } - } - } -} diff --git a/Source/Cmlx/mlx-generated/metal/reduction/reduce_init.h b/Source/Cmlx/mlx-generated/metal/reduction/reduce_init.h deleted file mode 100644 index 604efa78..00000000 --- a/Source/Cmlx/mlx-generated/metal/reduction/reduce_init.h +++ /dev/null @@ -1,8 +0,0 @@ -// Copyright © 2023-2024 Apple Inc. - -template -[[kernel]] void init_reduce( - device T* out [[buffer(0)]], - uint tid [[thread_position_in_grid]]) { - out[tid] = Op::init; -} diff --git a/Source/Cmlx/mlx-generated/metal/reduction/reduce_row.h b/Source/Cmlx/mlx-generated/metal/reduction/reduce_row.h deleted file mode 100644 index 936d75bb..00000000 --- a/Source/Cmlx/mlx-generated/metal/reduction/reduce_row.h +++ /dev/null @@ -1,369 +0,0 @@ -// Copyright © 2023-2024 Apple Inc. - -// Row reduction utilities -// - `per_thread_row_reduce` collaborative partial reduction in the threadgroup -// - `threadgroup_reduce` collaborative reduction in the threadgroup such that -// lid.x == 0 holds the reduced value -// - `thread_reduce` simple loop and reduce the row - -/** - * The thread group collaboratively reduces across the rows with bounds - * checking. In the end each thread holds a part of the reduction. - */ -template < - typename T, - typename U, - typename Op, - int N_READS = REDUCE_N_READS, - int N_WRITES = REDUCE_N_WRITES> -METAL_FUNC void per_thread_row_reduce( - thread U totals[N_WRITES], - const device T* inputs[N_WRITES], - int blocks, - int extra, - uint lsize_x, - uint lid_x) { - Op op; - - // Set up the accumulator registers - for (int i = 0; i < N_WRITES; i++) { - totals[i] = Op::init; - } - - // Loop over the reduction size within thread group - for (int i = 0; i < blocks; i++) { - for (int j = 0; j < N_WRITES; j++) { - for (int i = 0; i < N_READS; i++) { - totals[j] = op(static_cast(inputs[j][i]), totals[j]); - } - - inputs[j] += lsize_x * N_READS; - } - } - - // Separate case for the last set as we close the reduction size - int index = lid_x * N_READS; - if (index + N_READS <= extra) { - for (int j = 0; j < N_WRITES; j++) { - for (int i = 0; i < N_READS; i++) { - totals[j] = op(static_cast(inputs[j][i]), totals[j]); - } - } - } else { - for (int j = 0; j < N_WRITES; j++) { - for (int i = 0; index + i < extra; i++) { - totals[j] = op(static_cast(inputs[j][i]), totals[j]); - } - } - } -} - -/** - * Consecutive rows in a contiguous array. - */ -template < - typename T, - typename U, - typename Op, - int N_READS = REDUCE_N_READS, - int N_WRITES = REDUCE_N_WRITES> -METAL_FUNC void per_thread_row_reduce( - thread U totals[N_WRITES], - const device T* in, - const constant size_t& reduction_size, - int blocks, - int extra, - uint lsize_x, - uint lid_x) { - // Set up the input pointers - const device T* inputs[N_WRITES]; - inputs[0] = in + lid_x * N_READS; - for (int i = 1; i < N_READS; i++) { - inputs[i] = inputs[i - 1] + reduction_size; - } - - per_thread_row_reduce( - totals, inputs, blocks, extra, lsize_x, lid_x); -} - -/** - * Consecutive rows in an arbitrarily ordered array. - */ -template < - typename T, - typename U, - typename Op, - int N_READS = REDUCE_N_READS, - int N_WRITES = REDUCE_N_WRITES> -METAL_FUNC void per_thread_row_reduce( - thread U totals[N_WRITES], - const device T* in, - const int64_t row_idx, - int blocks, - int extra, - const constant int* shape, - const constant int64_t* strides, - const constant int& ndim, - uint lsize_x, - uint lid_x) { - // Set up the input pointers - const device T* inputs[N_WRITES]; - in += lid_x * N_READS; - for (int i = 0; i < N_READS; i++) { - inputs[i] = in + elem_to_loc(row_idx + i, shape, strides, ndim); - } - - per_thread_row_reduce( - totals, inputs, blocks, extra, lsize_x, lid_x); -} - -/** - * Reduce within the threadgroup. - */ -template < - typename T, - typename U, - typename Op, - int N_READS = REDUCE_N_READS, - int N_WRITES = REDUCE_N_WRITES> -METAL_FUNC void threadgroup_reduce( - thread U totals[N_WRITES], - threadgroup U* shared_vals, - uint3 lid [[thread_position_in_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_per_group [[simdgroups_per_threadgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - Op op; - - // Simdgroup first - for (int i = 0; i < N_WRITES; i++) { - totals[i] = op.simd_reduce(totals[i]); - } - - // Across simdgroups - if (simd_per_group > 1) { - if (simd_lane_id == 0) { - for (int i = 0; i < N_WRITES; i++) { - shared_vals[simd_group_id * N_WRITES + i] = totals[i]; - } - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - U values[N_WRITES]; - for (int i = 0; i < N_WRITES; i++) { - values[i] = (lid.x < simd_per_group) ? shared_vals[lid.x * N_WRITES + i] - : op.init; - } - - for (int i = 0; i < N_WRITES; i++) { - totals[i] = op.simd_reduce(values[i]); - } - } -} - -template -METAL_FUNC void -thread_reduce(thread U& total, const device T* row, int blocks, int extra) { - Op op; - for (int i = 0; i < blocks; i++) { - U vals[N_READS]; - for (int j = 0; j < N_READS; j++) { - vals[j] = row[j]; - } - for (int j = 0; j < N_READS; j++) { - total = op(vals[j], total); - } - row += N_READS; - } - for (int i = 0; i < extra; i++) { - total = op(*row++, total); - } -} - -// Reduction kernels -// - `row_reduce_small` depending on the non-row reductions and row size it -// either just loops over everything or a simd collaboratively reduces the -// non_row reductions. In the first case one thread is responsible for one -// output on the 2nd one simd is responsible for one output. -// - `row_reduce_simple` simple contiguous row reduction -// - `row_reduce_looped` simply loop and reduce each row for each non-row -// reduction. One threadgroup is responsible for one output. - -template < - typename T, - typename U, - typename Op, - typename IdxT, - int NDIMS, - int N_READS = REDUCE_N_READS> -[[kernel]] void row_reduce_small( - const device T* in [[buffer(0)]], - device U* out [[buffer(1)]], - const constant int64_t& row_size [[buffer(2)]], - const constant int64_t& non_row_reductions [[buffer(3)]], - const constant int* shape [[buffer(4)]], - const constant int64_t* strides [[buffer(5)]], - const constant int& ndim [[buffer(6)]], - const constant int* reduce_shape [[buffer(7)]], - const constant int64_t* reduce_strides [[buffer(8)]], - const constant int& reduce_ndim [[buffer(9)]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint3 gid [[threadgroup_position_in_grid]], - uint3 gsize [[threadgroups_per_grid]], - uint3 tid [[thread_position_in_grid]], - uint3 tsize [[threads_per_grid]]) { - Op op; - - U total_val = Op::init; - LoopedElemToLoc 2)> loop(reduce_ndim); - - // Precompute some row reduction numbers - const device T* row; - int blocks = IdxT(row_size) / N_READS; - int extra = IdxT(row_size) % N_READS; - - if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) { - // Simple loop over non_row_reductions and reduce the row in the thread. - IdxT out_idx = tid.x + tsize.x * IdxT(tid.y); - in += elem_to_loc(out_idx, shape, strides, ndim); - - for (uint r = 0; r < non_row_reductions; r++) { - row = in + loop.location(); - thread_reduce(total_val, row, blocks, extra); - loop.next(reduce_shape, reduce_strides); - } - - out[out_idx] = total_val; - } else { - // Collaboratively reduce over non_row_reductions in the simdgroup. Each - // thread reduces every 32nd row and then a simple simd reduce. - IdxT out_idx = gid.y + gsize.y * IdxT(gid.z); - in += elem_to_loc(out_idx, shape, strides, ndim); - - loop.next(simd_lane_id, reduce_shape, reduce_strides); - - for (uint r = simd_lane_id; r < non_row_reductions; r += simd_size) { - row = in + loop.location(); - thread_reduce(total_val, row, blocks, extra); - loop.next(simd_size, reduce_shape, reduce_strides); - } - - total_val = op.simd_reduce(total_val); - - if (simd_lane_id == 0) { - out[out_idx] = total_val; - } - } -} - -template < - typename T, - typename U, - typename Op, - typename IdxT = int64_t, - int N_READS = REDUCE_N_READS, - int N_WRITES = REDUCE_N_WRITES> -[[kernel]] void row_reduce_simple( - const device T* in [[buffer(0)]], - device U* out [[buffer(1)]], - const constant size_t& reduction_size [[buffer(2)]], - const constant int64_t& out_size [[buffer(3)]], - uint3 gid [[threadgroup_position_in_grid]], - uint3 gsize [[threadgroups_per_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint3 lsize [[threads_per_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_per_group [[simdgroups_per_threadgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - threadgroup U shared_vals[simd_size * N_WRITES]; - U totals[N_WRITES]; - - // Move to the row - IdxT out_idx = N_WRITES * (gid.y + gsize.y * IdxT(gid.z)); - if (out_idx + N_WRITES > out_size) { - out_idx = out_size - N_WRITES; - } - in += out_idx * IdxT(reduction_size); - out += out_idx; - - // Each thread reduces across the row - int blocks = IdxT(reduction_size) / (lsize.x * N_READS); - int extra = reduction_size - blocks * (lsize.x * N_READS); - per_thread_row_reduce( - totals, in, reduction_size, blocks, extra, lsize.x, lid.x); - - // Reduce across the threadgroup - threadgroup_reduce( - totals, shared_vals, lid, simd_lane_id, simd_per_group, simd_group_id); - - // Write the output - if (lid.x == 0) { - for (int i = 0; i < N_WRITES; i++) { - out[i] = totals[i]; - } - } -} - -template < - typename T, - typename U, - typename Op, - typename IdxT, - int NDIMS, - int N_READS = REDUCE_N_READS> -[[kernel]] void row_reduce_looped( - const device T* in [[buffer(0)]], - device U* out [[buffer(1)]], - const constant int64_t& row_size [[buffer(2)]], - const constant int64_t& non_row_reductions [[buffer(3)]], - const constant int* shape [[buffer(4)]], - const constant int64_t* strides [[buffer(5)]], - const constant int& ndim [[buffer(6)]], - const constant int* reduce_shape [[buffer(7)]], - const constant int64_t* reduce_strides [[buffer(8)]], - const constant int& reduce_ndim [[buffer(9)]], - uint3 gid [[threadgroup_position_in_grid]], - uint3 gsize [[threadgroups_per_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint3 lsize [[threads_per_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_per_group [[simdgroups_per_threadgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - Op op; - threadgroup U shared_vals[simd_size]; - U total = Op::init; - - IdxT out_idx = gid.y + gsize.y * IdxT(gid.z); - - // lid.x * N_READS breaks the per_thread_row_reduce interface a bit. Maybe it - // needs a small refactor. - in += elem_to_loc(out_idx, shape, strides, ndim) + lid.x * N_READS; - - LoopedElemToLoc 2)> loop(reduce_ndim); - const device T* row; - int blocks = IdxT(row_size) / (lsize.x * N_READS); - int extra = row_size - blocks * (lsize.x * N_READS); - - for (IdxT i = 0; i < non_row_reductions; i++) { - row = in + loop.location(); - - // Each thread reduces across the row - U row_total; - per_thread_row_reduce( - &row_total, &row, blocks, extra, lsize.x, lid.x); - - // Aggregate across rows - total = op(total, row_total); - - loop.next(reduce_shape, reduce_strides); - } - - // Reduce across the threadgroup - threadgroup_reduce( - &total, shared_vals, lid, simd_lane_id, simd_per_group, simd_group_id); - - // Write the output - if (lid.x == 0) { - out[out_idx] = total; - } -} diff --git a/Source/Cmlx/mlx-generated/metal/rms_norm.metal b/Source/Cmlx/mlx-generated/metal/rms_norm.metal deleted file mode 100644 index 22fae273..00000000 --- a/Source/Cmlx/mlx-generated/metal/rms_norm.metal +++ /dev/null @@ -1,391 +0,0 @@ -// Copyright © 2024 Apple Inc. - -#include -#include - -#include "utils.h" - -using namespace metal; - -constant bool has_w [[function_constant(20)]]; - -template -[[kernel]] void rms_single_row( - const device T* x, - const device T* w, - device T* out, - constant float& eps, - constant uint& axis_size, - constant uint& w_stride, - uint gid [[threadgroup_position_in_grid]], - uint lid [[thread_position_in_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - constexpr int SIMD_SIZE = 32; - - threadgroup float local_inv_mean[1]; - threadgroup float local_sums[SIMD_SIZE]; - - float acc = 0; - x += gid * size_t(axis_size) + lid * N_READS; - w += w_stride * lid * N_READS; - if (lid * N_READS + N_READS <= axis_size) { - for (int i = 0; i < N_READS; i++) { - float xi = x[i]; - acc += xi * xi; - } - } else { - for (int i = 0; i < N_READS; i++) { - if ((lid * N_READS + i) < axis_size) { - float xi = x[i]; - acc += xi * xi; - } - } - } - acc = simd_sum(acc); - // Initialize shared memory - if (simd_group_id == 0) { - local_sums[simd_lane_id] = 0; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Write simd accumulations into shared memory - if (simd_lane_id == 0) { - local_sums[simd_group_id] = acc; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Accumulate over simd groups - if (simd_group_id == 0) { - acc = simd_sum(local_sums[simd_lane_id]); - if (simd_lane_id == 0) { - local_inv_mean[0] = metal::precise::rsqrt(acc / axis_size + eps); - } - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Write the outputs - out += gid * size_t(axis_size) + lid * N_READS; - if (lid * N_READS + N_READS <= axis_size) { - for (int i = 0; i < N_READS; i++) { - out[i] = w[w_stride * i] * static_cast(x[i] * local_inv_mean[0]); - } - } else { - for (int i = 0; i < N_READS; i++) { - if ((lid * N_READS + i) < axis_size) { - out[i] = w[w_stride * i] * static_cast(x[i] * local_inv_mean[0]); - } - } - } -} - -template -[[kernel]] void rms_looped( - const device T* x, - const device T* w, - device T* out, - constant float& eps, - constant uint& axis_size, - constant uint& w_stride, - uint gid [[threadgroup_position_in_grid]], - uint lid [[thread_position_in_threadgroup]], - uint lsize [[threads_per_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - constexpr int SIMD_SIZE = 32; - threadgroup float local_inv_mean[1]; - threadgroup float local_sums[SIMD_SIZE]; - - float acc = 0; - x += gid * size_t(axis_size) + lid * N_READS; - w += w_stride * lid * N_READS; - for (uint r = 0; r < axis_size; r += lsize * N_READS) { - if (r + lid * N_READS + N_READS <= axis_size) { - for (int i = 0; i < N_READS; i++) { - float xi = x[i + r]; - acc += xi * xi; - } - } else { - for (int i = 0; i < N_READS; i++) { - if ((r + lid * N_READS + i) < axis_size) { - float xi = x[i + r]; - acc += xi * xi; - } - } - } - } - acc = simd_sum(acc); - // Initialize shared memory - if (simd_group_id == 0) { - local_sums[simd_lane_id] = 0; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Write simd accumulations into shared memory - if (simd_lane_id == 0) { - local_sums[simd_group_id] = acc; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Accumulate over simd groups - if (simd_group_id == 0) { - acc = simd_sum(local_sums[simd_lane_id]); - if (simd_lane_id == 0) { - local_inv_mean[0] = metal::precise::rsqrt(acc / axis_size + eps); - } - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Write the outputs - out += gid * size_t(axis_size) + lid * N_READS; - for (uint r = 0; r < axis_size; r += lsize * N_READS) { - if (r + lid * N_READS + N_READS <= axis_size) { - for (int i = 0; i < N_READS; i++) { - out[r + i] = w[w_stride * (i + r)] * - static_cast(x[r + i] * local_inv_mean[0]); - } - } else { - for (int i = 0; i < N_READS; i++) { - if ((r + lid * N_READS + i) < axis_size) { - out[r + i] = w[w_stride * (i + r)] * - static_cast(x[r + i] * local_inv_mean[0]); - } - } - } - } -} - -template -[[kernel]] void vjp_rms_single_row( - const device T* x, - const device T* w, - const device T* g, - device T* gx, - device T* gw, - constant float& eps, - constant uint& axis_size, - constant uint& w_stride, - uint gid [[threadgroup_position_in_grid]], - uint lid [[thread_position_in_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - // Advance the input pointers - x += gid * size_t(axis_size) + lid * N_READS; - g += gid * size_t(axis_size) + lid * N_READS; - w += w_stride * lid * N_READS; - - // Allocate registers for the computation and accumulators - float thread_x[N_READS]; - float thread_w[N_READS]; - float thread_g[N_READS]; - float sumx2 = 0; - float sumgwx = 0; - - // Allocate shared memory to implement the reduction - constexpr int SIMD_SIZE = 32; - threadgroup float local_sumx2[SIMD_SIZE]; - threadgroup float local_sumgwx[SIMD_SIZE]; - threadgroup float local_normalizer[1]; - threadgroup float local_meangwx[1]; - - // Read and accumulate locally - if (lid * N_READS + N_READS <= axis_size) { - for (int i = 0; i < N_READS; i++) { - thread_x[i] = x[i]; - thread_w[i] = w[w_stride * i]; - thread_g[i] = g[i]; - - sumx2 += thread_x[i] * thread_x[i]; - sumgwx += thread_x[i] * thread_w[i] * thread_g[i]; - } - } else { - for (int i = 0; i < N_READS; i++) { - if ((lid * N_READS + i) < axis_size) { - thread_x[i] = x[i]; - thread_w[i] = w[w_stride * i]; - thread_g[i] = g[i]; - - sumx2 += thread_x[i] * thread_x[i]; - sumgwx += thread_x[i] * thread_w[i] * thread_g[i]; - } - } - } - - // Accumulate across threads - sumx2 = simd_sum(sumx2); - sumgwx = simd_sum(sumgwx); - if (simd_group_id == 0) { - local_sumx2[simd_lane_id] = 0; - local_sumgwx[simd_lane_id] = 0; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (simd_lane_id == 0) { - local_sumx2[simd_group_id] = sumx2; - local_sumgwx[simd_group_id] = sumgwx; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (simd_group_id == 0) { - sumx2 = simd_sum(local_sumx2[simd_lane_id]); - sumgwx = simd_sum(local_sumgwx[simd_lane_id]); - if (simd_lane_id == 0) { - local_meangwx[0] = sumgwx / axis_size; - local_normalizer[0] = metal::precise::rsqrt(sumx2 / axis_size + eps); - } - } - threadgroup_barrier(mem_flags::mem_threadgroup); - float meangwx = local_meangwx[0]; - float normalizer = local_normalizer[0]; - float normalizer3 = normalizer * normalizer * normalizer; - - // Write the outputs - gx += gid * size_t(axis_size) + lid * N_READS; - gw += gid * size_t(axis_size) + lid * N_READS; - if (lid * N_READS + N_READS <= axis_size) { - for (int i = 0; i < N_READS; i++) { - gx[i] = static_cast( - thread_g[i] * thread_w[i] * normalizer - - thread_x[i] * meangwx * normalizer3); - if (has_w) { - gw[i] = static_cast(thread_g[i] * thread_x[i] * normalizer); - } - } - } else { - for (int i = 0; i < N_READS; i++) { - if ((lid * N_READS + i) < axis_size) { - gx[i] = static_cast( - thread_g[i] * thread_w[i] * normalizer - - thread_x[i] * meangwx * normalizer3); - if (has_w) { - gw[i] = static_cast(thread_g[i] * thread_x[i] * normalizer); - } - } - } - } -} - -template -[[kernel]] void vjp_rms_looped( - const device T* x, - const device T* w, - const device T* g, - device T* gx, - device T* gw, - constant float& eps, - constant uint& axis_size, - constant uint& w_stride, - uint gid [[threadgroup_position_in_grid]], - uint lid [[thread_position_in_threadgroup]], - uint lsize [[threads_per_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - // Advance the input pointers - x += gid * size_t(axis_size) + lid * N_READS; - g += gid * size_t(axis_size) + lid * N_READS; - w += w_stride * lid * N_READS; - - // Allocate registers for the accumulators - float sumx2 = 0; - float sumgwx = 0; - - // Allocate shared memory to implement the reduction - constexpr int SIMD_SIZE = 32; - threadgroup float local_sumx2[SIMD_SIZE]; - threadgroup float local_sumgwx[SIMD_SIZE]; - threadgroup float local_normalizer[1]; - threadgroup float local_meangwx[1]; - - // Read and accumulate locally - for (uint r = 0; r < axis_size; r += lsize * N_READS) { - if (r + lid * N_READS + N_READS <= axis_size) { - for (int i = 0; i < N_READS; i++) { - float xi = x[i + r]; - float wi = w[w_stride * (i + r)]; - float gi = g[i + r]; - - sumx2 += xi * xi; - sumgwx += xi * wi * gi; - } - } else { - for (int i = 0; i < N_READS; i++) { - if ((r + lid * N_READS + i) < axis_size) { - float xi = x[i + r]; - float wi = w[w_stride * (i + r)]; - float gi = g[i + r]; - - sumx2 += xi * xi; - sumgwx += xi * wi * gi; - } - } - } - } - - // Accumulate across threads - sumx2 = simd_sum(sumx2); - sumgwx = simd_sum(sumgwx); - if (simd_group_id == 0) { - local_sumx2[simd_lane_id] = 0; - local_sumgwx[simd_lane_id] = 0; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (simd_lane_id == 0) { - local_sumx2[simd_group_id] = sumx2; - local_sumgwx[simd_group_id] = sumgwx; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (simd_group_id == 0) { - sumx2 = simd_sum(local_sumx2[simd_lane_id]); - sumgwx = simd_sum(local_sumgwx[simd_lane_id]); - if (simd_lane_id == 0) { - local_meangwx[0] = sumgwx / axis_size; - local_normalizer[0] = metal::precise::rsqrt(sumx2 / axis_size + eps); - } - } - threadgroup_barrier(mem_flags::mem_threadgroup); - float meangwx = local_meangwx[0]; - float normalizer = local_normalizer[0]; - float normalizer3 = normalizer * normalizer * normalizer; - - // Write the outputs - gx += gid * size_t(axis_size) + lid * N_READS; - gw += gid * size_t(axis_size) + lid * N_READS; - for (uint r = 0; r < axis_size; r += lsize * N_READS) { - if (r + lid * N_READS + N_READS <= axis_size) { - for (int i = 0; i < N_READS; i++) { - float xi = x[i + r]; - float wi = w[w_stride * (i + r)]; - float gi = g[i + r]; - - gx[i + r] = - static_cast(gi * wi * normalizer - xi * meangwx * normalizer3); - if (has_w) { - gw[i + r] = static_cast(gi * xi * normalizer); - } - } - } else { - for (int i = 0; i < N_READS; i++) { - if ((r + lid * N_READS + i) < axis_size) { - float xi = x[i + r]; - float wi = w[w_stride * (i + r)]; - float gi = g[i + r]; - - gx[i + r] = - static_cast(gi * wi * normalizer - xi * meangwx * normalizer3); - if (has_w) { - gw[i + r] = static_cast(gi * xi * normalizer); - } - } - } - } - } -} - -// clang-format off -#define instantiate_rms(name, itype) \ - instantiate_kernel("rms" #name, rms_single_row, itype) \ - instantiate_kernel("vjp_rms" #name, vjp_rms_single_row, itype) \ - instantiate_kernel("rms_looped" #name, rms_looped, itype) \ - instantiate_kernel("vjp_rms_looped" #name, vjp_rms_looped, itype) - -instantiate_rms(float32, float) -instantiate_rms(float16, half) -instantiate_rms(bfloat16, bfloat16_t) // clang-format on diff --git a/Source/Cmlx/mlx-generated/metal/rope.metal b/Source/Cmlx/mlx-generated/metal/rope.metal deleted file mode 100644 index f8cafe78..00000000 --- a/Source/Cmlx/mlx-generated/metal/rope.metal +++ /dev/null @@ -1,229 +0,0 @@ -// Copyright © 2023-2024 Apple Inc. - -#include - -#include "utils.h" - -constant bool forward [[function_constant(1)]]; -constant bool traditional [[function_constant(2)]]; -constant bool hs_transpose [[function_constant(3)]]; - -template -void rope_single_impl( - const device T* in, - device T* out, - constant const int& offset, - const float inv_freq, - constant const float& scale, - constant const int64_t& stride, - uint2 pos, - uint2 grid) { - float L = scale * static_cast(offset); - - // Compute costheta, sintheta - float theta = L * inv_freq; - float costheta = metal::fast::cos(theta); - float sintheta = metal::fast::sin(theta); - - // Compute the input and output indices - uint index_1, index_2; - if (traditional) { - index_1 = 2 * pos.x + pos.y * stride; - index_2 = index_1 + 1; - } else { - index_1 = pos.x + pos.y * stride; - index_2 = index_1 + grid.x; - } - - // Read and write the output - float x1 = static_cast(in[index_1]); - float x2 = static_cast(in[index_2]); - float rx1; - float rx2; - if (forward) { - rx1 = x1 * costheta - x2 * sintheta; - rx2 = x1 * sintheta + x2 * costheta; - } else { - rx1 = x2 * sintheta + x1 * costheta; - rx2 = x2 * costheta - x1 * sintheta; - } - out[index_1] = static_cast(rx1); - out[index_2] = static_cast(rx2); -} - -template -[[kernel]] void rope_single( - const device T* in [[buffer(0)]], - device T* out [[buffer(1)]], - constant const int& offset, - constant const float& scale, - constant const int64_t& stride, - constant const float& base [[buffer(10)]], - uint2 pos [[thread_position_in_grid]], - uint2 grid [[threads_per_grid]]) { - float d = static_cast(pos.x) / static_cast(grid.x); - float inv_freq = metal::exp2(-d * base); - rope_single_impl(in, out, offset, inv_freq, scale, stride, pos, grid); -} - -template -[[kernel]] void rope_single_freqs( - const device T* in [[buffer(0)]], - device T* out [[buffer(1)]], - constant const int& offset, - constant const float& scale, - constant const int64_t& stride, - const device float* freqs [[buffer(10)]], - constant const int64_t& freq_stride [[buffer(11)]], - uint2 pos [[thread_position_in_grid]], - uint2 grid [[threads_per_grid]]) { - float inv_freq = 1.0 / (freqs[freq_stride * pos.x]); - rope_single_impl(in, out, offset, inv_freq, scale, stride, pos, grid); -} - -template -void rope_impl( - const device T* in, - device T* out, - const device int* offset, - const float inv_freq, - constant const float& scale, - constant const int64_t strides[3], - constant const int64_t out_strides[3], - constant const int64_t& offset_stride, - constant const int& n_head, - uint3 pos, - uint3 grid) { - auto n_head_up = N * ((n_head + N - 1) / N); - auto head_idx = static_cast((pos.z * N) % n_head_up); - auto batch_idx = (pos.z * N) / n_head_up; - auto batch_offset = offset[batch_idx * offset_stride]; - float L = scale * static_cast(pos.y + batch_offset); - auto mat_idx = batch_idx * n_head + head_idx; - - // Compute costheta, sintheta - float theta = L * inv_freq; - float costheta = metal::fast::cos(theta); - float sintheta = metal::fast::sin(theta); - // Compute the input and output indices - IdxT in_index_1; - if (hs_transpose) { - IdxT batch_stride = grid.y * IdxT(strides[1]); - in_index_1 = - batch_idx * batch_stride + pos.y * strides[1] + head_idx * strides[0]; - } else { - in_index_1 = pos.y * IdxT(strides[1]) + mat_idx * IdxT(strides[0]); - } - IdxT in_index_2; - IdxT out_index_1 = - pos.y * IdxT(out_strides[1]) + mat_idx * IdxT(out_strides[0]); - IdxT out_index_2; - if (traditional) { - out_index_1 += 2 * pos.x * IdxT(out_strides[2]); - out_index_2 = out_index_1 + 1; - in_index_1 += 2 * pos.x * IdxT(strides[2]); - in_index_2 = in_index_1 + IdxT(strides[2]); - } else { - out_index_1 += pos.x * IdxT(out_strides[2]); - out_index_2 = out_index_1 + grid.x * IdxT(out_strides[2]); - in_index_1 += pos.x * IdxT(strides[2]); - in_index_2 = in_index_1 + grid.x * IdxT(strides[2]); - } - for (int i = 0; i < N && head_idx + i < n_head; ++i) { - // Read and write the output - float x1 = static_cast(in[in_index_1]); - float x2 = static_cast(in[in_index_2]); - float rx1; - float rx2; - if (forward) { - rx1 = x1 * costheta - x2 * sintheta; - rx2 = x1 * sintheta + x2 * costheta; - } else { - rx1 = x2 * sintheta + x1 * costheta; - rx2 = x2 * costheta - x1 * sintheta; - } - out[out_index_1] = static_cast(rx1); - out[out_index_2] = static_cast(rx2); - in_index_1 += IdxT(strides[0]); - in_index_2 += IdxT(strides[0]); - out_index_1 += IdxT(out_strides[0]); - out_index_2 += IdxT(out_strides[0]); - } -} - -template -[[kernel]] void rope( - const device T* in [[buffer(0)]], - device T* out [[buffer(1)]], - const device int* offset, - constant const float& scale, - constant const int64_t strides[3], - constant const int64_t out_strides[3], - constant const int64_t& offset_stride, - constant const int& n_head, - constant const float& base [[buffer(10)]], - uint3 pos [[thread_position_in_grid]], - uint3 grid [[threads_per_grid]]) { - float d = static_cast(pos.x) / static_cast(grid.x); - float inv_freq = metal::exp2(-d * base); - rope_impl( - in, - out, - offset, - inv_freq, - scale, - strides, - out_strides, - offset_stride, - n_head, - pos, - grid); -} - -template -[[kernel]] void rope_freqs( - const device T* in [[buffer(0)]], - device T* out [[buffer(1)]], - const device int* offset, - constant const float& scale, - constant const int64_t strides[3], - constant const int64_t out_strides[3], - constant const int64_t& offset_stride, - constant const int& n_head, - const device float* freqs [[buffer(10)]], - constant const int64_t& freq_stride [[buffer(11)]], - uint3 pos [[thread_position_in_grid]], - uint3 grid [[threads_per_grid]]) { - float inv_freq = 1.0 / (freqs[freq_stride * pos.x]); - rope_impl( - in, - out, - offset, - inv_freq, - scale, - strides, - out_strides, - offset_stride, - n_head, - pos, - grid); -} - -// clang-format off -#define instantiate_rope_g(name, type) \ - instantiate_kernel("rope_" #name, rope, type, int32_t) \ - instantiate_kernel("rope_freqs_" #name, rope_freqs, type, int32_t) \ - instantiate_kernel("rope_large_" #name, rope, type, int64_t) \ - instantiate_kernel("rope_freqs_large_" #name, rope_freqs, type, int64_t) - -#define instantiate_rope_s(name, type) \ - instantiate_kernel("rope_single_" #name, rope_single, type) \ - instantiate_kernel("rope_single_freqs_" #name, rope_single_freqs, type) - -#define instantiate_rope(name, type) \ - instantiate_rope_s(name, type) \ - instantiate_rope_g(name, type) - -instantiate_rope(float16, half) -instantiate_rope(bfloat16, bfloat16_t) -instantiate_rope(float32, float) // clang-format on diff --git a/Source/Cmlx/mlx-generated/metal/scaled_dot_product_attention.metal b/Source/Cmlx/mlx-generated/metal/scaled_dot_product_attention.metal deleted file mode 100644 index ae04c6ba..00000000 --- a/Source/Cmlx/mlx-generated/metal/scaled_dot_product_attention.metal +++ /dev/null @@ -1,44 +0,0 @@ -#include - -// clang-format off -#include "utils.h" -#include "sdpa_vector.h" - -using namespace metal; - -// SDPA vector instantiations -#define instantiate_sdpa_vector_aggregation(type, value_dim) \ - instantiate_kernel( \ - "sdpa_vector_2pass_2_" #type "_" #value_dim, \ - sdpa_vector_2pass_2, \ - type, \ - value_dim) - -#define instantiate_sdpa_vector(type, qk_dim, value_dim) \ - instantiate_kernel( \ - "sdpa_vector_" #type "_" #qk_dim "_" #value_dim, \ - sdpa_vector, \ - type, \ - qk_dim, \ - value_dim) \ - instantiate_kernel( \ - "sdpa_vector_2pass_1_" #type "_" #qk_dim "_" #value_dim, \ - sdpa_vector_2pass_1, \ - type, \ - qk_dim, \ - value_dim) - -#define instantiate_sdpa_vector_heads(type) \ - instantiate_sdpa_vector(type, 64, 64) \ - instantiate_sdpa_vector(type, 96, 96) \ - instantiate_sdpa_vector(type, 128, 128) \ - instantiate_sdpa_vector(type, 256, 256) \ - instantiate_sdpa_vector_aggregation(type, 64) \ - instantiate_sdpa_vector_aggregation(type, 96) \ - instantiate_sdpa_vector_aggregation(type, 128) \ - instantiate_sdpa_vector_aggregation(type, 256) - -instantiate_sdpa_vector_heads(float) -instantiate_sdpa_vector_heads(bfloat16_t) -instantiate_sdpa_vector_heads(float16_t) - // clang-format on diff --git a/Source/Cmlx/mlx-generated/metal/scan.h b/Source/Cmlx/mlx-generated/metal/scan.h deleted file mode 100644 index a1f10340..00000000 --- a/Source/Cmlx/mlx-generated/metal/scan.h +++ /dev/null @@ -1,514 +0,0 @@ -// Copyright © 2023-2024 Apple Inc. - -#pragma once - -#include "binary_ops.h" - -#define DEFINE_SIMD_SCAN() \ - template = true> \ - T simd_scan(T val) { \ - return simd_scan_impl(val); \ - } \ - \ - template = true> \ - T simd_scan(T val) { \ - for (int i = 1; i <= 16; i *= 2) { \ - val = operator()(val, simd_shuffle_and_fill_up(val, init, i)); \ - } \ - return val; \ - } - -#define DEFINE_SIMD_EXCLUSIVE_SCAN() \ - template = true> \ - T simd_exclusive_scan(T val) { \ - return simd_exclusive_scan_impl(val); \ - } \ - \ - template = true> \ - T simd_exclusive_scan(T val) { \ - val = simd_scan(val); \ - return simd_shuffle_and_fill_up(val, init, 1); \ - } - -template -struct CumSum { - DEFINE_SIMD_SCAN() - DEFINE_SIMD_EXCLUSIVE_SCAN() - - static constexpr constant U init = static_cast(0); - - template - U operator()(U a, T b) { - return a + b; - } - - U simd_scan_impl(U x) { - return simd_prefix_inclusive_sum(x); - } - - U simd_exclusive_scan_impl(U x) { - return simd_prefix_exclusive_sum(x); - } -}; - -template -struct CumProd { - DEFINE_SIMD_SCAN() - DEFINE_SIMD_EXCLUSIVE_SCAN() - - static constexpr constant U init = static_cast(1.0f); - - template - U operator()(U a, T b) { - return a * b; - } - - U simd_scan_impl(U x) { - return simd_prefix_inclusive_product(x); - } - - U simd_exclusive_scan_impl(U x) { - return simd_prefix_exclusive_product(x); - } -}; - -template <> -struct CumProd { - static constexpr constant bool init = true; - - template - bool operator()(bool a, T b) { - return a & static_cast(b); - } - - bool simd_scan(bool x) { - for (int i = 1; i <= 16; i *= 2) { - bool other = simd_shuffle_and_fill_up(x, init, i); - x &= other; - } - return x; - } - - bool simd_exclusive_scan(bool x) { - x = simd_scan(x); - return simd_shuffle_and_fill_up(x, init, 1); - } -}; - -template -struct CumMax { - static constexpr constant U init = Limits::min; - - template - U operator()(U a, T b) { - return (a >= b) ? a : b; - } - - U simd_scan(U x) { - for (int i = 1; i <= 16; i *= 2) { - U other = simd_shuffle_and_fill_up(x, init, i); - x = (x >= other) ? x : other; - } - return x; - } - - U simd_exclusive_scan(U x) { - x = simd_scan(x); - return simd_shuffle_and_fill_up(x, init, 1); - } -}; - -template -struct CumMin { - static constexpr constant U init = Limits::max; - - template - U operator()(U a, T b) { - return (a <= b) ? a : b; - } - - U simd_scan(U x) { - for (int i = 1; i <= 16; i *= 2) { - U other = simd_shuffle_and_fill_up(x, init, i); - x = (x <= other) ? x : other; - } - return x; - } - - U simd_exclusive_scan(U x) { - x = simd_scan(x); - return simd_shuffle_and_fill_up(x, init, 1); - } -}; - -template -struct CumLogaddexp { - static constexpr constant U init = Limits::min; - - template - U operator()(U a, T b) { - return LogAddExp{}(a, static_cast(b)); - } - - U simd_scan(U x) { - for (int i = 1; i <= 16; i *= 2) { - U other = simd_shuffle_and_fill_up(x, init, i); - x = LogAddExp{}(x, other); - } - return x; - } - - U simd_exclusive_scan(U x) { - x = simd_scan(x); - return simd_shuffle_and_fill_up(x, init, 1); - } -}; - -template -inline void load_unsafe(U values[N_READS], const device T* input) { - if (reverse) { - for (int i = 0; i < N_READS; i++) { - values[N_READS - i - 1] = input[i]; - } - } else { - for (int i = 0; i < N_READS; i++) { - values[i] = input[i]; - } - } -} - -template -inline void load_safe( - U values[N_READS], - const device T* input, - int start, - int total, - U init) { - if (reverse) { - for (int i = 0; i < N_READS; i++) { - values[N_READS - i - 1] = - (start + N_READS - i - 1 < total) ? input[i] : init; - } - } else { - for (int i = 0; i < N_READS; i++) { - values[i] = (start + i < total) ? input[i] : init; - } - } -} - -template -inline void write_unsafe(U values[N_READS], device U* out) { - if (reverse) { - for (int i = 0; i < N_READS; i++) { - out[i] = values[N_READS - i - 1]; - } - } else { - for (int i = 0; i < N_READS; i++) { - out[i] = values[i]; - } - } -} - -template -inline void write_safe(U values[N_READS], device U* out, int start, int total) { - if (reverse) { - for (int i = 0; i < N_READS; i++) { - if (start + N_READS - i - 1 < total) { - out[i] = values[N_READS - i - 1]; - } - } - } else { - for (int i = 0; i < N_READS; i++) { - if (start + i < total) { - out[i] = values[i]; - } - } - } -} - -template < - typename T, - typename U, - typename Op, - int N_READS, - bool inclusive, - bool reverse> -[[kernel]] void contiguous_scan( - const device T* in [[buffer(0)]], - device U* out [[buffer(1)]], - const constant size_t& axis_size [[buffer(2)]], - uint3 gid [[threadgroup_position_in_grid]], - uint3 gsize [[threadgroups_per_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint3 lsize [[threads_per_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - constexpr int simd_size = 32; - Op op; - - // Position the pointers - size_t offset = (gid.y + gsize.y * size_t(gid.z)) * axis_size; - in += offset; - out += offset; - - // Compute the number of simd_groups - uint simd_groups = lsize.x / simd_size; - - // Allocate memory - U prefix = Op::init; - U values[N_READS]; - threadgroup U simdgroup_sums[32]; - - // Loop over the reduced axis in blocks of size ceildiv(axis_size, - // N_READS*lsize) - // Read block - // Compute inclusive scan of the block - // Compute inclusive scan per thread - // Compute exclusive scan of thread sums in simdgroup - // Write simdgroup sums in SM - // Compute exclusive scan of simdgroup sums - // Compute the output by scanning prefix, prev_simdgroup, prev_thread, - // value - // Write block - - for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize.x); r++) { - // Compute the block offset - uint offset = r * lsize.x * N_READS + lid.x * N_READS; - - // Read the values - if (reverse) { - if ((offset + N_READS) < axis_size) { - load_unsafe( - values, in + axis_size - offset - N_READS); - } else { - load_safe( - values, - in + axis_size - offset - N_READS, - offset, - axis_size, - Op::init); - } - } else { - if ((offset + N_READS) < axis_size) { - load_unsafe(values, in + offset); - } else { - load_safe( - values, in + offset, offset, axis_size, Op::init); - } - } - - // Compute an inclusive scan per thread - for (int i = 1; i < N_READS; i++) { - values[i] = op(values[i], values[i - 1]); - } - - // Compute exclusive scan of thread sums - U prev_thread = op.simd_exclusive_scan(values[N_READS - 1]); - - // Write simdgroup_sums to SM - threadgroup_barrier(mem_flags::mem_threadgroup); - if (simd_lane_id == simd_size - 1) { - simdgroup_sums[simd_group_id] = op(prev_thread, values[N_READS - 1]); - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Compute exclusive scan of simdgroup_sums - if (simd_group_id == 0) { - U prev_simdgroup = op.simd_exclusive_scan(simdgroup_sums[simd_lane_id]); - simdgroup_sums[simd_lane_id] = prev_simdgroup; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Compute the output - for (int i = 0; i < N_READS; i++) { - values[i] = op(values[i], prefix); - values[i] = op(values[i], simdgroup_sums[simd_group_id]); - values[i] = op(values[i], prev_thread); - } - - // Write the values - if (reverse) { - if (inclusive) { - if ((offset + N_READS) < axis_size) { - write_unsafe( - values, out + axis_size - offset - N_READS); - } else { - write_safe( - values, out + axis_size - offset - N_READS, offset, axis_size); - } - } else { - if (lid.x == 0 && offset == 0) { - out[axis_size - 1] = Op::init; - } - if ((offset + N_READS + 1) < axis_size) { - write_unsafe( - values, out + axis_size - offset - 1 - N_READS); - } else { - write_safe( - values, - out + axis_size - offset - 1 - N_READS, - offset + 1, - axis_size); - } - } - } else { - if (inclusive) { - if ((offset + N_READS) < axis_size) { - write_unsafe(values, out + offset); - } else { - write_safe( - values, out + offset, offset, axis_size); - } - } else { - if (lid.x == 0 && offset == 0) { - out[0] = Op::init; - } - if ((offset + N_READS + 1) < axis_size) { - write_unsafe(values, out + offset + 1); - } else { - write_safe( - values, out + offset + 1, offset + 1, axis_size); - } - } - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Share the prefix - if (simd_group_id == simd_groups - 1 && simd_lane_id == simd_size - 1) { - simdgroup_sums[0] = values[N_READS - 1]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - prefix = simdgroup_sums[0]; - } -} - -template < - typename T, - typename U, - typename Op, - int N_READS, - bool inclusive, - bool reverse> -[[kernel]] void strided_scan( - const device T* in [[buffer(0)]], - device U* out [[buffer(1)]], - const constant size_t& axis_size [[buffer(2)]], - const constant size_t& stride [[buffer(3)]], - const constant size_t& stride_blocks [[buffer(4)]], - uint3 gid [[threadgroup_position_in_grid]], - uint3 gsize [[threadgroups_per_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - constexpr int simd_size = 32; - constexpr int BM = 32; - constexpr int BN = 32; - constexpr int BN_pad = 32 + 16 / sizeof(U); - constexpr int n_simds = BN / N_READS; - constexpr int n_scans = BN / n_simds; - Op op; - - threadgroup U read_buffer[BM * BN_pad]; - U values[n_scans]; - U prefix[n_scans]; - for (int i = 0; i < n_scans; i++) { - prefix[i] = Op::init; - } - - // Compute offsets - size_t full_gid = gid.y + gsize.y * size_t(gid.z); - size_t offset = full_gid / stride_blocks * axis_size * stride; - size_t global_index_x = full_gid % stride_blocks * BN; - uint read_offset_y = (lid.x * N_READS) / BN; - uint read_offset_x = (lid.x * N_READS) % BN; - uint scan_offset_y = simd_lane_id; - uint scan_offset_x = simd_group_id * n_scans; - - uint stride_limit = stride - global_index_x; - in += offset + global_index_x + read_offset_x; - out += offset + global_index_x + read_offset_x; - threadgroup U* read_into = - read_buffer + read_offset_y * BN_pad + read_offset_x; - threadgroup U* read_from = - read_buffer + scan_offset_y * BN_pad + scan_offset_x; - - for (uint j = 0; j < axis_size; j += BM) { - // Calculate the indices for the current thread - uint index_y = j + read_offset_y; - uint check_index_y = index_y; - if (reverse) { - index_y = axis_size - 1 - index_y; - } - - // Read in SM - threadgroup_barrier(mem_flags::mem_threadgroup); - if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) { - for (int i = 0; i < N_READS; i++) { - read_into[i] = in[index_y * stride + i]; - } - } else { - for (int i = 0; i < N_READS; i++) { - if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) { - read_into[i] = in[index_y * stride + i]; - } else { - read_into[i] = Op::init; - } - } - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Read strided into registers - for (int i = 0; i < n_scans; i++) { - values[i] = read_from[i]; - } - simdgroup_barrier(mem_flags::mem_threadgroup); - - // Perform the scan - for (int i = 0; i < n_scans; i++) { - values[i] = op.simd_scan(values[i]); - values[i] = op(values[i], prefix[i]); - prefix[i] = simd_shuffle(values[i], simd_size - 1); - } - - // Write to SM - for (int i = 0; i < n_scans; i++) { - read_from[i] = values[i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Write to device memory - if (!inclusive) { - if (check_index_y == 0) { - if ((read_offset_x + N_READS) < stride_limit) { - for (int i = 0; i < N_READS; i++) { - out[index_y * stride + i] = Op::init; - } - } else { - for (int i = 0; i < N_READS; i++) { - if ((read_offset_x + i) < stride_limit) { - out[index_y * stride + i] = Op::init; - } - } - } - } - if (reverse) { - index_y -= 1; - check_index_y += 1; - } else { - index_y += 1; - check_index_y += 1; - } - } - if (check_index_y < axis_size && (read_offset_x + N_READS) < stride_limit) { - for (int i = 0; i < N_READS; i++) { - out[index_y * stride + i] = read_into[i]; - } - } else { - for (int i = 0; i < N_READS; i++) { - if (check_index_y < axis_size && (read_offset_x + i) < stride_limit) { - out[index_y * stride + i] = read_into[i]; - } - } - } - } -} diff --git a/Source/Cmlx/mlx-generated/metal/sdpa_vector.h b/Source/Cmlx/mlx-generated/metal/sdpa_vector.h deleted file mode 100644 index 1eec72be..00000000 --- a/Source/Cmlx/mlx-generated/metal/sdpa_vector.h +++ /dev/null @@ -1,394 +0,0 @@ -// Copyright © 2024 Apple Inc. - -#include - -using namespace metal; - -constant bool has_mask [[function_constant(20)]]; -constant bool query_transposed [[function_constant(21)]]; -constant bool do_causal [[function_constant(22)]]; -constant bool bool_mask [[function_constant(23)]]; -constant bool float_mask [[function_constant(24)]]; -constant bool has_sinks [[function_constant(25)]]; -constant int blocks [[function_constant(26)]]; - -template -[[kernel]] void sdpa_vector( - const device T* queries [[buffer(0)]], - const device T* keys [[buffer(1)]], - const device T* values [[buffer(2)]], - device T* out [[buffer(3)]], - const constant int& gqa_factor [[buffer(4)]], - const constant int& N [[buffer(5)]], - const constant size_t& k_head_stride [[buffer(6)]], - const constant size_t& k_seq_stride [[buffer(7)]], - const constant size_t& v_head_stride [[buffer(8)]], - const constant size_t& v_seq_stride [[buffer(9)]], - const constant float& scale [[buffer(10)]], - const device bool* bmask [[buffer(11), function_constant(bool_mask)]], - const device T* fmask [[buffer(12), function_constant(float_mask)]], - const constant int& mask_kv_seq_stride - [[buffer(13), function_constant(has_mask)]], - const constant int& mask_q_seq_stride - [[buffer(14), function_constant(has_mask)]], - const constant int& mask_head_stride - [[buffer(15), function_constant(has_mask)]], - const device T* sinks [[buffer(16), function_constant(has_sinks)]], - const constant int& num_q_heads - [[buffer(17), function_constant(has_sinks)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 tpg [[threadgroups_per_grid]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - constexpr int BN = 32; - constexpr int BD = 32; - constexpr int qk_per_thread = D / BD; - constexpr int v_per_thread = V / BD; - int inner_k_stride = BN * int(k_seq_stride); - int inner_v_stride = BN * int(v_seq_stride); - - typedef float U; - - thread U q[qk_per_thread]; - thread U k[qk_per_thread]; - thread U o[v_per_thread]; - - threadgroup U outputs[BN * BD]; - threadgroup U max_scores[BN]; - threadgroup U sum_exp_scores[BN]; - - // Adjust positions - const int q_batch_head_idx = tid.x; - const int q_seq_idx = tid.y; - const int kv_head_idx = q_batch_head_idx / gqa_factor; - const int o_offset = q_batch_head_idx * tpg.y + q_seq_idx; - const int q_offset = - query_transposed ? tpg.x * q_seq_idx + q_batch_head_idx : o_offset; - queries += q_offset * D + simd_lid * qk_per_thread; - keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride + - simd_lid * qk_per_thread; - values += kv_head_idx * v_head_stride + simd_gid * v_seq_stride + - simd_lid * v_per_thread; - if (bool_mask) { - bmask += q_batch_head_idx * mask_head_stride + - simd_gid * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride; - } - if (float_mask) { - fmask += q_batch_head_idx * mask_head_stride + - simd_gid * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride; - } - - out += o_offset * V + simd_gid * v_per_thread; - - // Read the query and 0 the output accumulator - for (int i = 0; i < qk_per_thread; i++) { - q[i] = static_cast(scale) * queries[i]; - } - for (int i = 0; i < v_per_thread; i++) { - o[i] = 0; - } - - U max_score = Limits::finite_min; - U sum_exp_score = 0; - if (has_sinks && simd_gid == 0) { - max_score = static_cast(sinks[q_batch_head_idx % num_q_heads]); - sum_exp_score = 1; - } - - // For each key - for (int i = simd_gid; i < N; i += BN) { - bool use_key = true; - if (do_causal) { - use_key = i <= (N - int(tpg.y) + int(q_seq_idx)); - } else if (bool_mask) { - use_key = bmask[0]; - } else if (float_mask) { - use_key = (fmask[0] >= Limits::finite_min); - } - if (use_key) { - // Read the key - for (int j = 0; j < qk_per_thread; j++) { - k[j] = keys[j]; - } - - // Compute the i-th score - U score = 0; - for (int j = 0; j < qk_per_thread; j++) { - score += q[j] * k[j]; - } - score = simd_sum(score); - if (float_mask) { - score += static_cast(fmask[0]); - } - - // Update the accumulators - U new_max = max(max_score, score); - U factor = fast::exp(max_score - new_max); - U exp_score = fast::exp(score - new_max); - - max_score = new_max; - sum_exp_score = sum_exp_score * factor + exp_score; - - // Update the output accumulator - for (int j = 0; j < v_per_thread; j++) { - o[j] = o[j] * factor + exp_score * values[j]; - } - } - - // Move the pointers to the next kv - keys += inner_k_stride; - values += inner_v_stride; - if (bool_mask) { - bmask += BN * mask_kv_seq_stride; - } - if (float_mask) { - fmask += BN * mask_kv_seq_stride; - } - } - - // Each thread has a partial part of the output so we need to combine them. - - // First let's communicate the max and sum_exp - if (simd_lid == 0) { - max_scores[simd_gid] = max_score; - sum_exp_scores[simd_gid] = sum_exp_score; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - max_score = max_scores[simd_lid]; - U new_max = simd_max(max_score); - U factor = fast::exp(max_score - new_max); - sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor); - - // Now we need to aggregate all the outputs - for (int i = 0; i < v_per_thread; i++) { - outputs[simd_lid * BD + simd_gid] = o[i]; - threadgroup_barrier(mem_flags::mem_threadgroup); - o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor); - o[i] = sum_exp_score == 0 ? o[i] : (o[i] / sum_exp_score); - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - // And write the output - if (simd_lid == 0) { - for (int i = 0; i < v_per_thread; i++) { - out[i] = static_cast(o[i]); - } - } -} - -template -[[kernel]] void sdpa_vector_2pass_1( - const device T* queries [[buffer(0)]], - const device T* keys [[buffer(1)]], - const device T* values [[buffer(2)]], - device T* out [[buffer(3)]], - device float* sums [[buffer(4)]], - device float* maxs [[buffer(5)]], - const constant int& N [[buffer(7)]], - const constant size_t& k_head_stride [[buffer(8)]], - const constant size_t& k_seq_stride [[buffer(9)]], - const constant size_t& v_head_stride [[buffer(10)]], - const constant size_t& v_seq_stride [[buffer(11)]], - const constant float& scale [[buffer(12)]], - const device bool* bmask [[buffer(13), function_constant(bool_mask)]], - const device T* fmask [[buffer(14), function_constant(float_mask)]], - const constant int& mask_kv_seq_stride - [[buffer(15), function_constant(has_mask)]], - const constant int& mask_q_seq_stride - [[buffer(16), function_constant(has_mask)]], - const constant int& mask_head_stride - [[buffer(17), function_constant(has_mask)]], - const device T* sinks [[buffer(18), function_constant(has_sinks)]], - uint3 tptg [[threads_per_threadgroup]], - uint3 tidtg [[thread_position_in_threadgroup]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 tpg [[threadgroups_per_grid]], - uint simd_lid [[thread_index_in_simdgroup]]) { - constexpr int BD = 32; - constexpr int qk_per_thread = D / BD; - constexpr int v_per_thread = V / BD; - - typedef float U; - - thread U q[qk_per_thread]; - thread U o[v_per_thread] = {0}; - - // Adjust positions - const int kv_head_idx = tid.x; - const int batch_idx = tid.y; - const int block_idx = tid.z; - const int gqa_factor = tptg.y; - const int q_seq_len = tptg.z; - const int q_seq_idx = tidtg.z; - const int q_head_idx = gqa_factor * kv_head_idx + tidtg.y; - const int num_kv_heads = tpg.x; - const int num_q_heads = num_kv_heads * gqa_factor; - const int q_batch_head_idx = (batch_idx * num_q_heads + q_head_idx); - const int o_offset = q_batch_head_idx * q_seq_len + q_seq_idx; - const int q_offset = - query_transposed ? num_q_heads * q_seq_idx + q_batch_head_idx : o_offset; - - queries += q_offset * D + simd_lid * qk_per_thread; - - const int kv_batch_head_idx = batch_idx * num_kv_heads + kv_head_idx; - keys += kv_batch_head_idx * k_head_stride + block_idx * k_seq_stride + - simd_lid * qk_per_thread; - values += kv_batch_head_idx * v_head_stride + block_idx * v_seq_stride + - simd_lid * v_per_thread; - out += o_offset * blocks * V + block_idx * V + simd_lid * v_per_thread; - if (bool_mask) { - bmask += q_batch_head_idx * mask_head_stride + - block_idx * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride; - } - if (float_mask) { - fmask += q_batch_head_idx * mask_head_stride + - block_idx * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride; - } - sums += o_offset * blocks + block_idx; - maxs += o_offset * blocks + block_idx; - - // Read the query - for (int i = 0; i < qk_per_thread; i++) { - q[i] = static_cast(scale) * queries[i]; - } - - U max_score = Limits::finite_min; - U sum_exp_score = 0; - if (has_sinks && block_idx == 0) { - max_score = static_cast(sinks[q_head_idx]); - sum_exp_score = 1; - } - - // For each key - for (int i = block_idx; i < N; i += blocks) { - bool use_key = true; - if (do_causal) { - use_key = i <= (N - q_seq_len + int(q_seq_idx)); - } else if (bool_mask) { - use_key = bmask[0]; - } else if (float_mask) { - use_key = (fmask[0] >= Limits::finite_min); - } - if (use_key) { - // Compute the i-th score - U score = 0; - for (int i = 0; i < qk_per_thread; i++) { - score += q[i] * keys[i]; - } - score = simd_sum(score); - - if (float_mask) { - score += fmask[0]; - } - - // Update the accumulators - U new_max = max(max_score, score); - U factor = fast::exp(max_score - new_max); - U exp_score = fast::exp(score - new_max); - - max_score = new_max; - sum_exp_score = sum_exp_score * factor + exp_score; - - // Update the output accumulator - for (int i = 0; i < v_per_thread; i++) { - o[i] = o[i] * factor + exp_score * values[i]; - } - } - - // Move the pointers to the next kv - keys += blocks * int(k_seq_stride); - values += blocks * int(v_seq_stride); - if (bool_mask) { - bmask += blocks * mask_kv_seq_stride; - } - if (float_mask) { - fmask += blocks * mask_kv_seq_stride; - } - } - - // Write the sum and max and outputs - if (simd_lid == 0) { - sums[0] = sum_exp_score; - maxs[0] = max_score; - } - - for (int i = 0; i < v_per_thread; i++) { - out[i] = static_cast(o[i]); - } -} - -template -[[kernel]] void sdpa_vector_2pass_2( - const device T* partials [[buffer(0)]], - const device float* sums [[buffer(1)]], - const device float* maxs [[buffer(2)]], - device T* out [[buffer(3)]], - const constant int& blocks [[buffer(4)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 tpg [[threadgroups_per_grid]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - constexpr int BN = 32; - constexpr int BD = 32; - constexpr int elem_per_thread = D / BD; - - typedef float U; - - thread U o[elem_per_thread] = {0}; - threadgroup U outputs[BN * BD]; - - // Adjust positions - const int head_idx = tid.x; - const int q_seq_idx = tid.y; - const int q_offset = head_idx * tpg.y + q_seq_idx; - partials += q_offset * blocks * D + simd_gid * D + simd_lid * elem_per_thread; - sums += q_offset * blocks; - maxs += q_offset * blocks; - out += q_offset * D + simd_gid * elem_per_thread; - - // Set defaults - U sum_exp_score = 0.0; - U max_score = Limits::finite_min; - - // Reduce the max - for (int b = 0; b < blocks / BN; ++b) { - max_score = max(max_score, maxs[simd_lid + BN * b]); - } - max_score = simd_max(max_score); - - // Reduce the d - for (int b = 0; b < blocks / BN; ++b) { - U factor = fast::exp(maxs[simd_lid + BN * b] - max_score); - sum_exp_score += factor * sums[simd_lid + BN * b]; - } - sum_exp_score = simd_sum(sum_exp_score); - - // Reduce the sum exp and partials - for (int b = 0; b < blocks / BN; ++b) { - U factor = fast::exp(maxs[simd_gid] - max_score); - - // Update the output accumulator - for (int i = 0; i < elem_per_thread; i++) { - o[i] += factor * static_cast(partials[i]); - } - maxs += BN; - sums += BN; - partials += BN * D; - } - - // Use shared memory to transpose and reduce the final block - for (int i = 0; i < elem_per_thread; i++) { - outputs[simd_lid * BD + simd_gid] = o[i]; - threadgroup_barrier(mem_flags::mem_threadgroup); - o[i] = simd_sum(outputs[simd_gid * BD + simd_lid]); - o[i] = sum_exp_score == 0 ? o[i] : (o[i] / sum_exp_score); - threadgroup_barrier(mem_flags::mem_threadgroup); - } - - // And write the output - if (simd_lid == 0) { - for (int i = 0; i < elem_per_thread; i++) { - out[i] = static_cast(o[i]); - } - } -} diff --git a/Source/Cmlx/mlx-generated/metal/softmax.h b/Source/Cmlx/mlx-generated/metal/softmax.h deleted file mode 100644 index 6ea4ac73..00000000 --- a/Source/Cmlx/mlx-generated/metal/softmax.h +++ /dev/null @@ -1,190 +0,0 @@ -// Copyright © 2023-2024 Apple Inc. - -template -inline T softmax_exp(T x) { - // Softmax doesn't need high precision exponential cause x is gonna be in - // (-oo, 0] anyway and subsequently it will be divided by sum(exp(x_i)). - return fast::exp(x); -} - -template -[[kernel]] void softmax_single_row( - const device T* in, - device T* out, - constant int& axis_size, - uint gid [[threadgroup_position_in_grid]], - uint _lid [[thread_position_in_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - int lid = _lid; - - constexpr int SIMD_SIZE = 32; - - threadgroup AccT local_max[SIMD_SIZE]; - threadgroup AccT local_normalizer[SIMD_SIZE]; - - AccT ld[N_READS]; - - in += gid * size_t(axis_size) + lid * N_READS; - if (lid * N_READS + N_READS <= axis_size) { - for (int i = 0; i < N_READS; i++) { - ld[i] = AccT(in[i]); - } - } else { - for (int i = 0; i < N_READS; i++) { - ld[i] = - ((lid * N_READS + i) < axis_size) ? AccT(in[i]) : Limits::min; - } - } - if (simd_group_id == 0) { - local_max[simd_lane_id] = Limits::min; - local_normalizer[simd_lane_id] = 0; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Get the max - AccT maxval = Limits::finite_min; - for (int i = 0; i < N_READS; i++) { - maxval = (maxval < ld[i]) ? ld[i] : maxval; - } - maxval = simd_max(maxval); - if (simd_lane_id == 0) { - local_max[simd_group_id] = maxval; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (simd_group_id == 0) { - maxval = simd_max(local_max[simd_lane_id]); - if (simd_lane_id == 0) { - local_max[0] = maxval; - } - } - threadgroup_barrier(mem_flags::mem_threadgroup); - maxval = local_max[0]; - - // Compute exp(x_i - maxval) and store the partial sums in local_normalizer - AccT normalizer = 0; - for (int i = 0; i < N_READS; i++) { - AccT exp_x = softmax_exp(ld[i] - maxval); - ld[i] = exp_x; - normalizer += exp_x; - } - normalizer = simd_sum(normalizer); - if (simd_lane_id == 0) { - local_normalizer[simd_group_id] = normalizer; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (simd_group_id == 0) { - normalizer = simd_sum(local_normalizer[simd_lane_id]); - if (simd_lane_id == 0) { - local_normalizer[0] = normalizer; - } - } - threadgroup_barrier(mem_flags::mem_threadgroup); - normalizer = 1 / local_normalizer[0]; - - // Normalize and write to the output - out += gid * size_t(axis_size) + lid * N_READS; - if (lid * N_READS + N_READS <= axis_size) { - for (int i = 0; i < N_READS; i++) { - out[i] = T(ld[i] * normalizer); - } - } else { - for (int i = 0; i < N_READS; i++) { - if ((lid * N_READS + i) < axis_size) { - out[i] = T(ld[i] * normalizer); - } - } - } -} - -template -[[kernel]] void softmax_looped( - const device T* in, - device T* out, - constant int& axis_size, - uint gid [[threadgroup_position_in_grid]], - uint lid [[thread_position_in_threadgroup]], - uint lsize [[threads_per_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - in += gid * size_t(axis_size); - - constexpr int SIMD_SIZE = 32; - - threadgroup AccT local_max[SIMD_SIZE]; - threadgroup AccT local_normalizer[SIMD_SIZE]; - - // Get the max and the normalizer in one go - AccT prevmax; - AccT maxval = Limits::finite_min; - AccT normalizer = 0; - for (int r = 0; r < static_cast(ceildiv(axis_size, N_READS * lsize)); - r++) { - int offset = r * lsize * N_READS + lid * N_READS; - AccT vals[N_READS]; - if (offset + N_READS <= axis_size) { - for (int i = 0; i < N_READS; i++) { - vals[i] = AccT(in[offset + i]); - } - } else { - for (int i = 0; i < N_READS; i++) { - vals[i] = - (offset + i < axis_size) ? AccT(in[offset + i]) : Limits::min; - } - } - prevmax = maxval; - for (int i = 0; i < N_READS; i++) { - maxval = (maxval < vals[i]) ? vals[i] : maxval; - } - normalizer *= softmax_exp(prevmax - maxval); - for (int i = 0; i < N_READS; i++) { - normalizer += softmax_exp(vals[i] - maxval); - } - } - // Now we got partial normalizer of N_READS * ceildiv(axis_size, N_READS * - // lsize) parts. We need to combine them. - // 1. We start by finding the max across simd groups - // 2. We then change the partial normalizers to account for a possible - // change in max - // 3. We sum all normalizers - prevmax = maxval; - maxval = simd_max(maxval); - normalizer *= softmax_exp(prevmax - maxval); - normalizer = simd_sum(normalizer); - - // Now the normalizer and max value is correct for each simdgroup. We write - // them shared memory and combine them. - prevmax = maxval; - if (simd_lane_id == 0) { - local_max[simd_group_id] = maxval; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - maxval = simd_max(local_max[simd_lane_id]); - normalizer *= softmax_exp(prevmax - maxval); - if (simd_lane_id == 0) { - local_normalizer[simd_group_id] = normalizer; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - normalizer = simd_sum(local_normalizer[simd_lane_id]); - normalizer = 1 / normalizer; - - // Finally given the normalizer and max value we can directly write the - // softmax output - out += gid * size_t(axis_size); - for (int r = 0; r < static_cast(ceildiv(axis_size, N_READS * lsize)); - r++) { - int offset = r * lsize * N_READS + lid * N_READS; - if (offset + N_READS <= axis_size) { - for (int i = 0; i < N_READS; i++) { - out[offset + i] = T(softmax_exp(in[offset + i] - maxval) * normalizer); - } - } else { - for (int i = 0; i < N_READS; i++) { - if (offset + i < axis_size) { - out[offset + i] = - T(softmax_exp(in[offset + i] - maxval) * normalizer); - } - } - } - } -} diff --git a/Source/Cmlx/mlx-generated/metal/sort.h b/Source/Cmlx/mlx-generated/metal/sort.h deleted file mode 100644 index 0d357333..00000000 --- a/Source/Cmlx/mlx-generated/metal/sort.h +++ /dev/null @@ -1,719 +0,0 @@ -// Copyright © 2023-2024 Apple Inc. - -#define MLX_MTL_CONST static constant constexpr const -#define MLX_MTL_LOOP_UNROLL _Pragma("clang loop unroll(full)") - -using namespace metal; - -// Based on GPU merge sort algorithm at -// https://github.com/NVIDIA/cccl/tree/main/cub/cub - -/////////////////////////////////////////////////////////////////////////////// -// Thread-level sort -/////////////////////////////////////////////////////////////////////////////// - -template -METAL_FUNC void thread_swap(thread T& a, thread T& b) { - T w = a; - a = b; - b = w; -} - -template -struct Init { - static constexpr constant T v = Limits::max; -}; - -template -struct Init>> { - static constexpr constant T v = metal::numeric_limits::quiet_NaN(); -}; - -template -struct LessThan { - static constexpr constant T init = Init::v; - METAL_FUNC bool operator()(T a, T b) const { - if constexpr ( - metal::is_floating_point_v || metal::is_same_v) { - bool an = isnan(a); - bool bn = isnan(b); - if (an | bn) { - return (!an) & bn; - } - } - return a < b; - } -}; - -template < - typename ValT, - typename IdxT, - bool ARG_SORT, - short N_PER_THREAD, - typename CompareOp> -struct ThreadSort { - static METAL_FUNC void sort( - thread ValT (&vals)[N_PER_THREAD], - thread IdxT (&idxs)[N_PER_THREAD]) { - CompareOp op; - MLX_MTL_LOOP_UNROLL - for (short i = 0; i < N_PER_THREAD; ++i) { - MLX_MTL_LOOP_UNROLL - for (short j = i & 1; j < N_PER_THREAD - 1; j += 2) { - if (op(vals[j + 1], vals[j])) { - thread_swap(vals[j + 1], vals[j]); - if (ARG_SORT) { - thread_swap(idxs[j + 1], idxs[j]); - } - } - } - } - } -}; - -/////////////////////////////////////////////////////////////////////////////// -// Threadgroup-level sort -/////////////////////////////////////////////////////////////////////////////// - -template < - typename ValT, - typename IdxT, - bool ARG_SORT, - short BLOCK_THREADS, - short N_PER_THREAD, - typename CompareOp> -struct BlockMergeSort { - using thread_sort_t = - ThreadSort; - static METAL_FUNC int merge_partition( - const threadgroup ValT* As, - const threadgroup ValT* Bs, - short A_sz, - short B_sz, - short sort_md) { - CompareOp op; - - short A_st = max(0, sort_md - B_sz); - short A_ed = min(sort_md, A_sz); - - while (A_st < A_ed) { - short md = A_st + (A_ed - A_st) / 2; - auto a = As[md]; - auto b = Bs[sort_md - 1 - md]; - - if (op(b, a)) { - A_ed = md; - } else { - A_st = md + 1; - } - } - - return A_ed; - } - - static METAL_FUNC void merge_step( - const threadgroup ValT* As, - const threadgroup ValT* Bs, - const threadgroup IdxT* As_idx, - const threadgroup IdxT* Bs_idx, - short A_sz, - short B_sz, - thread ValT (&vals)[N_PER_THREAD], - thread IdxT (&idxs)[N_PER_THREAD]) { - CompareOp op; - short a_idx = 0; - short b_idx = 0; - - for (int i = 0; i < N_PER_THREAD; ++i) { - auto a = (a_idx < A_sz) ? As[a_idx] : ValT(CompareOp::init); - auto b = (b_idx < B_sz) ? Bs[b_idx] : ValT(CompareOp::init); - bool pred = (b_idx < B_sz) && (a_idx >= A_sz || op(b, a)); - - vals[i] = pred ? b : a; - if (ARG_SORT) { - if (pred) { - idxs[i] = Bs_idx[b_idx]; - } else { - idxs[i] = (a_idx < A_sz) ? As_idx[a_idx] : IdxT(0); - } - } - - b_idx += short(pred); - a_idx += short(!pred); - } - } - - static METAL_FUNC void sort( - threadgroup ValT* tgp_vals [[threadgroup(0)]], - threadgroup IdxT* tgp_idxs [[threadgroup(1)]], - int size_sorted_axis, - uint3 lid [[thread_position_in_threadgroup]]) { - // Get thread location - int idx = lid.x * N_PER_THREAD; - - // Load from shared memory - thread ValT thread_vals[N_PER_THREAD]; - thread IdxT thread_idxs[N_PER_THREAD]; - for (int i = 0; i < N_PER_THREAD; ++i) { - thread_vals[i] = tgp_vals[idx + i]; - if (ARG_SORT) { - thread_idxs[i] = tgp_idxs[idx + i]; - } - } - - // Per thread sort - if (idx < size_sorted_axis) { - thread_sort_t::sort(thread_vals, thread_idxs); - } - - // Do merges using threadgroup memory - for (int merge_threads = 2; merge_threads <= BLOCK_THREADS; - merge_threads *= 2) { - // Update threadgroup memory - threadgroup_barrier(mem_flags::mem_threadgroup); - for (int i = 0; i < N_PER_THREAD; ++i) { - tgp_vals[idx + i] = thread_vals[i]; - if (ARG_SORT) { - tgp_idxs[idx + i] = thread_idxs[i]; - } - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Find location in merge step - int merge_group = lid.x / merge_threads; - int merge_lane = lid.x % merge_threads; - - int sort_sz = N_PER_THREAD * merge_threads; - int sort_st = N_PER_THREAD * merge_threads * merge_group; - - // As = tgp_vals[A_st:A_ed] is sorted - // Bs = tgp_vals[B_st:B_ed] is sorted - int A_st = sort_st; - int A_ed = sort_st + sort_sz / 2; - int B_st = sort_st + sort_sz / 2; - int B_ed = sort_st + sort_sz; - - const threadgroup ValT* As = tgp_vals + A_st; - const threadgroup ValT* Bs = tgp_vals + B_st; - int A_sz = A_ed - A_st; - int B_sz = B_ed - B_st; - - // Find a partition of merge elements - // Ci = merge(As[partition:], Bs[sort_md - partition:]) - // of size N_PER_THREAD for each merge lane i - // C = [Ci] is sorted - int sort_md = N_PER_THREAD * merge_lane; - int partition = merge_partition(As, Bs, A_sz, B_sz, sort_md); - - As += partition; - Bs += sort_md - partition; - - A_sz -= partition; - B_sz -= sort_md - partition; - - const threadgroup IdxT* As_idx = - ARG_SORT ? tgp_idxs + A_st + partition : nullptr; - const threadgroup IdxT* Bs_idx = - ARG_SORT ? tgp_idxs + B_st + sort_md - partition : nullptr; - - // Merge starting at the partition and store results in thread registers - merge_step(As, Bs, As_idx, Bs_idx, A_sz, B_sz, thread_vals, thread_idxs); - } - - // Write out to shared memory - threadgroup_barrier(mem_flags::mem_threadgroup); - for (int i = 0; i < N_PER_THREAD; ++i) { - tgp_vals[idx + i] = thread_vals[i]; - if (ARG_SORT) { - tgp_idxs[idx + i] = thread_idxs[i]; - } - } - } -}; - -/////////////////////////////////////////////////////////////////////////////// -// Kernel sort -/////////////////////////////////////////////////////////////////////////////// - -template < - typename T, - typename U, - bool ARG_SORT, - short BLOCK_THREADS, - short N_PER_THREAD, - typename CompareOp = LessThan> -struct KernelMergeSort { - using ValT = T; - using IdxT = uint; - using block_merge_sort_t = BlockMergeSort< - ValT, - IdxT, - ARG_SORT, - BLOCK_THREADS, - N_PER_THREAD, - CompareOp>; - - MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD; - - static METAL_FUNC void block_sort( - const device T* inp, - device U* out, - const constant int& size_sorted_axis, - const constant int& in_stride_sorted_axis, - const constant int& out_stride_sorted_axis, - const constant int& in_stride_segment_axis, - const constant int& out_stride_segment_axis, - threadgroup ValT* tgp_vals, - threadgroup IdxT* tgp_idxs, - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { - // tid.y tells us the segment index - inp += tid.y * in_stride_segment_axis; - out += tid.y * out_stride_segment_axis; - - // Copy into threadgroup memory - for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) { - tgp_vals[i] = i < size_sorted_axis ? inp[i * in_stride_sorted_axis] - : ValT(CompareOp::init); - if (ARG_SORT) { - tgp_idxs[i] = i; - } - } - - // Sort elements within the block - threadgroup_barrier(mem_flags::mem_threadgroup); - - block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Write output - for (int i = lid.x; i < size_sorted_axis; i += BLOCK_THREADS) { - if (ARG_SORT) { - out[i * out_stride_sorted_axis] = tgp_idxs[i]; - } else { - out[i * out_stride_sorted_axis] = tgp_vals[i]; - } - } - } -}; - -template < - typename T, - typename U, - bool ARG_SORT, - short BLOCK_THREADS, - short N_PER_THREAD> -[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort( - const device T* inp [[buffer(0)]], - device U* out [[buffer(1)]], - const constant int& size_sorted_axis [[buffer(2)]], - const constant int& in_stride_sorted_axis [[buffer(3)]], - const constant int& out_stride_sorted_axis [[buffer(4)]], - const constant int& in_stride_segment_axis [[buffer(5)]], - const constant int& out_stride_segment_axis [[buffer(6)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { - using sort_kernel = - KernelMergeSort; - using ValT = typename sort_kernel::ValT; - using IdxT = typename sort_kernel::IdxT; - - if (ARG_SORT) { - threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK]; - threadgroup IdxT tgp_idxs[sort_kernel::N_PER_BLOCK]; - sort_kernel::block_sort( - inp, - out, - size_sorted_axis, - in_stride_sorted_axis, - out_stride_sorted_axis, - in_stride_segment_axis, - out_stride_segment_axis, - tgp_vals, - tgp_idxs, - tid, - lid); - } else { - threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK]; - sort_kernel::block_sort( - inp, - out, - size_sorted_axis, - in_stride_sorted_axis, - out_stride_sorted_axis, - in_stride_segment_axis, - out_stride_segment_axis, - tgp_vals, - nullptr, - tid, - lid); - } -} - -constant constexpr const int zero_helper = 0; - -template < - typename T, - typename U, - bool ARG_SORT, - short BLOCK_THREADS, - short N_PER_THREAD> -[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void block_sort_nc( - const device T* inp [[buffer(0)]], - device U* out [[buffer(1)]], - const constant int& size_sorted_axis [[buffer(2)]], - const constant int& in_stride_sorted_axis [[buffer(3)]], - const constant int& out_stride_sorted_axis [[buffer(4)]], - const constant int& nc_dim [[buffer(5)]], - const constant int* nc_shape [[buffer(6)]], - const constant int64_t* in_nc_strides [[buffer(7)]], - const constant int64_t* out_nc_strides [[buffer(8)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { - using sort_kernel = - KernelMergeSort; - using ValT = typename sort_kernel::ValT; - using IdxT = typename sort_kernel::IdxT; - - auto in_block_idx = elem_to_loc(tid.y, nc_shape, in_nc_strides, nc_dim); - auto out_block_idx = elem_to_loc(tid.y, nc_shape, out_nc_strides, nc_dim); - inp += in_block_idx; - out += out_block_idx; - - if (ARG_SORT) { - threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK]; - threadgroup IdxT tgp_idxs[sort_kernel::N_PER_BLOCK]; - sort_kernel::block_sort( - inp, - out, - size_sorted_axis, - in_stride_sorted_axis, - out_stride_sorted_axis, - zero_helper, - zero_helper, - tgp_vals, - tgp_idxs, - tid, - lid); - } else { - threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK]; - sort_kernel::block_sort( - inp, - out, - size_sorted_axis, - in_stride_sorted_axis, - out_stride_sorted_axis, - zero_helper, - zero_helper, - tgp_vals, - nullptr, - tid, - lid); - } -} - -template < - typename ValT, - typename IdxT, - bool ARG_SORT, - short BLOCK_THREADS, - short N_PER_THREAD, - typename CompareOp = LessThan> -struct KernelMultiBlockMergeSort { - using block_merge_sort_t = BlockMergeSort< - ValT, - IdxT, - ARG_SORT, - BLOCK_THREADS, - N_PER_THREAD, - CompareOp>; - - MLX_MTL_CONST short N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD; - - static METAL_FUNC void block_sort( - const device ValT* inp, - device ValT* out_vals, - device IdxT* out_idxs, - const constant int& size_sorted_axis, - const constant int& stride_sorted_axis, - threadgroup ValT* tgp_vals, - threadgroup IdxT* tgp_idxs, - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { - // tid.y tells us the segment index - int base_idx = tid.x * N_PER_BLOCK; - - // Copy into threadgroup memory - for (short i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) { - int idx = base_idx + i; - tgp_vals[i] = idx < size_sorted_axis ? inp[idx * stride_sorted_axis] - : ValT(CompareOp::init); - tgp_idxs[i] = idx; - } - - // Sort elements within the block - threadgroup_barrier(mem_flags::mem_threadgroup); - - block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis, lid); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Write output - for (int i = lid.x; i < N_PER_BLOCK; i += BLOCK_THREADS) { - int idx = base_idx + i; - if (idx < size_sorted_axis) { - out_vals[idx] = tgp_vals[i]; - out_idxs[idx] = tgp_idxs[i]; - } - } - } - - static METAL_FUNC int merge_partition( - const device ValT* As, - const device ValT* Bs, - int A_sz, - int B_sz, - int sort_md) { - CompareOp op; - - int A_st = max(0, sort_md - B_sz); - int A_ed = min(sort_md, A_sz); - - while (A_st < A_ed) { - int md = A_st + (A_ed - A_st) / 2; - auto a = As[md]; - auto b = Bs[sort_md - 1 - md]; - - if (op(b, a)) { - A_ed = md; - } else { - A_st = md + 1; - } - } - - return A_ed; - } -}; - -template < - typename ValT, - typename IdxT, - bool ARG_SORT, - short BLOCK_THREADS, - short N_PER_THREAD> -[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void mb_block_sort( - const device ValT* inp [[buffer(0)]], - device ValT* out_vals [[buffer(1)]], - device IdxT* out_idxs [[buffer(2)]], - const constant int& size_sorted_axis [[buffer(3)]], - const constant int& stride_sorted_axis [[buffer(4)]], - const constant int& nc_dim [[buffer(5)]], - const constant int* nc_shape [[buffer(6)]], - const constant int64_t* nc_strides [[buffer(7)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { - using sort_kernel = KernelMultiBlockMergeSort< - ValT, - IdxT, - ARG_SORT, - BLOCK_THREADS, - N_PER_THREAD>; - - auto block_idx = elem_to_loc(tid.y, nc_shape, nc_strides, nc_dim); - inp += block_idx; - out_vals += tid.y * size_sorted_axis; - out_idxs += tid.y * size_sorted_axis; - - threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK]; - threadgroup IdxT tgp_idxs[sort_kernel::N_PER_BLOCK]; - - sort_kernel::block_sort( - inp, - out_vals, - out_idxs, - size_sorted_axis, - stride_sorted_axis, - tgp_vals, - tgp_idxs, - tid, - lid); -} - -template < - typename ValT, - typename IdxT, - bool ARG_SORT, - short BLOCK_THREADS, - short N_PER_THREAD> -[[kernel]] void mb_block_partition( - device IdxT* block_partitions [[buffer(0)]], - const device ValT* dev_vals [[buffer(1)]], - const device IdxT* dev_idxs [[buffer(2)]], - const constant int& size_sorted_axis [[buffer(3)]], - const constant int& merge_tiles [[buffer(4)]], - const constant int& n_blocks [[buffer(5)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint3 tgp_dims [[threads_per_threadgroup]]) { - using sort_kernel = KernelMultiBlockMergeSort< - ValT, - IdxT, - ARG_SORT, - BLOCK_THREADS, - N_PER_THREAD>; - - block_partitions += tid.y * tgp_dims.x; - dev_vals += tid.y * size_sorted_axis; - dev_idxs += tid.y * size_sorted_axis; - - for (int i = lid.x; i <= n_blocks; i += tgp_dims.x) { - // Find location in merge step - int merge_group = i / merge_tiles; - int merge_lane = i % merge_tiles; - - int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles; - int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group; - - int A_st = min(size_sorted_axis, sort_st); - int A_ed = min(size_sorted_axis, sort_st + sort_sz / 2); - int B_st = A_ed; - int B_ed = min(size_sorted_axis, B_st + sort_sz / 2); - - int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane); - int partition = sort_kernel::merge_partition( - dev_vals + A_st, - dev_vals + B_st, - A_ed - A_st, - B_ed - B_st, - partition_at); - - block_partitions[i] = A_st + partition; - } -} - -template < - typename ValT, - typename IdxT, - bool ARG_SORT, - short BLOCK_THREADS, - short N_PER_THREAD, - typename CompareOp = LessThan> -[[kernel, max_total_threads_per_threadgroup(BLOCK_THREADS)]] void -mb_block_merge( - const device IdxT* block_partitions [[buffer(0)]], - const device ValT* dev_vals_in [[buffer(1)]], - const device IdxT* dev_idxs_in [[buffer(2)]], - device ValT* dev_vals_out [[buffer(3)]], - device IdxT* dev_idxs_out [[buffer(4)]], - const constant int& size_sorted_axis [[buffer(5)]], - const constant int& merge_tiles [[buffer(6)]], - const constant int& num_tiles [[buffer(7)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { - using sort_kernel = KernelMultiBlockMergeSort< - ValT, - IdxT, - ARG_SORT, - BLOCK_THREADS, - N_PER_THREAD, - CompareOp>; - - using block_sort_t = typename sort_kernel::block_merge_sort_t; - - block_partitions += tid.y * (num_tiles + 1); - dev_vals_in += tid.y * size_sorted_axis; - dev_idxs_in += tid.y * size_sorted_axis; - dev_vals_out += tid.y * size_sorted_axis; - dev_idxs_out += tid.y * size_sorted_axis; - - int block_idx = tid.x; - int merge_group = block_idx / merge_tiles; - int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group; - int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles; - int sort_md = sort_kernel::N_PER_BLOCK * block_idx - sort_st; - - int A_st = block_partitions[block_idx + 0]; - int A_ed = block_partitions[block_idx + 1]; - int B_st = min(size_sorted_axis, 2 * sort_st + sort_sz / 2 + sort_md - A_st); - int B_ed = min( - size_sorted_axis, - 2 * sort_st + sort_sz / 2 + sort_md + sort_kernel::N_PER_BLOCK - A_ed); - - if ((block_idx % merge_tiles) == merge_tiles - 1) { - A_ed = min(size_sorted_axis, sort_st + sort_sz / 2); - B_ed = min(size_sorted_axis, sort_st + sort_sz); - } - - int A_sz = A_ed - A_st; - int B_sz = B_ed - B_st; - - // Load from global memory - thread ValT thread_vals[N_PER_THREAD]; - thread IdxT thread_idxs[N_PER_THREAD]; - for (int i = 0; i < N_PER_THREAD; i++) { - int idx = BLOCK_THREADS * i + lid.x; - if (idx < (A_sz + B_sz)) { - thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx] - : dev_vals_in[B_st + idx - A_sz]; - thread_idxs[i] = (idx < A_sz) ? dev_idxs_in[A_st + idx] - : dev_idxs_in[B_st + idx - A_sz]; - } else { - thread_vals[i] = CompareOp::init; - thread_idxs[i] = 0; - } - } - - // Write to shared memory - threadgroup ValT tgp_vals[sort_kernel::N_PER_BLOCK]; - threadgroup IdxT tgp_idxs[sort_kernel::N_PER_BLOCK]; - threadgroup_barrier(mem_flags::mem_threadgroup); - for (int i = 0; i < N_PER_THREAD; i++) { - int idx = BLOCK_THREADS * i + lid.x; - tgp_vals[idx] = thread_vals[i]; - tgp_idxs[idx] = thread_idxs[i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Merge - int sort_md_local = min(A_sz + B_sz, N_PER_THREAD * int(lid.x)); - - int A_st_local = block_sort_t::merge_partition( - tgp_vals, tgp_vals + A_sz, A_sz, B_sz, sort_md_local); - int A_ed_local = A_sz; - - int B_st_local = sort_md_local - A_st_local; - int B_ed_local = B_sz; - - int A_sz_local = A_ed_local - A_st_local; - int B_sz_local = B_ed_local - B_st_local; - - // Do merge - block_sort_t::merge_step( - tgp_vals + A_st_local, - tgp_vals + A_ed_local + B_st_local, - tgp_idxs + A_st_local, - tgp_idxs + A_ed_local + B_st_local, - A_sz_local, - B_sz_local, - thread_vals, - thread_idxs); - - threadgroup_barrier(mem_flags::mem_threadgroup); - for (int i = 0; i < N_PER_THREAD; ++i) { - int idx = lid.x * N_PER_THREAD; - tgp_vals[idx + i] = thread_vals[i]; - tgp_idxs[idx + i] = thread_idxs[i]; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - // Write output - int base_idx = tid.x * sort_kernel::N_PER_BLOCK; - for (int i = lid.x; i < sort_kernel::N_PER_BLOCK; i += BLOCK_THREADS) { - int idx = base_idx + i; - if (idx < size_sorted_axis) { - dev_vals_out[idx] = tgp_vals[i]; - dev_idxs_out[idx] = tgp_idxs[i]; - } - } -} diff --git a/Source/Cmlx/mlx-generated/metal/steel/attn/attn.h b/Source/Cmlx/mlx-generated/metal/steel/attn/attn.h deleted file mode 100644 index 8851df68..00000000 --- a/Source/Cmlx/mlx-generated/metal/steel/attn/attn.h +++ /dev/null @@ -1,296 +0,0 @@ -// Copyright © 2024 Apple Inc. - -#pragma once - -#include "../../steel/attn/loader.h" -#include "../../steel/attn/mma.h" -#include "../../steel/attn/params.h" -#include "../../steel/attn/transforms.h" -#include "../../steel/gemm/params.h" -#include "../../steel/utils.h" - -using namespace metal; - -/////////////////////////////////////////////////////////////////////////////// -// GEMM kernel class -/////////////////////////////////////////////////////////////////////////////// - -namespace mlx { -namespace steel { - -template -struct LoopAlignment {}; - -template < - typename T, - typename U, - int BM, - int BN, - int BK, - int WM, - int WN, - bool transpose_a, - bool transpose_b, - bool MN_aligned, - bool K_aligned, - typename AccumType = typename AccumHelper::accum_type, - typename Epilogue = TransformNone> -struct GEMMKernel { - STEEL_CONST short tgp_padding_a = 16 / sizeof(T); - STEEL_CONST short tgp_padding_b = 16 / sizeof(T); - STEEL_CONST short tgp_mem_size_a = - transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a); - STEEL_CONST short tgp_mem_size_b = - transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b); - STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b; - - STEEL_CONST short tgp_size = WM * WN * 32; - - using loader_a_t = BlockLoader< - T, - transpose_a ? BK : BM, - transpose_a ? BM : BK, - transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a, - !transpose_a, - tgp_size>; - using loader_b_t = BlockLoader< - T, - transpose_b ? BN : BK, - transpose_b ? BK : BN, - transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b, - transpose_b, - tgp_size>; - using mma_t = BlockMMA< - T, - U, - BM, - BN, - BK, - WM, - WN, - transpose_a, - transpose_b, - transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a, - transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b, - AccumType, - Epilogue>; - - /* Main kernel function */ - template - static METAL_FUNC void gemm_loop( - threadgroup T* As [[threadgroup(0)]], - threadgroup T* Bs [[threadgroup(1)]], - const int gemm_k_iterations, - thread loader_a_t& loader_a, - thread loader_b_t& loader_b, - thread mma_t& mma_op, - thread const short& tgp_bm, - thread const short& tgp_bn, - thread const short& lbk, - LoopAlignment l = {}) { - // Appease the compiler - (void)l; - - short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm); - - short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK); - - for (int k = 0; k < gemm_k_iterations; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - // Load elements into threadgroup - if (M_aligned) { - loader_a.load_unsafe(); - } else { - loader_a.load_safe(tile_dims_A); - } - - if (N_aligned) { - loader_b.load_unsafe(); - } else { - loader_b.load_safe(tile_dims_B); - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - - // Prepare for next iteration - loader_a.next(); - loader_b.next(); - } - - if (!K_aligned_) { - threadgroup_barrier(mem_flags::mem_threadgroup); - - short2 tile_dims_A_last = - transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm); - short2 tile_dims_B_last = - transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk); - - loader_a.load_safe(tile_dims_A_last); - loader_b.load_safe(tile_dims_B_last); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - mma_op.mma(As, Bs); - } - } - - /* Main kernel function */ - static METAL_FUNC void run( - const device T* A [[buffer(0)]], - const device T* B [[buffer(1)]], - device U* D [[buffer(2)]], - const constant GEMMParams* params [[buffer(3)]], - threadgroup T* As [[threadgroup(0)]], - threadgroup T* Bs [[threadgroup(1)]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { - // Pacifying compiler - (void)lid; - - const int tid_y = ((tid.y) << params->swizzle_log) + - ((tid.x) & ((1 << params->swizzle_log) - 1)); - const int tid_x = (tid.x) >> params->swizzle_log; - - if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { - return; - } - - threadgroup_barrier(mem_flags::mem_none); - - // Find block in A, B, C - const int c_row = tid_y * BM; - const int c_col = tid_x * BN; - const size_t c_row_long = size_t(c_row); - const size_t c_col_long = size_t(c_col); - - A += transpose_a ? c_row_long : c_row_long * params->lda; - B += transpose_b ? c_col_long * params->ldb : c_col_long; - D += c_row_long * params->ldd + c_col_long; - - // Prepare threadgroup loading operations - thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); - thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); - - // Prepare threadgroup mma operation - thread mma_t mma_op(simd_group_id, simd_lane_id); - - int gemm_k_iterations = params->gemm_k_iterations_aligned; - - /////////////////////////////////////////////////////////////////////////////// - // MNK aligned loop - if (MN_aligned) { - for (int k = 0; k < gemm_k_iterations; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - // Load elements into threadgroup - loader_a.load_unsafe(); - loader_b.load_unsafe(); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - - // Prepare for next iteration - loader_a.next(); - loader_b.next(); - } - - threadgroup_barrier(mem_flags::mem_none); - - // Loop tail - if (!K_aligned) { - int lbk = params->K - params->gemm_k_iterations_aligned * BK; - short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM); - short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk); - - loader_a.load_safe(tile_dims_A); - loader_b.load_safe(tile_dims_B); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - mma_op.mma(As, Bs); - } - - // Store results to device memory - mma_op.store_result(D, params->ldd); - return; - - } - /////////////////////////////////////////////////////////////////////////////// - // MN unaligned loop - else { // Loop over K - unaligned case - short tgp_bm = min(BM, params->M - c_row); - short tgp_bn = min(BN, params->N - c_col); - short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK; - - if (tgp_bm == BM && tgp_bn == BN) { - gemm_loop( - As, - Bs, - gemm_k_iterations, - loader_a, - loader_b, - mma_op, - tgp_bm, - tgp_bn, - leftover_bk); - - mma_op.store_result(D, params->ldd); - return; - - } else if (tgp_bn == BN) { - gemm_loop( - As, - Bs, - gemm_k_iterations, - loader_a, - loader_b, - mma_op, - tgp_bm, - tgp_bn, - leftover_bk); - - mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); - return; - - } else if (tgp_bm == BM) { - gemm_loop( - As, - Bs, - gemm_k_iterations, - loader_a, - loader_b, - mma_op, - tgp_bm, - tgp_bn, - leftover_bk); - - mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); - return; - - } else { - gemm_loop( - As, - Bs, - gemm_k_iterations, - loader_a, - loader_b, - mma_op, - tgp_bm, - tgp_bn, - leftover_bk); - - mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); - return; - } - } - } -}; - -} // namespace steel -} // namespace mlx \ No newline at end of file diff --git a/Source/Cmlx/mlx-generated/metal/steel/attn/kernels/steel_attention.h b/Source/Cmlx/mlx-generated/metal/steel/attn/kernels/steel_attention.h deleted file mode 100644 index df891fa3..00000000 --- a/Source/Cmlx/mlx-generated/metal/steel/attn/kernels/steel_attention.h +++ /dev/null @@ -1,471 +0,0 @@ -// Copyright © 2024-25 Apple Inc. - -#include "../../../steel/attn/attn.h" - -using namespace mlx::steel; - -/////////////////////////////////////////////////////////////////////////////// -// GEMM kernels -/////////////////////////////////////////////////////////////////////////////// - -constant bool align_Q [[function_constant(200)]]; -constant bool align_K [[function_constant(201)]]; - -constant bool has_mask [[function_constant(300)]]; -constant bool do_causal [[function_constant(301)]]; -constant bool has_sinks [[function_constant(302)]]; - -struct MaxOp { - template - METAL_FUNC static constexpr T apply(T x, T y) { - return metal::max(x, y); - } -}; - -struct SumOp { - template - METAL_FUNC static constexpr T apply(T x, T y) { - return x + y; - } -}; - -struct MulOp { - template - METAL_FUNC static constexpr T apply(T x, T y) { - return x * y; - } -}; - -struct SubOp { - template - METAL_FUNC static constexpr T apply(T x, T y) { - return x - y; - } -}; - -struct ExpSubOp { - template - METAL_FUNC static constexpr T apply(T x, T y) { - return fast::exp2(x - y); - } -}; - -struct DivOp { - template - METAL_FUNC static constexpr T apply(T x, T y) { - return x / y; - } -}; - -// clang-format off -template < - typename T, - int BQ, - int BK, - int BD, - int WM, - int WN, - typename MaskType = float, - typename AccumType = float> -[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention( - const device T* Q [[buffer(0)]], - const device T* K [[buffer(1)]], - const device T* V [[buffer(2)]], - device T* O [[buffer(3)]], - const constant AttnParams* params [[buffer(4)]], - const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]], - const device MaskType* mask [[buffer(6), function_constant(has_mask)]], - const device T* sinks [[buffer(7), function_constant(has_sinks)]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on - - // Pacifying compiler - (void)lid; - - // Move to correct block - ulong3 tidl{tid.x, tid.y, tid.z}; - - Q += tidl.z * params->Q_strides[0] + // Batch - tidl.y * params->Q_strides[1] + // Head - tidl.x * BQ * params->Q_strides[2]; // Sequence - - ulong kv_head_idx = int(tid.y) / params->gqa_factor; - K += tidl.z * params->K_strides[0] + // Batch - kv_head_idx * params->K_strides[1]; // Head - - V += tidl.z * params->V_strides[0] + // Batch - kv_head_idx * params->V_strides[1]; // Head - - O += tidl.z * params->O_strides[0] + // Batch - tidl.y * params->O_strides[1] + // Head - tidl.x * BQ * params->O_strides[2]; // Sequence - - if (has_mask) { - mask += tidl.z * mask_params->M_strides[0] + // Batch - tidl.y * mask_params->M_strides[1]; // Head - } - - // Prepare threadgroup memory - constexpr short padQ = 16 / sizeof(T); - constexpr short padK = 16 / sizeof(T); - constexpr short padV = 16 / sizeof(T); - - constexpr short LDQ_tgp = BD + padQ; - constexpr short LDK_tgp = BK + padK; - constexpr short LDV_tgp = BD + padV; - - constexpr short tgp_mem_0 = (BK + padK) * (BD); - constexpr short tgp_mem_1 = BK * (BD + padV); - constexpr short tgp_mem_s = tgp_mem_0 > tgp_mem_1 ? tgp_mem_0 : tgp_mem_1; - - threadgroup T Q_smem[BQ * (BD + padQ)]; - threadgroup T KV_smem[tgp_mem_s]; - - threadgroup T* Qs = Q_smem; - threadgroup T* Ks = KV_smem; - threadgroup T* Vs = KV_smem; - - // Prepare block loaders - using QBlockLoader = BlockLoaderT< - /* typename T = */ T, - /* short BROWS = */ BQ, - /* short BCOLS = */ BD, - /* short kDstStrRow = */ LDQ_tgp, - /* short kDstStrCol = */ 1, - /* short reduction_dim = */ 1, - /* short tgp_size = */ WM * WN * 32>; - - // K is loaded in transposed - using KBlockLoader = BlockLoaderT< - /* typename T = */ T, - /* short BROWS = */ BK, - /* short BCOLS = */ BD, - /* short kDstStrRow = */ 1, - /* short kDstStrCol = */ LDK_tgp, - /* short reduction_dim = */ 0, - /* short tgp_size = */ WM * WN * 32>; - - using VBlockLoader = BlockLoaderT< - /* typename T = */ T, - /* short BROWS = */ BK, - /* short BCOLS = */ BD, - /* short kDstStrRow = */ LDV_tgp, - /* short kDstStrCol = */ 1, - /* short reduction_dim = */ 0, - /* short tgp_size = */ WM * WN * 32>; - - QBlockLoader loader_q( - Q, params->Q_strides[2], Qs, simd_group_id, simd_lane_id); - KBlockLoader loader_k( - K, params->K_strides[2], Ks, simd_group_id, simd_lane_id); - VBlockLoader loader_v( - V, params->V_strides[2], Vs, simd_group_id, simd_lane_id); - - const AccumType scale = params->scale * M_LOG2E_F; - - // Prepare MMA tiles - constexpr short kFragSize = 8; // MMAFrag size - using MMAFrag_acc_t = BaseMMAFrag; - - constexpr int kNWarps = WM * WN; - static_assert( - BQ >= (kNWarps * kFragSize) && BQ % (kNWarps * kFragSize) == 0, - "Each simdgroup must host atleast 1 simdgroup matrix along Q sequence."); - - // Q seq frags per warp - constexpr int TQ = BQ / (kNWarps * kFragSize); - // KV sequence frags (all warps load the same frags) - constexpr int TK = BK / kFragSize; - // HeadDim frags (all warps load the same frags) - constexpr int TD = BD / kFragSize; - - static_assert(TQ == 1, "Check TQ"); - - MMATile Qtile; - MMATile Ktile; - MMATile Stile; - MMATile Vtile; - MMATile Otile; - - Otile.clear(); - - // Prepare mma tile offsets - const short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id); - const short sm = simd_coord.y; - const short sn = simd_coord.x; - const short tm = kFragSize * TQ * simd_group_id; - - const short Qs_offset = (tm + sm) * LDQ_tgp + sn; - const short Ks_offset = sm * LDK_tgp + sn; - const short Vs_offset = sm * LDV_tgp + sn; - - constexpr short Qs_tile_stride = kFragSize; - constexpr short Ks_tile_stride = kFragSize * LDK_tgp; - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Load Q blocks - if (!align_Q && int(tid.x) == (params->NQ_aligned)) { - loader_q.load_safe(short2(BD, params->qL_rem)); - } else { - loader_q.load_unsafe(); - } - - // Init row reduction variables - constexpr short kRowsPT = decltype(Stile)::kRowsPerThread; - - AccumType max_score[kRowsPT]; - AccumType sum_score[kRowsPT] = {0}; - - // Init to -Inf - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kRowsPT; ++i) { - max_score[i] = Limits::finite_min; - } - - if (has_sinks) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kRowsPT; ++i) { - max_score[i] = M_LOG2E_F * static_cast(sinks[tidl.y]); - sum_score[i] = 1; - } - } - - int kb_lim = params->NK; - - if (do_causal) { - int q_max = (tid.x + 1) * BQ + params->qL_off; - kb_lim = (q_max + BK - 1) / BK; - kb_lim = min(params->NK, kb_lim); - } - - // Loop over KV seq length - for (int kb = 0; kb < kb_lim; kb++) { - // Load K block and apply scale - threadgroup_barrier(mem_flags::mem_threadgroup); - if (!align_K && kb == (params->NK_aligned)) { - loader_k.load_safe(short2(BD, params->kL_rem)); - } else { - loader_k.load_unsafe(); - } - - // Do S = Q @ K.T - Stile.clear(); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - STEEL_PRAGMA_UNROLL - for (short dd = 0; dd < TD; dd++) { - simdgroup_barrier(mem_flags::mem_none); - - Qtile.template load( - &Qs[Qs_offset + dd * Qs_tile_stride]); - Ktile.template load( - &Ks[Ks_offset + dd * Ks_tile_stride]); - - simdgroup_barrier(mem_flags::mem_none); - - tile_matmad(Stile, Qtile, Ktile, Stile); - } - - // Apply scale in float32 - STEEL_PRAGMA_UNROLL - for (short ii = 0; ii < decltype(Stile)::kElemsPerTile; ii++) { - Stile.elems()[ii] *= scale; - } - - // Mask out length sequence - if (!align_K && kb == (params->NK_aligned)) { - using stile_t = decltype(Stile); - using selem_t = typename stile_t::elem_type; - constexpr auto neg_inf = Limits::finite_min; - - STEEL_PRAGMA_UNROLL - for (short i = 0; i < stile_t::kTileRows; i++) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < stile_t::kTileCols; j++) { - short col_pos = sn + (j * stile_t::kFragCols); - STEEL_PRAGMA_UNROLL - for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) { - if ((col_pos + jj) >= params->kL_rem) { - Stile.frag_at(i, j)[jj] = neg_inf; - } - } - } - } - } - - // Mask out if causal - if (do_causal && kb >= (kb_lim - ((BQ + BK - 1) / BK) - int(!align_K))) { - using stile_t = decltype(Stile); - using selem_t = typename stile_t::elem_type; - constexpr auto neg_inf = Limits::finite_min; - - STEEL_PRAGMA_UNROLL - for (short i = 0; i < stile_t::kTileRows; i++) { - const int row_pos = - tid.x * BQ + params->qL_off + tm + sm + (i * stile_t::kFragRows); - STEEL_PRAGMA_UNROLL - for (short j = 0; j < stile_t::kTileCols; j++) { - const int col_pos = kb * BK + sn + (j * stile_t::kFragCols); - STEEL_PRAGMA_UNROLL - for (short jj = 0; jj < stile_t::MMAFrag_t::kElemCols; jj++) { - if (row_pos < (col_pos + jj)) { - Stile.frag_at(i, j)[jj] = neg_inf; - } - } - } - } - } - - // Other masking as needed - if (has_mask) { - using stile_t = decltype(Stile); - using selem_t = typename stile_t::elem_type; - constexpr auto neg_inf = Limits::finite_min; - - constexpr bool is_bool = is_same_v; - using melem_t = typename metal::conditional_t; - - using MMAFrag_mask_t = BaseMMAFrag; - using frag_t = typename MMAFrag_mask_t::frag_type; - - STEEL_PRAGMA_UNROLL - for (short i = 0; i < stile_t::kTileRows; i++) { - const int row_pos = tid.x * BQ + tm + sm + (i * stile_t::kFragRows); - STEEL_PRAGMA_UNROLL - for (short j = 0; j < stile_t::kTileCols; j++) { - const int col_pos = kb * BK + sn + (j * stile_t::kFragCols); - - frag_t mfrag; - - MMAFrag_mask_t::load_safe( - mfrag, - mask, - int64_t(mask_params->M_strides[2]), - Int<1>{}, - params->qL, - params->kL, - row_pos, - col_pos); - - STEEL_PRAGMA_UNROLL - for (short jj = 0; jj < stile_t::MMAFrag_t::kElemsPerFrag; jj++) { - if constexpr (is_bool) { - Stile.frag_at(i, j)[jj] = - mfrag[jj] ? Stile.frag_at(i, j)[jj] : neg_inf; - } else { - Stile.frag_at(i, j)[jj] += M_LOG2E_F * selem_t(mfrag[jj]); - } - } - } - } - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Load V blocks - if (!align_K && kb == (params->NK_aligned)) { - loader_v.load_safe(short2(BD, params->kL_rem)); - } else { - loader_v.load_unsafe(); - } - - // Do softmax - - // Temp variables - AccumType new_max[kRowsPT]; - AccumType factor[kRowsPT]; - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kRowsPT; ++i) { - new_max[i] = max_score[i]; - } - - // Row max - Stile.template row_reduce(new_max); - - // exp(Si - rowmax(Si)) - Stile.template row_bin_op(new_max); - - // Factor exp(rowmax(Si) - rowmax(Si-1)) - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kRowsPT; ++i) { - factor[i] = fast::exp2(max_score[i] - new_max[i]); - } - - // Save max for next iteration - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kRowsPT; ++i) { - max_score[i] = new_max[i]; - } - - // Row Sum - AccumType sum_score_tmp[kRowsPT] = {0}; - Stile.template row_reduce(sum_score_tmp); - - // Update norm - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kRowsPT; ++i) { - sum_score[i] = sum_score[i] * factor[i] + sum_score_tmp[i]; - } - - // Update O - Otile.template row_bin_op(factor); - - // Load V into registers - threadgroup_barrier(mem_flags::mem_threadgroup); - - STEEL_PRAGMA_UNROLL - for (short iq = 0; iq < TQ; iq++) { - STEEL_PRAGMA_UNROLL - for (short id = 0; id < TD; id++) { - STEEL_PRAGMA_UNROLL - for (short ik = 0; ik < TK; ik++) { - if constexpr (BD == 128) { - simdgroup_barrier(mem_flags::mem_none); - } - - const short kk = ik * kFragSize; - const short dd = id * kFragSize; - - Vtile.template load( - &Vs[Vs_offset + kk * LDV_tgp + dd]); - - if constexpr (BD == 128) { - simdgroup_barrier(mem_flags::mem_none); - } - - MMAFrag_acc_t::mma( - Otile.frag_at(iq, id), - Stile.frag_at(iq, ik), - Vtile.frag_at(0, 0), - Otile.frag_at(iq, id)); - } - } - } - - // Prepare for next iteration - loader_k.next(); - loader_v.next(); - } - - // Normalize output - Otile.template row_bin_op(sum_score); - threadgroup_barrier(mem_flags::mem_none); - - // Store results - O += (tm + sm) * params->O_strides[2] + sn; - - if (!align_Q && int(tid.x) == (params->NQ_aligned)) { - auto dst_tile_dims = short2(BD - sn, params->qL_rem - (tm + sm)); - - if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) - return; - - Otile.template store_safe(O, params->O_strides[2], dst_tile_dims); - } else { - Otile.template store(O, params->O_strides[2]); - } -} diff --git a/Source/Cmlx/mlx-generated/metal/steel/attn/kernels/steel_attention.metal b/Source/Cmlx/mlx-generated/metal/steel/attn/kernels/steel_attention.metal deleted file mode 100644 index a68dcfc5..00000000 --- a/Source/Cmlx/mlx-generated/metal/steel/attn/kernels/steel_attention.metal +++ /dev/null @@ -1,27 +0,0 @@ -// Copyright © 2024-25 Apple Inc. - -// clang-format off -#include "../../../utils.h" - -#include "../../../steel/attn/kernels/steel_attention.h" - -#define instantiate_attn(tname, dtype, bq, bk, bd, wm, wn, mname, mtype) \ - instantiate_kernel( \ - "steel_attention_" #tname "_bq" #bq "_bk" #bk "_bd" #bd \ - "_wm" #wm "_wn" #wn "_mask" #mname, \ - attention, dtype, bq, bk, bd, wm, wn, mtype, float) - -#define instantiate_attn_shapes_helper(iname, itype, mname, mtype) \ - instantiate_attn(iname, itype, 32, 16, 128, 4, 1, mname, mtype) \ - instantiate_attn(iname, itype, 32, 32, 80, 4, 1, mname, mtype) \ - instantiate_attn(iname, itype, 32, 32, 64, 4, 1, mname, mtype) - -#define instantiate_attn_mask_helper(iname, itype) \ - instantiate_attn_shapes_helper(iname, itype, iname, itype) \ - instantiate_attn_shapes_helper(iname, itype, bool_, bool) - -instantiate_attn_mask_helper(float16, half); -instantiate_attn_mask_helper(bfloat16, bfloat16_t); - -instantiate_attn_mask_helper(float32, float); -// clang-format on diff --git a/Source/Cmlx/mlx-generated/metal/steel/attn/kernels/steel_attention_nax.h b/Source/Cmlx/mlx-generated/metal/steel/attn/kernels/steel_attention_nax.h deleted file mode 100644 index 4edc1729..00000000 --- a/Source/Cmlx/mlx-generated/metal/steel/attn/kernels/steel_attention_nax.h +++ /dev/null @@ -1,481 +0,0 @@ -// Copyright © 2024-25 Apple Inc. - -#include "../../../steel/attn/nax.h" -#include "../../../steel/attn/params.h" -#include "../../../steel/attn/transforms.h" -#include "../../../steel/utils.h" - -using namespace mlx::steel; - -/////////////////////////////////////////////////////////////////////////////// -// GEMM kernels -/////////////////////////////////////////////////////////////////////////////// - -constant bool align_Q [[function_constant(200)]]; -constant bool align_K [[function_constant(201)]]; - -constant bool has_mask [[function_constant(300)]]; -constant bool do_causal [[function_constant(301)]]; -constant bool has_sinks [[function_constant(302)]]; - -template -struct TransformScale { - T scale; - METAL_FUNC TransformScale(T scale_) : scale(scale_) {} - - METAL_FUNC T apply(T x) const { - return scale * x; - } -}; - -struct MaxOp { - template - METAL_FUNC static constexpr T apply(T x, T y) { - return metal::max(x, y); - } -}; - -struct SumOp { - template - METAL_FUNC static constexpr T apply(T x, T y) { - return x + y; - } -}; - -struct MulOp { - template - METAL_FUNC static constexpr T apply(T x, T y) { - return x * y; - } -}; - -struct SubOp { - template - METAL_FUNC static constexpr T apply(T x, T y) { - return x - y; - } -}; - -struct ExpSubOp { - template - METAL_FUNC static constexpr T apply(T x, T y) { - return fast::exp2(x - y); - } -}; - -struct DivOp { - template - METAL_FUNC static constexpr T apply(T x, T y) { - return x / y; - } -}; - -// clang-format off -template < - typename T, - int BQ, - int BK, - int BD, - int WM, - int WN, - typename MaskType = float, - typename AccumType = float> -[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void attention_nax( - const device T* Q [[buffer(0)]], - const device T* K [[buffer(1)]], - const device T* V [[buffer(2)]], - device T* O [[buffer(3)]], - const constant AttnParams* params [[buffer(4)]], - const constant AttnMaskParams* mask_params [[buffer(5), function_constant(has_mask)]], - const device MaskType* mask [[buffer(6), function_constant(has_mask)]], - const device T* sinks [[buffer(7), function_constant(has_sinks)]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on - - // Pacifying compiler - (void)lid; - (void)simd_lane_id; - - // Move to correct block - ulong3 tidl{tid.x, tid.y, tid.z}; - - Q += tidl.z * params->Q_strides[0] + // Batch - tidl.y * params->Q_strides[1] + // Head - tidl.x * BQ * params->Q_strides[2]; // Sequence - - ulong kv_head_idx = int(tid.y) / params->gqa_factor; - K += tidl.z * params->K_strides[0] + // Batch - kv_head_idx * params->K_strides[1]; // Head - - V += tidl.z * params->V_strides[0] + // Batch - kv_head_idx * params->V_strides[1]; // Head - - O += tidl.z * params->O_strides[0] + // Batch - tidl.y * params->O_strides[1] + // Head - tidl.x * BQ * params->O_strides[2]; // Sequence - - if (has_mask) { - mask += tidl.z * mask_params->M_strides[0] + // Batch - tidl.y * mask_params->M_strides[1]; // Head - } - - const metal::uniform scale2 = - make_uniform(params->scale) * make_uniform(1.44269504089f); - - // Prepare MMA tiles - constexpr short UQ = 16; - constexpr short UD = 32; - - constexpr int kNWarps = WM * WN; - static_assert( - BQ >= (kNWarps * UQ) && BQ % (kNWarps * UQ) == 0, - "Each simdgroup must host atleast 1 simdgroup matrix along Q sequence."); - - // Q seq frags per warp - constexpr int TQ = BQ / (kNWarps * UQ); - // HeadDim frags (all warps load the same frags) - constexpr int TD = BD / UD; - - static_assert(TQ == 1, "Check TQ"); - - using OSubTile = NAXSubTile; - NAXTile Otile; - - Otile.clear(); - - // Prepare mma tile offsets - const short2 simd_coord = OSubTile::NAXFrag_t::get_coord(); - const short sm = simd_coord.y; - const short sn = simd_coord.x; - const short tm = UQ * TQ * simd_group_id; - - Q += (tm + sm) * int(params->Q_strides[2]) + sn; - K += sm * int(params->K_strides[2]) + sn; - V += sm * int(params->V_strides[2]) + sn; - - // Init row reduction variables - constexpr short kRowsPT = decltype(Otile)::kRowsPerThread; - - metal::vec max_score; - metal::vec sum_score{0}; - - // Init to -Inf - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kRowsPT; ++i) { - max_score[i] = Limits::finite_min; - } - - if (has_sinks) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kRowsPT; ++i) { - max_score[i] = M_LOG2E_F * static_cast(sinks[tidl.y]); - sum_score[i] = 1; - } - } - - int kb_lim = params->NK; - - if (do_causal) { - int q_max = (tid.x + 1) * BQ + params->qL_off; - kb_lim = (q_max + BK - 1) / BK; - kb_lim = min(params->NK, kb_lim); - } - - const bool is_last_bq = int(tid.x) == (params->NQ_aligned); - // const bool is_last_tq = int(simd_group_id) >= (params->qL_rem / UQ); - const bool is_last_q = is_last_bq; - - const short lim_rows_q = params->qL_rem - (tm + sm); - const short lim_rows_k = params->kL_rem - sm; - - // Loop over KV seq length - for (int kb = 0; kb < kb_lim; kb++) { - const int is_last_k = (kb == (params->NK_aligned)); - - // Do S = Q @ K.T - constexpr short UDs = 16; - constexpr short UKs = 32; - - constexpr short TDs = BD / UDs; - constexpr short TKs = BK / UKs; - - using SSubTile = NAXSubTile; - using QSubTile = NAXSubTile; - using KSubTile = NAXSubTile; - - NAXTile Stile; - - Stile.clear(); - - STEEL_PRAGMA_UNROLL - for (short iq = 0; iq < TQ; iq++) { - STEEL_PRAGMA_UNROLL - for (short ik = 0; ik < TKs; ik++) { - STEEL_PRAGMA_UNROLL - for (short id = 0; id < TDs; id++) { - NAXTile Qtile; - NAXTile Ktile; - - const int Q_load_off = iq * UQ * int(params->Q_strides[2]) + id * UDs; - const int K_load_off = - ik * UKs * int(params->K_strides[2]) + id * UDs; - - if (!align_Q && is_last_q) { - // Qtile.load_rows( - // Q + Q_load_off, - // int(params->Q_strides[2]), - // lim_rows_q - iq * UQ); - Qtile.load_safe( - Q + Q_load_off, - int(params->Q_strides[2]), - short2(BD, lim_rows_q - iq * UQ)); - } else { - Qtile.load(Q + Q_load_off, int(params->Q_strides[2])); - } - - if (!align_K && is_last_k) { - // Ktile.load_rows( - // K + K_load_off, - // int(params->K_strides[2]), - // lim_rows_k - ik * UKs); - Ktile.load_safe( - K + K_load_off, - int(params->K_strides[2]), - short2(BD, lim_rows_k - ik * UKs)); - } else { - Ktile.load(K + K_load_off, int(params->K_strides[2])); - } - - subtile_matmad_nax( - Stile.subtile_at(iq, ik), - Qtile.subtile_at(0, 0), - metal::false_type{}, - Ktile.subtile_at(0, 0), - metal::true_type{}); - } - } - } - - // Scale S - STEEL_PRAGMA_UNROLL - for (short ii = 0; ii < decltype(Stile)::kElemsPerTile; ii++) { - Stile.elems()[ii] *= float(scale2); - } - - // Scale and Retile S - constexpr short UK = 16; - constexpr short TK = BK / UK; - using PSubTile = NAXSubTile; - - NAXTile Ptile; - - STEEL_PRAGMA_UNROLL - for (short ii = 0; ii < decltype(Stile)::kElemsPerTile; ii++) { - Ptile.elems()[ii] = Stile.elems()[ii]; - } - - // Mask out length sequence - if (!align_K && is_last_k) { - constexpr auto neg_inf = Limits::finite_min; - - STEEL_PRAGMA_UNROLL - for (short iq = 0; iq < TQ; iq++) { - STEEL_PRAGMA_UNROLL - for (short ik = 0; ik < TK; ik++) { - const short col_pos = sn + ik * UK; - - thread auto& fg = Ptile.subtile_at(iq, ik).frag_at(0, 0); - - STEEL_PRAGMA_UNROLL - for (short ii = 0; ii < PSubTile::kFragThrRows; ii++) { - STEEL_PRAGMA_UNROLL - for (short jj = 0; jj < PSubTile::kFragThrCols; jj++) { - const auto loc = ii * PSubTile::kFragThrCols + jj; - fg[loc] = ((col_pos + jj) >= params->kL_rem) ? neg_inf : fg[loc]; - } - } - } - } - } - - // Mask out if causal - if (do_causal && kb >= (kb_lim - ((BQ + BK - 1) / BK) - int(!align_K))) { - constexpr auto neg_inf = Limits::finite_min; - - const int base_row = tid.x * BQ + params->qL_off + tm; - const int base_col = kb * BK; - - STEEL_PRAGMA_UNROLL - for (short iq = 0; iq < TQ; iq++) { - STEEL_PRAGMA_UNROLL - for (short ik = 0; ik < TK; ik++) { - const short row_pos = base_row + iq * UQ; - const short col_pos = base_col + ik * UK; - - thread auto& fg = Ptile.subtile_at(iq, ik).frag_at(0, 0); - - STEEL_PRAGMA_UNROLL - for (short ii = 0; ii < PSubTile::kFragThrRows; ii++) { - STEEL_PRAGMA_UNROLL - for (short jj = 0; jj < PSubTile::kFragThrCols; jj++) { - const auto r = row_pos + ii * PSubTile::kFragRowsJump + sm; - const auto c = col_pos + jj + sn; - const auto loc = ii * PSubTile::kFragThrCols + jj; - fg[loc] = (r < c) ? neg_inf : fg[loc]; - } - } - } - } - } - - // Other masking as needed - if (has_mask) { - constexpr auto neg_inf = Limits::finite_min; - - const int base_row = tid.x * BQ + tm; - const int base_col = kb * BK; - - constexpr bool is_bool = is_same_v; - using melem_t = typename metal::conditional_t; - using MSubTile = NAXSubTile; - - STEEL_PRAGMA_UNROLL - for (short iq = 0; iq < TQ; iq++) { - STEEL_PRAGMA_UNROLL - for (short ik = 0; ik < TK; ik++) { - const short row_pos = base_row + iq * UQ + sm; - const short col_pos = base_col + ik * UK + sn; - - MSubTile mfrag; - mfrag.load_safe( - mask, - int64_t(mask_params->M_strides[2]), - Int<1>{}, - params->qL, - params->kL, - row_pos, - col_pos); - - thread auto& fg = Ptile.subtile_at(iq, ik).frag_at(0, 0); - - STEEL_PRAGMA_UNROLL - for (short jj = 0; jj < MSubTile::kElemsPerFrag; jj++) { - if constexpr (is_bool) { - fg[jj] = mfrag.elems()[jj] ? fg[jj] : neg_inf; - } else { - fg[jj] += M_LOG2E_F * AccumType(mfrag.elems()[jj]); - } - } - } - } - } - - // Do softmax - - // Temp variables - metal::vec new_max; - metal::vec factor; - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kRowsPT; ++i) { - new_max[i] = max_score[i]; - } - - // Row max - Ptile.template row_reduce(new_max); - - // exp(Si - rowmax(Si)) - Ptile.template row_bin_op(new_max); - - // Factor exp(rowmax(Si) - rowmax(Si-1)) - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kRowsPT; ++i) { - factor[i] = fast::exp2(max_score[i] - new_max[i]); - max_score[i] = new_max[i]; - } - - // Row Sum - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kRowsPT; ++i) { - sum_score[i] = sum_score[i] * factor[i]; - } - - Ptile.template row_reduce(sum_score); - - // Update O - Otile.template row_bin_op(factor); - - simdgroup_barrier(mem_flags::mem_none); - - // Do O = P @ V - STEEL_PRAGMA_UNROLL - for (short iq = 0; iq < TQ; iq++) { - STEEL_PRAGMA_UNROLL - for (short id = 0; id < TD; id++) { - if constexpr (BD == 128) { - if (id == 2) { - threadgroup_barrier(mem_flags::mem_none); - } - } - - STEEL_PRAGMA_UNROLL - for (short ik = 0; ik < TK; ik++) { - using VSubTile = NAXSubTile; - NAXTile Vtile; - - const int V_load_off = ik * UK * int(params->V_strides[2]) + id * UD; - - if (!align_K && is_last_k) { - // Vtile.load_rows( - // V + V_load_off, - // int(params->V_strides[2]), - // lim_rows_k - ik * UK); - Vtile.load_safe( - V + V_load_off, - int(params->V_strides[2]), - short2(BD, lim_rows_k - ik * UK)); - } else { - Vtile.load(V + V_load_off, int(params->V_strides[2])); - } - - subtile_matmad_nax( - Otile.subtile_at(iq, id), - Ptile.subtile_at(iq, ik), - metal::bool_constant{}, - Vtile.subtile_at(0, 0), - metal::bool_constant{}); - } - } - } - - // Prepare for next iteration - K += BK * int(params->K_strides[2]); - V += BK * int(params->V_strides[2]); - } - - // Normalize output - - threadgroup_barrier(mem_flags::mem_none); - - metal::vec rcp; - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kRowsPT; ++i) { - rcp[i] = 1.f / sum_score[i]; - } - - Otile.template row_bin_op(rcp); - - // Store results - O += (tm + sm) * int(params->O_strides[2]) + sn; - - if (!align_Q && is_last_q) { - if (lim_rows_q <= 0) - return; - - // Otile.store_rows(O, params->O_strides[2], lim_rows_q); - Otile.store_safe(O, params->O_strides[2], short2(BD, lim_rows_q)); - } else { - Otile.store(O, int(params->O_strides[2])); - } -} diff --git a/Source/Cmlx/mlx-generated/metal/steel/attn/loader.h b/Source/Cmlx/mlx-generated/metal/steel/attn/loader.h deleted file mode 100644 index 3b7c5166..00000000 --- a/Source/Cmlx/mlx-generated/metal/steel/attn/loader.h +++ /dev/null @@ -1,264 +0,0 @@ -// Copyright © 2024 Apple Inc. - -#pragma once - -#include "../../steel/defines.h" - -/////////////////////////////////////////////////////////////////////////////// -// Loading helper -/////////////////////////////////////////////////////////////////////////////// - -namespace mlx { -namespace steel { - -template < - typename T, - short BROWS, - short BCOLS, - short dst_ld, - short reduction_dim, - short tgp_size, - short alignment = 1, - short n_reads = (BCOLS * BROWS) / (tgp_size), - short TCOLS = BCOLS / n_reads, - short TROWS = tgp_size / TCOLS> -struct BlockLoader { - STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS; - STEEL_CONST short vec_size = n_reads; - - // Leading dimension for src - const int src_ld; - const int tile_stride; - - // Thread location indices - const short thread_idx; - const short bi; - const short bj; - - // threadgroup and device memory - threadgroup T* dst; - const device T* src; - - struct alignas(alignment * sizeof(T)) ReadVector { - uint8_t v[sizeof(T) * vec_size]; - }; - - /* Constructor */ - METAL_FUNC BlockLoader( - const device T* src_, - const int src_ld_, - threadgroup T* dst_, - ushort simd_group_id [[simdgroup_index_in_threadgroup]], - ushort simd_lane_id [[thread_index_in_simdgroup]]) - : src_ld(src_ld_), - tile_stride(reduction_dim ? BCOLS : BROWS * src_ld), - thread_idx(simd_group_id * 32 + simd_lane_id), - bi(thread_idx / TCOLS), - bj(vec_size * (thread_idx % TCOLS)), - dst(dst_ + bi * dst_ld + bj), - src(src_ + bi * src_ld + bj) {} - - /* Apply operation to threadgroup without bound checking */ - template - METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < BROWS; i += TROWS) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - dst[i * dst_ld + j] = op.apply(dst[i * dst_ld + j]); - } - } - } - - /* Load from device memory into threadgroup memory - without bound checking */ - METAL_FUNC void load_unsafe() const { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < BROWS; i += TROWS) { - *((threadgroup ReadVector*)(&dst[i * dst_ld])) = - *((const device ReadVector*)(&src[i * src_ld])); - } - } - - /* Load from device memory into threadgroup memory - with bound checking */ - METAL_FUNC void load_safe(short2 src_tile_dim) const { - src_tile_dim = src_tile_dim - short2(bj, bi); - - // Skip loading if thread has no valid reads - if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < BROWS; i += TROWS) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - dst[i * dst_ld + j] = T(0); - } - } - return; - } - - // Use fast thread memory for bound checks - bool tmp_idx[vec_size]; - T tmp_val[vec_size]; - - STEEL_PRAGMA_UNROLL - for (short i = 0; i < BROWS; i += TROWS) { - // Make sure tmp_idx only contains valid indices - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x); - } - - // Read valid indices into tmp_val - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; - } - - // Zero out unneeded values - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0); - } - - // Copy values to threadgroup memory - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - dst[i * dst_ld + j] = tmp_val[j]; - } - } - } - - /* Iteration helper */ - METAL_FUNC void next() { - src += tile_stride; - } -}; - -template -struct CShape { - STEEL_CONST int kRows = R; - STEEL_CONST int kCols = C; -}; - -template < - typename T, - short BROWS, - short BCOLS, - short kDstStrRow, - short kDstStrCol, - short reduction_dim, - short tgp_size, - short n_reads = (BCOLS * BROWS) / (tgp_size), - short TCOLS = BCOLS / n_reads, - short TROWS = tgp_size / TCOLS> -struct BlockLoaderT { - STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS; - STEEL_CONST short vec_size = n_reads; - - // Leading dimension for src - const int src_ld; - const int tile_stride; - - // Thread location indices - const short thread_idx; - const short bi; - const short bj; - - // threadgroup and device memory - threadgroup T* dst; - const device T* src; - - /* Constructor */ - METAL_FUNC BlockLoaderT( - const device T* src_, - const int src_ld_, - threadgroup T* dst_, - ushort simd_group_id [[simdgroup_index_in_threadgroup]], - ushort simd_lane_id [[thread_index_in_simdgroup]]) - : src_ld(src_ld_), - tile_stride(reduction_dim ? BCOLS : BROWS * src_ld), - thread_idx(simd_group_id * 32 + simd_lane_id), - bi(thread_idx / TCOLS), - bj(vec_size * (thread_idx % TCOLS)), - dst(dst_ + bi * kDstStrRow + bj * kDstStrCol), - src(src_ + bi * src_ld + bj) {} - - /* Apply operation to threadgroup without bound checking */ - template - METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < BROWS; i += TROWS) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - dst[i * kDstStrRow + j * kDstStrCol] = - op.apply(dst[i * kDstStrRow + j * kDstStrCol]); - } - } - } - - /* Load from device memory into threadgroup memory - without bound checking */ - METAL_FUNC void load_unsafe() const { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < BROWS; i += TROWS) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - dst[i * kDstStrRow + j * kDstStrCol] = src[i * src_ld + j]; - } - } - } - - /* Load from device memory into threadgroup memory - with bound checking */ - METAL_FUNC void load_safe(short2 src_tile_dim) const { - src_tile_dim = src_tile_dim - short2(bj, bi); - - // Skip loading if thread has no valid reads - if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < BROWS; i += TROWS) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - dst[i * kDstStrRow + j * kDstStrCol] = T(0); - } - } - return; - } - - // Use fast thread memory for bound checks - bool tmp_idx[vec_size]; - T tmp_val[vec_size]; - - STEEL_PRAGMA_UNROLL - for (short i = 0; i < BROWS; i += TROWS) { - // Make sure tmp_idx only contains valid indices - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x); - } - - // Read valid indices into tmp_val - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; - } - - // Zero out unneeded values - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0); - } - - // Copy values to threadgroup memory - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - dst[i * kDstStrRow + j * kDstStrCol] = tmp_val[j]; - } - } - } - - /* Iteration helper */ - METAL_FUNC void next() { - src += tile_stride; - } -}; - -} // namespace steel -} // namespace mlx diff --git a/Source/Cmlx/mlx-generated/metal/steel/attn/mma.h b/Source/Cmlx/mlx-generated/metal/steel/attn/mma.h deleted file mode 100644 index a735848d..00000000 --- a/Source/Cmlx/mlx-generated/metal/steel/attn/mma.h +++ /dev/null @@ -1,750 +0,0 @@ -// Copyright © 2024 Apple Inc. - -#pragma once - -#include -#include -#include - -#include "../../steel/attn/transforms.h" -#include "../../steel/defines.h" -#include "../../steel/utils/integral_constant.h" - -using namespace metal; - -/////////////////////////////////////////////////////////////////////////////// -// MMA helper -/////////////////////////////////////////////////////////////////////////////// - -namespace mlx { -namespace steel { - -template -struct Shape2D { - RInt r; - CInt c; - - Shape2D(RInt r_, CInt c_) : r(r_), c(c_) {} -}; - -template -struct Layout2D { - Shape shape; - Layout layout; -}; - -template -struct BaseMMAFrag { - static_assert( - kFragRows_ == 8, - "Only 8 x 8 fragment matrices are currently supported"); - static_assert( - kFragCols_ == 8, - "Only 8 x 8 fragment matrices are currently supported"); -}; - -template -struct BaseMMAFrag { - STEEL_CONST int kFragRows = 8; - STEEL_CONST int kFragCols = 8; - - STEEL_CONST int kElemsPerFrag = (kFragRows * kFragCols) / 32; - - STEEL_CONST int kElemRows = 1; - STEEL_CONST int kElemCols = 2; - - static_assert( - kElemRows * kElemCols == kElemsPerFrag, - "MMAFrag shape is not consistent with MMAFrag size"); - - typedef metal::simdgroup_matrix mat_type; - typedef metal::vec frag_type; - typedef metal::vec row_frag_type; - typedef metal::vec col_frag_type; - - template - using dtype_mat_t = typename metal::simdgroup_matrix; - - template - using dtype_frag_t = typename metal::vec; - - METAL_FUNC static constexpr short2 get_coord( - ushort simd_lane_id [[thread_index_in_simdgroup]]) { - const short qid = simd_lane_id / 4; - const short fm = (qid & 4) + ((simd_lane_id / 2) % 4); - const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; - return short2{fn, fm}; - } - - template - METAL_FUNC static constexpr void - load(thread frag_type& dst, SrcPtrType src, StrX str_x, StrY str_y) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kElemRows; i++) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kElemCols; j++) { - dst[i * kElemCols + j] = static_cast(src[i * str_x + j * str_y]); - } - } - } - - template < - typename SrcPtrType, - typename StrX, - typename StrY, - typename LimX, - typename LimY, - typename OffX, - typename OffY> - METAL_FUNC static constexpr void load_safe( - thread frag_type& dst, - SrcPtrType src, - StrX str_x, - StrY str_y, - LimX lim_x, - LimY lim_y, - OffX off_x = Int<0>{}, - OffY off_y = Int<0>{}) { - src += off_x * str_x + off_y * str_y; - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kElemRows; i++) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kElemCols; j++) { - if ((off_x + i) < lim_x && (off_y + j) < lim_y) { - dst[i * kElemCols + j] = static_cast(src[0]); - } else { - dst[i * kElemCols + j] = T(0); - } - src += str_y; - } - src -= kElemCols * str_y; - src += str_x; - } - } - - template - METAL_FUNC static constexpr void - store(const thread frag_type& src, DstPtrType dst, StrX str_x, StrY str_y) { - using U = pointer_element_t; - - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kElemRows; i++) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kElemCols; j++) { - dst[i * str_x + j * str_y] = static_cast(src[i * kElemCols + j]); - } - } - } - - template < - typename DstPtrType, - typename StrX, - typename StrY, - typename LimX, - typename LimY, - typename OffX, - typename OffY> - METAL_FUNC static constexpr void store_safe( - const thread frag_type& src, - DstPtrType dst, - StrX str_x, - StrY str_y, - LimX lim_x, - LimY lim_y, - OffX off_x = Int<0>{}, - OffY off_y = Int<0>{}) { - using U = pointer_element_t; - - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kElemRows; i++) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kElemCols; j++) { - if ((off_x + i) < lim_x && (off_y + j) < lim_y) { - dst[(off_x + i) * str_x + (off_y + j) * str_y] = - static_cast(src[i * kElemCols + j]); - } - } - } - } - - template - METAL_FUNC static constexpr void mma( - thread frag_type& D, - thread dtype_frag_t& A, - thread dtype_frag_t& B, - thread dtype_frag_t& C) { - mat_type D_mat; - dtype_mat_t A_mat; - dtype_mat_t B_mat; - dtype_mat_t C_mat; - - reinterpret_cast&>(A_mat.thread_elements()) = A; - reinterpret_cast&>(B_mat.thread_elements()) = B; - reinterpret_cast&>(C_mat.thread_elements()) = C; - - mma(D_mat, A_mat, B_mat, C_mat); - - D = reinterpret_cast(D_mat.thread_elements()); - } - - template - METAL_FUNC static constexpr void mma( - thread mat_type& D, - thread dtype_mat_t& A, - thread dtype_mat_t& B, - thread dtype_mat_t& C) { - simdgroup_multiply_accumulate(D, A, B, C); - } - - template - METAL_FUNC static constexpr void row_reduce( - thread const frag_type& inp_vals, - thread T* reduced_vals) { - T thr_reduce = Op::apply(inp_vals.x, inp_vals.y); - - T qgr_reduce = simd_shuffle_xor(thr_reduce, ushort(1)); - qgr_reduce = Op::apply(thr_reduce, qgr_reduce); - - T sgr_reduce = simd_shuffle_xor(qgr_reduce, ushort(8)); - sgr_reduce = Op::apply(qgr_reduce, sgr_reduce); - - reduced_vals[0] = Op::apply(reduced_vals[0], sgr_reduce); - } - - template - METAL_FUNC static constexpr void row_bin_op( - thread frag_type& inp_vals, - thread T* row_vals) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kElemRows; i++) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kElemCols; j++) { - inp_vals[i * kElemCols + j] = - Op::apply(inp_vals[i * kElemCols + j], row_vals[i]); - } - } - } -}; - -template < - typename T, - int kTileRows_, - int kTileCols_, - class MMAFrag_ = BaseMMAFrag> -struct MMATile { - using MMAFrag_t = MMAFrag_; - using elem_type = T; - STEEL_CONST int kFragRows = MMAFrag_t::kFragRows; - STEEL_CONST int kFragCols = MMAFrag_t::kFragCols; - STEEL_CONST int kElemsPerFrag = MMAFrag_t::kElemsPerFrag; - - STEEL_CONST int kTileRows = kTileRows_; - STEEL_CONST int kTileCols = kTileCols_; - - STEEL_CONST int kRows = kTileRows * kFragRows; - STEEL_CONST int kCols = kTileCols * kFragCols; - - STEEL_CONST int kNumFrags = kTileRows * kTileCols; - STEEL_CONST int kElemsPerTile = kNumFrags * kElemsPerFrag; - - STEEL_CONST int kRowsPerThread = kTileRows * MMAFrag_t::kElemRows; - STEEL_CONST int kColsPerThread = kTileCols * MMAFrag_t::kElemCols; - - typedef typename MMAFrag_t::mat_type mat_type; - typedef typename MMAFrag_t::frag_type frag_type; - - frag_type val_frags[kNumFrags]; // = {frag_type(0)}; - - METAL_FUNC MMATile() thread {} - - METAL_FUNC constexpr void clear() { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kNumFrags; ++i) { - val_frags[i] = frag_type(0); - } - } - - METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) { - return val_frags[i * kTileCols + j]; - } - - METAL_FUNC constexpr const thread frag_type& frag_at( - const short i, - const short j) const { - return val_frags[i * kTileCols + j]; - } - - METAL_FUNC mat_type mat_at(const short i, const short j) { - mat_type val_mat; - STEEL_PRAGMA_UNROLL - for (short ii = 0; ii < kElemsPerFrag; ++ii) { - val_mat.thread_elements()[ii] = frag_at(i, j)[ii]; - } - return val_mat; - } - - METAL_FUNC thread elem_type* elems() { - return reinterpret_cast(val_frags); - } - - METAL_FUNC const thread elem_type* elems() const { - return reinterpret_cast(val_frags); - } - - template - METAL_FUNC void row_reduce(thread T vals[kRowsPerThread]) const { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kTileCols; ++j) { - MMAFrag_t::template row_reduce( - frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]); - } - } - } - - template - METAL_FUNC void row_bin_op(thread T vals[kRowsPerThread]) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kTileCols; ++j) { - MMAFrag_t::template row_bin_op( - frag_at(i, j), &vals[i * MMAFrag_t::kElemRows]); - } - } - } - - template - METAL_FUNC void load(const threadgroup U* src) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kTileCols; ++j) { - MMAFrag_t::load( - frag_at(i, j), - &( - src[(i * kFragRows) * w_x * str_x + - (j * kFragCols) * w_y * str_y]), - Int{}, - Int{}); - } - } - } - - template - METAL_FUNC void store(threadgroup U* dst) const { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kTileCols; ++j) { - MMAFrag_t::store( - frag_at(i, j), - &( - dst[(i * kFragRows) * w_x * str_x + - (j * kFragCols) * w_y * str_y]), - Int{}, - Int{}); - } - } - } - - template - METAL_FUNC void load(const device U* src, const int ld) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kTileCols; ++j) { - MMAFrag_t::load( - frag_at(i, j), - &(src[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]), - ld, - Int<1>{}); - } - } - } - - template - METAL_FUNC void store(device U* dst, const int ld) const { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kTileCols; ++j) { - MMAFrag_t::store( - frag_at(i, j), - &(dst[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]), - ld, - Int<1>{}); - } - } - } - - template - METAL_FUNC void - load_safe(const device U* src, const int ld, const short2 src_tile_dims) { - STEEL_PRAGMA_UNROLL - for (int i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < kTileCols; ++j) { - MMAFrag_t::load_safe( - frag_at(i, j), - src, - ld, - Int<1>{}, - src_tile_dims.y, - src_tile_dims.x, - (i * kFragRows) * w_x, - (j * kFragCols) * w_y); - } - } - } - - template - METAL_FUNC void - store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const { - STEEL_PRAGMA_UNROLL - for (int i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < kTileCols; ++j) { - MMAFrag_t::store_safe( - frag_at(i, j), - dst, - ld, - Int<1>{}, - dst_tile_dims.y, - dst_tile_dims.x, - (i * kFragRows) * w_x, - (j * kFragCols) * w_y); - } - } - } -}; - -template < - typename Dtype, - typename Atype, - typename Btype, - typename Ctype, - int M, - int N, - int K, - class MMAFragD, - class MMAFragA, - class MMAFragB, - class MMAFragC> -METAL_FUNC void tile_matmad( - thread MMATile& D, - thread MMATile& A, - thread MMATile& B, - thread MMATile& C) { - STEEL_PRAGMA_UNROLL - for (short m = 0; m < M; ++m) { - STEEL_PRAGMA_UNROLL - for (short n = 0; n < N; ++n) { - short m_serp = m; //(n % 2) ? (M - 1 - m) : m; - short n_serp = (m % 2) ? (N - 1 - n) : n; - - STEEL_PRAGMA_UNROLL - for (short k = 0; k < K; ++k) { - MMAFragD::mma( - D.frag_at(m_serp, n_serp), - A.frag_at(m_serp, k), - B.frag_at(k, n_serp), - C.frag_at(m_serp, n_serp)); - } - } - } -} - -template < - typename T, - typename U, - int BM, - int BN, - int BK, - int WM, - int WN, - bool transpose_a, - bool transpose_b, - short lda_tgp, - short ldb_tgp, - typename AccumType = float, - typename Epilogue = TransformNone> -struct BlockMMA { - // MMAFrag size - STEEL_CONST short kFragSize = 8; - using MMAFrag_acc_t = BaseMMAFrag; - - // Warp tile simdgroup matrix strides along M - STEEL_CONST short TM_stride = kFragSize * WM; - // Warp tile simdgroup matrix strides along M - STEEL_CONST short TN_stride = kFragSize * WN; - - // Warp tile size along M - STEEL_CONST short TM = BM / TM_stride; - // Warp tile size along N - STEEL_CONST short TN = BN / TN_stride; - - // Threadgroup A strides - STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M - STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K - - // Threadgroup B strides - STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K - STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N - - // Threadgroup strides along K - STEEL_CONST short tile_stride_a = kFragSize * A_str_k; - STEEL_CONST short tile_stride_b = kFragSize * B_str_k; - - // Simdgroup matrices - MMATile Atile; - MMATile Btile; - MMATile Ctile; - - // Offsets within threadgroup - short sm; - short sn; - - short As_offset; - short Bs_offset; - - /* Constructor */ - METAL_FUNC BlockMMA( - ushort simd_group_id [[simdgroup_index_in_threadgroup]], - ushort simd_lane_id [[thread_index_in_simdgroup]]) { - // Determine thread position in simdgroup matrix - short tm = kFragSize * (simd_group_id / WN); - short tn = kFragSize * (simd_group_id % WN); - - short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id); - sm = simd_coord.y; - sn = simd_coord.x; - - // Determine thread and simdgroup offset - As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // M, K - Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // K, N - - sm += tm; - sn += tn; - } - - /* (BM, BK) X (BK, BN) multiply accumulate function */ - METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) { - // Adjust for simdgroup and thread location - As += As_offset; - Bs += Bs_offset; - - // Iterate over BK in blocks of kFragSize - STEEL_PRAGMA_UNROLL - for (short kk = 0; kk < BK; kk += kFragSize) { - simdgroup_barrier(mem_flags::mem_none); - - Atile.template load(As); - - simdgroup_barrier(mem_flags::mem_none); - - Btile.template load(Bs); - - simdgroup_barrier(mem_flags::mem_none); - - tile_matmad(Ctile, Atile, Btile, Ctile); - - // Progress to next simdgroup tile - As += tile_stride_a; - Bs += tile_stride_b; - } - } - - /* Store results from simdgroup_matrix results into device memory */ - METAL_FUNC void store_result(device U* D, const int ldd) { - // Apply epilogue - STEEL_PRAGMA_UNROLL - for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { - Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]); - } - - // Adjust for simdgroup and thread location - D += sm * ldd + sn; - - Ctile.template store(D, ldd); - } - - METAL_FUNC void - store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) { - // Apply epilogue - STEEL_PRAGMA_UNROLL - for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { - Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]); - } - - // Adjust for simdgroup and thread location - D += sm * ldd + sn; - dst_tile_dims -= short2(sn, sm); - - if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) - return; - - Ctile.template store_safe(D, ldd, dst_tile_dims); - } - - /* Apply epilogue */ - template - METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) { - // Loop over all simdgroup tiles - STEEL_PRAGMA_UNROLL - for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { - Ctile.elems()[i] = epilogue_op.apply(Ctile.elems()[i]); - } - } - - /* Apply epilogue */ - template - METAL_FUNC void apply_epilogue( - const device U* C, - const int ldc, - const int fdc, - thread const BinaryEpilogue& epilogue_op) { - // Adjust for simdgroup and thread location - C += (sm)*ldc + (sn)*fdc; - - // Loop over all simdgroup tiles - STEEL_PRAGMA_UNROLL - for (short i = 0; i < TM; i++) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; j++) { - // Get accumulated result and associated offset in C - thread auto& accum = Ctile.frag_at(i, j); - int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; - - // Apply epilogue - STEEL_PRAGMA_UNROLL - for (short k = 0; k < decltype(Ctile)::kElemsPerFrag; k++) { - accum[k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]); - } - } - } - } - - /* Apply epilogue */ - template - METAL_FUNC void apply_epilogue_safe( - const device U* C, - const int ldc, - const int fdc, - short2 dst_tile_dims, - thread const BinaryEpilogue& epilogue_op) { - // Adjust for simdgroup and thread location - C += (sm)*ldc + (sn)*fdc; - dst_tile_dims -= short2(sn, sm); - - if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) - return; - - // Loop over all simdgroup tiles - STEEL_PRAGMA_UNROLL - for (short i = 0; i < TM; i++) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; j++) { - // Get accumulated result and associated offset in C - thread auto& accum = Ctile.frag_at(i, j); - int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; - - constexpr short kelems = decltype(Ctile)::kElemsPerFrag; - - // Read C - U c_elems[kelems] = {0}; - - STEEL_PRAGMA_UNROLL - for (short k = 0; k < kelems; k++) { - if ((j * TN_stride + k) < dst_tile_dims.x) { - c_elems[k] = C[offset_c + k * fdc]; - } - } - - // Apply epilogue - STEEL_PRAGMA_UNROLL - for (short k = 0; k < kelems; k++) { - accum[k] = epilogue_op.apply(accum[k], c_elems[k]); - } - } - } - } - - /* Store results from simdgroup_matrix results into device memory */ - METAL_FUNC void store_result( - device U* D, - const int ldd, - const device U* C, - const int ldc, - const int fdc, - thread const Epilogue& epilogue_op) const { - // Adjust for simdgroup and thread location - C += (sm)*ldc + (sn)*fdc; - D += (sm)*ldd + sn; - - constexpr short kelems = decltype(Ctile)::kElemsPerFrag; - - // Loop over all simdgroup tiles - STEEL_PRAGMA_UNROLL - for (short i = 0; i < TM; i++) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; j++) { - // Get accumulated result and associated offset in C - thread const auto& accum = Ctile.frag_at(i, j); - int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; - int offset_d = (i * TM_stride) * ldd + (j * TN_stride); - - // Apply epilogue - STEEL_PRAGMA_UNROLL - for (short k = 0; k < kelems; k++) { - D[offset_d + k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]); - } - } - } - } - - METAL_FUNC void store_result_safe( - device U* D, - const int ldd, - const device U* C, - const int ldc, - const int fdc, - short2 dst_tile_dims, - thread const Epilogue& epilogue_op) const { - // Adjust for simdgroup and thread location - C += (sm)*ldc + (sn)*fdc; - D += (sm)*ldd + sn; - dst_tile_dims -= short2(sn, sm); - - if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) - return; - - constexpr short kelems = decltype(Ctile)::kElemsPerFrag; - - STEEL_PRAGMA_UNROLL - for (int i = 0; i < TM; i++) { - if (i * TM_stride < dst_tile_dims.y) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < TN; j++) { - // Get accumulated result and associated offset in C - thread const auto& accum = Ctile.frag_at(i, j); - int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; - int offset_d = (i * TM_stride) * ldd + (j * TN_stride); - - // Apply epilogue - STEEL_PRAGMA_UNROLL - for (short k = 0; k < kelems; k++) { - if ((j * TN_stride + k) < dst_tile_dims.x) { - D[offset_d + k] = - epilogue_op.apply(accum[k], C[offset_c + k * fdc]); - } - } - } - } - } - } -}; - -} // namespace steel -} // namespace mlx diff --git a/Source/Cmlx/mlx-generated/metal/steel/attn/nax.h b/Source/Cmlx/mlx-generated/metal/steel/attn/nax.h deleted file mode 100644 index 77f3ee41..00000000 --- a/Source/Cmlx/mlx-generated/metal/steel/attn/nax.h +++ /dev/null @@ -1,1076 +0,0 @@ -// Copyright © 2025 Apple Inc. - -#pragma once - -#include -#include -#include - -#include "../../steel/defines.h" -#include "../../steel/utils/integral_constant.h" - -#include - -using namespace metal; - -/////////////////////////////////////////////////////////////////////////////// -// MMA helper -/////////////////////////////////////////////////////////////////////////////// - -namespace mlx { -namespace steel { - -/////////////////////////////////////////////////////////////////////////////// -// NAX Steel with new tiles -/////////////////////////////////////////////////////////////////////////////// - -struct BaseNAXFrag { - STEEL_CONST short kFragRows = 16; - STEEL_CONST short kFragCols = 16; - - STEEL_CONST short kElemsPerFrag = (kFragRows * kFragCols) / 32; - - STEEL_CONST short kElemRows = 2; - STEEL_CONST short kElemCols = 4; - - STEEL_CONST short kElemRowsJump = 8; - - static_assert( - kElemRows * kElemCols == kElemsPerFrag, - "MMAFrag shape is not consistent with MMAFrag size"); - - template - using dtype_frag_t = typename metal::vec; - - METAL_FUNC static short2 get_coord() { - const ushort simd_lane_id = __metal_get_thread_index_in_simdgroup(ushort()); - const short qid = simd_lane_id >> 2; - const short fm = ((qid & 4) | ((simd_lane_id >> 1) & 3)); - const short fn = ((qid & 2) | (simd_lane_id & 1)) * 4; - return short2{fn, fm}; - } - - METAL_FUNC static short2 get_coord(short idx) { - const ushort simd_lane_id = __metal_get_thread_index_in_simdgroup(ushort()); - const short qid = simd_lane_id >> 2; - const short fm = ((qid & 4) | ((simd_lane_id >> 1) & 3)) + (idx >> 2) * 8; - const short fn = ((qid & 2) | (simd_lane_id & 1)) * 4 + idx % 4; - return short2{fn, fm}; - } - - template < - typename T, - typename SrcPtrType, - typename StrX, - typename StrY, - typename OffX = Int<0>, - typename OffY = Int<0>> - METAL_FUNC static constexpr void load( - thread dtype_frag_t& dst, - SrcPtrType src, - StrX str_x, - StrY str_y, - OffX off_x = {}, - OffY off_y = {}) { - const short2 sc = short2{0, 0}; // get_coord(); - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kElemRows; i++) { - const auto r = off_x + i * kElemRowsJump + sc.y; - const auto c = off_y + sc.x; - - if constexpr (metal::is_same_v>) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kElemCols; j++) { - dst[i * kElemCols + j] = static_cast(src[r * str_x + c + j]); - } - } else { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kElemCols; j++) { - dst[i * kElemCols + j] = - static_cast(src[r * str_x + (c + j) * str_y]); - } - } - } - } - - template < - typename T, - typename SrcPtrType, - typename StrX, - typename StrY, - typename LimX, - typename OffX = Int<0>, - typename OffY = Int<0>> - METAL_FUNC static constexpr void load_rows( - thread dtype_frag_t& dst, - SrcPtrType src, - StrX str_x, - StrY str_y, - LimX lim_x, - OffX off_x = {}, - OffY off_y = {}) { - const short2 sc = short2{0, 0}; // get_coord(); - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kElemRows; i++) { - const auto r = off_x + i * kElemRowsJump + sc.y; - const auto c = off_y + sc.x; - - if (r < lim_x) { - if constexpr (metal::is_same_v>) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kElemCols; j++) { - dst[i * kElemCols + j] = static_cast(src[r * str_x + (c + j)]); - } - } else { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kElemCols; j++) { - dst[i * kElemCols + j] = - static_cast(src[r * str_x + (c + j) * str_y]); - } - } - - } else { - dst = dtype_frag_t(0); - } - } - } - - template < - typename T, - typename SrcPtrType, - typename StrX, - typename StrY, - typename LimX, - typename LimY, - typename OffX = Int<0>, - typename OffY = Int<0>> - METAL_FUNC static constexpr void load_safe( - thread dtype_frag_t& dst, - SrcPtrType src, - StrX str_x, - StrY str_y, - LimX lim_x, - LimY lim_y, - OffX off_x = {}, - OffY off_y = {}) { - const short2 sc = short2{0, 0}; // get_coord(); - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kElemRows; i++) { - const auto r = off_x + i * kElemRowsJump + sc.y; - const auto c = off_y + sc.x; - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kElemCols; j++) { - if (r < lim_x && (c + j) < lim_y) { - dst[i * kElemCols + j] = - static_cast(src[r * str_x + (c + j) * str_y]); - } else { - dst[i * kElemCols + j] = T(0); - } - } - } - } - - template < - typename T, - typename DstPtrType, - typename StrX, - typename StrY, - typename OffX = Int<0>, - typename OffY = Int<0>> - METAL_FUNC static constexpr void store( - const thread dtype_frag_t& src, - DstPtrType dst, - StrX str_x, - StrY str_y, - OffX off_x = {}, - OffY off_y = {}) { - using U = pointer_element_t; - - const short2 sc = short2{0, 0}; // get_coord(); - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kElemRows; i++) { - const auto r = off_x + i * kElemRowsJump + sc.y; - const auto c = off_y + sc.x; - - if constexpr (metal::is_same_v>) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kElemCols; j++) { - dst[r * str_x + c + j] = static_cast(src[i * kElemCols + j]); - } - } else { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kElemCols; j++) { - dst[r * str_x + (c + j) * str_y] = - static_cast(src[i * kElemCols + j]); - } - } - } - } - - template < - typename T, - typename DstPtrType, - typename StrX, - typename StrY, - typename LimX, - typename OffX = Int<0>, - typename OffY = Int<0>> - METAL_FUNC static constexpr void store_rows( - const thread dtype_frag_t& src, - DstPtrType dst, - StrX str_x, - StrY str_y, - LimX lim_x, - OffX off_x = {}, - OffY off_y = {}) { - using U = pointer_element_t; - - const short2 sc = short2{0, 0}; // get_coord(); - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kElemRows; i++) { - const auto r = off_x + i * kElemRowsJump + sc.y; - const auto c = off_y + sc.x; - - if (r < lim_x) { - if constexpr (metal::is_same_v>) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kElemCols; j++) { - dst[r * str_x + c + j] = static_cast(src[i * kElemCols + j]); - } - } else { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kElemCols; j++) { - dst[r * str_x + (c + j) * str_y] = - static_cast(src[i * kElemCols + j]); - } - } - } - } - } - - template < - typename T, - typename DstPtrType, - typename StrX, - typename StrY, - typename LimX, - typename LimY, - typename OffX = Int<0>, - typename OffY = Int<0>> - METAL_FUNC static constexpr void store_safe( - const thread dtype_frag_t& src, - DstPtrType dst, - StrX str_x, - StrY str_y, - LimX lim_x, - LimY lim_y, - OffX off_x = {}, - OffY off_y = {}) { - using U = pointer_element_t; - - const short2 sc = short2{0, 0}; // get_coord(); - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kElemRows; i++) { - const auto r = off_x + i * kElemRowsJump + sc.y; - const auto c = off_y + sc.x; - - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kElemCols; j++) { - if (r < lim_x && (c + j) < lim_y) { - dst[r * str_x + (c + j) * str_y] = - static_cast(src[i * kElemCols + j]); - } - } - } - } - - template < - typename T, - typename DstPtrType, - typename StrX, - typename StrY, - typename StartX, - typename StopX, - typename StartY, - typename StopY, - typename OffX = Int<0>, - typename OffY = Int<0>> - METAL_FUNC static constexpr void store_slice( - const thread dtype_frag_t& src, - DstPtrType dst, - StrX str_x, - StrY str_y, - StartX start_x, - StopX stop_x, - StartY start_y, - StopY stop_y, - OffX off_x = Int<0>{}, - OffY off_y = Int<0>{}) { - using U = pointer_element_t; - - const short2 sc = short2{0, 0}; // get_coord(); - - const_for_loop<0, kElemRows, 1>([&](auto idx_row) { - const auto r = off_x + idx_row * Int{}; - if (r >= stop_x - sc.y || r < start_x - sc.y) { - return; - } - - const_for_loop<0, kElemCols, 1>([&](auto idx_col) { - const auto c = off_y + idx_col; - if (c >= stop_y - sc.x || c < start_y - sc.x) { - return; - } - - const auto src_idx = idx_row * Int{} + idx_col; - dst[(r + sc.y) * str_x + (c + sc.x) * str_y] = - static_cast(src[src_idx]); - }); - }); - } - - template - METAL_FUNC static constexpr void row_reduce( - thread const dtype_frag_t& inp_vals, - thread T* reduced_vals) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kElemRows; i++) { - T thr_reduce = Op::apply( - Op::apply(inp_vals[i * kElemCols + 0], inp_vals[i * kElemCols + 1]), - Op::apply(inp_vals[i * kElemCols + 2], inp_vals[i * kElemCols + 3])); - - T qgr_reduce = simd_shuffle_xor(thr_reduce, ushort(1)); - qgr_reduce = Op::apply(thr_reduce, qgr_reduce); - - T sgr_reduce = simd_shuffle_xor(qgr_reduce, ushort(8)); - sgr_reduce = Op::apply(qgr_reduce, sgr_reduce); - - reduced_vals[i] = Op::apply(reduced_vals[i], sgr_reduce); - } - } - - template - METAL_FUNC static constexpr void row_bin_op( - thread dtype_frag_t& inp_vals, - thread T* row_vals) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kElemRows; i++) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kElemCols; j++) { - inp_vals[i * kElemCols + j] = - Op::apply(inp_vals[i * kElemCols + j], row_vals[i]); - } - } - } -}; - -template < - typename T, - short kRows_, - short kCols_, - typename NAXFrag_ = BaseNAXFrag> -struct NAXSubTile { - using NAXFrag_t = NAXFrag_; - STEEL_CONST short kRows = kRows_; - STEEL_CONST short kCols = kCols_; - - STEEL_CONST short kFragRows = NAXFrag_t::kFragRows; - STEEL_CONST short kFragCols = NAXFrag_t::kFragCols; - STEEL_CONST short kElemsPerFrag = NAXFrag_t::kElemsPerFrag; - - STEEL_CONST short kSubTileRows = kRows / kFragRows; - STEEL_CONST short kSubTileCols = kCols / kFragCols; - - STEEL_CONST short kNumFrags = kSubTileRows * kSubTileCols; - STEEL_CONST short kElemsPerSubTile = kNumFrags * kElemsPerFrag; - - STEEL_CONST int kRowsPerThread = kSubTileRows * NAXFrag_t::kElemRows; - STEEL_CONST int kColsPerThread = kSubTileCols * NAXFrag_t::kElemCols; - - STEEL_CONST short kFragThrRows = NAXFrag_t::kElemRows; - STEEL_CONST short kFragThrCols = NAXFrag_t::kElemCols; - STEEL_CONST short kFragRowsJump = NAXFrag_t::kElemRowsJump; - - using frag_type = typename NAXFrag_t::template dtype_frag_t; - - frag_type val_frags[kNumFrags]; - - METAL_FUNC constexpr void clear() { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kNumFrags; ++i) { - val_frags[i] = frag_type(0); - } - } - - METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) { - return val_frags[i * kSubTileCols + j]; - } - - METAL_FUNC constexpr const thread frag_type& frag_at( - const short i, - const short j) const { - return val_frags[i * kSubTileCols + j]; - } - - template - METAL_FUNC constexpr thread frag_type& frag_at() { - return val_frags[i * kSubTileCols + j]; - } - - template - METAL_FUNC constexpr const thread frag_type& frag_at() const { - return val_frags[i * kSubTileCols + j]; - } - - METAL_FUNC thread T* elems() { - return reinterpret_cast(val_frags); - } - - METAL_FUNC const thread T* elems() const { - return reinterpret_cast(val_frags); - } - - template - METAL_FUNC void row_reduce(thread metal::vec& vals) const { - thread T* vptr = (thread T*)(&vals); - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kSubTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kSubTileCols; ++j) { - NAXFrag_t::template row_reduce( - frag_at(i, j), &vptr[i * kFragThrRows]); - } - } - } - - template - METAL_FUNC void row_bin_op(thread metal::vec& vals) { - thread T* vptr = (thread T*)(&vals); - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kSubTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kSubTileCols; ++j) { - NAXFrag_t::template row_bin_op( - frag_at(i, j), &vptr[i * kFragThrRows]); - } - } - } - - template < - typename SrcPtrType, - typename StrX, - typename StrY, - typename OffX = Int<0>, - typename OffY = Int<0>> - METAL_FUNC constexpr void load( - SrcPtrType src, - StrX str_x, - StrY str_y, - OffX off_x = {}, - OffY off_y = {}) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kSubTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kSubTileCols; ++j) { - NAXFrag_t::load( - frag_at(i, j), - src, - str_x, - str_y, - off_x + i * kFragRows, - off_y + j * kFragCols); - } - } - } - - template < - typename DstPtrType, - typename StrX, - typename StrY, - typename OffX = Int<0>, - typename OffY = Int<0>> - METAL_FUNC constexpr void store( - DstPtrType dst, - StrX str_x, - StrY str_y, - OffX off_x = {}, - OffY off_y = {}) const { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kSubTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kSubTileCols; ++j) { - NAXFrag_t::store( - frag_at(i, j), - dst, - str_x, - str_y, - off_x + i * kFragRows, - off_y + j * kFragCols); - } - } - } - - template < - typename SrcPtrType, - typename StrX, - typename StrY, - typename LimX, - typename OffX = Int<0>, - typename OffY = Int<0>> - METAL_FUNC constexpr void load_rows( - SrcPtrType src, - StrX str_x, - StrY str_y, - LimX lim_x, - OffX off_x = {}, - OffY off_y = {}) { - STEEL_PRAGMA_UNROLL - for (int i = 0; i < kSubTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < kSubTileCols; ++j) { - NAXFrag_t::load_rows( - frag_at(i, j), - src, - str_x, - str_y, - lim_x, - off_x + (i * kFragRows), - off_y + (j * kFragCols)); - } - } - } - - template < - typename SrcPtrType, - typename StrX, - typename StrY, - typename LimX, - typename LimY, - typename OffX = Int<0>, - typename OffY = Int<0>> - METAL_FUNC constexpr void load_safe( - SrcPtrType src, - StrX str_x, - StrY str_y, - LimX lim_x, - LimY lim_y, - OffX off_x = {}, - OffY off_y = {}) { - STEEL_PRAGMA_UNROLL - for (int i = 0; i < kSubTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < kSubTileCols; ++j) { - NAXFrag_t::load_safe( - frag_at(i, j), - src, - str_x, - str_y, - lim_x, - lim_y, - off_x + (i * kFragRows), - off_y + (j * kFragCols)); - } - } - } - - template < - typename DstPtrType, - typename StrX, - typename StrY, - typename LimX, - typename OffX = Int<0>, - typename OffY = Int<0>> - METAL_FUNC constexpr void store_rows( - DstPtrType dst, - StrX str_x, - StrY str_y, - LimX lim_x, - OffX off_x = {}, - OffY off_y = {}) const { - STEEL_PRAGMA_UNROLL - for (int i = 0; i < kSubTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < kSubTileCols; ++j) { - NAXFrag_t::store_safe( - frag_at(i, j), - dst, - str_x, - str_y, - lim_x, - off_x + (i * kFragRows), - off_y + (j * kFragCols)); - } - } - } - - template < - typename DstPtrType, - typename StrX, - typename StrY, - typename LimX, - typename LimY, - typename OffX = Int<0>, - typename OffY = Int<0>> - METAL_FUNC constexpr void store_safe( - DstPtrType dst, - StrX str_x, - StrY str_y, - LimX lim_x, - LimY lim_y, - OffX off_x = {}, - OffY off_y = {}) const { - STEEL_PRAGMA_UNROLL - for (int i = 0; i < kSubTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < kSubTileCols; ++j) { - NAXFrag_t::store_safe( - frag_at(i, j), - dst, - str_x, - str_y, - lim_x, - lim_y, - off_x + (i * kFragRows), - off_y + (j * kFragCols)); - } - } - } - - template < - typename DstPtrType, - typename StrX, - typename StrY, - typename StartX, - typename StopX, - typename StartY, - typename StopY, - typename OffX = Int<0>, - typename OffY = Int<0>> - METAL_FUNC constexpr void store_slice( - DstPtrType dst, - StrX str_x, - StrY str_y, - StartX start_x, - StopX stop_x, - StartY start_y, - StopY stop_y, - OffX off_x = Int<0>{}, - OffY off_y = Int<0>{}) const { - const_for_loop<0, kSubTileRows, 1>([&](auto idx_row) { - const_for_loop<0, kSubTileCols, 1>([&](auto idx_col) { - NAXFrag_t::store_slice( - frag_at(), - dst, - str_x, - str_y, - start_x, - stop_x, - start_y, - stop_y, - off_x + idx_row * Int{}, - off_y + idx_col * Int{}); - }); - }); - } -}; - -template < - short RC, - short CC, - short RA, - short CA, - short RB, - short CB, - typename CType, - typename AType, - typename BType, - bool transpose_a, - bool transpose_b, - typename NAXFrag_t = BaseNAXFrag> -METAL_FUNC void subtile_matmad_nax( - thread NAXSubTile& C, - thread NAXSubTile& A, - metal::bool_constant, - thread NAXSubTile& B, - metal::bool_constant) { - // Static checks - constexpr short FMa = transpose_a ? CA : RA; - constexpr short FMc = RC; - static_assert(FMa == FMc, "NAX matmul: M dimensions do not match"); - - constexpr short FNb = transpose_b ? RB : CB; - constexpr short FNc = CC; - static_assert(FNb == FNc, "NAX matmul: N dimensions do not match"); - - constexpr short FKa = transpose_a ? RA : CA; - constexpr short FKb = transpose_b ? CB : RB; - static_assert(FKa == FKb, "NAX matmul: N dimensions do not match"); - - constexpr short FM = FMc; - constexpr short FN = FNc; - constexpr short FK = FKa; - - constexpr int TM = FM / 16; - constexpr int TN = FN / 16; - constexpr int TK = FK / 16; - - constexpr auto desc = mpp::tensor_ops::matmul2d_descriptor( - FM, - FN, - FK, - transpose_a, - transpose_b, - true, - mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate); - - mpp::tensor_ops::matmul2d gemm_op; - - auto ct_a = - gemm_op.template get_left_input_cooperative_tensor(); - auto ct_b = - gemm_op - .template get_right_input_cooperative_tensor(); - auto ct_c = gemm_op.template get_destination_cooperative_tensor< - decltype(ct_a), - decltype(ct_b), - CType>(); - - STEEL_PRAGMA_UNROLL - for (short mm = 0; mm < TM; mm++) { - STEEL_PRAGMA_UNROLL - for (short kk = 0; kk < TK; kk++) { - const short fi = transpose_a ? kk : mm; - const short fj = transpose_a ? mm : kk; - - STEEL_PRAGMA_UNROLL - for (short i = 0; i < 8; i++) { - ct_a[(TK * mm + kk) * 8 + i] = A.frag_at(fi, fj)[i]; - } - } - } - - STEEL_PRAGMA_UNROLL - for (short nn = 0; nn < TN; nn++) { - STEEL_PRAGMA_UNROLL - for (short kk = 0; kk < TK; kk++) { - const short fi = transpose_b ? nn : kk; - const short fj = transpose_b ? kk : nn; - - STEEL_PRAGMA_UNROLL - for (short i = 0; i < 8; i++) { - ct_b[(TN * kk + nn) * 8 + i] = B.frag_at(fi, fj)[i]; - } - } - } - - STEEL_PRAGMA_UNROLL - for (short i = 0; i < ct_c.get_capacity(); i++) { - ct_c[i] = C.elems()[i]; - } - - gemm_op.run(ct_a, ct_b, ct_c); - - STEEL_PRAGMA_UNROLL - for (short i = 0; i < ct_c.get_capacity(); i++) { - C.elems()[i] = ct_c[i]; - } -} - -template -struct NAXTile { - using NAXSubTile_t = NAXSubTile_; - using elem_type = T; - STEEL_CONST short kSubTileRows = NAXSubTile_t::kRows; - STEEL_CONST short kSubTileCols = NAXSubTile_t::kCols; - STEEL_CONST short kElemsPerSubTile = NAXSubTile_t::kElemsPerSubTile; - - STEEL_CONST short kTileRows = kTileRows_; - STEEL_CONST short kTileCols = kTileCols_; - - STEEL_CONST short kRows = kTileRows * kSubTileRows; - STEEL_CONST short kCols = kTileCols * kSubTileCols; - - STEEL_CONST short kSubTiles = kTileRows * kTileCols; - STEEL_CONST short kElemsPerTile = kSubTiles * kElemsPerSubTile; - - STEEL_CONST short kRowsPerThread = kTileRows * NAXSubTile_t::kRowsPerThread; - STEEL_CONST short kColsPerThread = kTileCols * NAXSubTile_t::kColsPerThread; - - STEEL_CONST short kSubTileThrRows = NAXSubTile_t::kRowsPerThread; - STEEL_CONST short kSubTileThrCols = NAXSubTile_t::kColsPerThread; - - NAXSubTile_t val_subtiles[kSubTiles]; - - METAL_FUNC NAXTile() thread {} - - METAL_FUNC constexpr void clear() { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kSubTiles; ++i) { - val_subtiles[i].clear(); - } - } - - METAL_FUNC constexpr thread NAXSubTile_t& subtile_at( - const short i, - const short j) { - return val_subtiles[i * kTileCols + j]; - } - - METAL_FUNC constexpr const thread NAXSubTile_t& subtile_at( - const short i, - const short j) const { - return val_subtiles[i * kTileCols + j]; - } - - template - METAL_FUNC constexpr const thread NAXSubTile_t& subtile_at() const { - return val_subtiles[i * kTileCols + j]; - } - - METAL_FUNC thread elem_type* elems() { - return reinterpret_cast(val_subtiles[0].elems()); - } - - METAL_FUNC const thread elem_type* elems() const { - return reinterpret_cast(val_subtiles[0].elems()); - } - - template - METAL_FUNC void row_reduce(thread metal::vec& vals) const { - auto sub_rows = (thread metal::vec*)(&vals); - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kTileCols; ++j) { - subtile_at(i, j).template row_reduce(sub_rows[i]); - } - } - } - - template - METAL_FUNC void row_bin_op(thread metal::vec& vals) { - auto sub_rows = (thread metal::vec*)(&vals); - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kTileCols; ++j) { - subtile_at(i, j).template row_bin_op(sub_rows[i]); - } - } - } - - template - METAL_FUNC void load(const threadgroup U* src) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kTileCols; ++j) { - subtile_at(i, j).load( - src, - Int{}, - Int{}, - i * kSubTileRows, - j * kSubTileCols); - } - } - } - - template - METAL_FUNC void store(threadgroup U* dst) const { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kTileCols; ++j) { - subtile_at(i, j).store( - dst, - Int{}, - Int{}, - i * kSubTileRows, - j * kSubTileCols); - } - } - } - - template - METAL_FUNC void load(const device U* src, const int ld) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kTileCols; ++j) { - subtile_at(i, j).load( - &src[(i * kSubTileRows) * ld + (j * kSubTileCols)], ld, Int<1>{}); - } - } - } - - template - METAL_FUNC void store(device U* dst, const int ld) const { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kTileCols; ++j) { - subtile_at(i, j).store( - &dst[(i * kSubTileRows) * ld + (j * kSubTileCols)], ld, Int<1>{}); - } - } - } - - template - METAL_FUNC void - load_safe(const device U* src, const int ld, const short2 src_tile_dims) { - STEEL_PRAGMA_UNROLL - for (int i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < kTileCols; ++j) { - subtile_at(i, j).load_safe( - src, - ld, - Int<1>{}, - src_tile_dims.y, - src_tile_dims.x, - i * kSubTileRows, - j * kSubTileCols); - } - } - } - - template - METAL_FUNC void - load_rows(const device U* src, const int ld, const short n_rows) { - STEEL_PRAGMA_UNROLL - for (int i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < kTileCols; ++j) { - subtile_at(i, j).load_rows( - &src[(i * kSubTileRows) * ld + (j * kSubTileCols)], - ld, - Int<1>{}, - n_rows - i * kSubTileRows); - } - } - } - - template - METAL_FUNC void - store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const { - STEEL_PRAGMA_UNROLL - for (int i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < kTileCols; ++j) { - subtile_at(i, j).store_safe( - dst, - ld, - Int<1>{}, - dst_tile_dims.y, - dst_tile_dims.x, - i * kSubTileRows, - j * kSubTileCols); - } - } - } - - template - METAL_FUNC void store_rows(device U* dst, const int ld, const short n_rows) - const { - STEEL_PRAGMA_UNROLL - for (int i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < kTileCols; ++j) { - subtile_at(i, j).store_rows( - &dst[(i * kSubTileRows) * ld + (j * kSubTileCols)], - ld, - Int<1>{}, - n_rows - i * kSubTileRows); - } - } - } - - template - METAL_FUNC void store_slice( - device U* dst, - const int ld, - const short2 start, - const short2 stop) const { - const_for_loop<0, kTileRows, 1>([&](auto idx_row) { - const_for_loop<0, kTileCols, 1>([&](auto idx_col) { - subtile_at().store_slice( - dst, - ld, - Int<1>{}, - start.y, - stop.y, - start.x, - stop.x, - idx_row * Int{}, - idx_col * Int{}); - }); - }); - } -}; - -template < - class CTile, - class ATile, - class BTile, - bool transpose_a, - bool transpose_b> -METAL_FUNC void tile_matmad_nax( - thread CTile& C, - thread ATile& A, - metal::bool_constant, - thread BTile& B, - metal::bool_constant) { - // Static checks - constexpr short TMa = transpose_a ? ATile::kTileCols : ATile::kTileRows; - constexpr short TMc = CTile::kTileRows; - static_assert(TMa == TMc, "NAX tile matmul: M dimensions do not match"); - - constexpr short FMa = transpose_a ? ATile::kSubTileCols : ATile::kSubTileRows; - constexpr short FMc = CTile::kSubTileRows; - static_assert(FMa == FMc, "NAX subtile matmul: M dimensions do not match"); - - constexpr short TNb = transpose_b ? BTile::kTileRows : BTile::kTileCols; - constexpr short TNc = CTile::kTileCols; - static_assert(TNb == TNc, "NAX tile matmul: N dimensions do not match"); - - constexpr short FNb = transpose_b ? BTile::kSubTileRows : BTile::kSubTileCols; - constexpr short FNc = CTile::kSubTileCols; - static_assert(FNb == FNc, "NAX subtile matmul: N dimensions do not match"); - - constexpr short TKa = transpose_a ? ATile::kTileRows : ATile::kTileCols; - constexpr short TKb = transpose_b ? BTile::kTileCols : BTile::kTileRows; - static_assert(TKa == TKb, "NAX tile matmul: K dimensions do not match"); - - constexpr short FKa = transpose_a ? ATile::kSubTileRows : ATile::kSubTileCols; - constexpr short FKb = transpose_b ? BTile::kSubTileCols : BTile::kSubTileRows; - static_assert(FKa == FKb, "NAX subtile matmul: K dimensions do not match"); - - constexpr short TM = TMc; - constexpr short TN = TNc; - constexpr short TK = TKa; - - // Do matmul here - STEEL_PRAGMA_UNROLL - for (short i = 0; i < TM; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; ++j) { - STEEL_PRAGMA_UNROLL - for (short k = 0; k < TK; ++k) { - const short ra = transpose_a ? k : i; - const short ca = transpose_a ? i : k; - const short rb = transpose_b ? j : k; - const short cb = transpose_b ? k : j; - - subtile_matmad_nax( - C.subtile_at(i, j), - A.subtile_at(ra, ca), - metal::bool_constant{}, - B.subtile_at(rb, cb), - metal::bool_constant{}); - } - } - } -} - -} // namespace steel -} // namespace mlx diff --git a/Source/Cmlx/mlx-generated/metal/steel/attn/params.h b/Source/Cmlx/mlx-generated/metal/steel/attn/params.h deleted file mode 100644 index f1cf09fa..00000000 --- a/Source/Cmlx/mlx-generated/metal/steel/attn/params.h +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright © 2024 Apple Inc. - -#pragma once - -/////////////////////////////////////////////////////////////////////////////// -// Attn param classes -/////////////////////////////////////////////////////////////////////////////// - -namespace mlx { -namespace steel { - -struct AttnParams { - int B; ///< Batch Size - int H; ///< Heads - int D; ///< Head Dim - - int qL; ///< Query Sequence Length - int kL; ///< Key Sequence Length - - int gqa_factor; ///< Group Query factor - float scale; ///< Attention scale - - int NQ; ///< Number of query blocks - int NK; ///< Number of key/value blocks - - int NQ_aligned; ///< Number of full query blocks - int NK_aligned; ///< Number of full key/value blocks - - int qL_rem; ///< Remainder in last query block - int kL_rem; ///< Remainder in last key/value block - int qL_off; ///< Offset in query sequence start - - int64_t Q_strides[3]; ///< Query strides (B, H, L, D = 1) - int64_t K_strides[3]; ///< Key strides (B, H, L, D = 1) - int64_t V_strides[3]; ///< Value strides (B, H, L, D = 1) - int64_t O_strides[3]; ///< Output strides (B, H, L, D = 1) -}; - -struct AttnMaskParams { - int64_t M_strides[3]; ///< Mask strides (B, H, qL, kL = 1) -}; - -} // namespace steel -} // namespace mlx diff --git a/Source/Cmlx/mlx-generated/metal/steel/attn/transforms.h b/Source/Cmlx/mlx-generated/metal/steel/attn/transforms.h deleted file mode 100644 index 3d8ca054..00000000 --- a/Source/Cmlx/mlx-generated/metal/steel/attn/transforms.h +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright © 2024 Apple Inc. - -#pragma once - -#include "../../steel/utils.h" - -/////////////////////////////////////////////////////////////////////////////// -// Transforms and Epilogues -/////////////////////////////////////////////////////////////////////////////// - -namespace mlx { -namespace steel { - -template -struct TransformNone { - static METAL_FUNC OutT apply(InT x) { - return static_cast(x); - } - - static METAL_FUNC OutT apply(InT x, OutT) { - return static_cast(x); - } -}; - -template -struct TransformAdd { - TransformAdd(const float, const float) {} - - static METAL_FUNC OutT apply(InT x) { - return static_cast(x); - } - - static METAL_FUNC OutT apply(InT x, OutT c) { - return static_cast(x) + c; - } -}; - -template -struct TransformAxpby { - const float alpha; - const float beta; - - TransformAxpby(const float alpha_, const float beta_) - : alpha(alpha_), beta(beta_) {} - - static METAL_FUNC OutT apply(InT x) { - return static_cast(x); - } - - METAL_FUNC OutT apply(InT x, OutT c) const { - return static_cast(x * alpha + (beta * c)); - } -}; - -template -struct AccumHelper { - typedef float accum_type; -}; - -struct BlockSwizzle { - static METAL_FUNC int2 - swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) { - const int tid_x = (tid.x) >> swizzle_log; - const int tid_y = - ((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1)); - return int2(tid_x, tid_y); - } -}; - -} // namespace steel -} // namespace mlx \ No newline at end of file diff --git a/Source/Cmlx/mlx-generated/metal/steel/conv/conv.h b/Source/Cmlx/mlx-generated/metal/steel/conv/conv.h deleted file mode 100644 index 0845f521..00000000 --- a/Source/Cmlx/mlx-generated/metal/steel/conv/conv.h +++ /dev/null @@ -1,13 +0,0 @@ -// Copyright © 2024 Apple Inc. - -#pragma once - -#include "../../steel/defines.h" -#include "../../steel/utils.h" - -#include "../../steel/conv/loader.h" -#include "../../steel/conv/params.h" -#include "../../steel/gemm/mma.h" - -using namespace metal; -using namespace mlx::steel; diff --git a/Source/Cmlx/mlx-generated/metal/steel/conv/kernels/steel_conv.h b/Source/Cmlx/mlx-generated/metal/steel/conv/kernels/steel_conv.h deleted file mode 100644 index 850ec15b..00000000 --- a/Source/Cmlx/mlx-generated/metal/steel/conv/kernels/steel_conv.h +++ /dev/null @@ -1,176 +0,0 @@ -// Copyright © 2024 Apple Inc. - -#include - -using namespace metal; - -template < - typename T, - int BM, - int BN, - int BK, - int WM, - int WN, - int N_CHANNELS = 0, - bool SMALL_FILTER = false> -[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void -implicit_gemm_conv_2d( - const device T* A [[buffer(0)]], - const device T* B [[buffer(1)]], - device T* C [[buffer(2)]], - const constant MLXConvParams<2>* params [[buffer(3)]], - const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - using namespace mlx::steel; - - (void)lid; - - constexpr bool transpose_a = false; - constexpr bool transpose_b = true; - constexpr short tgp_padding_a = 16 / sizeof(T); - constexpr short tgp_padding_b = 16 / sizeof(T); - - constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a; - constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b; - constexpr short shape_a_rows = (transpose_a ? BK : BM); - constexpr short shape_b_rows = (transpose_b ? BN : BK); - constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows; - constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows; - - constexpr short tgp_size = WM * WN * 32; - - // Input loader - - using loader_a_t = typename metal::conditional_t< - // Check for small channel specialization - N_CHANNELS != 0 && N_CHANNELS <= 4, - - // Go to small channel specialization - Conv2DInputBlockLoaderSmallChannels< - T, - BM, - BN, - BK, - tgp_size, - N_CHANNELS, - tgp_padding_a>, - - // Else go to general loader - typename metal::conditional_t< - // Check if filter size is small enough - SMALL_FILTER, - - // Go to small filter specialization - Conv2DInputBlockLoaderSmallFilter< - T, - BM, - BN, - BK, - tgp_size, - tgp_padding_a>, - - // Else go to large filter generalization - Conv2DInputBlockLoaderLargeFilter< - T, - BM, - BN, - BK, - tgp_size, - tgp_padding_a>>>; - - // Weight loader - using loader_b_t = typename metal::conditional_t< - // Check for small channel specialization - N_CHANNELS != 0 && N_CHANNELS <= 4, - - // Go to small channel specialization - Conv2DWeightBlockLoaderSmallChannels< - T, - BM, - BN, - BK, - tgp_size, - N_CHANNELS, - tgp_padding_b>, - - // Else go to general loader - Conv2DWeightBlockLoader>; - - using mma_t = BlockMMA< - T, - T, - BM, - BN, - BK, - WM, - WN, - transpose_a, - transpose_b, - shape_a_cols, - shape_b_cols>; - - threadgroup T As[tgp_mem_size_a]; - threadgroup T Bs[tgp_mem_size_b]; - - const int tid_y = ((tid.y) << gemm_params->swizzle_log) + - ((tid.x) & ((1 << gemm_params->swizzle_log) - 1)); - const int tid_x = (tid.x) >> gemm_params->swizzle_log; - - if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) { - return; - } - - const int c_row = tid_y * BM; - const int c_col = tid_x * BN; - const int K = gemm_params->K; - const int N = gemm_params->N; - const int C_per_group = params->C / params->groups; - - // Groups - A += tid.z * C_per_group; - B += tid.z * N * K; - C += tid.z * N; - - B += c_col * K; - C += c_row * (N * params->groups) + c_col; - - const int2 offsets_a(0, c_row); - const int2 offsets_b(0, c_col); - - // Prepare threadgroup loading operations - loader_a_t loader_a( - A, As, offsets_a, params, gemm_params, simd_gid, simd_lid); - loader_b_t loader_b( - B, Bs, offsets_b, params, gemm_params, simd_gid, simd_lid); - - // Prepare threadgroup mma operation - mma_t mma_op(simd_gid, simd_lid); - - int gemm_k_iterations = gemm_params->gemm_k_iterations; - for (int k = 0; k < gemm_k_iterations; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - // Load elements into threadgroup - loader_a.load_unsafe(); - loader_b.load_unsafe(); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - - // Prepare for next iteration - loader_a.next(); - loader_b.next(); - } - - threadgroup_barrier(mem_flags::mem_none); - - // Store results to device memory - short tgp_bm = min(BM, gemm_params->M - c_row); - short tgp_bn = min(BN, gemm_params->N - c_col); - const int ldc = N * params->groups; - mma_op.store_result_safe(C, ldc, short2(tgp_bn, tgp_bm)); -} diff --git a/Source/Cmlx/mlx-generated/metal/steel/conv/kernels/steel_conv_3d.h b/Source/Cmlx/mlx-generated/metal/steel/conv/kernels/steel_conv_3d.h deleted file mode 100644 index d2fbac0f..00000000 --- a/Source/Cmlx/mlx-generated/metal/steel/conv/kernels/steel_conv_3d.h +++ /dev/null @@ -1,135 +0,0 @@ -// Copyright © 2024 Apple Inc. - -#include - -using namespace metal; - -template < - typename T, - int BM, - int BN, - int BK, - int WM, - int WN, - bool SMALL_FILTER = false> -[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void -implicit_gemm_conv_3d( - const device T* A [[buffer(0)]], - const device T* B [[buffer(1)]], - device T* C [[buffer(2)]], - const constant MLXConvParams<3>* params [[buffer(3)]], - const constant ImplicitGemmConv3DParams* gemm_params [[buffer(4)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - using namespace mlx::steel; - - (void)lid; - - constexpr bool transpose_a = false; - constexpr bool transpose_b = true; - constexpr short tgp_padding_a = 16 / sizeof(T); - constexpr short tgp_padding_b = 16 / sizeof(T); - - constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a; - constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b; - constexpr short shape_a_rows = (transpose_a ? BK : BM); - constexpr short shape_b_rows = (transpose_b ? BN : BK); - constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows; - constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows; - - constexpr short tgp_size = WM * WN * 32; - - // Input loader - using loader_a_t = typename metal::conditional_t< - // If the filter is small we can precompute masks for bounds checking - SMALL_FILTER, - Conv3DInputBlockLoaderSmallFilter, - Conv3DInputBlockLoaderLargeFilter< - T, - BM, - BN, - BK, - tgp_size, - tgp_padding_a>>; - - // Weight loader - using loader_b_t = - Conv3DWeightBlockLoader; - - using mma_t = BlockMMA< - T, - T, - BM, - BN, - BK, - WM, - WN, - transpose_a, - transpose_b, - shape_a_cols, - shape_b_cols>; - - threadgroup T As[tgp_mem_size_a]; - threadgroup T Bs[tgp_mem_size_b]; - - const int tid_y = ((tid.y) << gemm_params->swizzle_log) + - ((tid.x) & ((1 << gemm_params->swizzle_log) - 1)); - const int tid_x = (tid.x) >> gemm_params->swizzle_log; - - if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) { - return; - } - - const int c_row = tid_y * BM; - const int c_col = tid_x * BN; - const int K = gemm_params->K; - const int N = gemm_params->N; - const int C_per_group = params->C / params->groups; - - // Groups - A += tid.z * C_per_group; - B += tid.z * N * K; - C += tid.z * N; - - B += c_col * K; - C += c_row * (N * params->groups) + c_col; - - const int2 offsets_a(0, c_row); - const int2 offsets_b(0, c_col); - - // Prepare threadgroup loading operations - loader_a_t loader_a( - A, As, offsets_a, params, gemm_params, simd_gid, simd_lid); - loader_b_t loader_b( - B, Bs, offsets_b, params, gemm_params, simd_gid, simd_lid); - - // Prepare threadgroup mma operation - mma_t mma_op(simd_gid, simd_lid); - - int gemm_k_iterations = gemm_params->gemm_k_iterations; - for (int k = 0; k < gemm_k_iterations; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - // Load elements into threadgroup - loader_a.load_unsafe(); - loader_b.load_unsafe(); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - - // Prepare for next iteration - loader_a.next(); - loader_b.next(); - } - - threadgroup_barrier(mem_flags::mem_none); - - // Store results to device memory - short tgp_bm = min(BM, gemm_params->M - c_row); - short tgp_bn = min(BN, gemm_params->N - c_col); - const int ldc = N * params->groups; - mma_op.store_result_safe(C, ldc, short2(tgp_bn, tgp_bm)); -} diff --git a/Source/Cmlx/mlx-generated/metal/steel/conv/kernels/steel_conv_general.h b/Source/Cmlx/mlx-generated/metal/steel/conv/kernels/steel_conv_general.h deleted file mode 100644 index b775dd55..00000000 --- a/Source/Cmlx/mlx-generated/metal/steel/conv/kernels/steel_conv_general.h +++ /dev/null @@ -1,225 +0,0 @@ -// Copyright © 2024 Apple Inc. - -#include "../../../steel/conv/loaders/loader_general.h" - -constant bool align_C [[function_constant(200)]]; - -template < - typename T, - int BM, - int BN, - int BK, - int WM, - int WN, - typename AccumType = float, - typename Epilogue = TransformNone> -[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void -implicit_gemm_conv_2d_general( - const device T* A [[buffer(0)]], - const device T* B [[buffer(1)]], - device T* C [[buffer(2)]], - const constant MLXConvParams<2>* params [[buffer(3)]], - const constant ImplicitGemmConv2DParams* gemm_params [[buffer(4)]], - const constant Conv2DGeneralJumpParams* jump_params [[buffer(5)]], - const constant Conv2DGeneralBaseInfo* base_h [[buffer(6)]], - const constant Conv2DGeneralBaseInfo* base_w [[buffer(7)]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]], - uint simd_gid [[simdgroup_index_in_threadgroup]], - uint simd_lid [[thread_index_in_simdgroup]]) { - (void)lid; - - constexpr bool transpose_a = false; - constexpr bool transpose_b = true; - constexpr short tgp_padding_a = 16 / sizeof(T); - constexpr short tgp_padding_b = 16 / sizeof(T); - - constexpr short shape_a_cols = (transpose_a ? BM : BK) + tgp_padding_a; - constexpr short shape_b_cols = (transpose_b ? BK : BN) + tgp_padding_b; - constexpr short shape_a_rows = (transpose_a ? BK : BM); - constexpr short shape_b_rows = (transpose_b ? BN : BK); - constexpr short tgp_mem_size_a = shape_a_cols * shape_a_rows; - constexpr short tgp_mem_size_b = shape_b_cols * shape_b_rows; - - constexpr short tgp_size = WM * WN * 32; - - // Input loader - using loader_a_t = - Conv2DInputBlockLoaderGeneral; - - // Weight loader - using loader_b_t = - Conv2DWeightBlockLoaderGeneral; - - using mma_t = BlockMMA< - T, - T, - BM, - BN, - BK, - WM, - WN, - transpose_a, - transpose_b, - shape_a_cols, - shape_b_cols>; - - threadgroup T As[tgp_mem_size_a]; - threadgroup T Bs[tgp_mem_size_b]; - - const int tid_y = ((tid.y) << gemm_params->swizzle_log) + - ((tid.x) & ((1 << gemm_params->swizzle_log) - 1)); - const int tid_x = (tid.x) >> gemm_params->swizzle_log; - - if (gemm_params->tiles_n <= tid_x || gemm_params->tiles_m <= tid_y) { - return; - } - - const int tid_z = tid.z; - - const int base_oh = tid_z / jump_params->f_out_jump_w; - const int base_ow = tid_z % jump_params->f_out_jump_w; - - const int base_wh = base_h[base_oh].weight_base; - const int base_ww = base_w[base_ow].weight_base; - - const int base_wh_size = base_h[base_oh].weight_size; - const int base_ww_size = base_w[base_ow].weight_size; - - const int c_row = tid_y * BM; - const int c_col = tid_x * BN; - const int K = gemm_params->K; - - B += c_col * K; - - const int4 offsets_a(0, c_row, base_oh, base_ow); - const int2 offsets_b(0, c_col); - - // Prepare threadgroup loading operations - loader_a_t loader_a( - A, - As, - offsets_a, - params, - jump_params, - base_wh, - base_ww, - simd_gid, - simd_lid); - loader_b_t loader_b( - B, - Bs, - offsets_b, - params, - jump_params, - base_wh, - base_ww, - simd_gid, - simd_lid); - - // Prepare threadgroup mma operation - mma_t mma_op(simd_gid, simd_lid); - - if (align_C) { - int gemm_k_iterations = - base_wh_size * base_ww_size * gemm_params->gemm_k_iterations; - - for (int k = 0; k < gemm_k_iterations; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - // Load elements into threadgroup - loader_a.load_unsafe(); - loader_b.load_unsafe(); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - - // Prepare for next iteration - loader_a.next(); - loader_b.next(); - } - } - - else { - for (int k = 1; k < gemm_params->gemm_k_iterations; k++) { - for (int j = 0; j < base_wh_size * base_ww_size; j++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - // Load elements into threadgroup - loader_a.load_unsafe(); - loader_b.load_unsafe(); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - - // Prepare for next iteration - loader_a.next(); - loader_b.next(); - } - } - const short remaining_k = params->C % BK; - for (int j = 0; j < base_wh_size * base_ww_size; j++) { - // Load elements into threadgroup - threadgroup_barrier(mem_flags::mem_threadgroup); - loader_a.load_safe(remaining_k); - loader_b.load_safe(remaining_k); - threadgroup_barrier(mem_flags::mem_threadgroup); - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - // Prepare for next iteration - loader_a.next(); - loader_b.next(); - } - } - - threadgroup_barrier(mem_flags::mem_none); - - // Store results to device memory - { - // Adjust for simdgroup and thread location - int offset_m = c_row + mma_op.sm; - int offset_n = c_col + mma_op.sn; - C += offset_n; - - if (offset_n >= gemm_params->N) - return; - - short diff = gemm_params->N - offset_n; - - STEEL_PRAGMA_UNROLL - for (int i = 0; i < mma_t::TM; i++) { - int cm = offset_m + i * mma_t::TM_stride; - - int n = cm / jump_params->adj_out_hw; - int hw = cm % jump_params->adj_out_hw; - int oh = - (hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + base_oh; - int ow = - (hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + base_ow; - - if (n < params->N && oh < params->oS[0] && ow < params->oS[1]) { - int offset_cm = n * params->out_strides[0] + - oh * params->out_strides[1] + ow * params->out_strides[2]; - - STEEL_PRAGMA_UNROLL - for (int j = 0; j < mma_t::TN; j++) { - // Get accumulated result and associated offset in C - thread const auto& accum = mma_op.Ctile.frag_at(i, j); - int offset = offset_cm + (j * mma_t::TN_stride); - - constexpr short kelems = decltype(mma_op.Ctile)::kElemsPerFrag; - - // Apply epilogue and output C - STEEL_PRAGMA_UNROLL - for (short k = 0; k < kelems; k++) { - if ((j * mma_t::TN_stride + k) < diff) { - C[offset + k] = Epilogue::apply(accum[k]); - } - } - } - } - } - } -} diff --git a/Source/Cmlx/mlx-generated/metal/steel/conv/loader.h b/Source/Cmlx/mlx-generated/metal/steel/conv/loader.h deleted file mode 100644 index bb9b3926..00000000 --- a/Source/Cmlx/mlx-generated/metal/steel/conv/loader.h +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright © 2024 Apple Inc. - -#pragma once - -#include "../../steel/conv/loaders/loader_channel_l.h" -#include "../../steel/conv/loaders/loader_channel_n.h" \ No newline at end of file diff --git a/Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_channel_l.h b/Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_channel_l.h deleted file mode 100644 index a516c1ad..00000000 --- a/Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_channel_l.h +++ /dev/null @@ -1,955 +0,0 @@ -// Copyright © 2024 Apple Inc. - -#pragma once - -#include "../../../steel/utils.h" - -#include "../../../steel/conv/params.h" - -/////////////////////////////////////////////////////////////////////////////// -// Loading helper -/////////////////////////////////////////////////////////////////////////////// - -namespace mlx { -namespace steel { - -template < - typename T, - short BM, - short BN, - short BK, - short tgp_size, - short tgp_padding = 0> -struct Conv2DInputBlockLoaderLargeFilter { - // Destination dimensions - STEEL_CONST short BROWS = BM; - STEEL_CONST short BCOLS = BK; - - // Read dimensions - STEEL_CONST short dst_ld = BCOLS + tgp_padding; - STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4; - - // Thread read shape - STEEL_CONST short TCOLS = BCOLS / vec_size; - STEEL_CONST short TROWS = tgp_size / TCOLS; - - // Rows / strided reads within the block - STEEL_CONST short n_rows = BROWS / TROWS; - - // Thread location indices - const short thread_idx; - const short bi; - const short bj; - - // threadgroup and device memory - threadgroup T* dst; - - const constant MLXConvParams<2>* params; - const constant ImplicitGemmConv2DParams* gemm_params; - - short weight_h; - short weight_w; - - const device T* src[n_rows]; - - int read_n[n_rows]; - int read_ih[n_rows]; - int read_iw[n_rows]; - - /* Constructor */ - METAL_FUNC Conv2DInputBlockLoaderLargeFilter( - const device T* src_, - threadgroup T* dst_, - const int2 offsets, - const constant MLXConvParams<2>* params_, - const constant ImplicitGemmConv2DParams* gemm_params_, - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]]) - : thread_idx(simd_group_id * 32 + simd_lane_id), - bi(thread_idx / TCOLS), - bj(vec_size * (thread_idx % TCOLS)), - dst(dst_ + bi * dst_ld + bj), - params(params_), - gemm_params(gemm_params_), - weight_h(0), - weight_w(0) { - int out_n_pixels = params->oS[0] * params->oS[1]; - - STEEL_PRAGMA_UNROLL - for (short i = 0; i < n_rows; ++i) { - int offset_nhw = offsets.y + bi + i * TROWS; - int n = offset_nhw / out_n_pixels; - int hw = offset_nhw % out_n_pixels; - int oh = hw / params->oS[1]; - int ow = hw % params->oS[1]; - - int ih = oh * params->str[0] - params->pad[0]; - int iw = ow * params->str[1] - params->pad[1]; - - read_n[i] = n; - read_ih[i] = ih; - read_iw[i] = iw; - - // Adjust for flip - if (params->flip) { - ih += (params->wS[0] - 1) * params->kdil[0]; - iw += (params->wS[1] - 1) * params->kdil[1]; - } - - // Read from input if in bounds - src[i] = src_ + n * params->in_strides[0] + ih * params->in_strides[1] + - iw * params->in_strides[2] + bj; - } - } - - /* Load from device memory into threadgroup memory - without bound checking */ - METAL_FUNC void load_unsafe() const { - STEEL_PRAGMA_UNROLL - for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) { - // Find bounds - int n = read_n[i]; - int ih = read_ih[i] + weight_h * params->kdil[0]; - int iw = read_iw[i] + weight_w * params->kdil[1]; - - // Read from input if in bounds - if ((n < params->N) && (ih >= 0 && ih < params->iS[0]) && - (iw >= 0 && iw < params->iS[1])) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; ++j) { - dst[is * dst_ld + j] = src[i][j]; - } - } - - // Zero pad otherwise - else { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; ++j) { - dst[is * dst_ld + j] = T(0); - } - } - } - } - - /* Iteration helper */ - METAL_FUNC void next() { - if (++weight_w < params->wS[1]) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < n_rows; i++) { - src[i] += gemm_params->inp_jump_w; - } - - return; - } - - weight_w = 0; - - if (++weight_h < params->wS[0]) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < n_rows; i++) { - src[i] += gemm_params->inp_jump_h; - } - - return; - } - - weight_h = 0; - - STEEL_PRAGMA_UNROLL - for (short i = 0; i < n_rows; i++) { - src[i] += gemm_params->inp_jump_c; - } - } -}; - -template < - typename T, - short BM, - short BN, - short BK, - short tgp_size, - short tgp_padding = 0> -struct Conv2DInputBlockLoaderSmallFilter { - // Destination dimensions - STEEL_CONST short BROWS = BM; - STEEL_CONST short BCOLS = BK; - - // Read dimensions - STEEL_CONST short dst_ld = BCOLS + tgp_padding; - STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4; - - // Thread read shape - STEEL_CONST short TCOLS = BCOLS / vec_size; - STEEL_CONST short TROWS = tgp_size / TCOLS; - - // Rows / strided reads within the block - STEEL_CONST short n_rows = BROWS / TROWS; - - using mask_t = short; - - // Thread location indices - const short thread_idx; - const short bi; - const short bj; - - // threadgroup and device memory - threadgroup T* dst; - - const constant MLXConvParams<2>* params; - const constant ImplicitGemmConv2DParams* gemm_params; - - short weight_h; - short weight_w; - - const device T* src[n_rows]; - - mask_t mask_h[n_rows]; - mask_t mask_w[n_rows]; - - /* Constructor */ - METAL_FUNC Conv2DInputBlockLoaderSmallFilter( - const device T* src_, - threadgroup T* dst_, - const int2 offsets, - const constant MLXConvParams<2>* params_, - const constant ImplicitGemmConv2DParams* gemm_params_, - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]]) - : thread_idx(simd_group_id * 32 + simd_lane_id), - bi(thread_idx / TCOLS), - bj(vec_size * (thread_idx % TCOLS)), - dst(dst_ + bi * dst_ld + bj), - params(params_), - gemm_params(gemm_params_), - weight_h(0), - weight_w(0) { - int out_n_pixels = params->oS[0] * params->oS[1]; - - int read_n[n_rows]; - int read_ih[n_rows]; - int read_iw[n_rows]; - - STEEL_PRAGMA_UNROLL - for (short i = 0; i < n_rows; ++i) { - int offset_nhw = offsets.y + bi + i * TROWS; - int n = offset_nhw / out_n_pixels; - int hw = offset_nhw % out_n_pixels; - int oh = hw / params->oS[1]; - int ow = hw % params->oS[1]; - - int ih = oh * params->str[0] - params->pad[0]; - int iw = ow * params->str[1] - params->pad[1]; - - read_n[i] = n; - read_ih[i] = ih; - read_iw[i] = iw; - - // Adjust for flip - if (params->flip) { - ih += (params->wS[0] - 1) * params->kdil[0]; - iw += (params->wS[1] - 1) * params->kdil[1]; - } - - // Read from input if in bounds - src[i] = src_ + n * params->in_strides[0] + ih * params->in_strides[1] + - iw * params->in_strides[2] + bj; - } - - STEEL_PRAGMA_UNROLL - for (short i = 0; i < n_rows; ++i) { - mask_h[i] = 0; - mask_w[i] = 0; - } - - for (short kh = 0; kh < params->wS[0]; kh++) { - short flip_h = params->flip ? params->wS[0] - kh - 1 : kh; - STEEL_PRAGMA_UNROLL - for (short i = 0; i < n_rows; ++i) { - int n = read_n[i]; - int ih = read_ih[i] + flip_h * params->kdil[0]; - - bool in_bounds = n < params->N && ih >= 0 && ih < params->iS[0]; - - mask_h[i] |= (in_bounds << kh); - } - } - - for (short kw = 0; kw < params->wS[1]; kw++) { - short flip_w = params->flip ? params->wS[1] - kw - 1 : kw; - STEEL_PRAGMA_UNROLL - for (short i = 0; i < n_rows; ++i) { - int iw = read_iw[i] + flip_w * params->kdil[1]; - - bool in_bounds = iw >= 0 && iw < params->iS[1]; - - mask_w[i] |= (in_bounds << kw); - } - } - } - - /* Load from device memory into threadgroup memory - without bound checking */ - METAL_FUNC void load_unsafe() const { - mask_t h_mask = mask_t(1) << weight_h; - mask_t w_mask = mask_t(1) << weight_w; - - STEEL_PRAGMA_UNROLL - for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) { - // Read from input if in bounds - if ((mask_h[i] & h_mask) && (mask_w[i] & w_mask)) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; ++j) { - dst[is * dst_ld + j] = src[i][j]; - } - } - - // Zero pad otherwise - else { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; ++j) { - dst[is * dst_ld + j] = T(0); - } - } - } - } - - /* Iteration helper */ - METAL_FUNC void next() { - if (++weight_w < params->wS[1]) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < n_rows; i++) { - src[i] += gemm_params->inp_jump_w; - } - - return; - } - - weight_w = 0; - - if (++weight_h < params->wS[0]) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < n_rows; i++) { - src[i] += gemm_params->inp_jump_h; - } - - return; - } - - weight_h = 0; - - STEEL_PRAGMA_UNROLL - for (short i = 0; i < n_rows; i++) { - src[i] += gemm_params->inp_jump_c; - } - } -}; - -template < - typename T, - short BM, - short BN, - short BK, - short tgp_size, - short tgp_padding = 0> -struct Conv2DWeightBlockLoader { - // Destination dimensions - STEEL_CONST short BROWS = BN; - STEEL_CONST short BCOLS = BK; - - // Read dimensions - STEEL_CONST short dst_ld = BCOLS + tgp_padding; - STEEL_CONST short vec_size = - (BN == 8) ? 1 : (tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4); - - // Thread read shape - STEEL_CONST short TCOLS = BCOLS / vec_size; - STEEL_CONST short TROWS = tgp_size / TCOLS; - - // Rows / strided reads within the block - STEEL_CONST short n_rows = BROWS / TROWS; - - // Leading dimension for src - const int src_ld; - - // Thread location indices - const short thread_idx; - const short bi; - const short bj; - - // threadgroup and device memory - threadgroup T* dst; - const device T* src; - - const constant MLXConvParams<2>* params; - - int weight_hw; - int weight_step; - - const int read_n; - const bool do_read; - - /* Constructor */ - METAL_FUNC Conv2DWeightBlockLoader( - const device T* src_, - threadgroup T* dst_, - const int2 offsets, - const constant MLXConvParams<2>* params_, - const constant ImplicitGemmConv2DParams* gemm_params_, - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]]) - : src_ld(params_->wt_strides[0]), - thread_idx(simd_group_id * 32 + simd_lane_id), - bi(thread_idx / TCOLS), - bj(vec_size * (thread_idx % TCOLS)), - dst(dst_ + bi * dst_ld + bj), - src(src_ + bi * src_ld + bj), - params(params_), - weight_hw(0), - weight_step(params->C / params->groups), - read_n(offsets.y + bi), - do_read(read_n + n_rows * TROWS <= gemm_params_->N) {} - - /* Load from device memory into threadgroup memory - without bound checking */ - METAL_FUNC void load_unsafe() const { - if (BN != 8 || do_read) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < BN; i += TROWS) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - dst[i * dst_ld + j] = src[i * src_ld + j]; - } - } - } else { - for (short i = 0; i < BN; i += TROWS) { - if ((read_n + i) < params->O) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - dst[i * dst_ld + j] = src[i * src_ld + j]; - } - } else { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - dst[i * dst_ld + j] = T(0); - } - } - } - } - } - - /* Iteration helper */ - METAL_FUNC void next() { - if (++weight_hw < (params->wS[1] * params->wS[0])) { - src += weight_step; - return; - } - - weight_hw = 0; - - src += BK - (params->wS[1] * params->wS[0] - 1) * weight_step; - } -}; - -template < - typename T, - short BM, - short BN, - short BK, - short tgp_size, - short tgp_padding = 0> -struct Conv3DInputBlockLoaderLargeFilter { - // Destination dimensions - STEEL_CONST short BROWS = BM; - STEEL_CONST short BCOLS = BK; - - // Read dimensions - STEEL_CONST short dst_ld = BCOLS + tgp_padding; - STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4; - - // Thread read shape - STEEL_CONST short TCOLS = BCOLS / vec_size; - STEEL_CONST short TROWS = tgp_size / TCOLS; - - // Rows / strided reads within the block - STEEL_CONST short n_rows = BROWS / TROWS; - - // Thread location indices - const short thread_idx; - const short bi; - const short bj; - - // threadgroup and device memory - threadgroup T* dst; - - const constant MLXConvParams<3>* params; - const constant ImplicitGemmConv3DParams* gemm_params; - - short weight_d; - short weight_h; - short weight_w; - - short kdil_d; - short kdil_h; - short kdil_w; - - const device T* src[n_rows]; - - int read_n[n_rows]; - int read_id[n_rows]; - int read_ih[n_rows]; - int read_iw[n_rows]; - - /* Constructor */ - METAL_FUNC Conv3DInputBlockLoaderLargeFilter( - const device T* src_, - threadgroup T* dst_, - const int2 offsets, - const constant MLXConvParams<3>* params_, - const constant ImplicitGemmConv3DParams* gemm_params_, - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]]) - : thread_idx(simd_group_id * 32 + simd_lane_id), - bi(thread_idx / TCOLS), - bj(vec_size * (thread_idx % TCOLS)), - dst(dst_ + bi * dst_ld + bj), - params(params_), - gemm_params(gemm_params_), - weight_d(0), - weight_h(0), - weight_w(0), - kdil_d(params_->flip ? -params_->kdil[0] : params_->kdil[0]), - kdil_h(params_->flip ? -params_->kdil[1] : params_->kdil[1]), - kdil_w(params_->flip ? -params_->kdil[2] : params_->kdil[2]) { - int out_n_pixels = params->oS[0] * params->oS[1] * params->oS[2]; - - STEEL_PRAGMA_UNROLL - for (short i = 0; i < n_rows; ++i) { - int offset_ndhw = offsets.y + bi + i * TROWS; - int n = offset_ndhw / out_n_pixels; - int dhw = offset_ndhw % out_n_pixels; - int od = dhw / (params->oS[1] * params->oS[2]); - int hw = dhw % (params->oS[1] * params->oS[2]); - int oh = hw / params->oS[2]; - int ow = hw % params->oS[2]; - - int id = od * params->str[0] - params->pad[0]; - int ih = oh * params->str[1] - params->pad[1]; - int iw = ow * params->str[2] - params->pad[2]; - - read_n[i] = n; - - if (params->flip) { - read_id[i] = id + (params->wS[0] - 1) * params->kdil[0]; - read_ih[i] = ih + (params->wS[1] - 1) * params->kdil[1]; - read_iw[i] = iw + (params->wS[2] - 1) * params->kdil[2]; - } else { - read_id[i] = id; - read_ih[i] = ih; - read_iw[i] = iw; - } - - // Adjust for flip - if (params->flip) { - id += (params->wS[0] - 1) * params->kdil[0]; - ih += (params->wS[1] - 1) * params->kdil[1]; - iw += (params->wS[2] - 1) * params->kdil[2]; - } - - // Read from input if in bounds - src[i] = src_ + n * params->in_strides[0] + id * params->in_strides[1] + - ih * params->in_strides[2] + iw * params->in_strides[3] + bj; - } - } - - /* Load from device memory into threadgroup memory - without bound checking */ - METAL_FUNC void load_unsafe() const { - STEEL_PRAGMA_UNROLL - for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) { - // Find bounds - int n = read_n[i]; - int id = read_id[i] + weight_d * kdil_d; - int ih = read_ih[i] + weight_h * kdil_h; - int iw = read_iw[i] + weight_w * kdil_w; - - // Read from input if in bounds - if ((n < params->N) && (id >= 0 && id < params->iS[0]) && - (ih >= 0 && ih < params->iS[1]) && (iw >= 0 && iw < params->iS[2])) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; ++j) { - dst[is * dst_ld + j] = src[i][j]; - } - } - - // Zero pad otherwise - else { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; ++j) { - dst[is * dst_ld + j] = T(0); - } - } - } - } - - /* Iteration helper */ - METAL_FUNC void next() { - if (++weight_w < params->wS[2]) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < n_rows; i++) { - src[i] += gemm_params->inp_jump_w; - } - - return; - } - - weight_w = 0; - - if (++weight_h < params->wS[1]) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < n_rows; i++) { - src[i] += gemm_params->inp_jump_h; - } - - return; - } - - weight_h = 0; - - if (++weight_d < params->wS[0]) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < n_rows; i++) { - src[i] += gemm_params->inp_jump_d; - } - - return; - } - - weight_d = 0; - - STEEL_PRAGMA_UNROLL - for (short i = 0; i < n_rows; i++) { - src[i] += gemm_params->inp_jump_c; - } - } -}; - -template < - typename T, - short BM, - short BN, - short BK, - short tgp_size, - short tgp_padding = 0> -struct Conv3DInputBlockLoaderSmallFilter { - // Destination dimensions - STEEL_CONST short BROWS = BM; - STEEL_CONST short BCOLS = BK; - - // Read dimensions - STEEL_CONST short dst_ld = BCOLS + tgp_padding; - STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4; - - // Thread read shape - STEEL_CONST short TCOLS = BCOLS / vec_size; - STEEL_CONST short TROWS = tgp_size / TCOLS; - - // Rows / strided reads within the block - STEEL_CONST short n_rows = BROWS / TROWS; - - using mask_t = short; - - // Thread location indices - const short thread_idx; - const short bi; - const short bj; - - // threadgroup and device memory - threadgroup T* dst; - - const constant MLXConvParams<3>* params; - const constant ImplicitGemmConv3DParams* gemm_params; - - short weight_d; - short weight_h; - short weight_w; - - const device T* src[n_rows]; - - mask_t mask_d[n_rows]; - mask_t mask_h[n_rows]; - mask_t mask_w[n_rows]; - - /* Constructor */ - METAL_FUNC Conv3DInputBlockLoaderSmallFilter( - const device T* src_, - threadgroup T* dst_, - const int2 offsets, - const constant MLXConvParams<3>* params_, - const constant ImplicitGemmConv3DParams* gemm_params_, - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]]) - : thread_idx(simd_group_id * 32 + simd_lane_id), - bi(thread_idx / TCOLS), - bj(vec_size * (thread_idx % TCOLS)), - dst(dst_ + bi * dst_ld + bj), - params(params_), - gemm_params(gemm_params_), - weight_d(0), - weight_h(0), - weight_w(0) { - int out_n_pixels = params->oS[0] * params->oS[1] * params->oS[2]; - - int read_n[n_rows]; - int read_id[n_rows]; - int read_ih[n_rows]; - int read_iw[n_rows]; - - STEEL_PRAGMA_UNROLL - for (short i = 0; i < n_rows; ++i) { - int offset_ndhw = offsets.y + bi + i * TROWS; - int n = offset_ndhw / out_n_pixels; - int dhw = offset_ndhw % out_n_pixels; - int od = dhw / (params->oS[1] * params->oS[2]); - int hw = dhw % (params->oS[1] * params->oS[2]); - int oh = hw / params->oS[2]; - int ow = hw % params->oS[2]; - - int id = od * params->str[0] - params->pad[0]; - int ih = oh * params->str[1] - params->pad[1]; - int iw = ow * params->str[2] - params->pad[2]; - - read_n[i] = n; - read_id[i] = id; - read_ih[i] = ih; - read_iw[i] = iw; - - // Adjust for flip - if (params->flip) { - id += (params->wS[0] - 1) * params->kdil[0]; - ih += (params->wS[1] - 1) * params->kdil[1]; - iw += (params->wS[2] - 1) * params->kdil[2]; - } - - // Read from input if in bounds - src[i] = src_ + n * params->in_strides[0] + id * params->in_strides[1] + - ih * params->in_strides[2] + iw * params->in_strides[3] + bj; - } - - STEEL_PRAGMA_UNROLL - for (short i = 0; i < n_rows; ++i) { - mask_d[i] = 0; - mask_h[i] = 0; - mask_w[i] = 0; - } - - for (short kd = 0; kd < params->wS[0]; kd++) { - short flip_d = params->flip ? params->wS[0] - kd - 1 : kd; - STEEL_PRAGMA_UNROLL - for (short i = 0; i < n_rows; ++i) { - int n = read_n[i]; - int id = read_id[i] + flip_d * params->kdil[0]; - - bool in_bounds = n < params->N && id >= 0 && id < params->iS[0]; - - mask_d[i] |= (in_bounds << kd); - } - } - - for (short kh = 0; kh < params->wS[1]; kh++) { - short flip_h = params->flip ? params->wS[1] - kh - 1 : kh; - STEEL_PRAGMA_UNROLL - for (short i = 0; i < n_rows; ++i) { - int ih = read_ih[i] + flip_h * params->kdil[1]; - - bool in_bounds = ih >= 0 && ih < params->iS[1]; - - mask_h[i] |= (in_bounds << kh); - } - } - - for (short kw = 0; kw < params->wS[2]; kw++) { - short flip_w = params->flip ? params->wS[2] - kw - 1 : kw; - STEEL_PRAGMA_UNROLL - for (short i = 0; i < n_rows; ++i) { - int iw = read_iw[i] + flip_w * params->kdil[2]; - - bool in_bounds = iw >= 0 && iw < params->iS[2]; - - mask_w[i] |= (in_bounds << kw); - } - } - } - - /* Load from device memory into threadgroup memory - without bound checking */ - METAL_FUNC void load_unsafe() const { - mask_t d_mask = mask_t(1) << weight_d; - mask_t h_mask = mask_t(1) << weight_h; - mask_t w_mask = mask_t(1) << weight_w; - - STEEL_PRAGMA_UNROLL - for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) { - // Read from input if in bounds - if ((mask_d[i] & d_mask) && (mask_h[i] & h_mask) && - (mask_w[i] & w_mask)) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; ++j) { - dst[is * dst_ld + j] = src[i][j]; - } - } - - // Zero pad otherwise - else { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; ++j) { - dst[is * dst_ld + j] = T(0); - } - } - } - } - - /* Iteration helper */ - METAL_FUNC void next() { - if (++weight_w < params->wS[2]) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < n_rows; i++) { - src[i] += gemm_params->inp_jump_w; - } - - return; - } - - weight_w = 0; - - if (++weight_h < params->wS[1]) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < n_rows; i++) { - src[i] += gemm_params->inp_jump_h; - } - - return; - } - - weight_h = 0; - - if (++weight_d < params->wS[0]) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < n_rows; i++) { - src[i] += gemm_params->inp_jump_d; - } - - return; - } - - weight_d = 0; - - STEEL_PRAGMA_UNROLL - for (short i = 0; i < n_rows; i++) { - src[i] += gemm_params->inp_jump_c; - } - } -}; - -template < - typename T, - short BM, - short BN, - short BK, - short tgp_size, - short tgp_padding = 0> -struct Conv3DWeightBlockLoader { - // Destination dimensions - STEEL_CONST short BROWS = BN; - STEEL_CONST short BCOLS = BK; - - // Read dimensions - STEEL_CONST short dst_ld = BCOLS + tgp_padding; - STEEL_CONST short vec_size = - (BN == 8) ? 1 : (tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4); - - // Thread read shape - STEEL_CONST short TCOLS = BCOLS / vec_size; - STEEL_CONST short TROWS = tgp_size / TCOLS; - - // Rows / strided reads within the block - STEEL_CONST short n_rows = BROWS / TROWS; - - // Leading dimension for src - const int src_ld; - - // Thread location indices - const short thread_idx; - const short bi; - const short bj; - - // threadgroup and device memory - threadgroup T* dst; - const device T* src; - - const constant MLXConvParams<3>* params; - - int weight_dhw; - int weight_step; - - const int read_n; - const bool do_read; - - /* Constructor */ - METAL_FUNC Conv3DWeightBlockLoader( - const device T* src_, - threadgroup T* dst_, - const int2 offsets, - const constant MLXConvParams<3>* params_, - const constant ImplicitGemmConv3DParams* gemm_params_, - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]]) - : src_ld(params_->wt_strides[0]), - thread_idx(simd_group_id * 32 + simd_lane_id), - bi(thread_idx / TCOLS), - bj(vec_size * (thread_idx % TCOLS)), - dst(dst_ + bi * dst_ld + bj), - src(src_ + bi * src_ld + bj), - params(params_), - weight_dhw(0), - weight_step(params->C / params->groups), - read_n(offsets.y + bi), - do_read(read_n + n_rows * TROWS <= gemm_params_->N) {} - - /* Load from device memory into threadgroup memory - without bound checking */ - METAL_FUNC void load_unsafe() const { - if (BN != 8 || do_read) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < BN; i += TROWS) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - dst[i * dst_ld + j] = src[i * src_ld + j]; - } - } - } else { - for (short i = 0; i < BN; i += TROWS) { - if ((read_n + i) < params->O) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - dst[i * dst_ld + j] = src[i * src_ld + j]; - } - } else { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - dst[i * dst_ld + j] = T(0); - } - } - } - } - } - - /* Iteration helper */ - METAL_FUNC void next() { - if (++weight_dhw < (params->wS[0] * params->wS[1] * params->wS[2])) { - src += weight_step; - return; - } - - weight_dhw = 0; - - src += - BK - (params->wS[0] * params->wS[1] * params->wS[2] - 1) * weight_step; - } -}; - -} // namespace steel -} // namespace mlx diff --git a/Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_channel_n.h b/Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_channel_n.h deleted file mode 100644 index 1f37fb21..00000000 --- a/Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_channel_n.h +++ /dev/null @@ -1,319 +0,0 @@ -// Copyright © 2024 Apple Inc. - -#pragma once - -#include "../../../steel/utils.h" - -#include "../../../steel/conv/params.h" - -/////////////////////////////////////////////////////////////////////////////// -// Loading helper -/////////////////////////////////////////////////////////////////////////////// - -namespace mlx { -namespace steel { - -template -struct ChannelHelper { - STEEL_CONST short n_channels = n_channels_; - STEEL_CONST short vec_size = n_channels_ <= 4 ? 4 : 8; - STEEL_CONST short excess = vec_size - n_channels_; -}; - -template <> -struct ChannelHelper<1> { - STEEL_CONST short n_channels = 1; - STEEL_CONST short vec_size = 1; - STEEL_CONST short excess = 0; -}; - -template <> -struct ChannelHelper<2> { - STEEL_CONST short n_channels = 2; - STEEL_CONST short vec_size = 2; - STEEL_CONST short excess = 0; -}; - -template <> -struct ChannelHelper<3> { - STEEL_CONST short n_channels = 3; - STEEL_CONST short vec_size = 4; - STEEL_CONST short excess = 1; -}; - -template <> -struct ChannelHelper<4> { - STEEL_CONST short n_channels = 4; - STEEL_CONST short vec_size = 4; - STEEL_CONST short excess = 0; -}; - -template < - typename T, - short BM, - short BN, - short BK, - short tgp_size, - short n_channels, - short tgp_padding = 0> -struct Conv2DInputBlockLoaderSmallChannels { - // Destination dimensions - STEEL_CONST short BROWS = BM; - STEEL_CONST short BCOLS = BK; - - // Read dimensions - STEEL_CONST short dst_ld = BCOLS + tgp_padding; - STEEL_CONST short vec_size = ChannelHelper::vec_size; - - // Thread read shape - STEEL_CONST short TCOLS = BCOLS / vec_size; - STEEL_CONST short TROWS = tgp_size / TCOLS; - - // Rows / strided reads within the block - STEEL_CONST short n_rows = BROWS / TROWS; - - // Thread location indices - const short thread_idx; - const short bi; - const short bj; - - // threadgroup and device memory - threadgroup T* dst; - - const constant MLXConvParams<2>* params; - const constant ImplicitGemmConv2DParams* gemm_params; - - int weight_hw; - - const device T* src[n_rows]; - - int read_n[n_rows]; - int read_ih[n_rows]; - int read_iw[n_rows]; - - /* Constructor */ - METAL_FUNC Conv2DInputBlockLoaderSmallChannels( - const device T* src_, - threadgroup T* dst_, - const int2 offsets, - const constant MLXConvParams<2>* params_, - const constant ImplicitGemmConv2DParams* gemm_params_, - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]]) - : thread_idx(simd_group_id * 32 + simd_lane_id), - bi(thread_idx / TCOLS), - bj(vec_size * (thread_idx % TCOLS)), - dst(dst_ + bi * dst_ld + bj), - params(params_), - gemm_params(gemm_params_), - weight_hw(thread_idx % TCOLS) { - int out_n_pixels = params->oS[0] * params->oS[1]; - - STEEL_PRAGMA_UNROLL - for (short i = 0; i < n_rows; ++i) { - int offset_nhw = offsets.y + bi + i * TROWS; - int n = offset_nhw / out_n_pixels; - int hw = offset_nhw % out_n_pixels; - int oh = hw / params->oS[1]; - int ow = hw % params->oS[1]; - - int ih = oh * params->str[0] - params->pad[0]; - int iw = ow * params->str[1] - params->pad[1]; - - // Read from input if in bounds - src[i] = src_ + n * params->in_strides[0] + ih * params->in_strides[1] + - iw * params->in_strides[2]; - - read_n[i] = n; - read_ih[i] = ih; - read_iw[i] = iw; - } - } - - /* Load from device memory into threadgroup memory - without bound checking */ - METAL_FUNC void load_unsafe() const { - if (weight_hw >= params->wS[1] * params->wS[0]) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < BROWS; i += TROWS) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - dst[i * dst_ld + j] = T(0); - } - } - return; - } - - int wh = (weight_hw / params->wS[1]); - int ww = (weight_hw % params->wS[1]); - - int flip_h = params->flip ? params->wS[0] - wh - 1 : wh; - int flip_w = params->flip ? params->wS[1] - ww - 1 : ww; - - int weight_h = flip_h * params->kdil[0]; - int weight_w = flip_w * params->kdil[1]; - - STEEL_PRAGMA_UNROLL - for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) { - // Find bounds - int n = read_n[i]; - int ih = read_ih[i] + weight_h; - int iw = read_iw[i] + weight_w; - - // Read from input if in bounds - if ((n < params->N) && (ih >= 0 && ih < params->iS[0]) && - (iw >= 0 && iw < params->iS[1])) { - const device T* curr_src = src[i] + weight_h * params->in_strides[1] + - weight_w * params->in_strides[2]; - - STEEL_PRAGMA_UNROLL - for (short j = 0; j < n_channels; ++j) { - dst[is * dst_ld + j] = curr_src[j]; - } - - STEEL_PRAGMA_UNROLL - for (short j = n_channels; j < vec_size; ++j) { - dst[is * dst_ld + j] = T(0); - } - } - - // Zero pad otherwise - else { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; ++j) { - dst[is * dst_ld + j] = T(0); - } - } - } - } - - /* Iteration helper */ - METAL_FUNC void next() { - weight_hw += TCOLS; - } -}; - -template < - typename T, - short BM, - short BN, - short BK, - short tgp_size, - short n_channels, - short tgp_padding = 0> -struct Conv2DWeightBlockLoaderSmallChannels { - // Destination dimensions - STEEL_CONST short BROWS = BN; - STEEL_CONST short BCOLS = BK; - - // Read dimensions - STEEL_CONST short dst_ld = BCOLS + tgp_padding; - STEEL_CONST short vec_size = ChannelHelper::vec_size; - - // Thread read shape - STEEL_CONST short TCOLS = BCOLS / vec_size; - STEEL_CONST short TROWS = tgp_size / TCOLS; - - // Rows / strided reads within the block - STEEL_CONST short n_rows = BROWS / TROWS; - - // Leading dimension for src - const int src_ld; - - // Thread location indices - const short thread_idx; - const short bi; - const short bj; - - // threadgroup and device memory - threadgroup T* dst; - const device T* src; - - const constant MLXConvParams<2>* params; - - int weight_hw; - - const int read_n; - const bool do_read; - - /* Constructor */ - METAL_FUNC Conv2DWeightBlockLoaderSmallChannels( - const device T* src_, - threadgroup T* dst_, - const int2 offsets, - const constant MLXConvParams<2>* params_, - const constant ImplicitGemmConv2DParams* gemm_params_, - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]]) - : src_ld(params_->wt_strides[0]), - thread_idx(simd_group_id * 32 + simd_lane_id), - bi(thread_idx / TCOLS), - bj(vec_size * (thread_idx % TCOLS)), - dst(dst_ + bi * dst_ld + bj), - src(src_ + bi * src_ld), - params(params_), - weight_hw(thread_idx % TCOLS), - read_n(offsets.y + bi), - do_read(read_n + BN <= gemm_params_->N) {} - - /* Load from device memory into threadgroup memory - without bound checking */ - METAL_FUNC void load_unsafe() const { - if (bi >= BROWS || bj >= BCOLS) - return; - - if (read_n >= params->O || weight_hw >= params->wS[1] * params->wS[0]) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < BROWS; i += TROWS) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - dst[i * dst_ld + j] = T(0); - } - } - - return; - } - - const device T* curr_src = src + weight_hw * (params->C / params->groups); - - if (BN != 8 || do_read) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < BROWS; i += TROWS) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < n_channels; j++) { - dst[i * dst_ld + j] = curr_src[i * src_ld + j]; - } - - STEEL_PRAGMA_UNROLL - for (short j = n_channels; j < vec_size; j++) { - dst[i * dst_ld + j] = T(0); - } - } - } else { - for (short i = 0; i < BROWS; i += TROWS) { - if (((read_n + i) < params->O)) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < n_channels; j++) { - dst[i * dst_ld + j] = curr_src[i * src_ld + j]; - } - - STEEL_PRAGMA_UNROLL - for (short j = n_channels; j < vec_size; j++) { - dst[i * dst_ld + j] = T(0); - } - } else { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - dst[i * dst_ld + j] = T(0); - } - } - } - } - } - - /* Iteration helper */ - METAL_FUNC void next() { - weight_hw += TCOLS; - } -}; - -} // namespace steel -} // namespace mlx diff --git a/Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_general.h b/Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_general.h deleted file mode 100644 index 9043a3c4..00000000 --- a/Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_general.h +++ /dev/null @@ -1,381 +0,0 @@ -// Copyright © 2024 Apple Inc. - -#pragma once - -#include "../../../steel/defines.h" - -/////////////////////////////////////////////////////////////////////////////// -// Loading helper -/////////////////////////////////////////////////////////////////////////////// - -namespace mlx { -namespace steel { - -template < - typename T, - short BM, - short BN, - short BK, - short tgp_size, - short tgp_padding = 0> -struct Conv2DInputBlockLoaderGeneral { - // Destination dimensions - STEEL_CONST short BROWS = BM; - STEEL_CONST short BCOLS = BK; - - // Read dimensions - STEEL_CONST short dst_ld = BCOLS + tgp_padding; - STEEL_CONST short vec_size = tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4; - - // Thread read shape - STEEL_CONST short TCOLS = BCOLS / vec_size; - STEEL_CONST short TROWS = tgp_size / TCOLS; - - // Rows / strided reads within the block - STEEL_CONST short n_rows = BROWS / TROWS; - - // Thread location indices - const short thread_idx; - const short bi; - const short bj; - - // threadgroup and device memory - threadgroup T* dst; - - const constant MLXConvParams<2>* params; - const constant Conv2DGeneralJumpParams* jump_params; - - const short base_wh; - const short base_ww; - - short weight_h; - short weight_w; - - const device T* src[n_rows]; - - int read_n[n_rows]; - int read_ih[n_rows]; - int read_iw[n_rows]; - - /* Constructor */ - METAL_FUNC Conv2DInputBlockLoaderGeneral( - const device T* src_, - threadgroup T* dst_, - const int4 offsets, - const constant MLXConvParams<2>* params_, - const constant Conv2DGeneralJumpParams* jump_params_, - const short base_wh_, - const short base_ww_, - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]]) - : thread_idx(simd_group_id * 32 + simd_lane_id), - bi(thread_idx / TCOLS), - bj(vec_size * (thread_idx % TCOLS)), - dst(dst_ + bi * dst_ld + bj), - params(params_), - jump_params(jump_params_), - base_wh(base_wh_), - base_ww(base_ww_), - weight_h(base_wh_), - weight_w(base_ww_) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < n_rows; ++i) { - int offset_nhw = offsets.y + bi + i * TROWS; - int n = offset_nhw / jump_params->adj_out_hw; - int hw = offset_nhw % jump_params->adj_out_hw; - int oh = - (hw / jump_params->adj_out_w) * jump_params->f_out_jump_h + offsets.z; - int ow = - (hw % jump_params->adj_out_w) * jump_params->f_out_jump_w + offsets.w; - - int ih = oh * params->str[0] - params->pad[0]; - int iw = ow * params->str[1] - params->pad[1]; - - read_n[i] = n; - read_ih[i] = ih; - read_iw[i] = iw; - - // Read from input if in bounds - src[i] = src_ + n * params->in_strides[0] + bj; - } - } - - /* Load from device memory into threadgroup memory - without bound checking */ - METAL_FUNC void load_unsafe() const { - STEEL_PRAGMA_UNROLL - for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) { - // Find bounds - int n = read_n[i]; - - int h_flip = params->flip ? params->wS[0] - weight_h - 1 : weight_h; - int w_flip = params->flip ? params->wS[1] - weight_w - 1 : weight_w; - - int ih_dil = read_ih[i] + h_flip * params->kdil[0]; - int iw_dil = read_iw[i] + w_flip * params->kdil[1]; - - int ih = ih_dil / params->idil[0]; - int iw = iw_dil / params->idil[1]; - - size_t offset = ih * params->in_strides[1] + iw * params->in_strides[2]; - - // Read from input if in bounds - if ((n < params->N) && (ih_dil >= 0 && ih < params->iS[0]) && - (iw_dil >= 0 && iw < params->iS[1])) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; ++j) { - dst[is * dst_ld + j] = (src[i])[offset + j]; - } - } - - // Zero pad otherwise - else { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; ++j) { - dst[is * dst_ld + j] = T(0); - } - } - } - } - - METAL_FUNC void load_safe(const short remaining_k) const { - STEEL_PRAGMA_UNROLL - for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) { - // Find bounds - int n = read_n[i]; - - int h_flip = params->flip ? params->wS[0] - weight_h - 1 : weight_h; - int w_flip = params->flip ? params->wS[1] - weight_w - 1 : weight_w; - - int ih_dil = read_ih[i] + h_flip * params->kdil[0]; - int iw_dil = read_iw[i] + w_flip * params->kdil[1]; - - int ih = ih_dil / params->idil[0]; - int iw = iw_dil / params->idil[1]; - - size_t offset = ih * params->in_strides[1] + iw * params->in_strides[2]; - - // Read from input if in bounds - if ((n < params->N) && (ih_dil >= 0 && ih < params->iS[0]) && - (iw_dil >= 0 && iw < params->iS[1])) { - if (bj + vec_size <= remaining_k) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; ++j) { - dst[is * dst_ld + j] = (src[i])[offset + j]; - } - } else { - for (short j = 0; j < vec_size; ++j) { - if (bj + j < remaining_k) { - dst[is * dst_ld + j] = (src[i])[offset + j]; - } else { - dst[is * dst_ld + j] = T(0); - } - } - } - } - - // Zero pad otherwise - else { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; ++j) { - dst[is * dst_ld + j] = T(0); - } - } - } - } - - /* Iteration helper */ - METAL_FUNC void next() { - weight_w += jump_params->f_wgt_jump_w; - if (weight_w < params->wS[1]) { - return; - } - - weight_w = base_ww; - - weight_h += jump_params->f_wgt_jump_h; - if (weight_h < params->wS[0]) { - return; - } - - weight_h = base_wh; - - STEEL_PRAGMA_UNROLL - for (short i = 0; i < n_rows; i++) { - src[i] += BK; - } - } -}; - -template < - typename T, - short BM, - short BN, - short BK, - short tgp_size, - short tgp_padding = 0> -struct Conv2DWeightBlockLoaderGeneral { - // Destination dimensions - STEEL_CONST short BROWS = BN; - STEEL_CONST short BCOLS = BK; - - // Read dimensions - STEEL_CONST short dst_ld = BCOLS + tgp_padding; - STEEL_CONST short vec_size = - (BN == 8) ? 1 : (tgp_size / (BROWS * BCOLS) >= 8 ? 8 : 4); - - // Thread read shape - STEEL_CONST short TCOLS = BCOLS / vec_size; - STEEL_CONST short TROWS = tgp_size / TCOLS; - - // Rows / strided reads within the block - STEEL_CONST short n_rows = BROWS / TROWS; - - // Leading dimension for src - const int src_ld; - - // Thread location indices - const short thread_idx; - const short bi; - const short bj; - - // threadgroup and device memory - threadgroup T* dst; - const device T* src; - - const constant MLXConvParams<2>* params; - const constant Conv2DGeneralJumpParams* jump_params; - - const short base_wh; - const short base_ww; - - short weight_h; - short weight_w; - - const int start_row; - - /* Constructor */ - METAL_FUNC Conv2DWeightBlockLoaderGeneral( - const device T* src_, - threadgroup T* dst_, - const int2 offsets, - const constant MLXConvParams<2>* params_, - const constant Conv2DGeneralJumpParams* jump_params_, - const short base_wh_, - const short base_ww_, - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint simd_lane_id [[thread_index_in_simdgroup]]) - : src_ld(params_->wt_strides[0]), - thread_idx(simd_group_id * 32 + simd_lane_id), - bi(thread_idx / TCOLS), - bj(vec_size * (thread_idx % TCOLS)), - dst(dst_ + bi * dst_ld + bj), - src(src_ + bi * src_ld + bj), - params(params_), - jump_params(jump_params_), - base_wh(base_wh_), - base_ww(base_ww_), - weight_h(base_wh_), - weight_w(base_ww_), - start_row(offsets.y + bi) {} - - /* Load from device memory into threadgroup memory - without bound checking */ - METAL_FUNC void load_unsafe() const { - const device T* curr_src = src + weight_h * params->wt_strides[1] + - weight_w * params->wt_strides[2]; - - if ((start_row + BN <= params->O)) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < BN; i += TROWS) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - dst[i * dst_ld + j] = curr_src[i * src_ld + j]; - } - } - } else { - for (short i = 0; i < BN; i += TROWS) { - if ((start_row + i) < params->O) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - dst[i * dst_ld + j] = curr_src[i * src_ld + j]; - } - } else { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - dst[i * dst_ld + j] = T(0); - } - } - } - } - } - - METAL_FUNC void load_safe(const short remaining_k) const { - const device T* curr_src = src + weight_h * params->wt_strides[1] + - weight_w * params->wt_strides[2]; - - if ((start_row + BN <= params->O)) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < BN; i += TROWS) { - if (bj + vec_size <= remaining_k) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - dst[i * dst_ld + j] = curr_src[i * src_ld + j]; - } - } else { - for (short j = 0; j < vec_size; j++) { - if (bj + j < remaining_k) { - dst[i * dst_ld + j] = curr_src[i * src_ld + j]; - } else { - dst[i * dst_ld + j] = T(0); - } - } - } - } - } else { - for (short i = 0; i < BN; i += TROWS) { - if ((start_row + i) < params->O) { - if (bj + vec_size <= remaining_k) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - dst[i * dst_ld + j] = curr_src[i * src_ld + j]; - } - } else { - for (short j = 0; j < vec_size; j++) { - if (bj + j < remaining_k) { - dst[i * dst_ld + j] = curr_src[i * src_ld + j]; - } else { - dst[i * dst_ld + j] = T(0); - } - } - } - } else { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - dst[i * dst_ld + j] = T(0); - } - } - } - } - } - - /* Iteration helper */ - METAL_FUNC void next() { - weight_w += jump_params->f_wgt_jump_w; - if (weight_w < params->wS[1]) { - return; - } - - weight_w = base_ww; - - weight_h += jump_params->f_wgt_jump_h; - if (weight_h < params->wS[0]) { - return; - } - - weight_h = base_wh; - - src += BK; - } -}; - -} // namespace steel -} // namespace mlx diff --git a/Source/Cmlx/mlx-generated/metal/steel/conv/params.h b/Source/Cmlx/mlx-generated/metal/steel/conv/params.h deleted file mode 100644 index 67d38274..00000000 --- a/Source/Cmlx/mlx-generated/metal/steel/conv/params.h +++ /dev/null @@ -1,103 +0,0 @@ -// Copyright © 2024 Apple Inc. - -#pragma once - -template -struct MLXConvParams { - int N; // Batch size - int C; // In channels - int O; // Out channels - int iS[NDIM]; // Input spatial dim - int wS[NDIM]; // Weight spatial dim - int oS[NDIM]; // Output spatial dim - int str[NDIM]; // Kernel strides - int pad[NDIM]; // Input padding - int kdil[NDIM]; // Kernel dilation - int idil[NDIM]; // Input dilation - int64_t in_strides[NDIM + 2]; // In strides - int64_t wt_strides[NDIM + 2]; // Wt strides - int64_t out_strides[NDIM + 2]; // Out strides - int groups; // Input channel groups - bool flip; - - static MLXConvParams - with_padded_channels(MLXConvParams other, int pad_out, int pad_in) { - MLXConvParams params = other; - - // Update strides - for (int i = 0; i < NDIM + 1; i++) { - params.in_strides[i] = - (params.in_strides[i] / params.C) * (params.C + pad_in); - params.wt_strides[i] = - (params.wt_strides[i] / params.C) * (params.C + pad_in); - params.out_strides[i] = - (params.out_strides[i] / params.O) * (params.O + pad_out); - } - params.in_strides[NDIM + 1] = 1; - params.wt_strides[NDIM + 1] = 1; - params.out_strides[NDIM + 1] = 1; - - // Update channels - params.C += pad_in; - params.O += pad_out; - - return params; - }; -}; - -namespace mlx { -namespace steel { - -struct ImplicitGemmConv2DParams { - const int M; - const int N; - const int K; - - const int gemm_k_iterations; - - const int inp_jump_w; - const int inp_jump_h; - const int inp_jump_c; - - const int tiles_n; - const int tiles_m; - const int swizzle_log; -}; - -struct ImplicitGemmConv3DParams { - const int M; - const int N; - const int K; - - const int gemm_k_iterations; - - const int inp_jump_w; - const int inp_jump_h; - const int inp_jump_d; - const int inp_jump_c; - - const int tiles_n; - const int tiles_m; - const int swizzle_log; -}; - -struct Conv2DGeneralJumpParams { - const int f_wgt_jump_h; - const int f_wgt_jump_w; - - const int f_out_jump_h; - const int f_out_jump_w; - - const int adj_out_h; - const int adj_out_w; - const int adj_out_hw; - const int adj_implicit_m; -}; - -struct Conv2DGeneralBaseInfo { - int weight_base; - int weight_size; -}; - -} // namespace steel -} // namespace mlx diff --git a/Source/Cmlx/mlx-generated/metal/steel/defines.h b/Source/Cmlx/mlx-generated/metal/steel/defines.h deleted file mode 100644 index f5657ee3..00000000 --- a/Source/Cmlx/mlx-generated/metal/steel/defines.h +++ /dev/null @@ -1,7 +0,0 @@ -// Copyright © 2024 Apple Inc. - -#pragma once - -#define STEEL_CONST static constant constexpr const -#define STEEL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)") -#define STEEL_PRAGMA_NO_UNROLL _Pragma("clang loop unroll(disable)") diff --git a/Source/Cmlx/mlx-generated/metal/steel/gemm/gemm.h b/Source/Cmlx/mlx-generated/metal/steel/gemm/gemm.h deleted file mode 100644 index 697a8b56..00000000 --- a/Source/Cmlx/mlx-generated/metal/steel/gemm/gemm.h +++ /dev/null @@ -1,295 +0,0 @@ -// Copyright © 2024 Apple Inc. - -#pragma once - -#include "../../steel/gemm/loader.h" -#include "../../steel/gemm/mma.h" -#include "../../steel/gemm/params.h" -#include "../../steel/gemm/transforms.h" -#include "../../steel/utils.h" - -using namespace metal; - -/////////////////////////////////////////////////////////////////////////////// -// GEMM kernel class -/////////////////////////////////////////////////////////////////////////////// - -namespace mlx { -namespace steel { - -template -struct LoopAlignment {}; - -template < - typename T, - typename U, - int BM, - int BN, - int BK, - int WM, - int WN, - bool transpose_a, - bool transpose_b, - bool MN_aligned, - bool K_aligned, - typename AccumType = typename AccumHelper::accum_type, - typename Epilogue = TransformNone> -struct GEMMKernel { - STEEL_CONST short tgp_padding_a = 16 / sizeof(T); - STEEL_CONST short tgp_padding_b = 16 / sizeof(T); - STEEL_CONST short tgp_mem_size_a = - transpose_a ? BK * (BM + tgp_padding_a) : BM * (BK + tgp_padding_a); - STEEL_CONST short tgp_mem_size_b = - transpose_b ? BN * (BK + tgp_padding_b) : BK * (BN + tgp_padding_b); - STEEL_CONST short tgp_mem_size = tgp_mem_size_a + tgp_mem_size_b; - - STEEL_CONST short tgp_size = WM * WN * 32; - - using loader_a_t = BlockLoader< - T, - transpose_a ? BK : BM, - transpose_a ? BM : BK, - transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a, - !transpose_a, - tgp_size>; - using loader_b_t = BlockLoader< - T, - transpose_b ? BN : BK, - transpose_b ? BK : BN, - transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b, - transpose_b, - tgp_size>; - using mma_t = BlockMMA< - T, - U, - BM, - BN, - BK, - WM, - WN, - transpose_a, - transpose_b, - transpose_a ? BM + tgp_padding_a : BK + tgp_padding_a, - transpose_b ? BK + tgp_padding_b : BN + tgp_padding_b, - AccumType, - Epilogue>; - - /* Main kernel function */ - template - static METAL_FUNC void gemm_loop( - threadgroup T* As [[threadgroup(0)]], - threadgroup T* Bs [[threadgroup(1)]], - const int gemm_k_iterations, - thread loader_a_t& loader_a, - thread loader_b_t& loader_b, - thread mma_t& mma_op, - thread const short& tgp_bm, - thread const short& tgp_bn, - thread const short& lbk, - LoopAlignment l = {}) { - // Appease the compiler - (void)l; - - short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm); - - short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK); - - for (int k = 0; k < gemm_k_iterations; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - // Load elements into threadgroup - if (M_aligned) { - loader_a.load_unsafe(); - } else { - loader_a.load_safe(tile_dims_A); - } - - if (N_aligned) { - loader_b.load_unsafe(); - } else { - loader_b.load_safe(tile_dims_B); - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - - // Prepare for next iteration - loader_a.next(); - loader_b.next(); - } - - if (!K_aligned_) { - threadgroup_barrier(mem_flags::mem_threadgroup); - - short2 tile_dims_A_last = - transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm); - short2 tile_dims_B_last = - transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk); - - loader_a.load_safe(tile_dims_A_last); - loader_b.load_safe(tile_dims_B_last); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - mma_op.mma(As, Bs); - } - } - - /* Main kernel function */ - static METAL_FUNC void run( - const device T* A [[buffer(0)]], - const device T* B [[buffer(1)]], - device U* D [[buffer(2)]], - const constant GEMMParams* params [[buffer(3)]], - threadgroup T* As [[threadgroup(0)]], - threadgroup T* Bs [[threadgroup(1)]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { - // Pacifying compiler - (void)lid; - - const int tid_y = ((tid.y) << params->swizzle_log) + - ((tid.x) & ((1 << params->swizzle_log) - 1)); - const int tid_x = (tid.x) >> params->swizzle_log; - - if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { - return; - } - - threadgroup_barrier(mem_flags::mem_none); - - // Find block in A, B, C - const int c_row = tid_y * BM; - const int c_col = tid_x * BN; - const size_t c_row_long = size_t(c_row); - const size_t c_col_long = size_t(c_col); - - A += transpose_a ? c_row_long : c_row_long * params->lda; - B += transpose_b ? c_col_long * params->ldb : c_col_long; - D += c_row_long * params->ldd + c_col_long; - - // Prepare threadgroup loading operations - thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); - thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); - - // Prepare threadgroup mma operation - thread mma_t mma_op(simd_group_id, simd_lane_id); - - int gemm_k_iterations = params->gemm_k_iterations_aligned; - - /////////////////////////////////////////////////////////////////////////////// - // MNK aligned loop - if (MN_aligned) { - for (int k = 0; k < gemm_k_iterations; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - // Load elements into threadgroup - loader_a.load_unsafe(); - loader_b.load_unsafe(); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - - // Prepare for next iteration - loader_a.next(); - loader_b.next(); - } - - threadgroup_barrier(mem_flags::mem_none); - - // Loop tail - if (!K_aligned) { - int lbk = params->K - params->gemm_k_iterations_aligned * BK; - short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM); - short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk); - - loader_a.load_safe(tile_dims_A); - loader_b.load_safe(tile_dims_B); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - mma_op.mma(As, Bs); - } - - // Store results to device memory - mma_op.store_result(D, params->ldd); - return; - - } - /////////////////////////////////////////////////////////////////////////////// - // MN unaligned loop - else { // Loop over K - unaligned case - short tgp_bm = min(BM, params->M - c_row); - short tgp_bn = min(BN, params->N - c_col); - short leftover_bk = params->K - params->gemm_k_iterations_aligned * BK; - - if (tgp_bm == BM && tgp_bn == BN) { - gemm_loop( - As, - Bs, - gemm_k_iterations, - loader_a, - loader_b, - mma_op, - tgp_bm, - tgp_bn, - leftover_bk); - - mma_op.store_result(D, params->ldd); - return; - - } else if (tgp_bn == BN) { - gemm_loop( - As, - Bs, - gemm_k_iterations, - loader_a, - loader_b, - mma_op, - tgp_bm, - tgp_bn, - leftover_bk); - - mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); - return; - - } else if (tgp_bm == BM) { - gemm_loop( - As, - Bs, - gemm_k_iterations, - loader_a, - loader_b, - mma_op, - tgp_bm, - tgp_bn, - leftover_bk); - - mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); - return; - - } else { - gemm_loop( - As, - Bs, - gemm_k_iterations, - loader_a, - loader_b, - mma_op, - tgp_bm, - tgp_bn, - leftover_bk); - - mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); - return; - } - } - } -}; - -} // namespace steel -} // namespace mlx \ No newline at end of file diff --git a/Source/Cmlx/mlx-generated/metal/steel/gemm/gemm_nax.h b/Source/Cmlx/mlx-generated/metal/steel/gemm/gemm_nax.h deleted file mode 100644 index 9ccd2a96..00000000 --- a/Source/Cmlx/mlx-generated/metal/steel/gemm/gemm_nax.h +++ /dev/null @@ -1,157 +0,0 @@ -// Copyright © 2025 Apple Inc. - -#pragma once - -#include "../../steel/gemm/nax.h" -#include "../../steel/gemm/params.h" -#include "../../steel/gemm/transforms.h" -#include "../../steel/utils.h" - -using namespace metal; - -namespace mlx::steel { - -template < - typename T, - short SM, - short SN, - short SK, - short BK, - bool transpose_a, - bool transpose_b, - bool kAlignedM, - bool kAlignedN, - bool kAlignedK, - short UM, - short UN, - short UK, - typename AccumType = float> -auto gemm_loop( - const device T* A, - const device T* B, - int lda, - int ldb, - int K, - int gemm_k_iterations_aligned, - const short sgp_sm, - const short sgp_sn) { - constexpr short TM = SM / UM; - constexpr short TN = SN / UN; - constexpr short TK = SK / UK; - - constexpr int RA = transpose_a ? TK : TM; - constexpr int CA = transpose_a ? TM : TK; - - constexpr int RB = transpose_b ? TN : TK; - constexpr int CB = transpose_b ? TK : TN; - - using DSubTile = NAXSubTile; - using ASubTile = - NAXSubTile; - using BSubTile = - NAXSubTile; - - NAXTile Dtile; - Dtile.clear(); - - int gemm_k_iterations_ = gemm_k_iterations_aligned; - - STEEL_PRAGMA_NO_UNROLL - for (int kk0 = 0; kk0 < gemm_k_iterations_; kk0++) { - threadgroup_barrier(mem_flags::mem_none); - - STEEL_PRAGMA_NO_UNROLL - for (int kk1 = 0; kk1 < BK; kk1 += SK) { - NAXTile Atile; - NAXTile Btile; - const int k = kk1; - - volatile int compiler_barrier; - - const int A_offset = transpose_a ? k * lda : k; - const int B_offset = transpose_b ? k : k * ldb; - - if constexpr (kAlignedM) { - Atile.load(A + A_offset, lda); - } else { - const short rmax = transpose_a ? SK : sgp_sm; - const short cmax = transpose_a ? sgp_sm : SK; - Atile.load_safe(A + A_offset, lda, short2(cmax, rmax)); - } - - if constexpr (kAlignedN) { - Btile.load(B + B_offset, ldb); - } else { - const short rmax = transpose_b ? sgp_sn : SK; - const short cmax = transpose_b ? SK : sgp_sn; - Btile.load_safe(B + B_offset, ldb, short2(cmax, rmax)); - } - - tile_matmad_nax( - Dtile, - Atile, - metal::bool_constant{}, - Btile, - metal::bool_constant{}); - - (void)compiler_barrier; - } - - A += transpose_a ? (BK * lda) : BK; - B += transpose_b ? BK : (BK * ldb); - } - - if constexpr (!kAlignedK) { - simdgroup_barrier(mem_flags::mem_none); - - const short rem_bk = K - gemm_k_iterations_ * BK; - - STEEL_PRAGMA_NO_UNROLL - for (int kk1 = 0; kk1 < rem_bk; kk1 += SK) { - NAXTile Atile; - NAXTile Btile; - - STEEL_PRAGMA_UNROLL - for (int mm = 0; mm < TM; mm++) { - STEEL_PRAGMA_UNROLL - for (int nn = 0; nn < TN; nn++) { - STEEL_PRAGMA_UNROLL - for (int kk = 0; kk < TK; kk++) { - const int m = mm * UM; - const int n = nn * UN; - const int k = kk1 + kk * UK; - const short psk = max(0, rem_bk - k); - - const int A_offset = transpose_a ? (m + k * lda) : (m * lda + k); - const int B_offset = transpose_b ? (k + n * ldb) : (k * ldb + n); - - { - const short psm = kAlignedM ? SM : max(0, sgp_sm - m); - const short rmax = transpose_a ? psk : psm; - const short cmax = transpose_a ? psm : psk; - Atile.load_safe(A + A_offset, lda, short2(cmax, rmax)); - } - - { - const short psn = kAlignedN ? SN : max(0, sgp_sn - n); - const short rmax = transpose_b ? psn : psk; - const short cmax = transpose_b ? psk : psn; - Btile.load_safe(B + B_offset, ldb, short2(cmax, rmax)); - } - - subtile_matmad_nax( - Dtile.subtile_at(mm, nn), - Atile.subtile_at(0, 0), - metal::bool_constant{}, - Btile.subtile_at(0, 0), - metal::bool_constant{}); - } - } - } - } - } - - return Dtile; -} - -} // namespace mlx::steel diff --git a/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_fused.h b/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_fused.h deleted file mode 100644 index 85830872..00000000 --- a/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_fused.h +++ /dev/null @@ -1,346 +0,0 @@ -// Copyright © 2024 Apple Inc. - -using namespace mlx::steel; - -/////////////////////////////////////////////////////////////////////////////// -// GEMM kernels -/////////////////////////////////////////////////////////////////////////////// - -constant bool has_batch [[function_constant(10)]]; - -constant bool use_out_source [[function_constant(100)]]; -constant bool do_axpby [[function_constant(110)]]; - -constant bool align_M [[function_constant(200)]]; -constant bool align_N [[function_constant(201)]]; -constant bool align_K [[function_constant(202)]]; - -// clang-format off -template < - typename T, - int BM, - int BN, - int BK, - int WM, - int WN, - bool transpose_a, - bool transpose_b, - typename AccumType = float> -[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm( - const device T* A [[buffer(0)]], - const device T* B [[buffer(1)]], - const device T* C [[buffer(2), function_constant(use_out_source)]], - device T* D [[buffer(3)]], - const constant GEMMParams* params [[buffer(4)]], - const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], - const constant int* batch_shape [[buffer(6), function_constant(has_batch)]], - const constant int64_t* batch_strides [[buffer(7), function_constant(has_batch)]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { // clang-format on - // Pacifying compiler - (void)lid; - - using gemm_kernel = GEMMKernel< - T, - T, - BM, - BN, - BK, - WM, - WN, - transpose_a, - transpose_b, - true, - true, - AccumType>; - - using loader_a_t = typename gemm_kernel::loader_a_t; - using loader_b_t = typename gemm_kernel::loader_b_t; - using mma_t = typename gemm_kernel::mma_t; - - // Find block - const int tid_y = ((tid.y) << params->swizzle_log) + - ((tid.x) & ((1 << params->swizzle_log) - 1)); - const int tid_x = (tid.x) >> params->swizzle_log; - - // Exit early if out of bounds - if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { - return; - } - - // Adjust for batch - if (has_batch) { - const constant auto* A_bstrides = batch_strides; - const constant auto* B_bstrides = batch_strides + params->batch_ndim; - - ulong2 batch_offsets = elem_to_loc_broadcast( - tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim); - - A += batch_offsets.x; - B += batch_offsets.y; - - if (use_out_source) { - const constant auto* C_bstrides = B_bstrides + params->batch_ndim; - C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim); - } - } else { - A += params->batch_stride_a * tid.z; - B += params->batch_stride_b * tid.z; - - if (use_out_source) { - C += addmm_params->batch_stride_c * tid.z; - } - } - - D += params->batch_stride_d * tid.z; - - // Prepare threadgroup memory - threadgroup T As[gemm_kernel::tgp_mem_size_a]; - threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; - - threadgroup_barrier(mem_flags::mem_none); - - // Find block in A, B, C - const int c_row = tid_y * BM; - const int c_col = tid_x * BN; - const size_t c_row_long = size_t(c_row); - const size_t c_col_long = size_t(c_col); - - A += transpose_a ? c_row_long : c_row_long * params->lda; - B += transpose_b ? c_col_long * params->ldb : c_col_long; - D += c_row_long * params->ldd + c_col_long; - - if (use_out_source) { - C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc; - } - - // Prepare threadgroup mma operation - thread mma_t mma_op(simd_group_id, simd_lane_id); - - // Prepare threadgroup loading operations - thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); - thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); - - // Prepare threadgroup bounds - const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row)); - const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col)); - - // Prepare iterations - int gemm_k_iterations = params->gemm_k_iterations_aligned; - - // Do unaligned K iterations first - if (!align_K) { - const int k_last = params->gemm_k_iterations_aligned * BK; - const int k_remain = params->K - k_last; - const size_t k_jump_a = - transpose_a ? params->lda * size_t(k_last) : size_t(k_last); - const size_t k_jump_b = - transpose_b ? size_t(k_last) : params->ldb * size_t(k_last); - - // Move loader source ahead to end - loader_a.src += k_jump_a; - loader_b.src += k_jump_b; - - // Load tile - const short2 tile_dims_A = - transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); - const short2 tile_dims_B = - transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); - - loader_a.load_safe(tile_dims_A); - loader_b.load_safe(tile_dims_B); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Do matmul - mma_op.mma(As, Bs); - - // Reset source back to start - loader_a.src -= k_jump_a; - loader_b.src -= k_jump_b; - } - - const TransformAdd epilogue_op_add( - addmm_params->alpha, addmm_params->beta); - const TransformAxpby epilogue_op_axpby( - addmm_params->alpha, addmm_params->beta); - - /////////////////////////////////////////////////////////////////////////////// - // MNK aligned loop - if (align_M && align_N) { - // Do gemm - for (int k = 0; k < gemm_k_iterations; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - // Load elements into threadgroup - loader_a.load_unsafe(); - loader_b.load_unsafe(); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - - // Prepare for next iteration - loader_a.next(); - loader_b.next(); - } - - threadgroup_barrier(mem_flags::mem_none); - - // Do epilogue - if (use_out_source) { - if (do_axpby) { - mma_op.apply_epilogue( - C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby); - } else { - mma_op.apply_epilogue( - C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add); - } - } - - // Store results to device memory - return mma_op.store_result(D, params->ldd); - - } - /////////////////////////////////////////////////////////////////////////////// - // MN unaligned loop - else { // Loop over K - unaligned case - const int leftover_bk = 0; - - if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { - // Do gemm - gemm_kernel::gemm_loop( - As, - Bs, - gemm_k_iterations, - loader_a, - loader_b, - mma_op, - tgp_bm, - tgp_bn, - leftover_bk, - LoopAlignment{}); - - // Do epilogue - if (use_out_source) { - if (do_axpby) { - mma_op.apply_epilogue( - C, addmm_params->ldc, addmm_params->fdc, epilogue_op_axpby); - } else { - mma_op.apply_epilogue( - C, addmm_params->ldc, addmm_params->fdc, epilogue_op_add); - } - } - - // Store results to device memory - return mma_op.store_result(D, params->ldd); - - } else if (align_N || tgp_bn == BN) { - gemm_kernel::gemm_loop( - As, - Bs, - gemm_k_iterations, - loader_a, - loader_b, - mma_op, - tgp_bm, - tgp_bn, - leftover_bk, - LoopAlignment{}); - - // Do epilogue - if (use_out_source) { - if (do_axpby) { - mma_op.apply_epilogue_safe( - C, - addmm_params->ldc, - addmm_params->fdc, - short2(tgp_bn, tgp_bm), - epilogue_op_axpby); - } else { - mma_op.apply_epilogue_safe( - C, - addmm_params->ldc, - addmm_params->fdc, - short2(tgp_bn, tgp_bm), - epilogue_op_add); - } - } - - // Store results to device memory - return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); - - } else if (align_M || tgp_bm == BM) { - gemm_kernel::gemm_loop( - As, - Bs, - gemm_k_iterations, - loader_a, - loader_b, - mma_op, - tgp_bm, - tgp_bn, - leftover_bk, - LoopAlignment{}); - - // Do epilogue - if (use_out_source) { - if (do_axpby) { - mma_op.apply_epilogue_safe( - C, - addmm_params->ldc, - addmm_params->fdc, - short2(tgp_bn, tgp_bm), - epilogue_op_axpby); - } else { - mma_op.apply_epilogue_safe( - C, - addmm_params->ldc, - addmm_params->fdc, - short2(tgp_bn, tgp_bm), - epilogue_op_add); - } - } - - // Store results to device memory - return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); - - } else { - gemm_kernel::gemm_loop( - As, - Bs, - gemm_k_iterations, - loader_a, - loader_b, - mma_op, - tgp_bm, - tgp_bn, - leftover_bk, - LoopAlignment{}); - - // Do epilogue - if (use_out_source) { - if (do_axpby) { - mma_op.apply_epilogue_safe( - C, - addmm_params->ldc, - addmm_params->fdc, - short2(tgp_bn, tgp_bm), - epilogue_op_axpby); - } else { - mma_op.apply_epilogue_safe( - C, - addmm_params->ldc, - addmm_params->fdc, - short2(tgp_bn, tgp_bm), - epilogue_op_add); - } - } - - // Store results to device memory - return mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); - } - } -} diff --git a/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_fused_nax.h b/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_fused_nax.h deleted file mode 100644 index 4ff92606..00000000 --- a/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_fused_nax.h +++ /dev/null @@ -1,219 +0,0 @@ -// Copyright © 2025 Apple Inc. - -using namespace mlx::steel; - -constant bool has_batch [[function_constant(10)]]; - -constant bool use_out_source [[function_constant(100)]]; -constant bool do_axpby [[function_constant(110)]]; - -constant bool align_M [[function_constant(200)]]; -constant bool align_N [[function_constant(201)]]; -constant bool align_K [[function_constant(202)]]; - -// clang-format off -template < - bool kAlignedM, - bool kAlignedN, - typename NAXTile_t, - typename T> -void gemm_epilogue( - thread NAXTile_t& Dtile, - const device T* C, - const constant GEMMParams* params, - const constant GEMMAddMMParams* addmm_params, - const short sgp_sm, - const short sgp_sn) { // clang-format on - - (void)params; - - constexpr short UM = NAXTile_t::kSubTileRows; - constexpr short UN = NAXTile_t::kSubTileCols; - using CSubTile = NAXSubTile; - - using V = typename NAXTile_t::elem_type; - - constexpr short TM = NAXTile_t::kTileRows; - constexpr short TN = NAXTile_t::kTileCols; - constexpr short kElemsPerSubTile = NAXTile_t::kElemsPerSubTile; - - STEEL_PRAGMA_UNROLL - for (short mm = 0; mm < TM; mm++) { - STEEL_PRAGMA_UNROLL - for (short nn = 0; nn < TN; nn++) { - const short m = mm * UM; - const short n = nn * UN; - - CSubTile CTile; - - if constexpr (kAlignedM && kAlignedN) { - CTile.load(C, addmm_params->ldc, addmm_params->fdc, m, n); - } else { - CTile.load_safe( - C, addmm_params->ldc, addmm_params->fdc, sgp_sm, sgp_sn, m, n); - } - - auto delems = Dtile.subtile_at(mm, nn).elems(); - auto celems = CTile.elems(); - - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kElemsPerSubTile; i++) { - if (do_axpby) { - delems[i] = addmm_params->alpha * delems[i] + - addmm_params->beta * static_cast(celems[i]); - } else { - delems[i] += static_cast(celems[i]); - } - } - } - } -} - -// clang-format off -template < - typename T, - int BM, - int BN, - int BK, - int WM, - int WN, - bool transpose_a, - bool transpose_b, - typename AccumType = float> -[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm( - const device T* A [[buffer(0)]], - const device T* B [[buffer(1)]], - const device T* C [[buffer(2), function_constant(use_out_source)]], - device T* D [[buffer(3)]], - const constant GEMMParams* params [[buffer(4)]], - const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], - const constant int* batch_shape [[buffer(6), function_constant(has_batch)]], - const constant int64_t* batch_strides [[buffer(7), function_constant(has_batch)]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint3 tid [[threadgroup_position_in_grid]]) { // clang-format on - // Find block - const int tid_y = ((tid.y) << params->swizzle_log) + - ((tid.x) & ((1 << params->swizzle_log) - 1)); - const int tid_x = (tid.x) >> params->swizzle_log; - - // Exit early if out of bounds - if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { - return; - } - - // Adjust for batch - if (has_batch) { - const constant auto* A_bstrides = batch_strides; - const constant auto* B_bstrides = batch_strides + params->batch_ndim; - - ulong2 batch_offsets = elem_to_loc_broadcast( - tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim); - - A += batch_offsets.x; - B += batch_offsets.y; - - if (use_out_source) { - const constant auto* C_bstrides = B_bstrides + params->batch_ndim; - C += elem_to_loc(tid.z, batch_shape, C_bstrides, params->batch_ndim); - } - } else { - A += params->batch_stride_a * tid.z; - B += params->batch_stride_b * tid.z; - - if (use_out_source) { - C += addmm_params->batch_stride_c * tid.z; - } - } - - D += params->batch_stride_d * tid.z; - - // Prepare threadgroup memory - threadgroup_barrier(mem_flags::mem_none); - - // Find block in A, B, C - const int c_row = tid_y * BM; - const int c_col = tid_x * BN; - const size_t c_row_long = size_t(c_row); - const size_t c_col_long = size_t(c_col); - - A += transpose_a ? c_row_long : c_row_long * params->lda; - B += transpose_b ? c_col_long * params->ldb : c_col_long; - D += c_row_long * params->ldd + c_col_long; - - if (use_out_source) { - C += c_row_long * addmm_params->ldc + c_col_long * addmm_params->fdc; - } - - constexpr short UM = 16; - constexpr short UN = 32; - constexpr short UK = 16; - constexpr short SM = BM / WM; - constexpr short SN = BN / WN; - constexpr short SK = 32; - - constexpr short TM = SM / UM; - constexpr short TN = SN / UN; - - const short tm = SM * (simd_group_id / WN); - const short tn = SN * (simd_group_id % WN); - - const int sgp_sm_int = - align_M ? int(SM) : min(int(SM), params->M - (c_row + tm)); - const short sgp_sm = short(sgp_sm_int); - const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM); - - const int sgp_sn_int = - align_N ? int(SN) : min(int(SN), params->N - (c_col + tn)); - const short sgp_sn = short(sgp_sn_int); - const bool is_unaligned_sn = align_N ? false : (sgp_sn != SN); - - A += transpose_a ? tm : (tm * params->lda); - B += transpose_b ? (tn * params->ldb) : tn; - D += tm * params->ldd + tn; - - if (use_out_source) { - C += tm * addmm_params->ldc + tn * addmm_params->fdc; - } - - using DSubTile = NAXSubTile; - NAXTile Dtile; - - dispatch_bool(align_K, [&](auto kAlignedK) { - dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) { - dispatch_bool(align_N || !is_unaligned_sn, [&](auto kAlignedN) { - Dtile = gemm_loop< - T, - SM, - SN, - SK, - BK, - transpose_a, - transpose_b, - kAlignedM.value, - kAlignedN.value, - kAlignedK.value, - UM, - UN, - UK, - AccumType>( - A, - B, - params->lda, - params->ldb, - params->K, - params->gemm_k_iterations_aligned, - sgp_sm, - sgp_sn); - if (use_out_source) { - gemm_epilogue( - Dtile, C, params, addmm_params, sgp_sm, sgp_sn); - } - if constexpr (kAlignedM && kAlignedN) { - Dtile.store(D, int(params->ldd)); - } else { - Dtile.store_safe(D, int(params->ldd), short2(sgp_sn, sgp_sm)); - } - }); - }); - }); -} diff --git a/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_gather.h b/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_gather.h deleted file mode 100644 index 4c055e69..00000000 --- a/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_gather.h +++ /dev/null @@ -1,459 +0,0 @@ -// Copyright © 2024 Apple Inc. - -using namespace mlx::steel; - -constant bool has_batch [[function_constant(10)]]; -constant bool align_M [[function_constant(200)]]; -constant bool align_N [[function_constant(201)]]; -constant bool align_K [[function_constant(202)]]; - -template < - typename T, - int BM, - int BN, - int BK, - int WM, - int WN, - bool transpose_a, - bool transpose_b, - typename AccumType = float> -[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gather_mm_rhs( - const device T* A [[buffer(0)]], - const device T* B [[buffer(1)]], - const device uint32_t* rhs_indices [[buffer(2)]], - device T* C [[buffer(3)]], - const constant GEMMParams* params [[buffer(4)]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint3 tid [[threadgroup_position_in_grid]]) { - using gemm_kernel = GEMMKernel< - T, - T, - BM, - BN, - BK, - WM, - WN, - transpose_a, - transpose_b, - true, - true, - AccumType>; - - using loader_a_t = typename gemm_kernel::loader_a_t; - using loader_b_t = typename gemm_kernel::loader_b_t; - using mma_t = typename gemm_kernel::mma_t; - - if (params->tiles_n <= static_cast(tid.x) || - params->tiles_m <= static_cast(tid.y)) { - return; - } - - // Prepare threadgroup memory - threadgroup T As[gemm_kernel::tgp_mem_size_a]; - threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; - - // Find the block in A, B, C - const int c_row = tid.y * BM; - const int c_col = tid.x * BN; - const size_t c_row_long = size_t(c_row); - const size_t c_col_long = size_t(c_col); - - // Prepare threadgroup bounds - const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row)); - const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col)); - - A += transpose_a ? c_row_long : c_row_long * params->lda; - B += transpose_b ? c_col_long * params->ldb : c_col_long; - C += c_row_long * params->ldd + c_col_long; - - // Do as many matmuls as necessary - uint32_t index; - short offset; - uint32_t index_next = rhs_indices[c_row]; - short offset_next = 0; - int n = 0; - while (n < tgp_bm) { - n++; - offset = offset_next; - index = index_next; - offset_next = tgp_bm; - for (; n < tgp_bm; n++) { - if (rhs_indices[c_row + n] != index) { - offset_next = n; - index_next = rhs_indices[c_row + n]; - break; - } - } - threadgroup_barrier(mem_flags::mem_none); - - // Prepare threadgroup mma operation - thread mma_t mma_op(simd_group_id, simd_lane_id); - - // Prepare threadgroup loading operations - thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); - thread loader_b_t loader_b( - B + index * params->batch_stride_b, - params->ldb, - Bs, - simd_group_id, - simd_lane_id); - - // Prepare iterations - const int gemm_k_iterations = params->gemm_k_iterations_aligned; - - // Do unaligned K iterations first - if (!align_K) { - const int k_last = params->gemm_k_iterations_aligned * BK; - const int k_remain = params->K - k_last; - const size_t k_jump_a = - transpose_a ? params->lda * size_t(k_last) : size_t(k_last); - const size_t k_jump_b = - transpose_b ? size_t(k_last) : params->ldb * size_t(k_last); - - // Move loader source ahead to end - loader_a.src += k_jump_a; - loader_b.src += k_jump_b; - - // Load tile - const short2 tile_dims_A = - transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); - const short2 tile_dims_B = - transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); - - loader_a.load_safe(tile_dims_A); - loader_b.load_safe(tile_dims_B); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Do matmul - mma_op.mma(As, Bs); - - // Reset source back to start - loader_a.src -= k_jump_a; - loader_b.src -= k_jump_b; - } - - // Matrix level aligned never check - if (align_M && align_N) { - for (int k = 0; k < gemm_k_iterations; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Load elements into threadgroup - loader_a.load_unsafe(); - loader_b.load_unsafe(); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - - // Prepare for next iteration - loader_a.next(); - loader_b.next(); - } - - // Store results to device memory - if (offset_next - offset == BM) { - mma_op.store_result(C, params->ldd); - } else { - mma_op.store_result_slice( - C, params->ldd, short2(0, offset), short2(BN, offset_next)); - } - } else { - const short lbk = 0; - - // Tile aligned don't check - if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { - gemm_kernel::gemm_loop( - As, - Bs, - gemm_k_iterations, - loader_a, - loader_b, - mma_op, - tgp_bm, - tgp_bn, - lbk, - LoopAlignment{}); - if (offset_next - offset == BM) { - mma_op.store_result(C, params->ldd); - } else { - mma_op.store_result_slice( - C, params->ldd, short2(0, offset), short2(BN, offset_next)); - } - } - - // Tile partially aligned check rows - else if (align_N || tgp_bn == BN) { - gemm_kernel::gemm_loop( - As, - Bs, - gemm_k_iterations, - loader_a, - loader_b, - mma_op, - tgp_bm, - tgp_bn, - lbk, - LoopAlignment{}); - mma_op.store_result_slice( - C, params->ldd, short2(0, offset), short2(BN, offset_next)); - } - - // Tile partially aligned check cols - else if (align_M || tgp_bm == BM) { - gemm_kernel::gemm_loop( - As, - Bs, - gemm_k_iterations, - loader_a, - loader_b, - mma_op, - tgp_bm, - tgp_bn, - lbk, - LoopAlignment{}); - mma_op.store_result_slice( - C, params->ldd, short2(0, offset), short2(tgp_bn, offset_next)); - } - - // Nothing aligned so check both rows and cols - else { - gemm_kernel::gemm_loop( - As, - Bs, - gemm_k_iterations, - loader_a, - loader_b, - mma_op, - tgp_bm, - tgp_bn, - lbk, - LoopAlignment{}); - mma_op.store_result_slice( - C, params->ldd, short2(0, offset), short2(tgp_bn, offset_next)); - } - } - } -} - -template < - typename T, - int BM, - int BN, - int BK, - int WM, - int WN, - bool transpose_a, - bool transpose_b, - typename AccumType = float> -[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gather_mm( - const device T* A [[buffer(0)]], - const device T* B [[buffer(1)]], - const device uint32_t* lhs_indices [[buffer(2)]], - const device uint32_t* rhs_indices [[buffer(3)]], - device T* C [[buffer(4)]], - const constant GEMMParams* params [[buffer(5)]], - const constant int* indices_shape [[buffer(6)]], - const constant int64_t* lhs_strides [[buffer(7)]], - const constant int64_t* rhs_strides [[buffer(8)]], - const constant int& batch_ndim_a [[buffer(9)]], - const constant int* batch_shape_a [[buffer(10)]], - const constant int64_t* batch_strides_a [[buffer(11)]], - const constant int& batch_ndim_b [[buffer(12)]], - const constant int* batch_shape_b [[buffer(13)]], - const constant int64_t* batch_strides_b [[buffer(14)]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint3 tid [[threadgroup_position_in_grid]]) { - using gemm_kernel = GEMMKernel< - T, - T, - BM, - BN, - BK, - WM, - WN, - transpose_a, - transpose_b, - true, - true, - AccumType>; - - using loader_a_t = typename gemm_kernel::loader_a_t; - using loader_b_t = typename gemm_kernel::loader_b_t; - using mma_t = typename gemm_kernel::mma_t; - - if (params->tiles_n <= static_cast(tid.x) || - params->tiles_m <= static_cast(tid.y)) { - return; - } - - // Move A and B to the locations pointed by lhs_indices and rhs_indices. - uint32_t indx_A, indx_B; - if (has_batch) { - ulong2 indices_offsets = elem_to_loc_broadcast( - tid.z, indices_shape, lhs_strides, rhs_strides, params->batch_ndim); - indx_A = lhs_indices[indices_offsets.x]; - indx_B = rhs_indices[indices_offsets.y]; - } else { - indx_A = lhs_indices[params->batch_stride_a * tid.z]; - indx_B = rhs_indices[params->batch_stride_b * tid.z]; - } - A += elem_to_loc(indx_A, batch_shape_a, batch_strides_a, batch_ndim_a); - B += elem_to_loc(indx_B, batch_shape_b, batch_strides_b, batch_ndim_b); - C += params->batch_stride_d * tid.z; - - // Prepare threadgroup memory - threadgroup T As[gemm_kernel::tgp_mem_size_a]; - threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; - - // Just make sure everybody's finished with the indexing math above. - threadgroup_barrier(mem_flags::mem_none); - - // Find block in A, B, C - const int c_row = tid.y * BM; - const int c_col = tid.x * BN; - const size_t c_row_long = size_t(c_row); - const size_t c_col_long = size_t(c_col); - - A += transpose_a ? c_row_long : c_row_long * params->lda; - B += transpose_b ? c_col_long * params->ldb : c_col_long; - C += c_row_long * params->ldd + c_col_long; - - // Prepare threadgroup mma operation - thread mma_t mma_op(simd_group_id, simd_lane_id); - - // Prepare threadgroup loading operations - thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); - thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); - - // Prepare threadgroup bounds - const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row)); - const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col)); - - // Prepare iterations - int gemm_k_iterations = params->gemm_k_iterations_aligned; - - // Do unaligned K iterations first - if (!align_K) { - const int k_last = params->gemm_k_iterations_aligned * BK; - const int k_remain = params->K - k_last; - const size_t k_jump_a = - transpose_a ? params->lda * size_t(k_last) : size_t(k_last); - const size_t k_jump_b = - transpose_b ? size_t(k_last) : params->ldb * size_t(k_last); - - // Move loader source ahead to end - loader_a.src += k_jump_a; - loader_b.src += k_jump_b; - - // Load tile - const short2 tile_dims_A = - transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); - const short2 tile_dims_B = - transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); - - loader_a.load_safe(tile_dims_A); - loader_b.load_safe(tile_dims_B); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Do matmul - mma_op.mma(As, Bs); - - // Reset source back to start - loader_a.src -= k_jump_a; - loader_b.src -= k_jump_b; - } - - // Matrix level aligned never check - if (align_M && align_N) { - for (int k = 0; k < gemm_k_iterations; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Load elements into threadgroup - loader_a.load_unsafe(); - loader_b.load_unsafe(); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - - // Prepare for next iteration - loader_a.next(); - loader_b.next(); - } - - // Store results to device memory - mma_op.store_result(C, params->ldd); - } else { - const short lbk = 0; - - // Tile aligned don't check - if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { - gemm_kernel::gemm_loop( - As, - Bs, - gemm_k_iterations, - loader_a, - loader_b, - mma_op, - tgp_bm, - tgp_bn, - lbk, - LoopAlignment{}); - mma_op.store_result(C, params->ldd); - } - - // Tile partially aligned check rows - else if (align_N || tgp_bn == BN) { - gemm_kernel::gemm_loop( - As, - Bs, - gemm_k_iterations, - loader_a, - loader_b, - mma_op, - tgp_bm, - tgp_bn, - lbk, - LoopAlignment{}); - mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); - } - - // Tile partially aligned check cols - else if (align_M || tgp_bm == BM) { - gemm_kernel::gemm_loop( - As, - Bs, - gemm_k_iterations, - loader_a, - loader_b, - mma_op, - tgp_bm, - tgp_bn, - lbk, - LoopAlignment{}); - mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); - } - - // Nothing aligned so check both rows and cols - else { - gemm_kernel::gemm_loop( - As, - Bs, - gemm_k_iterations, - loader_a, - loader_b, - mma_op, - tgp_bm, - tgp_bn, - lbk, - LoopAlignment{}); - mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); - } - } -} diff --git a/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_gather_nax.h b/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_gather_nax.h deleted file mode 100644 index 67cd7378..00000000 --- a/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_gather_nax.h +++ /dev/null @@ -1,143 +0,0 @@ -// Copyright © 2024 Apple Inc. - -using namespace mlx::steel; - -constant bool align_M [[function_constant(200)]]; -constant bool align_N [[function_constant(201)]]; -constant bool align_K [[function_constant(202)]]; - -template < - typename T, - int BM, - int BN, - int BK, - int WM, - int WN, - bool transpose_a, - bool transpose_b, - typename AccumType = float> -[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void -gather_mm_rhs_nax( - const device T* A [[buffer(0)]], - const device T* B [[buffer(1)]], - const device uint32_t* rhs_indices [[buffer(2)]], - device T* C [[buffer(3)]], - const constant GEMMParams* params [[buffer(4)]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint3 tid [[threadgroup_position_in_grid]]) { - constexpr short UM = 16; - constexpr short UN = 32; - constexpr short UK = 16; - constexpr short SM = BM / WM; - constexpr short SN = BN / WN; - constexpr short SK = 32; - constexpr short TM = SM / UM; - constexpr short TN = SN / UN; - - if (params->tiles_n <= static_cast(tid.x) || - params->tiles_m <= static_cast(tid.y)) { - return; - } - - // Find the block in A, B, C - const int c_row = tid.y * BM; - const int c_col = tid.x * BN; - const size_t c_row_long = size_t(c_row); - const size_t c_col_long = size_t(c_col); - - A += transpose_a ? c_row_long : c_row_long * params->lda; - B += transpose_b ? c_col_long * params->ldb : c_col_long; - C += c_row_long * params->ldd + c_col_long; - rhs_indices += c_row; - - const short tm = SM * (simd_group_id / WN); - const short tn = SN * (simd_group_id % WN); - - const int sgp_sm_int = - align_M ? int(SM) : min(int(SM), params->M - (c_row + tm)); - const short sgp_sm = short(sgp_sm_int); - const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM); - - const int sgp_sn_int = - align_N ? int(SN) : min(int(SN), params->N - (c_col + tn)); - const short sgp_sn = short(sgp_sn_int); - const bool is_unaligned_sn = align_N ? false : (sgp_sn != SN); - - A += transpose_a ? tm : (tm * params->lda); - B += transpose_b ? (tn * params->ldb) : tn; - C += tm * params->ldd + tn; - rhs_indices += tm; - - // Do as many matmuls as necessary - uint32_t index; - short offset; - uint32_t index_next = rhs_indices[0]; - short offset_next = 0; - int n = 0; - while (n < sgp_sm) { - n++; - offset = offset_next; - index = index_next; - offset_next = sgp_sm; - for (; n < sgp_sm; n++) { - if (rhs_indices[n] != index) { - offset_next = n; - index_next = rhs_indices[n]; - break; - } - } - threadgroup_barrier(mem_flags::mem_none); - - using DSubTile = NAXSubTile; - NAXTile Ctile; - - dispatch_bool(align_K, [&](auto kAlignedK) { - dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) { - dispatch_bool(align_N || !is_unaligned_sn, [&](auto kAlignedN) { - auto do_gemm = gemm_loop< - T, - SM, - SN, - SK, - BK, - transpose_a, - transpose_b, - kAlignedM.value, - kAlignedN.value, - kAlignedK.value, - UM, - UN, - UK, - AccumType>; - Ctile = do_gemm( - A, - B + index * params->batch_stride_b, - params->lda, - params->ldb, - params->K, - params->gemm_k_iterations_aligned, - sgp_sm, - sgp_sn); - - if constexpr (kAlignedN.value) { - if (offset_next - offset == SM) { - Ctile.store(C, int(params->ldd)); - } else { - Ctile.store_slice( - C, - int(params->ldd), - short2(0, offset), - short2(SN, offset_next)); - } - } else { - Ctile.store_slice( - C, - int(params->ldd), - short2(0, offset), - short2(sgp_sn, offset_next)); - } - }); - }); - }); - } -} diff --git a/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_masked.h b/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_masked.h deleted file mode 100644 index 6546215e..00000000 --- a/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_masked.h +++ /dev/null @@ -1,719 +0,0 @@ -// Copyright © 2024 Apple Inc. - -#include "../../../steel/defines.h" -using namespace metal; -using namespace mlx::steel; - -/////////////////////////////////////////////////////////////////////////////// -// GEMM kernels -/////////////////////////////////////////////////////////////////////////////// - -struct _NoMask { - char x; - - constexpr METAL_FUNC operator bool() { - return true; - } - constexpr METAL_FUNC operator bool() const threadgroup { - return true; - } - constexpr METAL_FUNC operator bool() const device { - return true; - } - constexpr METAL_FUNC operator bool() const constant { - return true; - } -}; - -template -struct ScaleOp { - OutT scale; - - METAL_FUNC OutT apply(InT x) const { - return static_cast(x) * scale; - } -}; - -typedef struct _NoMask nomask_t; - -template < - typename T, - typename out_mask_t, - typename op_mask_t, - int BM, - int BN, - int BK, - int WM, - int WN, - bool transpose_a, - bool transpose_b, - bool MN_aligned, - bool K_aligned> -[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void -block_masked_gemm( - const device T* A [[buffer(0)]], - const device T* B [[buffer(1)]], - device T* D [[buffer(3)]], - const constant GEMMParams* params [[buffer(4)]], - const constant int* batch_shape [[buffer(6)]], - const constant int64_t* batch_strides [[buffer(7)]], - const device out_mask_t* out_mask [[buffer(10)]], - const device op_mask_t* lhs_mask [[buffer(11)]], - const device op_mask_t* rhs_mask [[buffer(12)]], - const constant int* mask_strides [[buffer(13)]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { - // Appease the compiler - (void)lid; - - static_assert( - BM == BN, - "block_masked_gemm must have the same block M and block N size"); - static_assert(BM % BK == 0, "block_masked_gemm must have BM % BK == 0"); - - constexpr bool has_operand_mask = !metal::is_same_v; - constexpr bool has_output_mask = !metal::is_same_v; - - constexpr bool has_mul_operand_mask = - has_operand_mask && !metal::is_same_v; - constexpr bool has_mul_output_mask = - has_output_mask && !metal::is_same_v; - - constexpr short k_mask_factor = short(BM / BK); - - using gemm_kernel = GEMMKernel< - T, - T, - BM, - BN, - BK, - WM, - WN, - transpose_a, - transpose_b, - MN_aligned, - K_aligned>; - - const int tid_y = ((tid.y) << params->swizzle_log) + - ((tid.x) & ((1 << params->swizzle_log) - 1)); - const int tid_x = (tid.x) >> params->swizzle_log; - - if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { - return; - } - - const constant auto* mask_batch_strides = - batch_strides + 2 * params->batch_ndim; - - if (params->batch_ndim > 1) { - if (has_output_mask) { - out_mask += elem_to_loc( - tid.z, batch_shape, mask_batch_strides, params->batch_ndim); - - mask_batch_strides += params->batch_ndim; - } - - if (has_operand_mask) { - const constant auto* mask_strides_lhs = mask_batch_strides; - const constant auto* mask_strides_rhs = - mask_strides_lhs + params->batch_ndim; - - ulong2 batch_offsets = elem_to_loc_broadcast( - tid.z, - batch_shape, - mask_strides_lhs, - mask_strides_rhs, - params->batch_ndim); - - lhs_mask += batch_offsets.x; - rhs_mask += batch_offsets.y; - } - } else { - if (has_output_mask) { - out_mask += tid.z * mask_batch_strides[0]; - mask_batch_strides += params->batch_ndim; - } - - if (has_operand_mask) { - lhs_mask += tid.z * mask_batch_strides[0]; - rhs_mask += tid.z * mask_batch_strides[params->batch_ndim]; - } - } - - // Adjust for batch - if (params->batch_ndim > 1) { - const constant auto* A_bstrides = batch_strides; - const constant auto* B_bstrides = batch_strides + params->batch_ndim; - - ulong2 batch_offsets = elem_to_loc_broadcast( - tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim); - - A += batch_offsets.x; - B += batch_offsets.y; - - } else { - A += params->batch_stride_a * tid.z; - B += params->batch_stride_b * tid.z; - } - - D += params->batch_stride_d * tid.z; - - // Find block in A, B, C - const int c_row = tid_y * BM; - const int c_col = tid_x * BN; - const size_t c_row_long = size_t(c_row); - const size_t c_col_long = size_t(c_col); - - A += transpose_a ? c_row_long : c_row_long * params->lda; - B += transpose_b ? c_col_long * params->ldb : c_col_long; - D += c_row_long * params->ldd + c_col_long; - - const constant int* out_mask_strides = mask_strides; - const constant int* lhs_mask_strides = - mask_strides + (has_output_mask ? 2 : 0); - const constant int* rhs_mask_strides = - lhs_mask_strides + (has_operand_mask ? 2 : 0); - - const int out_mask_offset = !has_output_mask - ? 0 - : tid_y * out_mask_strides[1] + tid_x * out_mask_strides[0]; - int lhs_mask_offset = !has_operand_mask ? 0 : tid_y * lhs_mask_strides[1]; - int rhs_mask_offset = !has_operand_mask ? 0 : tid_x * rhs_mask_strides[0]; - const int lhs_mask_step = !has_operand_mask ? 0 : lhs_mask_strides[0]; - const int rhs_mask_step = !has_operand_mask ? 0 : rhs_mask_strides[1]; - short k_factor_cnt = k_mask_factor; - - ScaleOp out_mask_op; - ScaleOp lhs_mask_op; - ScaleOp rhs_mask_op; - - if (has_output_mask) { - auto mask_out = out_mask[out_mask_offset]; - - if (has_mul_output_mask) { - out_mask_op.scale = float(mask_out); - } - - // Write zeros and return - if (!mask_out) { - constexpr short tgp_size = WM * WN * 32; - constexpr short vec_size = 4; - - // Tile threads in threadgroup - constexpr short TN = BN / vec_size; - constexpr short TM = tgp_size / TN; - - const short thread_idx = simd_group_id * 32 + simd_lane_id; - const short bi = thread_idx / TN; - const short bj = vec_size * (thread_idx % TN); - - D += bi * params->ldd + bj; - - short tgp_bm = min(BM, params->M - c_row); - short tgp_bn = min(BN, params->N - c_col); - - if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) { - for (short ti = 0; ti < BM; ti += TM) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - D[ti * params->ldd + j] = T(0.); - } - } - } else { - short jmax = tgp_bn - bj; - jmax = jmax < vec_size ? jmax : vec_size; - for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) { - for (short j = 0; j < jmax; j++) { - D[ti * params->ldd + j] = T(0.); - } - } - } - - return; - } - } - - threadgroup_barrier(mem_flags::mem_none); - - // Prepare threadgroup mma operation - thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id); - - threadgroup T As[gemm_kernel::tgp_mem_size_a]; - threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; - - // Prepare threadgroup loading operations - thread typename gemm_kernel::loader_a_t loader_a( - A, params->lda, As, simd_group_id, simd_lane_id); - thread typename gemm_kernel::loader_b_t loader_b( - B, params->ldb, Bs, simd_group_id, simd_lane_id); - - // Prepare threadgroup bounds - const short tgp_bm = - MN_aligned ? short(BM) : short(min(BM, params->M - c_row)); - const short tgp_bn = - MN_aligned ? short(BN) : short(min(BN, params->N - c_col)); - - int gemm_k_iterations = params->gemm_k_iterations_aligned; - - /////////////////////////////////////////////////////////////////////////////// - // Do unaligned K iterations first - if (!K_aligned) { - const int k_last = params->gemm_k_iterations_aligned * BK; - const int mask_idx_last = k_last / BM; - - if (!has_operand_mask || - (bool(lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step]) && - bool(rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step]))) { - if (has_mul_operand_mask) { - lhs_mask_op.scale = - lhs_mask[lhs_mask_offset + mask_idx_last * lhs_mask_step]; - rhs_mask_op.scale = - rhs_mask[rhs_mask_offset + mask_idx_last * rhs_mask_step]; - } - - // Move loader source ahead to end - const int k_remain = params->K - k_last; - const size_t k_jump_a = - transpose_a ? params->lda * size_t(k_last) : size_t(k_last); - const size_t k_jump_b = - transpose_b ? size_t(k_last) : params->ldb * size_t(k_last); - - loader_a.src += k_jump_a; - loader_b.src += k_jump_b; - - // Load tile - const short2 tile_dims_A = - transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); - const short2 tile_dims_B = - transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); - - loader_a.load_safe(tile_dims_A); - loader_b.load_safe(tile_dims_B); - - if (has_mul_operand_mask) { - loader_a.apply_inplace_op(lhs_mask_op); - loader_b.apply_inplace_op(rhs_mask_op); - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Do matmul - mma_op.mma(As, Bs); - - // Reset source back to start - loader_a.src -= k_jump_a; - loader_b.src -= k_jump_b; - } - } - - /////////////////////////////////////////////////////////////////////////////// - // MNK aligned loop - if (MN_aligned) { - for (; gemm_k_iterations > 0; gemm_k_iterations--) { - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (!has_operand_mask || - (bool(lhs_mask[lhs_mask_offset]) && - bool(rhs_mask[rhs_mask_offset]))) { - if (has_mul_operand_mask) { - lhs_mask_op.scale = lhs_mask[lhs_mask_offset]; - rhs_mask_op.scale = rhs_mask[rhs_mask_offset]; - } - - // Load elements into threadgroup - loader_a.load_unsafe(); - loader_b.load_unsafe(); - - if (has_mul_operand_mask) { - loader_a.apply_inplace_op(lhs_mask_op); - loader_b.apply_inplace_op(rhs_mask_op); - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - } - - // Prepare for next iteration - loader_a.next(); - loader_b.next(); - - k_factor_cnt--; - lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0; - rhs_mask_offset += k_factor_cnt == 0 ? rhs_mask_step : 0; - k_factor_cnt = k_factor_cnt == 0 ? k_mask_factor : k_factor_cnt; - } - - if (has_mul_output_mask) { - mma_op.apply_epilogue(out_mask_op); - } - - // Store results to device memory - mma_op.store_result(D, params->ldd); - return; - - } - /////////////////////////////////////////////////////////////////////////////// - // MN unaligned loop - else { - const bool M_aligned = (tgp_bm == BM); - const bool N_aligned = (tgp_bn == BN); - - const short2 tile_dims_A = - transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm); - const short2 tile_dims_B = - transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK); - - for (; gemm_k_iterations > 0; gemm_k_iterations--) { - threadgroup_barrier(mem_flags::mem_threadgroup); - if (!has_operand_mask || - (bool(lhs_mask[lhs_mask_offset]) && - bool(rhs_mask[rhs_mask_offset]))) { - if (has_mul_operand_mask) { - lhs_mask_op.scale = lhs_mask[lhs_mask_offset]; - rhs_mask_op.scale = rhs_mask[rhs_mask_offset]; - } - - // Load elements into threadgroup - if (M_aligned) { - loader_a.load_unsafe(); - } else { - loader_a.load_safe(tile_dims_A); - } - - if (N_aligned) { - loader_b.load_unsafe(); - } else { - loader_b.load_safe(tile_dims_B); - } - - if (has_mul_operand_mask) { - loader_a.apply_inplace_op(lhs_mask_op); - loader_b.apply_inplace_op(rhs_mask_op); - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - } - - // Prepare for next iteration - loader_a.next(); - loader_b.next(); - - k_factor_cnt--; - lhs_mask_offset += k_factor_cnt == 0 ? lhs_mask_step : 0; - rhs_mask_offset += k_factor_cnt == 0 ? rhs_mask_step : 0; - k_factor_cnt = k_factor_cnt == 0 ? k_mask_factor : k_factor_cnt; - } - - if (has_mul_output_mask) { - mma_op.apply_epilogue(out_mask_op); - } - - if (M_aligned && N_aligned) { - mma_op.store_result(D, params->ldd); - } else { - mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); - } - } -} - -template < - typename T, - int BM, - int BN, - int BK, - int WM, - int WN, - bool transpose_a, - bool transpose_b, - bool MN_aligned, - bool K_aligned, - bool has_operand_mask = false> -[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void -block_masked_gemm( - const device T* A [[buffer(0)]], - const device T* B [[buffer(1)]], - device T* D [[buffer(3)]], - const constant GEMMParams* params [[buffer(4)]], - const constant int* batch_shape [[buffer(6)]], - const constant int64_t* batch_strides [[buffer(7)]], - const device bool* out_mask [[buffer(10)]], - const device bool* lhs_mask [[buffer(11)]], - const device bool* rhs_mask [[buffer(12)]], - const constant int* mask_strides [[buffer(13)]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { - // Appease the compiler - (void)lid; - - using gemm_kernel = GEMMKernel< - T, - T, - BM, - BN, - BK, - WM, - WN, - transpose_a, - transpose_b, - MN_aligned, - K_aligned>; - - const int tid_y = ((tid.y) << params->swizzle_log) + - ((tid.x) & ((1 << params->swizzle_log) - 1)); - const int tid_x = (tid.x) >> params->swizzle_log; - - if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { - return; - } - - if (params->batch_ndim > 1) { - const constant auto* mask_batch_strides = - batch_strides + 2 * params->batch_ndim; - out_mask += - elem_to_loc(tid.z, batch_shape, mask_batch_strides, params->batch_ndim); - - if (has_operand_mask) { - const constant auto* mask_strides_lhs = - mask_batch_strides + params->batch_ndim; - const constant auto* mask_strides_rhs = - mask_strides_lhs + params->batch_ndim; - - ulong2 batch_offsets = elem_to_loc_broadcast( - tid.z, - batch_shape, - mask_strides_lhs, - mask_strides_rhs, - params->batch_ndim); - - lhs_mask += batch_offsets.x; - rhs_mask += batch_offsets.y; - } - } else { - out_mask += tid.z * batch_strides[2 * params->batch_ndim]; - if (has_operand_mask) { - lhs_mask += tid.z * batch_strides[3 * params->batch_ndim]; - rhs_mask += tid.z * batch_strides[4 * params->batch_ndim]; - } - } - - // Adjust for batch - if (params->batch_ndim > 1) { - const constant auto* A_bstrides = batch_strides; - const constant auto* B_bstrides = batch_strides + params->batch_ndim; - - ulong2 batch_offsets = elem_to_loc_broadcast( - tid.z, batch_shape, A_bstrides, B_bstrides, params->batch_ndim); - - A += batch_offsets.x; - B += batch_offsets.y; - - } else { - A += params->batch_stride_a * tid.z; - B += params->batch_stride_b * tid.z; - } - - D += params->batch_stride_d * tid.z; - - // Find block in A, B, C - const int c_row = tid_y * BM; - const int c_col = tid_x * BN; - const size_t c_row_long = size_t(c_row); - const size_t c_col_long = size_t(c_col); - - A += transpose_a ? c_row_long : c_row_long * params->lda; - B += transpose_b ? c_col_long * params->ldb : c_col_long; - D += c_row_long * params->ldd + c_col_long; - - bool mask_out = out_mask[tid_y * mask_strides[1] + tid_x * mask_strides[0]]; - - // Write zeros and return - if (!mask_out) { - constexpr short tgp_size = WM * WN * 32; - constexpr short vec_size = 4; - - // Tile threads in threadgroup - constexpr short TN = BN / vec_size; - constexpr short TM = tgp_size / TN; - - const short thread_idx = simd_group_id * 32 + simd_lane_id; - const short bi = thread_idx / TN; - const short bj = vec_size * (thread_idx % TN); - - D += bi * params->ldd + bj; - - short tgp_bm = min(BM, params->M - c_row); - short tgp_bn = min(BN, params->N - c_col); - - if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) { - for (short ti = 0; ti < BM; ti += TM) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - D[ti * params->ldd + j] = T(0.); - } - } - } else { - short jmax = tgp_bn - bj; - jmax = jmax < vec_size ? jmax : vec_size; - for (short ti = 0; (bi + ti) < tgp_bm; ti += TM) { - for (short j = 0; j < jmax; j++) { - D[ti * params->ldd + j] = T(0.); - } - } - } - - return; - } - - threadgroup_barrier(mem_flags::mem_none); - - // Prepare threadgroup mma operation - thread typename gemm_kernel::mma_t mma_op(simd_group_id, simd_lane_id); - - int gemm_k_iterations = params->gemm_k_iterations_aligned; - - threadgroup T As[gemm_kernel::tgp_mem_size_a]; - threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; - - // Prepare threadgroup loading operations - thread typename gemm_kernel::loader_a_t loader_a( - A, params->lda, As, simd_group_id, simd_lane_id); - thread typename gemm_kernel::loader_b_t loader_b( - B, params->ldb, Bs, simd_group_id, simd_lane_id); - - /////////////////////////////////////////////////////////////////////////////// - // MNK aligned loop - if (MN_aligned) { - for (int k = 0; k < gemm_k_iterations; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (!has_operand_mask || - (lhs_mask - [tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] && - rhs_mask - [((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) { - // Load elements into threadgroup - loader_a.load_unsafe(); - loader_b.load_unsafe(); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - } - - // Prepare for next iteration - loader_a.next(); - loader_b.next(); - } - - threadgroup_barrier(mem_flags::mem_none); - - // Loop tail - if (!K_aligned) { - if (!has_operand_mask || - (lhs_mask - [tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] && - rhs_mask - [(params->K / BM) * mask_strides[5] + - tid_x * mask_strides[4]])) { - int lbk = params->K - params->gemm_k_iterations_aligned * BK; - short2 tile_dims_A = transpose_a ? short2(BM, lbk) : short2(lbk, BM); - short2 tile_dims_B = transpose_b ? short2(lbk, BN) : short2(BN, lbk); - - loader_a.load_safe(tile_dims_A); - loader_b.load_safe(tile_dims_B); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - mma_op.mma(As, Bs); - } - } - - // Store results to device memory - mma_op.store_result(D, params->ldd); - return; - - } - /////////////////////////////////////////////////////////////////////////////// - // MN unaligned loop - else { // Loop over K - unaligned case - short tgp_bm = min(BM, params->M - c_row); - short tgp_bn = min(BN, params->N - c_col); - short lbk = params->K - params->gemm_k_iterations_aligned * BK; - - bool M_aligned = (tgp_bm == BM); - bool N_aligned = (tgp_bn == BN); - - short2 tile_dims_A = transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm); - short2 tile_dims_B = transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK); - - for (int k = 0; k < gemm_k_iterations; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - if (!has_operand_mask || - (lhs_mask - [tid_y * mask_strides[3] + ((k * BK) / BM) * mask_strides[2]] && - rhs_mask - [((k * BK) / BM) * mask_strides[5] + tid_x * mask_strides[4]])) { - // Load elements into threadgroup - if (M_aligned) { - loader_a.load_unsafe(); - } else { - loader_a.load_safe(tile_dims_A); - } - - if (N_aligned) { - loader_b.load_unsafe(); - } else { - loader_b.load_safe(tile_dims_B); - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - } - - // Prepare for next iteration - loader_a.next(); - loader_b.next(); - } - - if (!K_aligned) { - threadgroup_barrier(mem_flags::mem_threadgroup); - - if (!has_operand_mask || - (lhs_mask - [tid_y * mask_strides[3] + (params->K / BM) * mask_strides[2]] && - rhs_mask - [(params->K / BM) * mask_strides[5] + - tid_x * mask_strides[4]])) { - short2 tile_dims_A_last = - transpose_a ? short2(tgp_bm, lbk) : short2(lbk, tgp_bm); - short2 tile_dims_B_last = - transpose_b ? short2(lbk, tgp_bn) : short2(tgp_bn, lbk); - - loader_a.load_safe(tile_dims_A_last); - loader_b.load_safe(tile_dims_B_last); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - mma_op.mma(As, Bs); - } - } - - if (M_aligned && N_aligned) { - mma_op.store_result(D, params->ldd); - } else { - mma_op.store_result_safe(D, params->ldd, short2(tgp_bn, tgp_bm)); - } - } -} diff --git a/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_segmented.h b/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_segmented.h deleted file mode 100644 index 5a43e223..00000000 --- a/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_segmented.h +++ /dev/null @@ -1,266 +0,0 @@ -// Copyright © 2025 Apple Inc. - -using namespace mlx::steel; - -constant bool segments_contiguous [[function_constant(199)]]; -constant bool align_M [[function_constant(200)]]; -constant bool align_N [[function_constant(201)]]; - -template < - typename T, - int BM, - int BN, - int BK, - int WM, - int WN, - bool transpose_a, - bool transpose_b, - typename AccumType = float> -[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void segmented_mm( - const device T* A [[buffer(0)]], - const device T* B [[buffer(1)]], - const device uint32_t* segments [[buffer(2)]], - device T* C [[buffer(3)]], - const constant GEMMParams* params [[buffer(4)]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint3 tid [[threadgroup_position_in_grid]]) { - using gemm_kernel = GEMMKernel< - T, - T, - BM, - BN, - BK, - WM, - WN, - transpose_a, - transpose_b, - true, - true, - AccumType>; - - using loader_a_t = typename gemm_kernel::loader_a_t; - using loader_b_t = typename gemm_kernel::loader_b_t; - using mma_t = typename gemm_kernel::mma_t; - - if (params->tiles_n <= static_cast(tid.x) || - params->tiles_m <= static_cast(tid.y)) { - return; - } - - // Prepare threadgroup memory - threadgroup T As[gemm_kernel::tgp_mem_size_a]; - threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; - - // Find the block in A, B, C - const int c_row = tid.y * BM; - const int c_col = tid.x * BN; - const size_t c_row_long = size_t(c_row); - const size_t c_col_long = size_t(c_col); - - // Prepare threadgroup bounds - const short tgp_bm = align_M ? BM : short(min(BM, params->M - c_row)); - const short tgp_bn = align_N ? BN : short(min(BN, params->N - c_col)); - - // Move the pointers to the output tile - A += transpose_a ? c_row_long : c_row_long * params->lda; - B += transpose_b ? c_col_long * params->ldb : c_col_long; - C += c_row_long * params->ldd + c_col_long; - - // Move the pointers to the start of the segment - uint32_t k_start, k_end; - if (segments_contiguous) { - k_start = segments[2 * tid.z]; - k_end = segments[2 * tid.z + 1]; - } else { - // We accept either contiguous (above) or weird strides where the beginning - // of the next one is the previous one. Basically the last two strides are - // both 1! - k_start = segments[tid.z]; - k_end = segments[tid.z + 1]; - } - A += transpose_a ? k_start * params->lda : k_start; - B += transpose_b ? k_start : k_start * params->ldb; - C += tid.z * params->batch_stride_d; - - // Prepare threadgroup mma operation - thread mma_t mma_op(simd_group_id, simd_lane_id); - - // Prepare threadgroup loading operations - thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); - thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); - - // Matrix level alignment so only check K - if (align_M && align_N) { - uint32_t k = k_start + BK; - for (; k <= k_end; k += BK) { - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Load elements into threadgroup - loader_a.load_unsafe(); - loader_b.load_unsafe(); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - - // Prepare for next iteration - loader_a.next(); - loader_b.next(); - } - short k_remain = BK - short(k - k_end); - const short2 tile_dims_A = - transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); - const short2 tile_dims_B = - transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); - if (k_remain > 0) { - threadgroup_barrier(mem_flags::mem_threadgroup); - loader_a.load_safe(tile_dims_A); - loader_b.load_safe(tile_dims_B); - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(As, Bs); - } - mma_op.store_result(C, params->ldd); - } else { - // Tile aligned do the same as above - if ((align_M || tgp_bm == BM) && (align_N || tgp_bn == BN)) { - uint32_t k = k_start + BK; - for (; k <= k_end; k += BK) { - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Load elements into threadgroup - loader_a.load_unsafe(); - loader_b.load_unsafe(); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - - // Prepare for next iteration - loader_a.next(); - loader_b.next(); - } - short k_remain = BK - short(k - k_end); - const short2 tile_dims_A = - transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); - const short2 tile_dims_B = - transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); - if (k_remain > 0) { - threadgroup_barrier(mem_flags::mem_threadgroup); - loader_a.load_safe(tile_dims_A); - loader_b.load_safe(tile_dims_B); - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(As, Bs); - } - mma_op.store_result(C, params->ldd); - } - - // Tile partially aligned check rows - else if (align_N || tgp_bn == BN) { - uint32_t k = k_start + BK; - for (; k <= k_end; k += BK) { - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Load elements into threadgroup - loader_a.load_safe( - transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm)); - loader_b.load_unsafe(); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - - // Prepare for next iteration - loader_a.next(); - loader_b.next(); - } - short k_remain = BK - short(k - k_end); - const short2 tile_dims_A = - transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); - const short2 tile_dims_B = - transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); - if (k_remain > 0) { - threadgroup_barrier(mem_flags::mem_threadgroup); - loader_a.load_safe(tile_dims_A); - loader_b.load_safe(tile_dims_B); - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(As, Bs); - } - mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); - } - - // Tile partially aligned check cols - else if (align_M || tgp_bm == BM) { - uint32_t k = k_start + BK; - for (; k <= k_end; k += BK) { - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Load elements into threadgroup - loader_a.load_unsafe(); - loader_b.load_safe( - transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK)); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - - // Prepare for next iteration - loader_a.next(); - loader_b.next(); - } - short k_remain = BK - short(k - k_end); - const short2 tile_dims_A = - transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); - const short2 tile_dims_B = - transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); - if (k_remain > 0) { - threadgroup_barrier(mem_flags::mem_threadgroup); - loader_a.load_safe(tile_dims_A); - loader_b.load_safe(tile_dims_B); - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(As, Bs); - } - mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); - } - - // Nothing aligned so check both rows and cols - else { - uint32_t k = k_start + BK; - for (; k <= k_end; k += BK) { - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Load elements into threadgroup - loader_a.load_safe( - transpose_a ? short2(tgp_bm, BK) : short2(BK, tgp_bm)); - loader_b.load_safe( - transpose_b ? short2(BK, tgp_bn) : short2(tgp_bn, BK)); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); - - // Prepare for next iteration - loader_a.next(); - loader_b.next(); - } - short k_remain = BK - short(k - k_end); - const short2 tile_dims_A = - transpose_a ? short2(tgp_bm, k_remain) : short2(k_remain, tgp_bm); - const short2 tile_dims_B = - transpose_b ? short2(k_remain, tgp_bn) : short2(tgp_bn, k_remain); - if (k_remain > 0) { - threadgroup_barrier(mem_flags::mem_threadgroup); - loader_a.load_safe(tile_dims_A); - loader_b.load_safe(tile_dims_B); - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(As, Bs); - } - mma_op.store_result_safe(C, params->ldd, short2(tgp_bn, tgp_bm)); - } - } -} diff --git a/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_splitk.h b/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_splitk.h deleted file mode 100644 index a372e939..00000000 --- a/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_splitk.h +++ /dev/null @@ -1,227 +0,0 @@ -// Copyright © 2024 Apple Inc. - -using namespace mlx::steel; - -/////////////////////////////////////////////////////////////////////////////// -// GEMM kernels -/////////////////////////////////////////////////////////////////////////////// - -template < - typename T, - typename U, - int BM, - int BN, - int BK, - int WM, - int WN, - bool transpose_a, - bool transpose_b, - bool MN_aligned, - bool K_aligned> -[[kernel, max_total_threads_per_threadgroup(WM * WN * 32)]] void gemm_splitk( - const device T* A [[buffer(0)]], - const device T* B [[buffer(1)]], - device U* C [[buffer(2)]], - const constant GEMMSpiltKParams* params [[buffer(3)]], - uint simd_lane_id [[thread_index_in_simdgroup]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint3 tid [[threadgroup_position_in_grid]], - uint3 lid [[thread_position_in_threadgroup]]) { - (void)lid; - - using gemm_kernel = GEMMKernel< - T, - U, - BM, - BN, - BK, - WM, - WN, - transpose_a, - transpose_b, - MN_aligned, - K_aligned>; - using loader_a_t = typename gemm_kernel::loader_a_t; - using loader_b_t = typename gemm_kernel::loader_b_t; - using mma_t = typename gemm_kernel::mma_t; - - threadgroup T As[gemm_kernel::tgp_mem_size_a]; - threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; - - const int tid_x = tid.x; - const int tid_y = tid.y; - const int tid_z = tid.z; - - if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { - return; - } - - // Find block in A, B, C - const int c_row = tid_y * BM; - const int c_col = tid_x * BN; - const int k_start = params->split_k_partition_size * tid_z; - - const size_t c_row_long = size_t(c_row); - const size_t c_col_long = size_t(c_col); - const size_t k_start_long = size_t(k_start); - - A += transpose_a ? (c_row_long + k_start_long * params->lda) - : (k_start_long + c_row_long * params->lda); - B += transpose_b ? (k_start_long + c_col_long * params->ldb) - : (c_col_long + k_start_long * params->ldb); - C += (size_t(params->split_k_partition_stride) * tid_z) + - (c_row_long * params->ldc + c_col_long); - - // Prepare threadgroup loading operations - thread loader_a_t loader_a(A, params->lda, As, simd_group_id, simd_lane_id); - thread loader_b_t loader_b(B, params->ldb, Bs, simd_group_id, simd_lane_id); - - // Prepare threadgroup mma operation - thread mma_t mma_op(simd_group_id, simd_lane_id); - - int gemm_k_iterations = params->gemm_k_iterations_aligned; - - short tgp_bm = min(BM, params->M - c_row); - short tgp_bn = min(BN, params->N - c_col); - short leftover_bk = params->K % BK; - - if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) { - gemm_kernel::gemm_loop( - As, - Bs, - gemm_k_iterations, - loader_a, - loader_b, - mma_op, - tgp_bm, - tgp_bn, - leftover_bk, - LoopAlignment{}); - } else if (tgp_bn == BN) { - gemm_kernel::gemm_loop( - As, - Bs, - gemm_k_iterations, - loader_a, - loader_b, - mma_op, - tgp_bm, - tgp_bn, - leftover_bk, - LoopAlignment{}); - } else if (tgp_bm == BM) { - gemm_kernel::gemm_loop( - As, - Bs, - gemm_k_iterations, - loader_a, - loader_b, - mma_op, - tgp_bm, - tgp_bn, - leftover_bk, - LoopAlignment{}); - } else { - gemm_kernel::gemm_loop( - As, - Bs, - gemm_k_iterations, - loader_a, - loader_b, - mma_op, - tgp_bm, - tgp_bn, - leftover_bk, - LoopAlignment{}); - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - if ((tid_z + 1) == (params->split_k_partitions)) { - int gemm_k_iter_remaining = - (params->K - (k_start + params->split_k_partition_size)) / BK; - if (!K_aligned || gemm_k_iter_remaining > 0) - gemm_kernel::gemm_loop( - As, - Bs, - gemm_k_iter_remaining, - loader_a, - loader_b, - mma_op, - tgp_bm, - tgp_bn, - leftover_bk, - LoopAlignment{}); - } - - if (MN_aligned || (tgp_bm == BM && tgp_bn == BN)) { - mma_op.store_result(C, params->ldc); - } else { - mma_op.store_result_safe(C, params->ldc, short2(tgp_bn, tgp_bm)); - } -} - -/////////////////////////////////////////////////////////////////////////////// -// Split k accumulation kernel -/////////////////////////////////////////////////////////////////////////////// - -template < - typename AccT, - typename OutT, - typename Epilogue = TransformNone> -[[kernel]] void gemm_splitk_accum( - const device AccT* C_split [[buffer(0)]], - device OutT* D [[buffer(1)]], - const constant int& k_partitions [[buffer(2)]], - const constant int& partition_stride [[buffer(3)]], - const constant int& ldd [[buffer(4)]], - uint2 gid [[thread_position_in_grid]]) { - // Ajust D and C - D += gid.x + gid.y * size_t(ldd); - C_split += gid.x + gid.y * size_t(ldd); - - size_t offset = 0; - AccT out = 0; - - for (int i = 0; i < k_partitions; i++) { - out += C_split[offset]; - offset += partition_stride; - } - - // Write output - D[0] = Epilogue::apply(out); -} - -template < - typename AccT, - typename OutT, - typename Epilogue = TransformAxpby> -[[kernel]] void gemm_splitk_accum_axpby( - const device AccT* C_split [[buffer(0)]], - device OutT* D [[buffer(1)]], - const constant int& k_partitions [[buffer(2)]], - const constant int& partition_stride [[buffer(3)]], - const constant int& ldd [[buffer(4)]], - const device OutT* C [[buffer(5)]], - const constant int& ldc [[buffer(6)]], - const constant int& fdc [[buffer(7)]], - const constant float& alpha [[buffer(8)]], - const constant float& beta [[buffer(9)]], - uint2 gid [[thread_position_in_grid]]) { - // Ajust D and C - C += gid.x * size_t(fdc) + gid.y * size_t(ldc); - D += gid.x + gid.y * size_t(ldd); - C_split += gid.x + gid.y * size_t(ldd); - - size_t offset = 0; - AccT out = 0; - - for (int i = 0; i < k_partitions; i++) { - out += C_split[offset]; - offset += partition_stride; - } - - // Write output - Epilogue op(alpha, beta); - D[0] = op.apply(out, *C); -} diff --git a/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_splitk_nax.h b/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_splitk_nax.h deleted file mode 100644 index 1b6b8280..00000000 --- a/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_splitk_nax.h +++ /dev/null @@ -1,152 +0,0 @@ -// Copyright © 2026 Apple Inc. - -using namespace mlx::steel; - -constant bool align_M [[function_constant(200)]]; -constant bool align_N [[function_constant(201)]]; - -/////////////////////////////////////////////////////////////////////////////// -// NAX Split-K GEMM kernel -/////////////////////////////////////////////////////////////////////////////// - -// clang-format off -template < - typename T, - int BM, - int BN, - int BK, - int WM, - int WN, - bool transpose_a, - bool transpose_b, - typename AccumType = float> -[[kernel, max_total_threads_per_threadgroup(WM* WN * 32)]] void gemm_splitk_nax( - const device T* A [[buffer(0)]], - const device T* B [[buffer(1)]], - device AccumType* C [[buffer(2)]], - const constant GEMMSpiltKParams* params [[buffer(3)]], - uint simd_group_id [[simdgroup_index_in_threadgroup]], - uint3 tid [[threadgroup_position_in_grid]]) { // clang-format on - - const int linear_tid = tid.x; - - // Compute swizzled tile dimensions - const int tn_swizzled = params->tiles_n << params->swizzle_log; - const int tm_swizzled = - (params->tiles_m + (1 << params->swizzle_log) - 1) >> params->swizzle_log; - const int tiles_per_partition = tn_swizzled * tm_swizzled; - - const int tid_z = linear_tid / tiles_per_partition; - const int xy_flat = linear_tid % tiles_per_partition; - - // Decode 2D grid coordinates in swizzled space - const int grid_x = xy_flat % tn_swizzled; - const int grid_y = xy_flat / tn_swizzled; - - // Apply X-Y swizzle - const int tid_y = (grid_y << params->swizzle_log) + - (grid_x & ((1 << params->swizzle_log) - 1)); - const int tid_x = grid_x >> params->swizzle_log; - - // Exit early - if (params->tiles_n <= tid_x || params->tiles_m <= tid_y) { - return; - } - - // Calculate partition bounds - const int c_row = tid_y * BM; - const int c_col = tid_x * BN; - const int k_start = params->split_k_partition_size * tid_z; - const int k_end = min(k_start + params->split_k_partition_size, params->K); - - const size_t c_row_long = size_t(c_row); - const size_t c_col_long = size_t(c_col); - const size_t k_start_long = size_t(k_start); - - // Adjust pointers for split-K partition - A += transpose_a ? (c_row_long + k_start_long * params->lda) - : (k_start_long + c_row_long * params->lda); - B += transpose_b ? (k_start_long + c_col_long * params->ldb) - : (c_col_long + k_start_long * params->ldb); - C += (size_t(params->split_k_partition_stride) * tid_z) + - (c_row_long * params->ldc + c_col_long); - - // NAX tile configuration - constexpr short UM = 16; - constexpr short UN = 32; - constexpr short UK = 16; - constexpr short SM = BM / WM; - constexpr short SN = BN / WN; - constexpr short SK = 32; - - constexpr short TM = SM / UM; - constexpr short TN = SN / UN; - - // Calculate simdgroup offsets and alignment - const short tm = SM * (simd_group_id / WN); - const short tn = SN * (simd_group_id % WN); - - const int sgp_sm_int = - align_M ? int(SM) : min(int(SM), params->M - (c_row + tm)); - const short sgp_sm = short(sgp_sm_int); - const bool is_unaligned_sm = align_M ? false : (sgp_sm != SM); - - const int sgp_sn_int = - align_N ? int(SN) : min(int(SN), params->N - (c_col + tn)); - const short sgp_sn = short(sgp_sn_int); - const bool is_unaligned_sn = align_N ? false : (sgp_sn != SN); - - A += transpose_a ? tm : (tm * params->lda); - B += transpose_b ? (tn * params->ldb) : tn; - C += tm * params->ldc + tn; - - using DSubTile = NAXSubTile; - NAXTile Dtile; - - // gemm_loop through the partition - // Check K-alignment at runtime (partition-specific) - const int partition_k_size = k_end - k_start; - const int partition_k_iters = partition_k_size / BK; - const bool partition_k_aligned = (partition_k_size % BK) == 0; - - dispatch_bool(partition_k_aligned, [&](auto kAlignedK) { - dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) { - dispatch_bool(align_N || !is_unaligned_sn, [&](auto kAlignedN) { - Dtile = gemm_loop< - T, - SM, - SN, - SK, - BK, - transpose_a, - transpose_b, - kAlignedM.value, - kAlignedN.value, - kAlignedK.value, - UM, - UN, - UK, - AccumType>( - A, - B, - params->lda, - params->ldb, - partition_k_size, - partition_k_iters, - sgp_sm, - sgp_sn); - }); - }); - }); - - // Store result - dispatch_bool(align_M || !is_unaligned_sm, [&](auto kAlignedM) { - dispatch_bool(align_N || !is_unaligned_sn, [&](auto kAlignedN) { - if constexpr (kAlignedM && kAlignedN) { - Dtile.store(C, int(params->ldc)); - } else { - Dtile.store_safe(C, int(params->ldc), short2(sgp_sn, sgp_sm)); - } - }); - }); -} diff --git a/Source/Cmlx/mlx-generated/metal/steel/gemm/loader.h b/Source/Cmlx/mlx-generated/metal/steel/gemm/loader.h deleted file mode 100644 index cc79de86..00000000 --- a/Source/Cmlx/mlx-generated/metal/steel/gemm/loader.h +++ /dev/null @@ -1,137 +0,0 @@ -// Copyright © 2024 Apple Inc. - -#pragma once - -#include "../../steel/defines.h" - -/////////////////////////////////////////////////////////////////////////////// -// Loading helper -/////////////////////////////////////////////////////////////////////////////// - -namespace mlx { -namespace steel { - -template < - typename T, - short BROWS, - short BCOLS, - short dst_ld, - short reduction_dim, - short tgp_size, - short alignment = 1, - short n_reads = (BCOLS * BROWS) / (tgp_size), - short TCOLS = BCOLS / n_reads, - short TROWS = tgp_size / TCOLS> -struct BlockLoader { - STEEL_CONST short n_rows = (BROWS + TROWS - 1) / TROWS; - STEEL_CONST short vec_size = n_reads; - - // Leading dimension for src - const int src_ld; - const int tile_stride; - - // Thread location indices - const short thread_idx; - const short bi; - const short bj; - - // threadgroup and device memory - threadgroup T* dst; - const device T* src; - - struct alignas(alignment * sizeof(T)) ReadVector { - uint8_t v[sizeof(T) * vec_size]; - }; - - /* Constructor */ - METAL_FUNC BlockLoader( - const device T* src_, - const int src_ld_, - threadgroup T* dst_, - ushort simd_group_id [[simdgroup_index_in_threadgroup]], - ushort simd_lane_id [[thread_index_in_simdgroup]]) - : src_ld(src_ld_), - tile_stride(reduction_dim ? BCOLS : BROWS * src_ld), - thread_idx(simd_group_id * 32 + simd_lane_id), - bi(thread_idx / TCOLS), - bj(vec_size * (thread_idx % TCOLS)), - dst(dst_ + bi * dst_ld + bj), - src(src_ + bi * src_ld + bj) {} - - /* Apply operation to threadgroup without bound checking */ - template - METAL_FUNC void apply_inplace_op(thread const UnaryOp& op) const { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < BROWS; i += TROWS) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - dst[i * dst_ld + j] = op.apply(dst[i * dst_ld + j]); - } - } - } - - /* Load from device memory into threadgroup memory - without bound checking */ - METAL_FUNC void load_unsafe() const { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < BROWS; i += TROWS) { - *((threadgroup ReadVector*)(&dst[i * dst_ld])) = - *((const device ReadVector*)(&src[i * src_ld])); - } - } - - /* Load from device memory into threadgroup memory - with bound checking */ - METAL_FUNC void load_safe(short2 src_tile_dim) const { - src_tile_dim = src_tile_dim - short2(bj, bi); - - // Skip loading if thread has no valid reads - if (src_tile_dim.x <= 0 || src_tile_dim.y <= 0) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < BROWS; i += TROWS) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - dst[i * dst_ld + j] = T(0); - } - } - return; - } - - // Use fast thread memory for bound checks - bool tmp_idx[vec_size]; - T tmp_val[vec_size]; - - STEEL_PRAGMA_UNROLL - for (short i = 0; i < BROWS; i += TROWS) { - // Make sure tmp_idx only contains valid indices - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - tmp_idx[j] = (i < src_tile_dim.y) && (j < src_tile_dim.x); - } - - // Read valid indices into tmp_val - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; - } - - // Zero out unneeded values - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0); - } - - // Copy values to threadgroup memory - STEEL_PRAGMA_UNROLL - for (short j = 0; j < vec_size; j++) { - dst[i * dst_ld + j] = tmp_val[j]; - } - } - } - - /* Iteration helper */ - METAL_FUNC void next() { - src += tile_stride; - } -}; - -} // namespace steel -} // namespace mlx diff --git a/Source/Cmlx/mlx-generated/metal/steel/gemm/mma.h b/Source/Cmlx/mlx-generated/metal/steel/gemm/mma.h deleted file mode 100644 index 8b9ddb29..00000000 --- a/Source/Cmlx/mlx-generated/metal/steel/gemm/mma.h +++ /dev/null @@ -1,1146 +0,0 @@ -// Copyright © 2024 Apple Inc. - -#pragma once - -#include -#include -#include - -#include "../../steel/defines.h" -#include "../../steel/gemm/transforms.h" -#include "../../steel/utils/integral_constant.h" - -using namespace metal; - -/////////////////////////////////////////////////////////////////////////////// -// MMA helper -/////////////////////////////////////////////////////////////////////////////// - -namespace mlx { -namespace steel { - -template -struct BaseMMAFrag { - static_assert( - kFragRows_ == 8, - "Only 8 x 8 fragment matrices are currently supported"); - static_assert( - kFragCols_ == 8, - "Only 8 x 8 fragment matrices are currently supported"); -}; - -template -struct BaseMMAFrag { - STEEL_CONST int kFragRows = 8; - STEEL_CONST int kFragCols = 8; - - STEEL_CONST int kElemsPerFrag = (kFragRows * kFragCols) / 32; - - STEEL_CONST int kElemRows = 1; - STEEL_CONST int kElemCols = 2; - - static_assert( - kElemRows * kElemCols == kElemsPerFrag, - "MMAFrag shape is not consistent with MMAFrag size"); - - typedef metal::simdgroup_matrix mat_type; - typedef metal::vec frag_type; - - METAL_FUNC static constexpr short2 get_coord( - ushort simd_lane_id [[thread_index_in_simdgroup]]) { - const short qid = simd_lane_id / 4; - const short fm = (qid & 4) + ((simd_lane_id / 2) % 4); - const short fn = (qid & 2) * 2 + (simd_lane_id % 2) * 2; - return short2{fn, fm}; - } - - template - METAL_FUNC static constexpr void - load(thread frag_type& dst, SrcPtrType src, StrX str_x, StrY str_y) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kElemRows; i++) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kElemCols; j++) { - dst[i * kElemCols + j] = static_cast(src[i * str_x + j * str_y]); - } - } - } - - template < - typename SrcPtrType, - typename StrX, - typename StrY, - typename LimX, - typename LimY, - typename OffX, - typename OffY> - METAL_FUNC static constexpr void load_safe( - thread frag_type& dst, - SrcPtrType src, - StrX str_x, - StrY str_y, - LimX lim_x, - LimY lim_y, - OffX off_x = Int<0>{}, - OffY off_y = Int<0>{}) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kElemRows; i++) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kElemCols; j++) { - if ((off_x + i) < lim_x && (off_y + j) < lim_y) { - dst[i * kElemCols + j] = - static_cast(src[(off_x + i) * str_x + (off_x + j) * str_y]); - } else { - dst[i * kElemCols + j] = T(0); - } - } - } - } - - template - METAL_FUNC static constexpr void - store(const thread frag_type& src, DstPtrType dst, StrX str_x, StrY str_y) { - using U = pointer_element_t; - - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kElemRows; i++) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kElemCols; j++) { - dst[i * str_x + j * str_y] = static_cast(src[i * kElemCols + j]); - } - } - } - - template < - typename DstPtrType, - typename StrX, - typename StrY, - typename LimX, - typename LimY, - typename OffX, - typename OffY> - METAL_FUNC static constexpr void store_safe( - const thread frag_type& src, - DstPtrType dst, - StrX str_x, - StrY str_y, - LimX lim_x, - LimY lim_y, - OffX off_x = Int<0>{}, - OffY off_y = Int<0>{}) { - using U = pointer_element_t; - - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kElemRows; i++) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kElemCols; j++) { - if ((off_x + i) < lim_x && (off_y + j) < lim_y) { - dst[(off_x + i) * str_x + (off_y + j) * str_y] = - static_cast(src[i * kElemCols + j]); - } - } - } - } - - template < - typename DstPtrType, - typename StrX, - typename StrY, - typename StartX, - typename StopX, - typename StartY, - typename StopY, - typename OffX, - typename OffY> - METAL_FUNC static constexpr void store_slice( - const thread frag_type& src, - DstPtrType dst, - StrX str_x, - StrY str_y, - StartX start_x, - StopX stop_x, - StartY start_y, - StopY stop_y, - OffX off_x = Int<0>{}, - OffY off_y = Int<0>{}) { - using U = pointer_element_t; - - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kElemRows; i++) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kElemCols; j++) { - if ((off_x + i) < stop_x && (off_x + i) >= start_x && - (off_y + j) < stop_y && (off_y + j) >= start_y) { - dst[(off_x + i) * str_x + (off_y + j) * str_y] = - static_cast(src[i * kElemCols + j]); - } - } - } - } - - METAL_FUNC static constexpr void mma( - thread frag_type& D, - thread frag_type& A, - thread frag_type& B, - thread frag_type& C) { - mat_type D_mat; - mat_type A_mat; - mat_type B_mat; - mat_type C_mat; - - reinterpret_cast(A_mat.thread_elements()) = A; - reinterpret_cast(B_mat.thread_elements()) = B; - reinterpret_cast(C_mat.thread_elements()) = C; - - mma(D_mat, A_mat, B_mat, C_mat); - - D = reinterpret_cast(D_mat.thread_elements()); - } - - METAL_FUNC static constexpr void mma( - thread mat_type& D, - thread mat_type& A, - thread mat_type& B, - thread mat_type& C) { - simdgroup_multiply_accumulate(D, A, B, C); - } -}; - -template < - typename T, - int kTileRows_, - int kTileCols_, - class MMAFrag_ = BaseMMAFrag> -struct MMATile { - using MMAFrag_t = MMAFrag_; - using elem_type = T; - STEEL_CONST int kFragRows = MMAFrag_t::kFragRows; - STEEL_CONST int kFragCols = MMAFrag_t::kFragCols; - STEEL_CONST int kElemsPerFrag = MMAFrag_t::kElemsPerFrag; - - STEEL_CONST int kTileRows = kTileRows_; - STEEL_CONST int kTileCols = kTileCols_; - - STEEL_CONST int kRows = kTileRows * kFragRows; - STEEL_CONST int kCols = kTileCols * kFragCols; - - STEEL_CONST int kNumFrags = kTileRows * kTileCols; - STEEL_CONST int kElemsPerTile = kNumFrags * kElemsPerFrag; - - typedef typename MMAFrag_t::mat_type mat_type; - typedef typename MMAFrag_t::frag_type frag_type; - - frag_type val_frags[kNumFrags] = {frag_type(0)}; - - METAL_FUNC MMATile() thread {} - - METAL_FUNC constexpr void clear() { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kNumFrags; ++i) { - val_frags[i] = frag_type(0); - } - } - - METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) { - return val_frags[i * kTileCols + j]; - } - - METAL_FUNC constexpr const thread frag_type& frag_at( - const short i, - const short j) const { - return val_frags[i * kTileCols + j]; - } - - METAL_FUNC mat_type mat_at(const short i, const short j) { - mat_type val_mat; - STEEL_PRAGMA_UNROLL - for (short ii = 0; ii < kElemsPerFrag; ++ii) { - val_mat.thread_elements()[ii] = frag_at(i, j)[ii]; - } - return val_mat; - } - - METAL_FUNC thread elem_type* elems() { - return reinterpret_cast(val_frags); - } - - METAL_FUNC const thread elem_type* elems() const { - return reinterpret_cast(val_frags); - } - - template - METAL_FUNC void load(const threadgroup U* src) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kTileCols; ++j) { - MMAFrag_t::load( - frag_at(i, j), - &( - src[(i * kFragRows) * w_x * str_x + - (j * kFragCols) * w_y * str_y]), - Int{}, - Int{}); - } - } - } - - template - METAL_FUNC void store(threadgroup U* dst) const { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kTileCols; ++j) { - MMAFrag_t::store( - frag_at(i, j), - &( - dst[(i * kFragRows) * w_x * str_x + - (j * kFragCols) * w_y * str_y]), - Int{}, - Int{}); - } - } - } - - template - METAL_FUNC void load(const device U* src, const int ld) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kTileCols; ++j) { - MMAFrag_t::load( - frag_at(i, j), - &(src[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]), - ld, - Int<1>{}); - } - } - } - - template - METAL_FUNC void store(device U* dst, const int ld) const { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kTileCols; ++j) { - MMAFrag_t::store( - frag_at(i, j), - &(dst[(i * kFragRows) * w_x * ld + (j * kFragCols) * w_y]), - ld, - Int<1>{}); - } - } - } - - template - METAL_FUNC void - load_safe(const device U* src, const int ld, const short2 src_tile_dims) { - STEEL_PRAGMA_UNROLL - for (int i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < kTileCols; ++j) { - MMAFrag_t::load_safe( - frag_at(i, j), - src, - ld, - Int<1>{}, - src_tile_dims.y, - src_tile_dims.x, - (i * kFragRows) * w_x, - (j * kFragCols) * w_y); - } - } - } - - template - METAL_FUNC void - store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const { - STEEL_PRAGMA_UNROLL - for (int i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < kTileCols; ++j) { - MMAFrag_t::store_safe( - frag_at(i, j), - dst, - ld, - Int<1>{}, - dst_tile_dims.y, - dst_tile_dims.x, - (i * kFragRows) * w_x, - (j * kFragCols) * w_y); - } - } - } - - template - METAL_FUNC void store_slice( - device U* dst, - const int ld, - const short2 start, - const short2 stop) const { - STEEL_PRAGMA_UNROLL - for (int i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < kTileCols; ++j) { - MMAFrag_t::store_slice( - frag_at(i, j), - dst, - ld, - Int<1>{}, - start.y, - stop.y, - start.x, - stop.x, - (i * kFragRows) * w_x, - (j * kFragCols) * w_y); - } - } - } -}; - -template -METAL_FUNC void tile_matmad( - thread MMATile& D, - thread MMATile& A, - thread MMATile& B, - thread MMATile& C) { - STEEL_PRAGMA_UNROLL - for (short m = 0; m < M; ++m) { - STEEL_PRAGMA_UNROLL - for (short n = 0; n < N; ++n) { - short n_serp = (m % 2) ? (N - 1 - n) : n; - STEEL_PRAGMA_UNROLL - for (short k = 0; k < K; ++k) { - MMATile::MMAFrag_t::mma( - D.frag_at(m, n_serp), - A.frag_at(m, k), - B.frag_at(k, n_serp), - C.frag_at(m, n_serp)); - } - } - } -} - -template -struct TransformNone { - static METAL_FUNC complex64_t apply(complex64_t x) { - return x; - } - static METAL_FUNC complex64_t apply(complex64_t x, complex64_t) { - return x; - } -}; - -template < - typename T, - typename U, - int BM, - int BN, - int BK, - int WM, - int WN, - bool transpose_a, - bool transpose_b, - short lda_tgp, - short ldb_tgp, - typename AccumType = float, - typename Epilogue = TransformNone> -struct BlockMMA { - // MMAFrag size - STEEL_CONST short kFragSize = 8; - using MMAFrag_acc_t = BaseMMAFrag; - - // Warp tile simdgroup matrix strides along M - STEEL_CONST short TM_stride = kFragSize * WM; - // Warp tile simdgroup matrix strides along M - STEEL_CONST short TN_stride = kFragSize * WN; - - // Warp tile size along M - STEEL_CONST short TM = BM / (kFragSize * WM); - // Warp tile size along N - STEEL_CONST short TN = BN / (kFragSize * WN); - - // Threadgroup A strides - STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M - STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K - - // Threadgroup B strides - STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K - STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N - - // Threadgroup strides along K - STEEL_CONST short tile_stride_a = kFragSize * A_str_k; - STEEL_CONST short tile_stride_b = kFragSize * B_str_k; - - // Simdgroup matrices - MMATile Atile; - MMATile Btile; - MMATile Ctile; - - // Offsets within threadgroup - short sm; - short sn; - - short As_offset; - short Bs_offset; - - /* Constructor */ - METAL_FUNC BlockMMA( - ushort simd_group_id [[simdgroup_index_in_threadgroup]], - ushort simd_lane_id [[thread_index_in_simdgroup]]) { - // Determine thread position in simdgroup matrix - short tm = kFragSize * (simd_group_id / WN); - short tn = kFragSize * (simd_group_id % WN); - - short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id); - sm = simd_coord.y; - sn = simd_coord.x; - - // Determine thread and simdgroup offset - As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // M, K - Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // K, N - - sm += tm; - sn += tn; - } - - /* (BM, BK) X (BK, BN) multiply accumulate function */ - METAL_FUNC void mma(const threadgroup T* As, const threadgroup T* Bs) { - // Adjust for simdgroup and thread location - As += As_offset; - Bs += Bs_offset; - - // Iterate over BK in blocks of kFragSize - STEEL_PRAGMA_UNROLL - for (short kk = 0; kk < BK; kk += kFragSize) { - simdgroup_barrier(mem_flags::mem_none); - - Atile.template load(As); - - simdgroup_barrier(mem_flags::mem_none); - - Btile.template load(Bs); - - simdgroup_barrier(mem_flags::mem_none); - - tile_matmad(Ctile, Atile, Btile, Ctile); - - // Progress to next simdgroup tile - As += tile_stride_a; - Bs += tile_stride_b; - } - } - - /* Store results from simdgroup_matrix results into device memory */ - METAL_FUNC void store_result(device U* D, const int ldd) { - // Apply epilogue - STEEL_PRAGMA_UNROLL - for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { - Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]); - } - - // Adjust for simdgroup and thread location - D += sm * ldd + sn; - - Ctile.template store(D, ldd); - } - - METAL_FUNC void - store_result_slice(device U* D, const int ldd, short2 start, short2 stop) { - // Apply epilogue - STEEL_PRAGMA_UNROLL - for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { - Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]); - } - - D += sm * ldd + sn; - start -= short2(sn, sm); - stop -= short2(sn, sm); - - // TODO: Check the start as well - if (stop.y <= 0 || stop.x <= 0) { - return; - } - - Ctile.template store_slice(D, ldd, start, stop); - } - - METAL_FUNC void - store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) { - // Apply epilogue - STEEL_PRAGMA_UNROLL - for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { - Ctile.elems()[i] = Epilogue::apply(Ctile.elems()[i]); - } - - // Adjust for simdgroup and thread location - D += sm * ldd + sn; - dst_tile_dims -= short2(sn, sm); - - if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) - return; - - Ctile.template store_safe(D, ldd, dst_tile_dims); - } - - /* Apply epilogue */ - template - METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) { - // Loop over all simdgroup tiles - STEEL_PRAGMA_UNROLL - for (short i = 0; i < decltype(Ctile)::kElemsPerTile; i++) { - Ctile.elems()[i] = epilogue_op.apply(Ctile.elems()[i]); - } - } - - /* Apply epilogue */ - template - METAL_FUNC void apply_epilogue( - const device U* C, - const int ldc, - const int fdc, - thread const BinaryEpilogue& epilogue_op) { - // Adjust for simdgroup and thread location - C += (sm)*ldc + (sn)*fdc; - - // Loop over all simdgroup tiles - STEEL_PRAGMA_UNROLL - for (short i = 0; i < TM; i++) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; j++) { - // Get accumulated result and associated offset in C - thread auto& accum = Ctile.frag_at(i, j); - int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; - - // Apply epilogue - STEEL_PRAGMA_UNROLL - for (short k = 0; k < decltype(Ctile)::kElemsPerFrag; k++) { - accum[k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]); - } - } - } - } - - /* Apply epilogue */ - template - METAL_FUNC void apply_epilogue_safe( - const device U* C, - const int ldc, - const int fdc, - short2 dst_tile_dims, - thread const BinaryEpilogue& epilogue_op) { - // Adjust for simdgroup and thread location - C += (sm)*ldc + (sn)*fdc; - dst_tile_dims -= short2(sn, sm); - - if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) - return; - - // Loop over all simdgroup tiles - STEEL_PRAGMA_UNROLL - for (short i = 0; i < TM; i++) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; j++) { - // Get accumulated result and associated offset in C - thread auto& accum = Ctile.frag_at(i, j); - int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; - - constexpr short kelems = decltype(Ctile)::kElemsPerFrag; - - // Read C - U c_elems[kelems] = {0}; - - STEEL_PRAGMA_UNROLL - for (short k = 0; k < kelems; k++) { - if ((j * TN_stride + k) < dst_tile_dims.x) { - c_elems[k] = C[offset_c + k * fdc]; - } - } - - // Apply epilogue - STEEL_PRAGMA_UNROLL - for (short k = 0; k < kelems; k++) { - accum[k] = epilogue_op.apply(accum[k], c_elems[k]); - } - } - } - } - - /* Store results from simdgroup_matrix results into device memory */ - METAL_FUNC void store_result( - device U* D, - const int ldd, - const device U* C, - const int ldc, - const int fdc, - thread const Epilogue& epilogue_op) const { - // Adjust for simdgroup and thread location - C += (sm)*ldc + (sn)*fdc; - D += (sm)*ldd + sn; - - constexpr short kelems = decltype(Ctile)::kElemsPerFrag; - - // Loop over all simdgroup tiles - STEEL_PRAGMA_UNROLL - for (short i = 0; i < TM; i++) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; j++) { - // Get accumulated result and associated offset in C - thread const auto& accum = Ctile.frag_at(i, j); - int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; - int offset_d = (i * TM_stride) * ldd + (j * TN_stride); - - // Apply epilogue - STEEL_PRAGMA_UNROLL - for (short k = 0; k < kelems; k++) { - D[offset_d + k] = epilogue_op.apply(accum[k], C[offset_c + k * fdc]); - } - } - } - } - - METAL_FUNC void store_result_safe( - device U* D, - const int ldd, - const device U* C, - const int ldc, - const int fdc, - short2 dst_tile_dims, - thread const Epilogue& epilogue_op) const { - // Adjust for simdgroup and thread location - C += (sm)*ldc + (sn)*fdc; - D += (sm)*ldd + sn; - dst_tile_dims -= short2(sn, sm); - - if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) - return; - - constexpr short kelems = decltype(Ctile)::kElemsPerFrag; - - STEEL_PRAGMA_UNROLL - for (int i = 0; i < TM; i++) { - if (i * TM_stride < dst_tile_dims.y) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < TN; j++) { - // Get accumulated result and associated offset in C - thread const auto& accum = Ctile.frag_at(i, j); - int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; - int offset_d = (i * TM_stride) * ldd + (j * TN_stride); - - // Apply epilogue - STEEL_PRAGMA_UNROLL - for (short k = 0; k < kelems; k++) { - if ((j * TN_stride + k) < dst_tile_dims.x) { - D[offset_d + k] = - epilogue_op.apply(accum[k], C[offset_c + k * fdc]); - } - } - } - } - } - } -}; - -template < - typename U, - int BM, - int BN, - int BK, - int WM, - int WN, - bool transpose_a, - bool transpose_b, - short lda_tgp, - short ldb_tgp, - typename AccumType, - typename Epilogue> -struct BlockMMA< - complex64_t, - U, - BM, - BN, - BK, - WM, - WN, - transpose_a, - transpose_b, - lda_tgp, - ldb_tgp, - AccumType, - Epilogue> { - static_assert( - metal::is_same_v, - "BlockMMA expects float accumulators"); - static_assert( - metal::is_same_v, - "For complex BlockMMA, U must be complex64_t; use a different epilogue for projections"); - // MMAFrag size - STEEL_CONST short kFragSize = 8; - using MMAFrag_acc_t = BaseMMAFrag; - - // Warp tile simdgroup matrix strides along M - STEEL_CONST short TM_stride = kFragSize * WM; - // Warp tile simdgroup matrix strides along M - STEEL_CONST short TN_stride = kFragSize * WN; - - // Warp tile size along M - STEEL_CONST short TM = BM / (kFragSize * WM); - // Warp tile size along N - STEEL_CONST short TN = BN / (kFragSize * WN); - - // Threadgroup A strides - STEEL_CONST short A_str_m = transpose_a ? 1 : lda_tgp; // M - STEEL_CONST short A_str_k = transpose_a ? lda_tgp : 1; // K - - // Threadgroup B strides - STEEL_CONST short B_str_k = transpose_b ? 1 : ldb_tgp; // K - STEEL_CONST short B_str_n = transpose_b ? ldb_tgp : 1; // N - - // Threadgroup strides along K - STEEL_CONST short tile_stride_a = kFragSize * A_str_k; - STEEL_CONST short tile_stride_b = kFragSize * B_str_k; - - // When indexing complex as float[2] - STEEL_CONST short A_str_m_f = A_str_m * 2; - STEEL_CONST short A_str_k_f = A_str_k * 2; - STEEL_CONST short B_str_k_f = B_str_k * 2; - STEEL_CONST short B_str_n_f = B_str_n * 2; - STEEL_CONST short tile_stride_a_f = tile_stride_a * 2; - STEEL_CONST short tile_stride_b_f = tile_stride_b * 2; - - // Accumulators (real/imag) - MMATile Ctile_r; - MMATile Ctile_i; - - // Offsets within threadgroup - short sm, sn; - short As_offset, Bs_offset; - - /* Constructor */ - METAL_FUNC BlockMMA( - ushort simd_group_id [[simdgroup_index_in_threadgroup]], - ushort simd_lane_id [[thread_index_in_simdgroup]]) { - // Determine thread position in simdgroup matrix - short tm = kFragSize * (simd_group_id / WN); - short tn = kFragSize * (simd_group_id % WN); - - short2 simd_coord = MMAFrag_acc_t::get_coord(simd_lane_id); - sm = simd_coord.y; - sn = simd_coord.x; - - // Determine thread and simdgroup offset - As_offset = (tm + sm) * A_str_m + (sn)*A_str_k; // (M,K) - Bs_offset = (sm)*B_str_k + (tn + sn) * B_str_n; // (K,N) - - sm += tm; - sn += tn; - } - - /* Karatsuba MMA: 3 real MMAs per K-chunk */ - METAL_FUNC void mma( - const threadgroup complex64_t* As, - const threadgroup complex64_t* Bs) { - // Adjust for simdgroup and thread location - As += As_offset; - Bs += Bs_offset; - threadgroup const float* As_f = - reinterpret_cast(As); - threadgroup const float* Bs_f = - reinterpret_cast(Bs); - - // Iterate over BK in blocks of kFragSize - STEEL_PRAGMA_UNROLL - for (short kk = 0; kk < BK; kk += kFragSize) { - simdgroup_barrier(mem_flags::mem_none); - - MMATile Ar, Ai; - Ar.template load(As_f + 0); - Ai.template load(As_f + 1); - - simdgroup_barrier(mem_flags::mem_none); - - MMATile Br, Bi; - Br.template load(Bs_f + 0); - Bi.template load(Bs_f + 1); - - simdgroup_barrier(mem_flags::mem_none); - - // P = Ar*Br ; Q = Ai*Bi ; R = (Ar+Ai)*(Br+Bi) - MMATile P, Q, R; - - tile_matmad(P, Ar, Br, P); - tile_matmad(Q, Ai, Bi, Q); - - STEEL_PRAGMA_UNROLL - for (short i = 0; i < decltype(Ar)::kElemsPerTile; ++i) - Ar.elems()[i] += Ai.elems()[i]; - STEEL_PRAGMA_UNROLL - for (short i = 0; i < decltype(Br)::kElemsPerTile; ++i) - Br.elems()[i] += Bi.elems()[i]; - - tile_matmad(R, Ar, Br, R); - - // C_r += P - Q ; C_i -= Q - STEEL_PRAGMA_UNROLL - for (short i = 0; i < decltype(Ctile_r)::kElemsPerTile; ++i) { - const auto p = P.elems()[i]; - const auto q = Q.elems()[i]; - const auto r = R.elems()[i]; - Ctile_r.elems()[i] += (p - q); - Ctile_i.elems()[i] += (r - p - q); - } - - // Progress to next simdgroup tile - As_f += tile_stride_a_f; - Bs_f += tile_stride_b_f; - } - } - - /* Store results from simdgroup_matrix results into device memory */ - METAL_FUNC void store_result(device U* D, const int ldd) { - // Adjust for simdgroup and thread location - D += sm * ldd + sn; - - STEEL_PRAGMA_UNROLL - for (short i = 0; i < TM; i++) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; j++) { - thread const auto& r = Ctile_r.frag_at(i, j); - thread const auto& im = Ctile_i.frag_at(i, j); - int off = (i * TM_stride) * ldd + (j * TN_stride); - STEEL_PRAGMA_UNROLL - for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; k++) { - D[off + k] = Epilogue::apply(complex64_t(r[k], im[k])); - } - } - } - } - - METAL_FUNC void - store_result_slice(device U* D, const int ldd, short2 start, short2 stop) { - D += sm * ldd + sn; - start -= short2(sn, sm); - stop -= short2(sn, sm); - - if (stop.y <= 0 || stop.x <= 0) - return; - - STEEL_PRAGMA_UNROLL - for (short i = 0; i < TM; ++i) { - const int row = i * TM_stride; - if (row >= start.y && row < stop.y) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; ++j) { - const int off = row * ldd + (j * TN_stride); - thread const auto& r = Ctile_r.frag_at(i, j); - thread const auto& im = Ctile_i.frag_at(i, j); - - STEEL_PRAGMA_UNROLL - for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; ++k) { - const int col = j * TN_stride + k; - if (col >= start.x && col < stop.x) { - D[off + k] = Epilogue::apply(complex64_t(r[k], im[k])); - } - } - } - } - } - } - - METAL_FUNC void - store_result_safe(device U* D, const int ldd, short2 dst_tile_dims) { - D += sm * ldd + sn; - dst_tile_dims -= short2(sn, sm); - if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) - return; - STEEL_PRAGMA_UNROLL - for (short i = 0; i < TM; i++) { - if (i * TM_stride < dst_tile_dims.y) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; j++) { - int off = (i * TM_stride) * ldd + (j * TN_stride); - thread const auto& r = Ctile_r.frag_at(i, j); - thread const auto& im = Ctile_i.frag_at(i, j); - STEEL_PRAGMA_UNROLL - for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; k++) { - if ((j * TN_stride + k) < dst_tile_dims.x) { - D[off + k] = Epilogue::apply(complex64_t(r[k], im[k])); - } - } - } - } - } - } - - /* Apply epilogue */ - template - METAL_FUNC void apply_epilogue(thread const UnaryEpilogue& epilogue_op) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < decltype(Ctile_r)::kElemsPerTile; i++) { - complex64_t out = epilogue_op.apply( - complex64_t(Ctile_r.elems()[i], Ctile_i.elems()[i])); - Ctile_r.elems()[i] = out.real; - Ctile_i.elems()[i] = out.imag; - } - } - - /* Apply epilogue */ - template - METAL_FUNC void apply_epilogue( - const device U* C, - const int ldc, - const int fdc, - thread const BinaryEpilogue& epilogue_op) { - // Adjust for simdgroup and thread location - C += (sm)*ldc + (sn)*fdc; - - // Loop over all simdgroup tiles - STEEL_PRAGMA_UNROLL - for (short i = 0; i < TM; i++) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; j++) { - // Get accumulated result and associated offset in Cr, Ci - thread auto& r = Ctile_r.frag_at(i, j); - thread auto& im = Ctile_i.frag_at(i, j); - int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; - - STEEL_PRAGMA_UNROLL - for (short k = 0; k < decltype(Ctile_r)::kElemsPerFrag; k++) { - complex64_t out = epilogue_op.apply( - complex64_t(r[k], im[k]), C[offset_c + k * fdc]); - r[k] = out.real; - im[k] = out.imag; - } - } - } - } - - /* Apply epilogue */ - template - METAL_FUNC void apply_epilogue_safe( - const device U* C, - const int ldc, - const int fdc, - short2 dst_tile_dims, - thread const BinaryEpilogue& epilogue_op) { - // Adjust for simdgroup and thread location - C += (sm)*ldc + (sn)*fdc; - dst_tile_dims -= short2(sn, sm); - - if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) - return; - - // Loop over all simdgroup tiles - STEEL_PRAGMA_UNROLL - for (short i = 0; i < TM; i++) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; j++) { - // Get accumulated result and associated offset in Cr, Ci - thread auto& r = Ctile_r.frag_at(i, j); - thread auto& im = Ctile_i.frag_at(i, j); - int offset_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; - - constexpr short kelems = decltype(Ctile_r)::kElemsPerFrag; - complex64_t tmp[kelems]; - - STEEL_PRAGMA_UNROLL - for (short k = 0; k < kelems; k++) { - if ((j * TN_stride + k) < dst_tile_dims.x && - (i * TM_stride) < dst_tile_dims.y) { - tmp[k] = C[offset_c + k * fdc]; - } else { - tmp[k] = complex64_t(0.0f, 0.0f); - } - } - - // Apply epilogue - STEEL_PRAGMA_UNROLL - for (short k = 0; k < kelems; k++) { - complex64_t out = epilogue_op.apply(complex64_t(r[k], im[k]), tmp[k]); - r[k] = out.real; - im[k] = out.imag; - } - } - } - } - - /* Store results from simdgroup_matrix results into device memory */ - METAL_FUNC void store_result( - device U* D, - const int ldd, - const device U* C, - const int ldc, - const int fdc, - thread const Epilogue& epilogue_op) const { - // Adjust for simdgroup and thread location - C += (sm)*ldc + (sn)*fdc; - D += (sm)*ldd + sn; - - constexpr short kelems = decltype(Ctile_r)::kElemsPerFrag; - - // Loop over all simdgroup tiles - STEEL_PRAGMA_UNROLL - for (short i = 0; i < TM; i++) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; j++) { - // Get accumulated result and associated offset in Cr, Ci - thread const auto& r = Ctile_r.frag_at(i, j); - thread const auto& im = Ctile_i.frag_at(i, j); - int off_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; - int off_d = (i * TM_stride) * ldd + (j * TN_stride); - - // Apply epilogue - STEEL_PRAGMA_UNROLL - for (short k = 0; k < kelems; k++) { - D[off_d + k] = - epilogue_op.apply(complex64_t(r[k], im[k]), C[off_c + k * fdc]); - } - } - } - } - - METAL_FUNC void store_result_safe( - device U* D, - const int ldd, - const device U* C, - const int ldc, - const int fdc, - short2 dst_tile_dims, - thread const Epilogue& epilogue_op) const { - // Adjust for simdgroup and thread location - C += (sm)*ldc + (sn)*fdc; - D += (sm)*ldd + sn; - dst_tile_dims -= short2(sn, sm); - - if (dst_tile_dims.x <= 0 || dst_tile_dims.y <= 0) - return; - - constexpr short kelems = decltype(Ctile_r)::kElemsPerFrag; - - STEEL_PRAGMA_UNROLL - for (int i = 0; i < TM; i++) { - if (i * TM_stride < dst_tile_dims.y) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < TN; j++) { - // Get accumulated result and associated offset in Cr, Ci - thread const auto& r = Ctile_r.frag_at(i, j); - thread const auto& im = Ctile_i.frag_at(i, j); - int off_c = (i * TM_stride) * ldc + (j * TN_stride) * fdc; - int off_d = (i * TM_stride) * ldd + (j * TN_stride); - - // Apply epilogue - STEEL_PRAGMA_UNROLL - for (short k = 0; k < kelems; k++) { - if ((j * TN_stride + k) < dst_tile_dims.x) { - D[off_d + k] = epilogue_op.apply( - complex64_t(r[k], im[k]), C[off_c + k * fdc]); - } - } - } - } - } - } -}; - -} // namespace steel -} // namespace mlx diff --git a/Source/Cmlx/mlx-generated/metal/steel/gemm/nax.h b/Source/Cmlx/mlx-generated/metal/steel/gemm/nax.h deleted file mode 100644 index 740068be..00000000 --- a/Source/Cmlx/mlx-generated/metal/steel/gemm/nax.h +++ /dev/null @@ -1,1084 +0,0 @@ -// Copyright © 2025 Apple Inc. - -#pragma once - -#include -#include -#include - -#include "../../steel/defines.h" -#include "../../steel/gemm/transforms.h" -#include "../../steel/utils/integral_constant.h" - -#include - -using namespace metal; - -/////////////////////////////////////////////////////////////////////////////// -// MMA helper -/////////////////////////////////////////////////////////////////////////////// - -namespace mlx { -namespace steel { - -/////////////////////////////////////////////////////////////////////////////// -// NAX Steel with new tiles -/////////////////////////////////////////////////////////////////////////////// - -struct BaseNAXFrag { - STEEL_CONST short kFragRows = 16; - STEEL_CONST short kFragCols = 16; - - STEEL_CONST short kElemsPerFrag = (kFragRows * kFragCols) / 32; - - STEEL_CONST short kElemRows = 2; - STEEL_CONST short kElemCols = 4; - - STEEL_CONST short kElemRowsJump = 8; - - static_assert( - kElemRows * kElemCols == kElemsPerFrag, - "MMAFrag shape is not consistent with MMAFrag size"); - - template - using dtype_frag_t = typename metal::vec; - - METAL_FUNC static short2 get_coord() { - const ushort simd_lane_id = __metal_get_thread_index_in_simdgroup(ushort()); - const short qid = simd_lane_id >> 2; - const short fm = ((qid & 4) | ((simd_lane_id >> 1) & 3)); - const short fn = ((qid & 2) | (simd_lane_id & 1)) * 4; - return short2{fn, fm}; - } - - METAL_FUNC static short2 get_coord(short idx) { - const ushort simd_lane_id = __metal_get_thread_index_in_simdgroup(ushort()); - const short qid = simd_lane_id >> 2; - const short fm = ((qid & 4) | ((simd_lane_id >> 1) & 3)) + (idx >> 2) * 8; - const short fn = ((qid & 2) | (simd_lane_id & 1)) * 4 + idx % 4; - return short2{fn, fm}; - } - - template < - typename T, - typename SrcPtrType, - typename StrX, - typename StrY, - typename OffX = Int<0>, - typename OffY = Int<0>> - METAL_FUNC static constexpr void load( - thread dtype_frag_t& dst, - SrcPtrType src, - StrX str_x, - StrY str_y, - OffX off_x = {}, - OffY off_y = {}) { - const short2 sc = get_coord(); - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kElemRows; i++) { - const auto r = off_x + i * kElemRowsJump + sc.y; - const auto c = off_y + sc.x; - - if constexpr (metal::is_same_v>) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kElemCols; j++) { - dst[i * kElemCols + j] = static_cast(src[r * str_x + c + j]); - } - } else { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kElemCols; j++) { - dst[i * kElemCols + j] = - static_cast(src[r * str_x + (c + j) * str_y]); - } - } - } - } - - template < - typename T, - typename SrcPtrType, - typename StrX, - typename StrY, - typename LimX, - typename OffX = Int<0>, - typename OffY = Int<0>> - METAL_FUNC static constexpr void load_rows( - thread dtype_frag_t& dst, - SrcPtrType src, - StrX str_x, - StrY str_y, - LimX lim_x, - OffX off_x = {}, - OffY off_y = {}) { - const short2 sc = get_coord(); - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kElemRows; i++) { - const auto r = off_x + i * kElemRowsJump + sc.y; - const auto c = off_y + sc.x; - - if (r < lim_x) { - if constexpr (metal::is_same_v>) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kElemCols; j++) { - dst[i * kElemCols + j] = static_cast(src[r * str_x + (c + j)]); - } - } else { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kElemCols; j++) { - dst[i * kElemCols + j] = - static_cast(src[r * str_x + (c + j) * str_y]); - } - } - - } else { - dst = dtype_frag_t(0); - } - } - } - - template < - typename T, - typename SrcPtrType, - typename StrX, - typename StrY, - typename LimX, - typename LimY, - typename OffX = Int<0>, - typename OffY = Int<0>> - METAL_FUNC static constexpr void load_safe( - thread dtype_frag_t& dst, - SrcPtrType src, - StrX str_x, - StrY str_y, - LimX lim_x, - LimY lim_y, - OffX off_x = {}, - OffY off_y = {}) { - const short2 sc = get_coord(); - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kElemRows; i++) { - const auto r = off_x + i * kElemRowsJump + sc.y; - const auto c = off_y + sc.x; - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kElemCols; j++) { - if (r < lim_x && (c + j) < lim_y) { - dst[i * kElemCols + j] = - static_cast(src[r * str_x + (c + j) * str_y]); - } else { - dst[i * kElemCols + j] = T(0); - } - } - } - } - - template < - typename T, - typename DstPtrType, - typename StrX, - typename StrY, - typename OffX = Int<0>, - typename OffY = Int<0>> - METAL_FUNC static constexpr void store( - const thread dtype_frag_t& src, - DstPtrType dst, - StrX str_x, - StrY str_y, - OffX off_x = {}, - OffY off_y = {}) { - using U = pointer_element_t; - - const short2 sc = get_coord(); - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kElemRows; i++) { - const auto r = off_x + i * kElemRowsJump + sc.y; - const auto c = off_y + sc.x; - - if constexpr (metal::is_same_v>) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kElemCols; j++) { - dst[r * str_x + c + j] = static_cast(src[i * kElemCols + j]); - } - } else { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kElemCols; j++) { - dst[r * str_x + (c + j) * str_y] = - static_cast(src[i * kElemCols + j]); - } - } - } - } - - template < - typename T, - typename DstPtrType, - typename StrX, - typename StrY, - typename LimX, - typename OffX = Int<0>, - typename OffY = Int<0>> - METAL_FUNC static constexpr void store_rows( - const thread dtype_frag_t& src, - DstPtrType dst, - StrX str_x, - StrY str_y, - LimX lim_x, - OffX off_x = {}, - OffY off_y = {}) { - using U = pointer_element_t; - - const short2 sc = get_coord(); - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kElemRows; i++) { - const auto r = off_x + i * kElemRowsJump + sc.y; - const auto c = off_y + sc.x; - - if (r < lim_x) { - if constexpr (metal::is_same_v>) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kElemCols; j++) { - dst[r * str_x + c + j] = static_cast(src[i * kElemCols + j]); - } - } else { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kElemCols; j++) { - dst[r * str_x + (c + j) * str_y] = - static_cast(src[i * kElemCols + j]); - } - } - } - } - } - - template < - typename T, - typename DstPtrType, - typename StrX, - typename StrY, - typename LimX, - typename LimY, - typename OffX = Int<0>, - typename OffY = Int<0>> - METAL_FUNC static constexpr void store_safe( - const thread dtype_frag_t& src, - DstPtrType dst, - StrX str_x, - StrY str_y, - LimX lim_x, - LimY lim_y, - OffX off_x = {}, - OffY off_y = {}) { - using U = pointer_element_t; - - const short2 sc = get_coord(); - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kElemRows; i++) { - const auto r = off_x + i * kElemRowsJump + sc.y; - const auto c = off_y + sc.x; - - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kElemCols; j++) { - if (r < lim_x && (c + j) < lim_y) { - dst[r * str_x + (c + j) * str_y] = - static_cast(src[i * kElemCols + j]); - } - } - } - } - - template < - typename T, - typename DstPtrType, - typename StrX, - typename StrY, - typename StartX, - typename StopX, - typename StartY, - typename StopY, - typename OffX = Int<0>, - typename OffY = Int<0>> - METAL_FUNC static constexpr void store_slice( - const thread dtype_frag_t& src, - DstPtrType dst, - StrX str_x, - StrY str_y, - StartX start_x, - StopX stop_x, - StartY start_y, - StopY stop_y, - OffX off_x = Int<0>{}, - OffY off_y = Int<0>{}) { - using U = pointer_element_t; - - const short2 sc = get_coord(); - - const_for_loop<0, kElemRows, 1>([&](auto idx_row) { - const auto r = off_x + idx_row * Int{}; - if (r >= stop_x - sc.y || r < start_x - sc.y) { - return; - } - - const_for_loop<0, kElemCols, 1>([&](auto idx_col) { - const auto c = off_y + idx_col; - if (c >= stop_y - sc.x || c < start_y - sc.x) { - return; - } - - const auto src_idx = idx_row * Int{} + idx_col; - dst[(r + sc.y) * str_x + (c + sc.x) * str_y] = - static_cast(src[src_idx]); - }); - }); - } - - template - METAL_FUNC static constexpr void row_reduce( - thread const dtype_frag_t& inp_vals, - thread T* reduced_vals) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kElemRows; i++) { - T thr_reduce = Op::apply( - Op::apply(inp_vals[i * kElemCols + 0], inp_vals[i * kElemCols + 1]), - Op::apply(inp_vals[i * kElemCols + 2], inp_vals[i * kElemCols + 3])); - - T qgr_reduce = simd_shuffle_xor(thr_reduce, ushort(1)); - qgr_reduce = Op::apply(thr_reduce, qgr_reduce); - - T sgr_reduce = simd_shuffle_xor(qgr_reduce, ushort(8)); - sgr_reduce = Op::apply(qgr_reduce, sgr_reduce); - - reduced_vals[i] = Op::apply(reduced_vals[i], sgr_reduce); - } - } - - template - METAL_FUNC static constexpr void row_bin_op( - thread dtype_frag_t& inp_vals, - thread T* row_vals) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kElemRows; i++) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kElemCols; j++) { - inp_vals[i * kElemCols + j] = - Op::apply(inp_vals[i * kElemCols + j], row_vals[i]); - } - } - } -}; - -template < - typename T, - short kRows_, - short kCols_, - typename NAXFrag_t = BaseNAXFrag> -struct NAXSubTile { - STEEL_CONST short kRows = kRows_; - STEEL_CONST short kCols = kCols_; - - STEEL_CONST short kFragRows = NAXFrag_t::kFragRows; - STEEL_CONST short kFragCols = NAXFrag_t::kFragCols; - STEEL_CONST short kElemsPerFrag = NAXFrag_t::kElemsPerFrag; - - STEEL_CONST short kSubTileRows = kRows / kFragRows; - STEEL_CONST short kSubTileCols = kCols / kFragCols; - - STEEL_CONST short kNumFrags = kSubTileRows * kSubTileCols; - STEEL_CONST short kElemsPerSubTile = kNumFrags * kElemsPerFrag; - - STEEL_CONST int kRowsPerThread = kSubTileRows * NAXFrag_t::kElemRows; - STEEL_CONST int kColsPerThread = kSubTileCols * NAXFrag_t::kElemCols; - - STEEL_CONST short kFragThrRows = NAXFrag_t::kElemRows; - STEEL_CONST short kFragThrCols = NAXFrag_t::kElemCols; - STEEL_CONST short kFragRowsJump = NAXFrag_t::kElemRowsJump; - - using frag_type = typename NAXFrag_t::template dtype_frag_t; - - frag_type val_frags[kNumFrags]; - - METAL_FUNC constexpr void clear() { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kNumFrags; ++i) { - val_frags[i] = frag_type(0); - } - } - - METAL_FUNC constexpr thread frag_type& frag_at(const short i, const short j) { - return val_frags[i * kSubTileCols + j]; - } - - METAL_FUNC constexpr const thread frag_type& frag_at( - const short i, - const short j) const { - return val_frags[i * kSubTileCols + j]; - } - - template - METAL_FUNC constexpr thread frag_type& frag_at() { - return val_frags[i * kSubTileCols + j]; - } - - template - METAL_FUNC constexpr const thread frag_type& frag_at() const { - return val_frags[i * kSubTileCols + j]; - } - - METAL_FUNC thread T* elems() { - return reinterpret_cast(val_frags); - } - - METAL_FUNC const thread T* elems() const { - return reinterpret_cast(val_frags); - } - - template - METAL_FUNC void row_reduce(thread metal::vec& vals) const { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kSubTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kSubTileCols; ++j) { - NAXFrag_t::template row_reduce( - frag_at(i, j), &vals[i * kFragThrRows]); - } - } - } - - template - METAL_FUNC void row_bin_op(thread metal::vec& vals) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kSubTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kSubTileCols; ++j) { - NAXFrag_t::template row_bin_op( - frag_at(i, j), &vals[i * kFragThrRows]); - } - } - } - - template < - typename SrcPtrType, - typename StrX, - typename StrY, - typename OffX = Int<0>, - typename OffY = Int<0>> - METAL_FUNC constexpr void load( - SrcPtrType src, - StrX str_x, - StrY str_y, - OffX off_x = {}, - OffY off_y = {}) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kSubTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kSubTileCols; ++j) { - NAXFrag_t::load( - frag_at(i, j), - src, - str_x, - str_y, - off_x + i * kFragRows, - off_y + j * kFragCols); - } - } - } - - template < - typename DstPtrType, - typename StrX, - typename StrY, - typename OffX = Int<0>, - typename OffY = Int<0>> - METAL_FUNC constexpr void store( - DstPtrType dst, - StrX str_x, - StrY str_y, - OffX off_x = {}, - OffY off_y = {}) const { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kSubTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kSubTileCols; ++j) { - NAXFrag_t::store( - frag_at(i, j), - dst, - str_x, - str_y, - off_x + i * kFragRows, - off_y + j * kFragCols); - } - } - } - - template < - typename SrcPtrType, - typename StrX, - typename StrY, - typename LimX, - typename OffX = Int<0>, - typename OffY = Int<0>> - METAL_FUNC constexpr void load_rows( - SrcPtrType src, - StrX str_x, - StrY str_y, - LimX lim_x, - OffX off_x = {}, - OffY off_y = {}) { - STEEL_PRAGMA_UNROLL - for (int i = 0; i < kSubTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < kSubTileCols; ++j) { - NAXFrag_t::load_rows( - frag_at(i, j), - src, - str_x, - str_y, - lim_x, - off_x + (i * kFragRows), - off_y + (j * kFragCols)); - } - } - } - - template < - typename SrcPtrType, - typename StrX, - typename StrY, - typename LimX, - typename LimY, - typename OffX = Int<0>, - typename OffY = Int<0>> - METAL_FUNC constexpr void load_safe( - SrcPtrType src, - StrX str_x, - StrY str_y, - LimX lim_x, - LimY lim_y, - OffX off_x = {}, - OffY off_y = {}) { - STEEL_PRAGMA_UNROLL - for (int i = 0; i < kSubTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < kSubTileCols; ++j) { - NAXFrag_t::load_safe( - frag_at(i, j), - src, - str_x, - str_y, - lim_x, - lim_y, - off_x + (i * kFragRows), - off_y + (j * kFragCols)); - } - } - } - - template < - typename DstPtrType, - typename StrX, - typename StrY, - typename LimX, - typename LimY, - typename OffX = Int<0>, - typename OffY = Int<0>> - METAL_FUNC constexpr void store_safe( - DstPtrType dst, - StrX str_x, - StrY str_y, - LimX lim_x, - LimY lim_y, - OffX off_x = {}, - OffY off_y = {}) const { - STEEL_PRAGMA_UNROLL - for (int i = 0; i < kSubTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < kSubTileCols; ++j) { - NAXFrag_t::store_safe( - frag_at(i, j), - dst, - str_x, - str_y, - lim_x, - lim_y, - off_x + (i * kFragRows), - off_y + (j * kFragCols)); - } - } - } - - template < - typename DstPtrType, - typename StrX, - typename StrY, - typename LimX, - typename OffX = Int<0>, - typename OffY = Int<0>> - METAL_FUNC constexpr void store_rows( - DstPtrType dst, - StrX str_x, - StrY str_y, - LimX lim_x, - OffX off_x = {}, - OffY off_y = {}) const { - STEEL_PRAGMA_UNROLL - for (int i = 0; i < kSubTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < kSubTileCols; ++j) { - NAXFrag_t::store_safe( - frag_at(i, j), - dst, - str_x, - str_y, - lim_x, - off_x + (i * kFragRows), - off_y + (j * kFragCols)); - } - } - } - - template < - typename DstPtrType, - typename StrX, - typename StrY, - typename StartX, - typename StopX, - typename StartY, - typename StopY, - typename OffX = Int<0>, - typename OffY = Int<0>> - METAL_FUNC constexpr void store_slice( - DstPtrType dst, - StrX str_x, - StrY str_y, - StartX start_x, - StopX stop_x, - StartY start_y, - StopY stop_y, - OffX off_x = Int<0>{}, - OffY off_y = Int<0>{}) const { - const_for_loop<0, kSubTileRows, 1>([&](auto idx_row) { - const_for_loop<0, kSubTileCols, 1>([&](auto idx_col) { - NAXFrag_t::store_slice( - frag_at(), - dst, - str_x, - str_y, - start_x, - stop_x, - start_y, - stop_y, - off_x + idx_row * Int{}, - off_y + idx_col * Int{}); - }); - }); - } -}; - -template < - short RC, - short CC, - short RA, - short CA, - short RB, - short CB, - typename CType, - typename AType, - typename BType, - bool transpose_a, - bool transpose_b, - typename NAXFrag_t = BaseNAXFrag> -METAL_FUNC void subtile_matmad_nax( - thread NAXSubTile& C, - thread NAXSubTile& A, - metal::bool_constant, - thread NAXSubTile& B, - metal::bool_constant) { - // Static checks - constexpr short FMa = transpose_a ? CA : RA; - constexpr short FMc = RC; - static_assert(FMa == FMc, "NAX matmul: M dimensions do not match"); - - constexpr short FNb = transpose_b ? RB : CB; - constexpr short FNc = CC; - static_assert(FNb == FNc, "NAX matmul: N dimensions do not match"); - - constexpr short FKa = transpose_a ? RA : CA; - constexpr short FKb = transpose_b ? CB : RB; - static_assert(FKa == FKb, "NAX matmul: N dimensions do not match"); - - constexpr short FM = FMc; - constexpr short FN = FNc; - constexpr short FK = FKa; - - constexpr int TM = FM / 16; - constexpr int TN = FN / 16; - constexpr int TK = FK / 16; - - // Create Matmul descriptor - constexpr auto desc = mpp::tensor_ops::matmul2d_descriptor( - FM, - FN, - FK, - transpose_a, - transpose_b, - true, - mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate); - - // Create matmul op - mpp::tensor_ops::matmul2d gemm_op; - - // Create matmul operands in registers - auto ct_a = - gemm_op.template get_left_input_cooperative_tensor(); - auto ct_b = - gemm_op - .template get_right_input_cooperative_tensor(); - - // Create matmul output in register - auto ct_c = gemm_op.template get_destination_cooperative_tensor< - decltype(ct_a), - decltype(ct_b), - CType>(); - - // Load A in to left operand registers - STEEL_PRAGMA_UNROLL - for (short mm = 0; mm < TM; mm++) { - STEEL_PRAGMA_UNROLL - for (short kk = 0; kk < TK; kk++) { - const short fi = transpose_a ? kk : mm; - const short fj = transpose_a ? mm : kk; - - STEEL_PRAGMA_UNROLL - for (short i = 0; i < 8; i++) { - ct_a[(TK * mm + kk) * 8 + i] = A.frag_at(fi, fj)[i]; - } - } - } - - // Load B into right operand registers - STEEL_PRAGMA_UNROLL - for (short nn = 0; nn < TN; nn++) { - STEEL_PRAGMA_UNROLL - for (short kk = 0; kk < TK; kk++) { - const short fi = transpose_b ? nn : kk; - const short fj = transpose_b ? kk : nn; - - STEEL_PRAGMA_UNROLL - for (short i = 0; i < 8; i++) { - ct_b[(TN * kk + nn) * 8 + i] = B.frag_at(fi, fj)[i]; - } - } - } - - // Load C into output registers (op handles accumulation) - STEEL_PRAGMA_UNROLL - for (short i = 0; i < ct_c.get_capacity(); i++) { - ct_c[i] = C.elems()[i]; - } - - // Do matmul - gemm_op.run(ct_a, ct_b, ct_c); - - // Copy out results - STEEL_PRAGMA_UNROLL - for (short i = 0; i < ct_c.get_capacity(); i++) { - C.elems()[i] = ct_c[i]; - } -} - -template -struct NAXTile { - using NAXSubTile_t = NAXSubTile_; - using elem_type = T; - STEEL_CONST short kSubTileRows = NAXSubTile_t::kRows; - STEEL_CONST short kSubTileCols = NAXSubTile_t::kCols; - STEEL_CONST short kElemsPerSubTile = NAXSubTile_t::kElemsPerSubTile; - - STEEL_CONST short kTileRows = kTileRows_; - STEEL_CONST short kTileCols = kTileCols_; - - STEEL_CONST short kRows = kTileRows * kSubTileRows; - STEEL_CONST short kCols = kTileCols * kSubTileCols; - - STEEL_CONST short kSubTiles = kTileRows * kTileCols; - STEEL_CONST short kElemsPerTile = kSubTiles * kElemsPerSubTile; - - STEEL_CONST short kRowsPerThread = kTileRows * NAXSubTile_t::kRowsPerThread; - STEEL_CONST short kColsPerThread = kTileCols * NAXSubTile_t::kColsPerThread; - - STEEL_CONST short kSubTileThrRows = NAXSubTile_t::kRowsPerThread; - STEEL_CONST short kSubTileThrCols = NAXSubTile_t::kColsPerThread; - - NAXSubTile_t val_subtiles[kSubTiles]; - - METAL_FUNC NAXTile() thread {} - - METAL_FUNC constexpr void clear() { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kSubTiles; ++i) { - val_subtiles[i].clear(); - } - } - - METAL_FUNC constexpr thread NAXSubTile_t& subtile_at( - const short i, - const short j) { - return val_subtiles[i * kTileCols + j]; - } - - METAL_FUNC constexpr const thread NAXSubTile_t& subtile_at( - const short i, - const short j) const { - return val_subtiles[i * kTileCols + j]; - } - - template - METAL_FUNC constexpr const thread NAXSubTile_t& subtile_at() const { - return val_subtiles[i * kTileCols + j]; - } - - METAL_FUNC thread elem_type* elems() { - return reinterpret_cast(val_subtiles[0].elems()); - } - - METAL_FUNC const thread elem_type* elems() const { - return reinterpret_cast(val_subtiles[0].elems()); - } - - template - METAL_FUNC void row_reduce(thread metal::vec& vals) const { - auto sub_rows = (thread metal::vec*)(&vals); - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kTileCols; ++j) { - subtile_at(i, j).template row_reduce(sub_rows[i]); - } - } - } - - template - METAL_FUNC void row_bin_op(thread metal::vec& vals) { - auto sub_rows = (thread metal::vec*)(&vals); - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kTileCols; ++j) { - subtile_at(i, j).template row_bin_op(sub_rows[i]); - } - } - } - - template - METAL_FUNC void load(const threadgroup U* src) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kTileCols; ++j) { - subtile_at(i, j).load( - src, - Int{}, - Int{}, - i * kSubTileRows, - j * kSubTileCols); - } - } - } - - template - METAL_FUNC void store(threadgroup U* dst) const { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kTileCols; ++j) { - subtile_at(i, j).store( - dst, - Int{}, - Int{}, - i * kSubTileRows, - j * kSubTileCols); - } - } - } - - template - METAL_FUNC void load(const device U* src, const int ld) { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kTileCols; ++j) { - subtile_at(i, j).load( - &src[(i * kSubTileRows * ld + j * kSubTileCols)], ld, Int<1>{}); - } - } - } - - template - METAL_FUNC void store(device U* dst, const int ld) const { - STEEL_PRAGMA_UNROLL - for (short i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < kTileCols; ++j) { - subtile_at(i, j).store( - &dst[(i * kSubTileRows * ld + j * kSubTileCols)], ld, Int<1>{}); - } - } - } - - template - METAL_FUNC void - load_rows(const device U* src, const int ld, const short n_rows) { - STEEL_PRAGMA_UNROLL - for (int i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < kTileCols; ++j) { - subtile_at(i, j).load_rows( - &src[(i * kSubTileRows) * ld + (j * kSubTileCols)], - ld, - Int<1>{}, - n_rows - i * kSubTileRows); - } - } - } - - template - METAL_FUNC void - load_safe(const device U* src, const int ld, const short2 src_tile_dims) { - STEEL_PRAGMA_UNROLL - for (int i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < kTileCols; ++j) { - subtile_at(i, j).load_safe( - src, - ld, - Int<1>{}, - src_tile_dims.y, - src_tile_dims.x, - i * kSubTileRows, - j * kSubTileCols); - } - } - } - - template - METAL_FUNC void store_rows(device U* dst, const int ld, const short n_rows) - const { - STEEL_PRAGMA_UNROLL - for (int i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < kTileCols; ++j) { - subtile_at(i, j).store_rows( - &dst[(i * kSubTileRows) * ld + (j * kSubTileCols)], - ld, - Int<1>{}, - n_rows - i * kSubTileRows); - } - } - } - - template - METAL_FUNC void - store_safe(device U* dst, const int ld, const short2 dst_tile_dims) const { - STEEL_PRAGMA_UNROLL - for (int i = 0; i < kTileRows; ++i) { - STEEL_PRAGMA_UNROLL - for (int j = 0; j < kTileCols; ++j) { - subtile_at(i, j).store_safe( - dst, - ld, - Int<1>{}, - dst_tile_dims.y, - dst_tile_dims.x, - i * kSubTileRows, - j * kSubTileCols); - } - } - } - - template - METAL_FUNC void store_slice( - device U* dst, - const int ld, - const short2 start, - const short2 stop) const { - const_for_loop<0, kTileRows, 1>([&](auto idx_row) { - const_for_loop<0, kTileCols, 1>([&](auto idx_col) { - subtile_at().store_slice( - dst, - ld, - Int<1>{}, - start.y, - stop.y, - start.x, - stop.x, - idx_row * Int{}, - idx_col * Int{}); - }); - }); - } -}; - -template < - class CTile, - class ATile, - class BTile, - bool transpose_a, - bool transpose_b> -METAL_FUNC void tile_matmad_nax( - thread CTile& C, - thread ATile& A, - metal::bool_constant, - thread BTile& B, - metal::bool_constant) { - // Static checks - constexpr short TMa = transpose_a ? ATile::kTileCols : ATile::kTileRows; - constexpr short TMc = CTile::kTileRows; - static_assert(TMa == TMc, "NAX tile matmul: M dimensions do not match"); - - constexpr short FMa = transpose_a ? ATile::kSubTileCols : ATile::kSubTileRows; - constexpr short FMc = CTile::kSubTileRows; - static_assert(FMa == FMc, "NAX subtile matmul: M dimensions do not match"); - - constexpr short TNb = transpose_b ? BTile::kTileRows : BTile::kTileCols; - constexpr short TNc = CTile::kTileCols; - static_assert(TNb == TNc, "NAX tile matmul: N dimensions do not match"); - - constexpr short FNb = transpose_b ? BTile::kSubTileRows : BTile::kSubTileCols; - constexpr short FNc = CTile::kSubTileCols; - static_assert(FNb == FNc, "NAX subtile matmul: N dimensions do not match"); - - constexpr short TKa = transpose_a ? ATile::kTileRows : ATile::kTileCols; - constexpr short TKb = transpose_b ? BTile::kTileCols : BTile::kTileRows; - static_assert(TKa == TKb, "NAX tile matmul: K dimensions do not match"); - - constexpr short FKa = transpose_a ? ATile::kSubTileRows : ATile::kSubTileCols; - constexpr short FKb = transpose_b ? BTile::kSubTileCols : BTile::kSubTileRows; - static_assert(FKa == FKb, "NAX subtile matmul: K dimensions do not match"); - - constexpr short TM = TMc; - constexpr short TN = TNc; - constexpr short TK = TKa; - - // Do matmul here - STEEL_PRAGMA_UNROLL - for (short i = 0; i < TM; ++i) { - STEEL_PRAGMA_UNROLL - for (short j = 0; j < TN; ++j) { - STEEL_PRAGMA_UNROLL - for (short k = 0; k < TK; ++k) { - const short ra = transpose_a ? k : i; - const short ca = transpose_a ? i : k; - const short rb = transpose_b ? j : k; - const short cb = transpose_b ? k : j; - - subtile_matmad_nax( - C.subtile_at(i, j), - A.subtile_at(ra, ca), - metal::bool_constant{}, - B.subtile_at(rb, cb), - metal::bool_constant{}); - } - } - } -} - -} // namespace steel -} // namespace mlx diff --git a/Source/Cmlx/mlx-generated/metal/steel/gemm/params.h b/Source/Cmlx/mlx-generated/metal/steel/gemm/params.h deleted file mode 100644 index b0ba07dd..00000000 --- a/Source/Cmlx/mlx-generated/metal/steel/gemm/params.h +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright © 2024 Apple Inc. - -#pragma once - -/////////////////////////////////////////////////////////////////////////////// -// GEMM param classes -/////////////////////////////////////////////////////////////////////////////// - -namespace mlx { -namespace steel { - -struct GEMMParams { - const int M; - const int N; - const int K; - - const int lda; - const int ldb; - const int ldd; - - const int tiles_n; - const int tiles_m; - - const int64_t batch_stride_a; - const int64_t batch_stride_b; - const int64_t batch_stride_d; - - const int swizzle_log; - const int gemm_k_iterations_aligned; - - const int batch_ndim; -}; - -struct GEMMSpiltKParams { - const int M; - const int N; - const int K; - - const int lda; - const int ldb; - const int ldc; - - const int tiles_n; - const int tiles_m; - - const int split_k_partitions; - const int split_k_partition_stride; - const int split_k_partition_size; - - const int swizzle_log; - const int gemm_k_iterations_aligned; -}; - -struct GEMMAddMMParams { - const int ldc; - const int fdc; - - const int64_t batch_stride_c; - - const float alpha; - const float beta; -}; - -} // namespace steel -} // namespace mlx diff --git a/Source/Cmlx/mlx-generated/metal/steel/gemm/transforms.h b/Source/Cmlx/mlx-generated/metal/steel/gemm/transforms.h deleted file mode 100644 index 704776ba..00000000 --- a/Source/Cmlx/mlx-generated/metal/steel/gemm/transforms.h +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright © 2024 Apple Inc. - -#pragma once - -#include "../../steel/utils.h" - -/////////////////////////////////////////////////////////////////////////////// -// Transforms and Epilogues -/////////////////////////////////////////////////////////////////////////////// - -namespace mlx { -namespace steel { - -template -struct TransformNone { - static METAL_FUNC OutT apply(InT x) { - return static_cast(x); - } - - static METAL_FUNC OutT apply(InT x, OutT) { - return static_cast(x); - } -}; - -template -struct TransformAdd { - TransformAdd(const float, const float) {} - - static METAL_FUNC OutT apply(InT x) { - return static_cast(x); - } - - static METAL_FUNC OutT apply(InT x, OutT c) { - return static_cast(x) + c; - } -}; - -template -struct TransformAxpby { - const float alpha; - const float beta; - - TransformAxpby(const float alpha_, const float beta_) - : alpha(alpha_), beta(beta_) {} - - static METAL_FUNC OutT apply(InT x) { - return static_cast(x); - } - - METAL_FUNC OutT apply(InT x, OutT c) const { - return static_cast( - x * static_cast(alpha) + (static_cast(beta) * c)); - } -}; - -template -struct AccumHelper { - typedef float accum_type; -}; - -struct BlockSwizzle { - static METAL_FUNC int2 - swizzle(uint3 tid [[threadgroup_position_in_grid]], const int swizzle_log) { - const int tid_x = (tid.x) >> swizzle_log; - const int tid_y = - ((tid.y) << swizzle_log) + ((tid.x) & ((1 << swizzle_log) - 1)); - return int2(tid_x, tid_y); - } -}; - -} // namespace steel -} // namespace mlx \ No newline at end of file diff --git a/Source/Cmlx/mlx-generated/metal/steel/utils.h b/Source/Cmlx/mlx-generated/metal/steel/utils.h deleted file mode 100644 index 55720a28..00000000 --- a/Source/Cmlx/mlx-generated/metal/steel/utils.h +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright © 2024 Apple Inc. - -#pragma once - -#include - -METAL_FUNC ulong2 elem_to_loc_broadcast( - uint elem, - constant const int* shape, - constant const int64_t* a_strides, - constant const int64_t* b_strides, - int ndim) { - ulong loc_a{0}; - ulong loc_b{0}; - for (int i = ndim - 1; i >= 0 && elem > 0; --i) { - int pos_in_dim = (elem % shape[i]); - elem /= shape[i]; - loc_a += pos_in_dim * a_strides[i]; - loc_b += pos_in_dim * b_strides[i]; - } - return ulong2(loc_a, loc_b); -} - -METAL_FUNC ulong3 elem_to_loc_broadcast( - uint elem, - constant const int* shape, - constant const int64_t* a_strides, - constant const int64_t* b_strides, - constant const int64_t* c_strides, - int ndim) { - ulong loc_a{0}; - ulong loc_b{0}; - ulong loc_c{0}; - for (int i = ndim - 1; i >= 0 && elem > 0; --i) { - int pos_in_dim = (elem % shape[i]); - elem /= shape[i]; - loc_a += pos_in_dim * a_strides[i]; - loc_b += pos_in_dim * b_strides[i]; - loc_c += pos_in_dim * c_strides[i]; - } - return ulong3(loc_a, loc_b, loc_c); -} diff --git a/Source/Cmlx/mlx-generated/metal/steel/utils/integral_constant.h b/Source/Cmlx/mlx-generated/metal/steel/utils/integral_constant.h deleted file mode 100644 index 40bcff8c..00000000 --- a/Source/Cmlx/mlx-generated/metal/steel/utils/integral_constant.h +++ /dev/null @@ -1,134 +0,0 @@ -// Copyright © 2024 Apple Inc. - -#pragma once - -#include -#include "../../steel/utils/type_traits.h" - -#pragma METAL internals : enable - -namespace mlx { -namespace steel { - -/////////////////////////////////////////////////////////////////////////////// -// Integral constant with casting -/////////////////////////////////////////////////////////////////////////////// - -template -struct integral_constant { - static constexpr constant T value = v; - using value_type = T; - using type = integral_constant; - - METAL_FUNC constexpr operator value_type() const noexcept { - return value; - } - - // METAL_FUNC constexpr value_type operator()() const noexcept { - // return value; - // } -}; - -template -using bool_constant = integral_constant; -using true_type = bool_constant; -using false_type = bool_constant; - -template -struct is_integral : bool_constant::value> {}; - -template -struct is_integral> - : bool_constant::value> {}; - -template -constexpr constant bool is_integral_v = is_integral::value; - -template -using Int = integral_constant; - -/////////////////////////////////////////////////////////////////////////////// -// Binary Operators on Integral constants -/////////////////////////////////////////////////////////////////////////////// - -#define integral_const_binop(__op__, __operator__) \ - template \ - METAL_FUNC constexpr auto __operator__( \ - integral_constant, integral_constant) { \ - constexpr auto res = tv __op__ uv; \ - return integral_constant{}; \ - } - -integral_const_binop(+, operator+); -integral_const_binop(-, operator-); -integral_const_binop(*, operator*); -integral_const_binop(/, operator/); - -integral_const_binop(==, operator==); -integral_const_binop(!=, operator!=); -integral_const_binop(<, operator<); -integral_const_binop(>, operator>); -integral_const_binop(<=, operator<=); -integral_const_binop(>=, operator>=); - -integral_const_binop(&&, operator&&); -integral_const_binop(||, operator||); - -template >> -METAL_FUNC constexpr auto operator||(true_type, T) { - return true_type{}; -} -template >> -METAL_FUNC constexpr auto operator||(T, true_type) { - return true_type{}; -} - -template >> -METAL_FUNC constexpr auto operator&&(false_type, T) { - return false_type{}; -} - -template >> -METAL_FUNC constexpr auto operator&&(T, false_type) { - return false_type{}; -} - -// Dispatch utilities -template -void dispatch_bool(bool v, F f) { - if (v) { - f(true_type{}); - } else { - f(false_type{}); - } -} - -template -constexpr void const_for_loop(F f) { - if constexpr (start < stop) { - constexpr auto idx = Int{}; - f(idx); - const_for_loop(f); - } -} - -#undef integral_const_binop - -/////////////////////////////////////////////////////////////////////////////// -// Reduction operators -/////////////////////////////////////////////////////////////////////////////// - -template -METAL_FUNC constexpr T sum(T x) { - return x; -} - -template -METAL_FUNC constexpr auto sum(T x, Us... us) { - return x + sum(us...); -} - -} // namespace steel -} // namespace mlx - -#pragma METAL internals : disable \ No newline at end of file diff --git a/Source/Cmlx/mlx-generated/metal/steel/utils/type_traits.h b/Source/Cmlx/mlx-generated/metal/steel/utils/type_traits.h deleted file mode 100644 index f004dc83..00000000 --- a/Source/Cmlx/mlx-generated/metal/steel/utils/type_traits.h +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright © 2024 Apple Inc. - -#pragma once - -#include - -#pragma METAL internals : enable - -namespace metal { - -template -struct is_empty : metal::bool_constant<__is_empty(T)> {}; - -#ifdef __cpp_variable_templates -template -constexpr constant bool is_empty_v = is_empty::value; -#endif - -template -struct make_void { - typedef void type; -}; - -template -using void_t = typename make_void::type; - -template -struct is_static : metal::bool_constant>::value> {}; - -template -struct pointer_element {}; - -template -struct pointer_element { - using type = remove_cv_t; -}; -template -struct pointer_element { - using type = remove_cv_t; -}; -template -struct pointer_element { - using type = remove_cv_t; -}; -template -struct pointer_element { - using type = remove_cv_t; -}; - -template -using pointer_element_t = typename pointer_element>::type; - -} // namespace metal - -#pragma METAL internals : disable \ No newline at end of file diff --git a/Source/Cmlx/mlx-generated/metal/ternary.h b/Source/Cmlx/mlx-generated/metal/ternary.h deleted file mode 100644 index 705b73e2..00000000 --- a/Source/Cmlx/mlx-generated/metal/ternary.h +++ /dev/null @@ -1,145 +0,0 @@ -// Copyright © 2024 Apple Inc. - -template < - typename T, - typename Op, - bool BSCALAR, - bool CSCALAR, - int N = WorkPerThread::n> -[[kernel]] void ternary_v( - device const bool* a, - device const T* b, - device const T* c, - device T* d, - constant uint& size, - uint index [[thread_position_in_grid]]) { - index *= N; - if (N > 1 && index + N > size) { - for (int i = 0; index + i < size; ++i) { - auto bidx = BSCALAR ? 0 : index + i; - auto cidx = CSCALAR ? 0 : index + i; - d[index + i] = Op()(a[index + i], b[bidx], c[cidx]); - } - } else { - for (int i = 0; i < N; ++i) { - auto bidx = BSCALAR ? 0 : index + i; - auto cidx = CSCALAR ? 0 : index + i; - d[index + i] = Op()(a[index + i], b[bidx], c[cidx]); - } - } -} - -template < - typename T, - typename Op, - bool BSCALAR, - bool CSCALAR, - int N = WorkPerThread::n> -[[kernel]] void ternary_v2( - device const bool* a, - device const T* b, - device const T* c, - device T* d, - constant int64_t& size, - uint2 index [[thread_position_in_grid]], - uint2 grid_dim [[threads_per_grid]]) { - int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); - if (N > 1 && offset + N > size) { - for (int i = 0; offset + i < size; ++i) { - auto bidx = BSCALAR ? 0 : offset + i; - auto cidx = CSCALAR ? 0 : offset + i; - d[offset + i] = Op()(a[offset + i], b[bidx], c[cidx]); - } - } else { - for (int i = 0; i < N; ++i) { - auto bidx = BSCALAR ? 0 : offset + i; - auto cidx = CSCALAR ? 0 : offset + i; - d[offset + i] = Op()(a[offset + i], b[bidx], c[cidx]); - } - } -} - -template -[[kernel]] void ternary_g_nd1( - device const bool* a, - device const T* b, - device const T* c, - device T* d, - constant const int64_t& a_strides, - constant const int64_t& b_strides, - constant const int64_t& c_strides, - uint index [[thread_position_in_grid]]) { - auto a_idx = elem_to_loc_1(index, a_strides); - auto b_idx = elem_to_loc_1(index, b_strides); - auto c_idx = elem_to_loc_1(index, c_strides); - d[index] = Op()(a[a_idx], b[b_idx], c[c_idx]); -} - -template -[[kernel]] void ternary_g_nd2( - device const bool* a, - device const T* b, - device const T* c, - device T* d, - constant const int64_t a_strides[2], - constant const int64_t b_strides[2], - constant const int64_t c_strides[2], - uint2 index [[thread_position_in_grid]], - uint2 grid_dim [[threads_per_grid]]) { - auto a_idx = elem_to_loc_2(index, a_strides); - auto b_idx = elem_to_loc_2(index, b_strides); - auto c_idx = elem_to_loc_2(index, c_strides); - IdxT out_idx = index.x + IdxT(grid_dim.x) * index.y; - d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]); -} - -template -[[kernel]] void ternary_g_nd3( - device const bool* a, - device const T* b, - device const T* c, - device T* d, - constant const int64_t a_strides[3], - constant const int64_t b_strides[3], - constant const int64_t c_strides[3], - uint3 index [[thread_position_in_grid]], - uint3 grid_dim [[threads_per_grid]]) { - auto a_idx = elem_to_loc_3(index, a_strides); - auto b_idx = elem_to_loc_3(index, b_strides); - auto c_idx = elem_to_loc_3(index, c_strides); - IdxT out_idx = index.x + grid_dim.x * (index.y + IdxT(grid_dim.y) * index.z); - d[out_idx] = Op()(a[a_idx], b[b_idx], c[c_idx]); -} - -template -[[kernel]] void ternary_g( - device const bool* a, - device const T* b, - device const T* c, - device T* d, - constant const int* shape, - constant const int64_t* a_strides, - constant const int64_t* b_strides, - constant const int64_t* c_strides, - constant const int& ndim, - uint3 index [[thread_position_in_grid]], - uint3 grid_dim [[threads_per_grid]]) { - auto idx = elem_to_loc_3_nd( - {N * index.x, index.y, index.z}, - shape, - a_strides, - b_strides, - c_strides, - ndim); - auto xshape = shape[ndim - 1]; - IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z); - IdxT a_xstride = a_strides[ndim - 1]; - IdxT b_xstride = b_strides[ndim - 1]; - IdxT c_xstride = c_strides[ndim - 1]; - for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { - d[out_idx++] = Op()(a[idx.x], b[idx.y], c[idx.z]); - idx.x += a_xstride; - idx.y += b_xstride; - idx.z += c_xstride; - } -} diff --git a/Source/Cmlx/mlx-generated/metal/ternary_ops.h b/Source/Cmlx/mlx-generated/metal/ternary_ops.h deleted file mode 100644 index e0235d9d..00000000 --- a/Source/Cmlx/mlx-generated/metal/ternary_ops.h +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright © 2023-2024 Apple Inc. - -#pragma once - -struct Select { - template - T operator()(bool condition, T x, T y) { - return condition ? x : y; - } -}; diff --git a/Source/Cmlx/mlx-generated/metal/unary.h b/Source/Cmlx/mlx-generated/metal/unary.h deleted file mode 100644 index db7be3d4..00000000 --- a/Source/Cmlx/mlx-generated/metal/unary.h +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright © 2024 Apple Inc. - -template ::n> -[[kernel]] void unary_v( - device const T* in, - device U* out, - constant uint& size, - uint index [[thread_position_in_grid]]) { - index *= N; - if (N > 1 && index + N > size) { - for (int i = 0; index + i < size; ++i) { - out[index + i] = static_cast(Op()(in[index + i])); - } - } else { - for (int i = 0; i < N; ++i) { - out[index + i] = static_cast(Op()(in[index + i])); - } - } -} - -template ::n> -[[kernel]] void unary_v2( - device const T* in, - device U* out, - constant int64_t& size, - uint2 index [[thread_position_in_grid]], - uint2 grid_dim [[threads_per_grid]]) { - int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); - if (N > 1 && offset + N > size) { - for (int i = 0; offset + i < size; ++i) { - out[offset + i] = static_cast(Op()(in[offset + i])); - } - } else { - for (int i = 0; i < N; ++i) { - out[offset + i] = static_cast(Op()(in[offset + i])); - } - } -} - -template < - typename T, - typename U, - typename Op, - int N = 1, - typename IdxT = int64_t> -[[kernel]] void unary_g( - device const T* in, - device U* out, - constant const int* in_shape, - constant const int64_t* in_strides, - device const int& ndim, - uint3 index [[thread_position_in_grid]], - uint3 grid_dim [[threads_per_grid]]) { - auto idx = elem_to_loc( - {N * index.x, index.y, index.z}, in_shape, in_strides, ndim); - auto xshape = in_shape[ndim - 1]; - IdxT xstride = in_strides[ndim - 1]; - IdxT out_idx = N * index.x + xshape * (index.y + IdxT(grid_dim.y) * index.z); - for (int i = 0; i < N && (int(N * index.x) + i) < xshape; ++i) { - out[out_idx++] = static_cast(Op()(in[idx])); - idx += xstride; - } -} diff --git a/Source/Cmlx/mlx-generated/metal/unary_ops.h b/Source/Cmlx/mlx-generated/metal/unary_ops.h deleted file mode 100644 index 0ec0febc..00000000 --- a/Source/Cmlx/mlx-generated/metal/unary_ops.h +++ /dev/null @@ -1,454 +0,0 @@ -// Copyright © 2023-2024 Apple Inc. - -#pragma once - -#include -#include - -#include "cexpf.h" -#include "erf.h" -#include "expm1f.h" -#include "fp8.h" - -namespace { -constant float inf = metal::numeric_limits::infinity(); -} - -struct Abs { - template - T operator()(T x) { - return metal::abs(x); - }; - uint8_t operator()(uint8_t x) { - return x; - }; - uint16_t operator()(uint16_t x) { - return x; - }; - uint32_t operator()(uint32_t x) { - return x; - }; - uint64_t operator()(uint64_t x) { - return x; - }; - bool operator()(bool x) { - return x; - }; - complex64_t operator()(complex64_t x) { - return {metal::precise::sqrt(x.real * x.real + x.imag * x.imag), 0}; - }; -}; - -struct ArcCos { - template - T operator()(T x) { - return metal::precise::acos(x); - }; - - complex64_t operator()(complex64_t x); -}; - -struct ArcCosh { - template - T operator()(T x) { - return metal::precise::acosh(x); - }; -}; - -struct ArcSin { - template - T operator()(T x) { - return metal::precise::asin(x); - }; - - complex64_t operator()(complex64_t x); -}; - -struct ArcSinh { - template - T operator()(T x) { - return metal::precise::asinh(x); - }; -}; - -struct ArcTan { - template - T operator()(T x) { - return metal::precise::atan(x); - }; - - complex64_t operator()(complex64_t x); -}; - -struct ArcTanh { - template - T operator()(T x) { - return metal::precise::atanh(x); - }; -}; - -struct BitwiseInvert { - template - T operator()(T x) { - return ~x; - }; -}; - -struct Ceil { - template - T operator()(T x) { - return metal::ceil(x); - }; - int8_t operator()(int8_t x) { - return x; - }; - int16_t operator()(int16_t x) { - return x; - }; - int32_t operator()(int32_t x) { - return x; - }; - int64_t operator()(int64_t x) { - return x; - }; - uint8_t operator()(uint8_t x) { - return x; - }; - uint16_t operator()(uint16_t x) { - return x; - }; - uint32_t operator()(uint32_t x) { - return x; - }; - uint64_t operator()(uint64_t x) { - return x; - }; - bool operator()(bool x) { - return x; - }; -}; - -struct Cos { - template - T operator()(T x) { - return metal::precise::cos(x); - }; - - complex64_t operator()(complex64_t x) { - return { - metal::precise::cos(x.real) * metal::precise::cosh(x.imag), - -metal::precise::sin(x.real) * metal::precise::sinh(x.imag)}; - }; -}; - -struct Cosh { - template - T operator()(T x) { - return metal::precise::cosh(x); - }; - - complex64_t operator()(complex64_t x) { - return { - metal::precise::cosh(x.real) * metal::precise::cos(x.imag), - metal::precise::sinh(x.real) * metal::precise::sin(x.imag)}; - }; -}; - -struct Conjugate { - complex64_t operator()(complex64_t x) { - return complex64_t{x.real, -x.imag}; - } -}; - -struct Erf { - template - T operator()(T x) { - return static_cast(erf(static_cast(x))); - }; -}; - -struct ErfInv { - template - T operator()(T x) { - return static_cast(erfinv(static_cast(x))); - }; -}; - -struct Exp { - template - T operator()(T x) { - return metal::precise::exp(x); - }; - complex64_t operator()(complex64_t x) { - return cexpf(x); - } -}; - -struct Expm1 { - template - T operator()(T x) { - return static_cast(expm1f(static_cast(x))); - }; -}; - -struct Floor { - template - T operator()(T x) { - return metal::floor(x); - }; - int8_t operator()(int8_t x) { - return x; - }; - int16_t operator()(int16_t x) { - return x; - }; - int32_t operator()(int32_t x) { - return x; - }; - int64_t operator()(int64_t x) { - return x; - }; - uint8_t operator()(uint8_t x) { - return x; - }; - uint16_t operator()(uint16_t x) { - return x; - }; - uint32_t operator()(uint32_t x) { - return x; - }; - uint64_t operator()(uint64_t x) { - return x; - }; - bool operator()(bool x) { - return x; - }; -}; - -struct Imag { - float operator()(complex64_t x) { - return x.imag; - }; -}; - -struct Log { - template - T operator()(T x) { - return metal::precise::log(x); - }; - - complex64_t operator()(complex64_t x) { - auto r = metal::precise::log(Abs{}(x).real); - auto i = metal::precise::atan2(x.imag, x.real); - return {r, i}; - }; -}; - -struct Log2 { - template - T operator()(T x) { - return metal::precise::log2(x); - }; - - complex64_t operator()(complex64_t x) { - auto y = Log{}(x); - return {y.real / M_LN2_F, y.imag / M_LN2_F}; - }; -}; - -struct Log10 { - template - T operator()(T x) { - return metal::precise::log10(x); - }; - - complex64_t operator()(complex64_t x) { - auto y = Log{}(x); - return {y.real / M_LN10_F, y.imag / M_LN10_F}; - }; -}; - -struct Log1p { - template - T operator()(T x) { - return log1p(x); - }; -}; - -struct LogicalNot { - template - T operator()(T x) { - return !x; - }; -}; - -struct Negative { - template - T operator()(T x) { - return -x; - }; -}; - -struct Real { - float operator()(complex64_t x) { - return x.real; - }; -}; - -struct Round { - template - T operator()(T x) { - return metal::rint(x); - }; - complex64_t operator()(complex64_t x) { - return {metal::rint(x.real), metal::rint(x.imag)}; - }; -}; - -struct Sigmoid { - template - T operator()(T x) { - auto y = 1 / (1 + metal::exp(metal::abs(x))); - return (x < 0) ? y : 1 - y; - } -}; - -struct Sign { - template - T operator()(T x) { - return (x > T(0)) - (x < T(0)); - }; - uint32_t operator()(uint32_t x) { - return x != 0; - }; - complex64_t operator()(complex64_t x) { - if (x == complex64_t(0)) { - return x; - } - return x / - (complex64_t)metal::precise::sqrt(x.real * x.real + x.imag * x.imag); - }; -}; - -struct Sin { - template - T operator()(T x) { - return metal::precise::sin(x); - }; - - complex64_t operator()(complex64_t x) { - return { - metal::precise::sin(x.real) * metal::precise::cosh(x.imag), - metal::precise::cos(x.real) * metal::precise::sinh(x.imag)}; - }; -}; - -struct Sinh { - template - T operator()(T x) { - return metal::precise::sinh(x); - }; - - complex64_t operator()(complex64_t x) { - return { - metal::precise::sinh(x.real) * metal::precise::cos(x.imag), - metal::precise::cosh(x.real) * metal::precise::sin(x.imag)}; - }; -}; - -struct Square { - template - T operator()(T x) { - return x * x; - }; -}; - -struct Sqrt { - template - T operator()(T x) { - return metal::precise::sqrt(x); - }; - - complex64_t operator()(complex64_t x) { - if (x.real == 0.0 && x.imag == 0.0) { - return {0.0, 0.0}; - } - auto r = Abs{}(x).real; - auto a = metal::precise::sqrt((r + x.real) / 2.0); - auto b_abs = metal::precise::sqrt((r - x.real) / 2.0); - auto b = metal::copysign(b_abs, x.imag); - return {a, b}; - } -}; - -struct Rsqrt { - template - T operator()(T x) { - return metal::precise::rsqrt(x); - }; - - complex64_t operator()(complex64_t x) { - return 1.0 / Sqrt{}(x); - } -}; - -struct Tan { - template - T operator()(T x) { - return metal::precise::tan(x); - }; - - complex64_t operator()(complex64_t x) { - float tan_a = metal::precise::tan(x.real); - float tanh_b = metal::precise::tanh(x.imag); - float t1 = tan_a * tanh_b; - float denom = 1. + t1 * t1; - return {(tan_a - tanh_b * t1) / denom, (tanh_b + tan_a * t1) / denom}; - }; -}; - -struct Tanh { - template - T operator()(T x) { - return metal::precise::tanh(x); - }; - - complex64_t operator()(complex64_t x) { - float tanh_a = metal::precise::tanh(x.real); - float tan_b = metal::precise::tan(x.imag); - float t1 = tanh_a * tan_b; - float denom = 1. + t1 * t1; - return {(tanh_a + tan_b * t1) / denom, (tan_b - tanh_a * t1) / denom}; - }; -}; - -complex64_t ArcCos::operator()(complex64_t x) { - auto i = complex64_t{0.0, 1.0}; - auto y = Log{}(x + i * Sqrt{}(1.0 - x * x)); - return {y.imag, -y.real}; -}; - -complex64_t ArcSin::operator()(complex64_t x) { - auto i = complex64_t{0.0, 1.0}; - auto y = Log{}(i * x + Sqrt{}(1.0 - x * x)); - return {y.imag, -y.real}; -}; - -complex64_t ArcTan::operator()(complex64_t x) { - auto i = complex64_t{0.0, 1.0}; - auto ix = i * x; - return (1.0 / complex64_t{0.0, 2.0}) * Log{}((1.0 + ix) / (1.0 - ix)); -}; - -struct ToFP8 { - template - uint8_t operator()(T f) { - return fp8_e4m3(f).bits; - } -}; - -struct FromFP8 { - float operator()(uint8_t x) { - return float(*(thread fp8_e4m3*)(&x)); - } -}; diff --git a/Source/Cmlx/mlx-generated/metal/utils.h b/Source/Cmlx/mlx-generated/metal/utils.h deleted file mode 100644 index 9651ef06..00000000 --- a/Source/Cmlx/mlx-generated/metal/utils.h +++ /dev/null @@ -1,445 +0,0 @@ -// Copyright © 2023-2024 Apple Inc. - -#pragma once - -#include - -#include "bf16.h" -#include "bf16_math.h" -#include "complex.h" -#include "defines.h" -#include "logging.h" - -typedef half float16_t; - -// Work per thread values for different types. The values here are expected to -// match get_work_per_thread in mlx/backend/metal/utils.h -template -struct WorkPerThread { - static_assert(sizeof(U) <= 8, "Type too large"); - static constexpr int constant n = 8 / sizeof(U); -}; - -/////////////////////////////////////////////////////////////////////////////// -// Type limits utils -/////////////////////////////////////////////////////////////////////////////// - -template -struct Limits { - static const constant U max = metal::numeric_limits::max(); - static const constant U min = metal::numeric_limits::min(); - static const constant U finite_max = metal::numeric_limits::max(); - static const constant U finite_min = metal::numeric_limits::min(); -}; - -#define instantiate_default_limit(type) \ - template <> \ - struct Limits { \ - static constexpr constant type max = metal::numeric_limits::max(); \ - static constexpr constant type min = metal::numeric_limits::min(); \ - static constexpr constant type finite_max = \ - metal::numeric_limits::max(); \ - static constexpr constant type finite_min = \ - metal::numeric_limits::min(); \ - }; - -instantiate_default_limit(uint8_t); -instantiate_default_limit(uint16_t); -instantiate_default_limit(uint32_t); -instantiate_default_limit(uint64_t); -instantiate_default_limit(int8_t); -instantiate_default_limit(int16_t); -instantiate_default_limit(int32_t); -instantiate_default_limit(int64_t); - -#define instantiate_float_limit(type) \ - template <> \ - struct Limits { \ - static constexpr constant type max = \ - metal::numeric_limits::infinity(); \ - static constexpr constant type min = \ - -metal::numeric_limits::infinity(); \ - static constexpr constant type finite_max = \ - metal::numeric_limits::max(); \ - static constexpr constant type finite_min = \ - -metal::numeric_limits::max(); \ - }; - -instantiate_float_limit(half); -instantiate_float_limit(float); -instantiate_float_limit(bfloat16_t); - -template <> -struct Limits { - static constexpr constant bool max = true; - static constexpr constant bool min = false; -}; - -template <> -struct Limits { - static constexpr constant complex64_t max = complex64_t( - metal::numeric_limits::infinity(), - metal::numeric_limits::infinity()); - static constexpr constant complex64_t min = complex64_t( - -metal::numeric_limits::infinity(), - -metal::numeric_limits::infinity()); -}; - -/////////////////////////////////////////////////////////////////////////////// -// Indexing utils -/////////////////////////////////////////////////////////////////////////////// - -#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)") - -/////////////////////////////////////////////////////////////////////////////// -// Single Array with generic dims - -template -METAL_FUNC IdxT elem_to_loc( - IdxT elem, - constant const int* shape, - constant const int64_t* strides, - int ndim) { - IdxT loc = 0; - for (int i = ndim - 1; i >= 0 && elem > 0; --i) { - loc += (elem % shape[i]) * IdxT(strides[i]); - elem /= shape[i]; - } - return loc; -} - -// Non templated version to handle arbitrary dims -template -METAL_FUNC IdxT elem_to_loc( - uint3 elem, - constant const int* shape, - constant const int64_t* strides, - int ndim) { - IdxT loc = - elem.x * IdxT(strides[ndim - 1]) + elem.y * IdxT(strides[ndim - 2]); - for (int d = ndim - 3; d >= 0; --d) { - loc += (elem.z % shape[d]) * IdxT(strides[d]); - elem.z /= shape[d]; - } - return loc; -} - -/////////////////////////////////////////////////////////////////////////////// -// Single Array with fixed N dims - -template -METAL_FUNC IdxT elem_to_loc_1(uint elem, constant const int64_t& stride) { - return elem * IdxT(stride); -} - -template -METAL_FUNC IdxT elem_to_loc_2(uint2 elem, constant const int64_t strides[2]) { - return elem.x * IdxT(strides[1]) + elem.y * IdxT(strides[0]); -} - -template -METAL_FUNC IdxT elem_to_loc_3(uint3 elem, constant const int64_t strides[3]) { - return elem.x * IdxT(strides[2]) + elem.y * IdxT(strides[1]) + - elem.z * IdxT(strides[0]); -} - -/////////////////////////////////////////////////////////////////////////////// -// Multiple Arrays with generic dims - -template -METAL_FUNC vec elem_to_loc_2_nd( - uint3 elem, - constant const int* shape, - constant const int64_t* a_strides, - constant const int64_t* b_strides, - int ndim) { - vec loc = { - IdxT( - elem.x * IdxT(a_strides[ndim - 1]) + - IdxT(elem.y) * IdxT(a_strides[ndim - 2])), - IdxT( - elem.x * IdxT(b_strides[ndim - 1]) + - elem.y * IdxT(b_strides[ndim - 2]))}; - for (int d = ndim - 3; d >= 0; --d) { - uint l = elem.z % shape[d]; - loc.x += l * IdxT(a_strides[d]); - loc.y += l * IdxT(b_strides[d]); - elem.z /= shape[d]; - } - return loc; -} - -template -METAL_FUNC vec elem_to_loc_3_nd( - uint3 elem, - constant const int* shape, - constant const int64_t* a_strides, - constant const int64_t* b_strides, - constant const int64_t* c_strides, - int ndim) { - vec loc = { - IdxT(elem.x * IdxT(a_strides[ndim - 1])) + - IdxT(elem.y * IdxT(a_strides[ndim - 2])), - IdxT(elem.x * IdxT(b_strides[ndim - 1])) + - IdxT(elem.y * IdxT(b_strides[ndim - 2])), - IdxT(elem.x * IdxT(c_strides[ndim - 1])) + - IdxT(elem.y * IdxT(c_strides[ndim - 2]))}; - for (int d = ndim - 3; d >= 0; --d) { - uint l = elem.z % shape[d]; - loc.x += l * IdxT(a_strides[d]); - loc.y += l * IdxT(b_strides[d]); - loc.z += l * IdxT(c_strides[d]); - elem.z /= shape[d]; - } - return loc; -} - -/////////////////////////////////////////////////////////////////////////////// -// Elem to loc in a loop utils -/////////////////////////////////////////////////////////////////////////////// - -template -struct LoopedElemToLoc { - int dim; - LoopedElemToLoc inner_looper; - OffsetT offset{0}; - int index{0}; - - LoopedElemToLoc(int dim) : dim(dim), inner_looper(dim - 1) {} - - void next(const constant int* shape, const constant int64_t* strides) { - if (dim == 0) { - return; - } - index++; - offset += OffsetT(strides[dim - 1]); - if (index >= shape[dim - 1]) { - index = 0; - inner_looper.next(shape, strides); - offset = inner_looper.offset; - } - } - - void next(int n, const constant int* shape, const constant int64_t* strides) { - if (dim == 0) { - return; - } - index += n; - offset += n * OffsetT(strides[dim - 1]); - - if (index >= shape[dim - 1]) { - int extra = index - shape[dim - 1]; - if (extra >= shape[dim - 1]) { - inner_looper.next(1 + extra / shape[dim - 1], shape, strides); - extra = extra % shape[dim - 1]; - } else { - inner_looper.next(shape, strides); - } - index = 0; - offset = inner_looper.offset; - if (extra > 0) { - next(extra, shape, strides); - } - } - } - - OffsetT location() { - return offset; - } -}; - -template -struct LoopedElemToLoc<1, OffsetT, true> { - int dim; - OffsetT offset{0}; - uint index{0}; - - LoopedElemToLoc(int dim) : dim(dim) {} - - void next(const constant int* shape, const constant int64_t* strides) { - index++; - if (dim > 1) { - offset = elem_to_loc(index, shape, strides, dim); - } else { - offset += OffsetT(strides[0]); - } - } - - void next(int n, const constant int* shape, const constant int64_t* strides) { - index += n; - if (dim > 1) { - offset = elem_to_loc(index, shape, strides, dim); - } else { - offset = index * OffsetT(strides[0]); - } - } - - OffsetT location() { - return offset; - } -}; - -template -struct LoopedElemToLoc<1, OffsetT, false> { - OffsetT offset{0}; - - LoopedElemToLoc(int) {} - - void next(const constant int*, const constant int64_t* strides) { - offset += OffsetT(strides[0]); - } - - void next(int n, const constant int*, const constant int64_t* strides) { - offset += n * OffsetT(strides[0]); - } - - OffsetT location() { - return offset; - } -}; - -/////////////////////////////////////////////////////////////////////////////// -// Calculation utils -/////////////////////////////////////////////////////////////////////////////// - -/** Compute ceil((float)N/(float)M) */ -template -inline T ceildiv(T N, U M) { - return (N + M - 1) / M; -} - -// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202 -inline float log1p(float x) { - float xp1 = 1.0f + x; - if (xp1 == Limits::max) { - return Limits::max; - } - if (xp1 == 1.0f) { - return x; - } - - return x * (metal::log(xp1) / (xp1 - 1.0f)); -} - -inline bfloat16_t log1p(bfloat16_t x) { - float xp1 = 1.0f + static_cast(x); - if (xp1 == Limits::max) { - return Limits::max; - } - if (xp1 == 1.0f) { - return x; - } - - return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f))); -} - -inline complex64_t log1p(complex64_t in) { - float x = in.real; - float y = in.imag; - float zabs = metal::precise::sqrt(x * x + y * y); - float theta = metal::atan2(y, x + 1); - if (zabs < 0.5f) { - float r = x * (2 + x) + y * y; - if (r == 0) { // handle underflow - return {x, theta}; - } - return {0.5f * log1p(r), theta}; - } else { - auto z0 = metal::sqrt((x + 1) * (x + 1) + y * y); - return {metal::log(z0), theta}; - } -} - -/////////////////////////////////////////////////////////////////////////////// -// SIMD shuffle ops -/////////////////////////////////////////////////////////////////////////////// - -inline uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) { - return as_type( - metal::simd_shuffle_down(as_type(data), delta)); -} - -inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) { - return as_type( - metal::simd_shuffle_down(as_type(data), delta)); -} - -inline bool simd_shuffle_down(bool data, uint16_t delta) { - return simd_shuffle_down(static_cast(data), delta); -} - -inline complex64_t simd_shuffle_down(complex64_t data, uint16_t delta) { - return complex64_t( - simd_shuffle_down(data.real, delta), simd_shuffle_down(data.imag, delta)); -} - -inline uint64_t simd_shuffle_up(uint64_t data, uint16_t delta) { - return as_type(metal::simd_shuffle_up(as_type(data), delta)); -} - -inline int64_t simd_shuffle_up(int64_t data, uint16_t delta) { - return as_type(metal::simd_shuffle_up(as_type(data), delta)); -} - -inline bool simd_shuffle_up(bool data, uint16_t delta) { - return simd_shuffle_up(static_cast(data), delta); -} - -inline complex64_t simd_shuffle_up(complex64_t data, uint16_t delta) { - return complex64_t( - simd_shuffle_up(data.real, delta), simd_shuffle_up(data.imag, delta)); -} - -inline uint64_t -simd_shuffle_and_fill_up(uint64_t data, uint64_t filling, uint16_t delta) { - return as_type(metal::simd_shuffle_and_fill_up( - as_type(data), as_type(filling), delta)); -} - -inline int64_t -simd_shuffle_and_fill_up(int64_t data, int64_t filling, uint16_t delta) { - return as_type(metal::simd_shuffle_and_fill_up( - as_type(data), as_type(filling), delta)); -} - -inline bool simd_shuffle_and_fill_up(bool data, bool filling, uint16_t delta) { - return simd_shuffle_and_fill_up( - static_cast(data), static_cast(filling), delta); -} - -inline complex64_t simd_shuffle_and_fill_up( - complex64_t data, - complex64_t filling, - uint16_t delta) { - return complex64_t( - simd_shuffle_and_fill_up(data.real, filling.real, delta), - simd_shuffle_and_fill_up(data.imag, filling.imag, delta)); -} - -inline uint64_t simd_shuffle(uint64_t data, uint16_t lane) { - return as_type(metal::simd_shuffle(as_type(data), lane)); -} - -inline int64_t simd_shuffle(int64_t data, uint16_t lane) { - return as_type(metal::simd_shuffle(as_type(data), lane)); -} - -inline bool simd_shuffle(bool data, uint16_t lane) { - return simd_shuffle(static_cast(data), lane); -} - -inline complex64_t simd_shuffle(complex64_t data, uint16_t lane) { - return complex64_t( - simd_shuffle(data.real, lane), simd_shuffle(data.imag, lane)); -} - -// std::conditional is not included with Metal -template -struct ConditionalType { - using type = U; -}; - -template -struct ConditionalType { - using type = T; -}; diff --git a/Source/MLX/MLXFastKernel.swift b/Source/MLX/MLXFastKernel.swift index 03714913..45606d76 100644 --- a/Source/MLX/MLXFastKernel.swift +++ b/Source/MLX/MLXFastKernel.swift @@ -6,6 +6,7 @@ import Cmlx /// /// Currently: /// - `Int` +/// - `UInt32` /// - `Bool` /// - `DType` /// @@ -14,6 +15,7 @@ public protocol KernelTemplateArg {} extension Bool: KernelTemplateArg {} extension Int: KernelTemplateArg {} +extension UInt32: KernelTemplateArg {} extension DType: KernelTemplateArg {} extension MLXFast { @@ -114,8 +116,16 @@ extension MLXFast { mlx_fast_metal_kernel_config_add_template_arg_bool(config, name, value) case let value as Int: + guard let int32Value = Int32(exactly: value) else { + fatalError( + "KernelTemplateArg \(name) Int value \(value) is outside the Int32 range." + ) + } mlx_fast_metal_kernel_config_add_template_arg_int( - config, name, Int32(value)) + config, name, int32Value) + + case let value as UInt32: + mlx_fast_metal_kernel_config_add_template_arg_uint32(config, name, value) case let value as DType: mlx_fast_metal_kernel_config_add_template_arg_dtype( diff --git a/Source/MLX/TurboQuant.swift b/Source/MLX/TurboQuant.swift new file mode 100644 index 00000000..85f15b17 --- /dev/null +++ b/Source/MLX/TurboQuant.swift @@ -0,0 +1,4778 @@ +import Cmlx +import Foundation + +#if canImport(Metal) + import Metal +#endif + +/// TurboQuant preset requested by higher-level runtime code. +/// +/// This additive Swift API gives callers one stable surface for the fast packed +/// MLX compatibility path, a deterministic TurboQuantProd/QJL reference codec, +/// and the TurboQuantProd key plus bitpacked-value Metal backend. +public enum TurboQuantPreset: String, Codable, Sendable, CaseIterable { + case turbo2_5 + case turbo3_5 + + public var displayName: String { + switch self { + case .turbo2_5: + "TurboQuant 2.5-bit" + case .turbo3_5: + "TurboQuant 3.5-bit" + } + } + + /// Current native MLX packed-lane width used by the compatibility path. + /// + /// MLX's public packed quantized matmul kernels accept integer lane widths. + /// The mixed-bit Metal path uses ``baseMagnitudeBits`` and + /// ``highMagnitudeBits`` directly; this value exists for MLX packed fallback + /// interoperability. + public var effectiveBits: Int { + switch self { + case .turbo2_5: + 2 + case .turbo3_5: + 4 + } + } + + public var baseMagnitudeBits: Int { + switch self { + case .turbo2_5: + 2 + case .turbo3_5: + 3 + } + } + + public var highMagnitudeBits: Int { + switch self { + case .turbo2_5: + 3 + case .turbo3_5: + 4 + } + } + + public var targetMagnitudeBits: Float { + switch self { + case .turbo2_5: + 2.5 + case .turbo3_5: + 3.5 + } + } + + public var defaultValueBits: Int { + switch self { + case .turbo2_5: + 2 + case .turbo3_5: + 4 + } + } +} + +public enum TurboQuantTensorRole: String, Codable, Sendable, CaseIterable { + case key + case value + case vector +} + +public enum TurboQuantBackend: String, Codable, Sendable, CaseIterable { + /// MLX's native packed quantization and quantized matrix-multiply kernels. + /// + /// This is the production backend Pine uses today on iOS. + case mlxPacked + + /// Deterministic CPU reference implementation for the TurboQuantProd key + /// path, affine value path, and QJL residual sign estimator. + case polarQJLReference + + /// Mixed-bit key and bitpacked-value PolarQuant/QJL Metal kernels. + case metalPolarQJL +} + +public enum TurboQuantReferenceFormat: String, Codable, Sendable, Hashable, CaseIterable { + case magnitudeResidualSign + case turboQuantProd + case affineValue +} + +public enum TurboQuantKernelProfile: String, Codable, Sendable, CaseIterable { + case portableA16A17 + case wideA18A19 + case sustainedA19Pro + case mlxPackedFallback + + public var displayName: String { + switch self { + case .portableA16A17: + "Portable A16/A17" + case .wideA18A19: + "Wide A18/A19" + case .sustainedA19Pro: + "Sustained A19 Pro" + case .mlxPackedFallback: + "MLX packed fallback" + } + } + + var fusedDecodeThreadgroupWidth: Int { + switch self { + case .portableA16A17: + 128 + case .wideA18A19, .sustainedA19Pro: + 256 + case .mlxPackedFallback: + 128 + } + } +} + +public enum TurboQuantRuntimeSelfTestStatus: String, Codable, Sendable, CaseIterable { + case notRun + case passed + case failed +} + +public struct TurboQuantRuntimeProbeResult: Equatable, Codable, Sendable { + public var status: TurboQuantRuntimeSelfTestStatus + public var metalRuntimeAvailable: Bool + public var encodeDecodePassed: Bool + public var qkPassed: Bool + public var avPassed: Bool + public var tiledFusedPassed: Bool + public var selectedKernelProfile: TurboQuantKernelProfile + public var failureReason: String? + public var encodeDecodeLatencySeconds: Double? + public var twoStageLatencySeconds: Double? + public var tiledFusedLatencySeconds: Double? + + public init( + status: TurboQuantRuntimeSelfTestStatus = .notRun, + metalRuntimeAvailable: Bool = false, + encodeDecodePassed: Bool = false, + qkPassed: Bool = false, + avPassed: Bool = false, + tiledFusedPassed: Bool = false, + selectedKernelProfile: TurboQuantKernelProfile = .mlxPackedFallback, + failureReason: String? = nil, + encodeDecodeLatencySeconds: Double? = nil, + twoStageLatencySeconds: Double? = nil, + tiledFusedLatencySeconds: Double? = nil + ) { + self.status = status + self.metalRuntimeAvailable = metalRuntimeAvailable + self.encodeDecodePassed = encodeDecodePassed + self.qkPassed = qkPassed + self.avPassed = avPassed + self.tiledFusedPassed = tiledFusedPassed + self.selectedKernelProfile = selectedKernelProfile + self.failureReason = failureReason + self.encodeDecodeLatencySeconds = encodeDecodeLatencySeconds + self.twoStageLatencySeconds = twoStageLatencySeconds + self.tiledFusedLatencySeconds = tiledFusedLatencySeconds + } + + public var passed: Bool { + status == .passed + && metalRuntimeAvailable + && encodeDecodePassed + && qkPassed + && avPassed + && tiledFusedPassed + } +} + +public struct TurboQuantDeviceCapabilities: Equatable, Codable, Sendable { + public var metalAvailable: Bool + public var architectureName: String + public var supportedGPUFamilies: [String: Bool] + public var maxBufferBytes: Int + public var recommendedWorkingSetBytes: Int? + public var physicalMemoryBytes: Int? + public var maxThreadgroupWidth: Int? + public var runtimeProbe: TurboQuantRuntimeProbeResult + + public init( + metalAvailable: Bool, + architectureName: String, + supportedGPUFamilies: [String: Bool] = [:], + maxBufferBytes: Int = 0, + recommendedWorkingSetBytes: Int? = nil, + physicalMemoryBytes: Int? = nil, + maxThreadgroupWidth: Int? = nil, + runtimeProbe: TurboQuantRuntimeProbeResult = TurboQuantRuntimeProbeResult() + ) { + self.metalAvailable = metalAvailable + self.architectureName = architectureName + self.supportedGPUFamilies = supportedGPUFamilies + self.maxBufferBytes = maxBufferBytes + self.recommendedWorkingSetBytes = recommendedWorkingSetBytes + self.physicalMemoryBytes = physicalMemoryBytes + self.maxThreadgroupWidth = maxThreadgroupWidth + self.runtimeProbe = runtimeProbe + } + + public var selectedKernelProfile: TurboQuantKernelProfile { + runtimeProbe.selectedKernelProfile + } + + public static var current: TurboQuantDeviceCapabilities { + var capabilities = detectedTurboQuantDeviceCapabilities() + capabilities.runtimeProbe = TurboQuantRuntimeProbe.shared.result() + return capabilities + } +} + +public struct TurboQuantKernelAvailability: Equatable, Codable, Sendable { + public var supportsMLXPacked: Bool + public var supportsPolarQJLReference: Bool + public var supportsMetalPolarQJLCodec: Bool + public var supportsMetalPolarQJLAttention: Bool + public var supportsMetalPolarQJL: Bool + public var selectedKernelProfile: TurboQuantKernelProfile + public var selfTestStatus: TurboQuantRuntimeSelfTestStatus + public var selfTestFailureReason: String? + + public init( + supportsMLXPacked: Bool = true, + supportsPolarQJLReference: Bool = true, + supportsMetalPolarQJLCodec: Bool = false, + supportsMetalPolarQJLAttention: Bool = false, + supportsMetalPolarQJL: Bool = false, + selectedKernelProfile: TurboQuantKernelProfile = .mlxPackedFallback, + selfTestStatus: TurboQuantRuntimeSelfTestStatus = .notRun, + selfTestFailureReason: String? = nil + ) { + self.supportsMLXPacked = supportsMLXPacked + self.supportsPolarQJLReference = supportsPolarQJLReference + self.supportsMetalPolarQJLCodec = supportsMetalPolarQJLCodec + self.supportsMetalPolarQJLAttention = supportsMetalPolarQJLAttention + self.supportsMetalPolarQJL = supportsMetalPolarQJL + self.selectedKernelProfile = selectedKernelProfile + self.selfTestStatus = selfTestStatus + self.selfTestFailureReason = selfTestFailureReason + } + + public static var current: TurboQuantKernelAvailability { + let metalAvailable = metalRuntimeAvailable() + let probe = TurboQuantRuntimeProbe.shared.result() + let attentionAvailable = metalAvailable && probe.passed + return TurboQuantKernelAvailability( + supportsMetalPolarQJLCodec: metalAvailable, + supportsMetalPolarQJLAttention: attentionAvailable, + supportsMetalPolarQJL: attentionAvailable, + selectedKernelProfile: probe.selectedKernelProfile, + selfTestStatus: probe.status, + selfTestFailureReason: probe.failureReason + ) + } + + public func supports(_ backend: TurboQuantBackend) -> Bool { + switch backend { + case .mlxPacked: + supportsMLXPacked + case .polarQJLReference: + supportsPolarQJLReference + case .metalPolarQJL: + supportsMetalPolarQJL + } + } + + public func runtimeBackend(for requestedBackend: TurboQuantBackend) -> TurboQuantBackend { + if supports(requestedBackend) { + requestedBackend + } else { + .mlxPacked + } + } + + public func fallbackReason(for requestedBackend: TurboQuantBackend) -> String? { + guard !supports(requestedBackend) else { return nil } + + switch requestedBackend { + case .mlxPacked: + return nil + case .polarQJLReference: + return + "PolarQuant/QJL reference backend unavailable; using MLX packed TurboQuant lanes." + case .metalPolarQJL: + if let selfTestFailureReason { + return + "TurboQuant Metal self-test failed: \(selfTestFailureReason); using MLX packed TurboQuant lanes." + } + return + "TurboQuant Metal kernels unavailable; using MLX packed TurboQuant lanes." + } + } +} + +public enum TurboQuantError: Error, Equatable, CustomStringConvertible { + case invalidGroupSize(Int) + case invalidMetalConfiguration(String) + case invalidQualityInput(String) + case invalidReferenceCode(String) + case unsupportedBackend(TurboQuantBackend, String) + + public var description: String { + switch self { + case .invalidGroupSize(let groupSize): + "TurboQuant group size must be positive, got \(groupSize)." + case .invalidMetalConfiguration(let message): + "Invalid TurboQuant Metal configuration: \(message)" + case .invalidQualityInput(let message): + "Invalid TurboQuant quality input: \(message)" + case .invalidReferenceCode(let message): + "Invalid TurboQuant reference code: \(message)" + case .unsupportedBackend(let backend, let message): + "Unsupported TurboQuant backend \(backend.rawValue): \(message)" + } + } +} + +public struct TurboQuantConfiguration: Hashable, Codable, Sendable { + public var preset: TurboQuantPreset + public var role: TurboQuantTensorRole + public var groupSize: Int + public var mode: QuantizationMode + public var backend: TurboQuantBackend + public var seed: UInt64 + public var qjlResidualScale: Float + public var valueBits: Int? + + public init( + preset: TurboQuantPreset = .turbo3_5, + role: TurboQuantTensorRole = .vector, + groupSize: Int = 64, + mode: QuantizationMode = .affine, + backend: TurboQuantBackend = .mlxPacked, + seed: UInt64 = 0x9E37_79B9_7F4A_7C15, + qjlResidualScale: Float = 0.5, + valueBits: Int? = nil + ) { + self.preset = preset + self.role = role + self.groupSize = groupSize + self.mode = mode + self.backend = backend + self.seed = seed + self.qjlResidualScale = qjlResidualScale + self.valueBits = valueBits + } + + public var effectiveBits: Int { preset.effectiveBits } + + public var resolvedValueBits: Int { + valueBits ?? preset.defaultValueBits + } + + public var runtimeBackend: TurboQuantBackend { + TurboQuantKernelAvailability.current.runtimeBackend(for: backend) + } + + public var runtimeFallbackReason: String? { + TurboQuantKernelAvailability.current.fallbackReason(for: backend) + } + + public static func deterministicSeed( + modelID: String, + revision: String, + cacheLayoutVersion: Int + ) -> UInt64 { + var hash: UInt64 = 0xCBF2_9CE4_8422_2325 + for byte in "\(modelID)#\(revision)#\(cacheLayoutVersion)".utf8 { + hash ^= UInt64(byte) + hash &*= 0x0000_0100_0000_01B3 + } + return hash == 0 ? 0x9E37_79B9_7F4A_7C15 : hash + } +} + +public typealias TurboQuantPackedTensor = ( + weight: MLXArray, + scales: MLXArray, + biases: MLXArray? +) + +public struct TurboQuantReferenceCode: Hashable, Codable, Sendable { + public var shape: [Int] + public var preset: TurboQuantPreset + public var role: TurboQuantTensorRole + public var format: TurboQuantReferenceFormat + public var groupSize: Int + public var seed: UInt64 + public var residualScale: Float + public var baseMagnitudeBits: Int + public var highMagnitudeBits: Int + public var valueCount: Int + public var baseScales: [Float] + public var highScales: [Float] + public var residualScales: [Float] + public var signs: Data + public var highPrecisionMask: Data + public var residualSigns: Data + public var packedMagnitudes: Data + + private enum CodingKeys: String, CodingKey { + case shape + case preset + case role + case format + case groupSize + case seed + case residualScale + case baseMagnitudeBits + case highMagnitudeBits + case valueCount + case baseScales + case highScales + case residualScales + case signs + case highPrecisionMask + case residualSigns + case packedMagnitudes + } + + public init( + shape: [Int], + preset: TurboQuantPreset, + role: TurboQuantTensorRole, + format: TurboQuantReferenceFormat = .magnitudeResidualSign, + groupSize: Int, + seed: UInt64, + residualScale: Float, + baseMagnitudeBits: Int, + highMagnitudeBits: Int, + valueCount: Int, + baseScales: [Float], + highScales: [Float], + residualScales: [Float]? = nil, + signs: Data, + highPrecisionMask: Data, + residualSigns: Data, + packedMagnitudes: Data + ) { + self.shape = shape + self.preset = preset + self.role = role + self.format = format + self.groupSize = groupSize + self.seed = seed + self.residualScale = residualScale + self.baseMagnitudeBits = baseMagnitudeBits + self.highMagnitudeBits = highMagnitudeBits + self.valueCount = valueCount + self.baseScales = baseScales + self.highScales = highScales + self.residualScales = residualScales ?? [] + self.signs = signs + self.highPrecisionMask = highPrecisionMask + self.residualSigns = residualSigns + self.packedMagnitudes = packedMagnitudes + } + + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + shape = try container.decode([Int].self, forKey: .shape) + preset = try container.decode(TurboQuantPreset.self, forKey: .preset) + role = try container.decode(TurboQuantTensorRole.self, forKey: .role) + format = + try container.decodeIfPresent(TurboQuantReferenceFormat.self, forKey: .format) + ?? .magnitudeResidualSign + groupSize = try container.decode(Int.self, forKey: .groupSize) + seed = try container.decode(UInt64.self, forKey: .seed) + residualScale = try container.decodeIfPresent(Float.self, forKey: .residualScale) ?? 0.5 + baseMagnitudeBits = try container.decode(Int.self, forKey: .baseMagnitudeBits) + highMagnitudeBits = try container.decode(Int.self, forKey: .highMagnitudeBits) + valueCount = try container.decode(Int.self, forKey: .valueCount) + baseScales = try container.decode([Float].self, forKey: .baseScales) + highScales = try container.decode([Float].self, forKey: .highScales) + residualScales = try container.decodeIfPresent([Float].self, forKey: .residualScales) ?? [] + signs = try container.decode(Data.self, forKey: .signs) + highPrecisionMask = try container.decode(Data.self, forKey: .highPrecisionMask) + residualSigns = try container.decode(Data.self, forKey: .residualSigns) + packedMagnitudes = try container.decode(Data.self, forKey: .packedMagnitudes) + } + + public func encode(to encoder: Encoder) throws { + var container = encoder.container(keyedBy: CodingKeys.self) + try container.encode(shape, forKey: .shape) + try container.encode(preset, forKey: .preset) + try container.encode(role, forKey: .role) + try container.encode(format, forKey: .format) + try container.encode(groupSize, forKey: .groupSize) + try container.encode(seed, forKey: .seed) + try container.encode(residualScale, forKey: .residualScale) + try container.encode(baseMagnitudeBits, forKey: .baseMagnitudeBits) + try container.encode(highMagnitudeBits, forKey: .highMagnitudeBits) + try container.encode(valueCount, forKey: .valueCount) + try container.encode(baseScales, forKey: .baseScales) + try container.encode(highScales, forKey: .highScales) + try container.encode(residualScales, forKey: .residualScales) + try container.encode(signs, forKey: .signs) + try container.encode(highPrecisionMask, forKey: .highPrecisionMask) + try container.encode(residualSigns, forKey: .residualSigns) + try container.encode(packedMagnitudes, forKey: .packedMagnitudes) + } + + public var storageByteCount: Int { + switch format { + case .affineValue: + packedMagnitudes.count + + (baseScales.count + highScales.count) * MemoryLayout.stride + case .turboQuantProd: + packedMagnitudes.count + + signs.count + + (baseScales.count + highScales.count) * MemoryLayout.stride + case .magnitudeResidualSign: + packedMagnitudes.count + + signs.count + + highPrecisionMask.count + + residualSigns.count + + (baseScales.count + highScales.count + residualScales.count) + * MemoryLayout.stride + } + } + + public var approximateBitsPerValue: Double { + guard valueCount > 0 else { return 0 } + return Double(storageByteCount * 8) / Double(valueCount) + } +} + +public struct TurboQuantMetalCode { + public var shape: [Int] + public var preset: TurboQuantPreset + public var role: TurboQuantTensorRole + public var groupSize: Int + public var seed: UInt64 + public var valueBits: Int + public var valueCount: Int + public var groupCount: Int + public var magnitudeWordsPerGroup: Int + public var bitsetWordsPerGroup: Int + public var scalesPerGroup: Int + public var packedMagnitudes: MLXArray + public var signs: MLXArray + public var highPrecisionMask: MLXArray + public var residualSigns: MLXArray + public var scales: MLXArray + + public var storageByteCount: Int { + if role == .value { + return packedMagnitudes.nbytes + scales.nbytes + } + return packedMagnitudes.nbytes + + signs.nbytes + + highPrecisionMask.nbytes + + residualSigns.nbytes + + scales.nbytes + } + + public var approximateBitsPerValue: Double { + guard valueCount > 0 else { return 0 } + return Double(storageByteCount * 8) / Double(valueCount) + } +} + +public enum TurboQuantAttentionPath: String, Codable, Sendable, CaseIterable { + case onlineFused + case tiledOnlineFused + case twoStageCompressed + case mlxPackedFallback + case baseline +} + +public struct TurboQuantAttentionLayout: Hashable, Codable, Sendable { + public static let currentVersion = 4 + + public var layoutVersion: Int + public var batchSize: Int + public var kvHeadCount: Int + public var capacity: Int + public var logicalLength: Int + public var ringOffset: Int + public var pinnedPrefixLength: Int + public var headDimension: Int + public var groupsPerVector: Int + public var magnitudeWordsPerGroup: Int + public var bitsetWordsPerGroup: Int + + public init( + layoutVersion: Int = TurboQuantAttentionLayout.currentVersion, + batchSize: Int, + kvHeadCount: Int, + capacity: Int, + logicalLength: Int, + ringOffset: Int = 0, + pinnedPrefixLength: Int = 0, + headDimension: Int, + groupsPerVector: Int, + magnitudeWordsPerGroup: Int, + bitsetWordsPerGroup: Int + ) { + self.layoutVersion = layoutVersion + self.batchSize = batchSize + self.kvHeadCount = kvHeadCount + self.capacity = capacity + self.logicalLength = logicalLength + self.ringOffset = ringOffset + self.pinnedPrefixLength = pinnedPrefixLength + self.headDimension = headDimension + self.groupsPerVector = groupsPerVector + self.magnitudeWordsPerGroup = magnitudeWordsPerGroup + self.bitsetWordsPerGroup = bitsetWordsPerGroup + } + + public var logicalShape: [Int] { + [batchSize, kvHeadCount, logicalLength, headDimension] + } + + public var storageShape: [Int] { + [batchSize, kvHeadCount, capacity, headDimension] + } +} + +public struct TurboQuantAttentionCode { + public var layout: TurboQuantAttentionLayout + public var preset: TurboQuantPreset + public var role: TurboQuantTensorRole + public var groupSize: Int + public var seed: UInt64 + public var valueBits: Int + public var scalesPerGroup: Int + public var packedMagnitudes: MLXArray + public var signs: MLXArray + public var highPrecisionMask: MLXArray + public var residualSigns: MLXArray + public var scales: MLXArray + + public init( + layout: TurboQuantAttentionLayout, + preset: TurboQuantPreset, + role: TurboQuantTensorRole, + groupSize: Int, + seed: UInt64, + valueBits: Int? = nil, + scalesPerGroup: Int? = nil, + packedMagnitudes: MLXArray, + signs: MLXArray, + highPrecisionMask: MLXArray, + residualSigns: MLXArray, + scales: MLXArray + ) { + self.layout = layout + self.preset = preset + self.role = role + self.groupSize = groupSize + self.seed = seed + self.valueBits = valueBits ?? preset.defaultValueBits + self.scalesPerGroup = scalesPerGroup ?? (role == .value ? 2 : 3) + self.packedMagnitudes = packedMagnitudes + self.signs = signs + self.highPrecisionMask = highPrecisionMask + self.residualSigns = residualSigns + self.scales = scales + } + + public var storageByteCount: Int { + if role == .value { + return packedMagnitudes.nbytes + scales.nbytes + } + return packedMagnitudes.nbytes + + signs.nbytes + + highPrecisionMask.nbytes + + residualSigns.nbytes + + scales.nbytes + } + + public var approximateBitsPerValue: Double { + let values = + layout.batchSize * layout.kvHeadCount + * Swift.max(layout.logicalLength, 1) * layout.headDimension + return Double(storageByteCount * 8) / Double(values) + } +} + +public struct TurboQuantQualityThresholds: Hashable, Codable, Sendable { + public var maxRelativeMSE: Float + public var minCosineSimilarity: Float + public var maxInnerProductRelativeError: Float + + public init( + maxRelativeMSE: Float = 0.02, + minCosineSimilarity: Float = 0.99, + maxInnerProductRelativeError: Float = 0.08 + ) { + self.maxRelativeMSE = maxRelativeMSE + self.minCosineSimilarity = minCosineSimilarity + self.maxInnerProductRelativeError = maxInnerProductRelativeError + } +} + +public struct TurboQuantQualityReport: Hashable, Codable, Sendable { + public var mse: Float + public var relativeMSE: Float + public var maxAbsoluteError: Float + public var cosineSimilarity: Float + public var innerProductRelativeError: Float + public var thresholds: TurboQuantQualityThresholds + + public var passes: Bool { + relativeMSE <= thresholds.maxRelativeMSE + && cosineSimilarity >= thresholds.minCosineSimilarity + && innerProductRelativeError <= thresholds.maxInnerProductRelativeError + } +} + +public func turboQuantized( + _ array: MLXArray, + configuration: TurboQuantConfiguration = TurboQuantConfiguration(), + stream: StreamOrDevice = .default +) -> TurboQuantPackedTensor { + let packed = quantized( + array, + groupSize: configuration.groupSize, + bits: configuration.effectiveBits, + mode: configuration.mode, + stream: stream + ) + return (packed.wq, packed.scales, packed.biases) +} + +public func turboDequantized( + _ packed: TurboQuantPackedTensor, + configuration: TurboQuantConfiguration = TurboQuantConfiguration(), + dtype: DType? = nil, + stream: StreamOrDevice = .default +) -> MLXArray { + dequantized( + packed.weight, + scales: packed.scales, + biases: packed.biases, + groupSize: configuration.groupSize, + bits: configuration.effectiveBits, + mode: configuration.mode, + dtype: dtype, + stream: stream + ) +} + +public func turboQuantizedMM( + _ x: MLXArray, + _ packed: TurboQuantPackedTensor, + transpose: Bool = true, + configuration: TurboQuantConfiguration = TurboQuantConfiguration(), + stream: StreamOrDevice = .default +) -> MLXArray { + quantizedMM( + x, + packed.weight, + scales: packed.scales, + biases: packed.biases, + transpose: transpose, + groupSize: configuration.groupSize, + bits: configuration.effectiveBits, + mode: configuration.mode, + stream: stream + ) +} + +public func turboQuantizedMM( + _ x: MLXArray, + _ code: TurboQuantMetalCode, + transpose: Bool = true, + outputDType: DType? = nil, + stream: StreamOrDevice = .gpu +) throws -> MLXArray { + try turboQuantMetalMM( + x, + code, + transpose: transpose, + outputDType: outputDType, + stream: stream + ) +} + +public func turboQuantReferenceEncode( + _ array: MLXArray, + configuration: TurboQuantConfiguration = TurboQuantConfiguration( + backend: .polarQJLReference + ) +) throws -> TurboQuantReferenceCode { + guard configuration.groupSize > 0 else { + throw TurboQuantError.invalidGroupSize(configuration.groupSize) + } + + let values = array.asArray(Float.self) + return try encodeTurboQuantReference( + values: values, shape: array.shape, configuration: configuration) +} + +public func turboQuantReferenceDecode( + _ code: TurboQuantReferenceCode +) throws -> MLXArray { + let values = try decodeTurboQuantReference(code) + return MLXArray(values, code.shape) +} + +public func turboQuantReferenceQuality( + _ array: MLXArray, + configuration: TurboQuantConfiguration = TurboQuantConfiguration( + backend: .polarQJLReference + ), + thresholds: TurboQuantQualityThresholds = TurboQuantQualityThresholds() +) throws -> TurboQuantQualityReport { + let original = array.asArray(Float.self) + let code = try turboQuantReferenceEncode(array, configuration: configuration) + let decoded = try turboQuantReferenceDecode(code).asArray(Float.self) + return try turboQuantQuality( + original: original, + decoded: decoded, + seed: configuration.seed, + thresholds: thresholds + ) +} + +public func turboQuantReferenceInnerProduct( + query: MLXArray, + code: TurboQuantReferenceCode +) throws -> Float { + let queryValues = query.asArray(Float.self) + guard queryValues.count == code.valueCount else { + throw TurboQuantError.invalidQualityInput( + "query contains \(queryValues.count) values but code contains \(code.valueCount)" + ) + } + if code.format == .turboQuantProd { + return try turboQuantProductInnerProduct(query: queryValues, code: code) + } + let decoded = try decodeTurboQuantReference(code) + return zip(queryValues, decoded).reduce(Float(0)) { partial, pair in + partial + pair.0 * pair.1 + } +} + +public func turboQuantMetalEncode( + _ array: MLXArray, + configuration: TurboQuantConfiguration = TurboQuantConfiguration(backend: .metalPolarQJL), + stream: StreamOrDevice = .gpu +) throws -> TurboQuantMetalCode { + try validateMetalConfiguration(array: array, configuration: configuration) + + let valueCount = array.size + let groupSize = configuration.groupSize + let groupCount = (valueCount + groupSize - 1) / groupSize + let magnitudeWordsPerGroup = metalMagnitudeWordsPerGroup( + groupSize: groupSize, + preset: configuration.preset, + role: configuration.role, + valueBits: configuration.resolvedValueBits + ) + let bitsetWordsPerGroup = (groupSize + 31) / 32 + let scalesPerGroup = metalScalesPerGroup(role: configuration.role) + let threadGroupSize = Swift.max(1, Swift.min(groupCount, 64)) + let bitsetShape = [groupCount * bitsetWordsPerGroup] + + let outputs = TurboQuantMetalKernels.encode( + [array], + template: metalTemplate( + configuration: configuration, + valueCount: valueCount, + groupCount: groupCount, + magnitudeWordsPerGroup: magnitudeWordsPerGroup, + bitsetWordsPerGroup: bitsetWordsPerGroup + ), + grid: (groupCount, 1, 1), + threadGroup: (threadGroupSize, 1, 1), + outputShapes: [ + [groupCount * magnitudeWordsPerGroup], + bitsetShape, + bitsetShape, + bitsetShape, + [groupCount, scalesPerGroup], + ], + outputDTypes: [.uint32, .uint32, .uint32, .uint32, .float32], + initValue: 0, + stream: stream + ) + + return TurboQuantMetalCode( + shape: array.shape, + preset: configuration.preset, + role: configuration.role, + groupSize: groupSize, + seed: configuration.seed, + valueBits: configuration.resolvedValueBits, + valueCount: valueCount, + groupCount: groupCount, + magnitudeWordsPerGroup: magnitudeWordsPerGroup, + bitsetWordsPerGroup: bitsetWordsPerGroup, + scalesPerGroup: scalesPerGroup, + packedMagnitudes: outputs[0], + signs: outputs[1], + highPrecisionMask: outputs[2], + residualSigns: outputs[3], + scales: outputs[4] + ) +} + +public func turboQuantMetalDecode( + _ code: TurboQuantMetalCode, + dtype: DType = .float32, + stream: StreamOrDevice = .gpu +) throws -> MLXArray { + guard code.valueCount > 0 else { + throw TurboQuantError.invalidMetalConfiguration("empty arrays are not supported") + } + guard code.groupSize > 0, code.groupSize <= 128, code.groupSize % 32 == 0 else { + throw TurboQuantError.invalidGroupSize(code.groupSize) + } + guard dtype.isFloatingPoint else { + throw TurboQuantError.invalidMetalConfiguration( + "decode output dtype must be floating point") + } + + let threadGroupSize = Swift.max(1, Swift.min(code.valueCount, 256)) + let configuration = TurboQuantConfiguration( + preset: code.preset, + role: code.role, + groupSize: code.groupSize, + backend: .metalPolarQJL, + seed: code.seed, + valueBits: code.valueBits + ) + let outputs = TurboQuantMetalKernels.decode( + [ + code.packedMagnitudes, + code.signs, + code.highPrecisionMask, + code.residualSigns, + code.scales, + ], + template: metalTemplate( + configuration: configuration, + valueCount: code.valueCount, + groupCount: code.groupCount, + magnitudeWordsPerGroup: code.magnitudeWordsPerGroup, + bitsetWordsPerGroup: code.bitsetWordsPerGroup + ), + grid: (code.valueCount, 1, 1), + threadGroup: (threadGroupSize, 1, 1), + outputShapes: [code.shape], + outputDTypes: [dtype], + stream: stream + ) + + return outputs[0] +} + +public func turboQuantMetalMM( + _ x: MLXArray, + _ code: TurboQuantMetalCode, + transpose: Bool = true, + outputDType: DType? = nil, + stream: StreamOrDevice = .gpu +) throws -> MLXArray { + try requireTurboQuantMetalCodec() + guard x.ndim == 2 else { + throw TurboQuantError.invalidMetalConfiguration( + "mixed-bit matmul input must have shape [M, K]" + ) + } + guard code.shape.count == 2 else { + throw TurboQuantError.invalidMetalConfiguration( + "mixed-bit matmul weight code must have shape [N, K] or [K, N]" + ) + } + guard x.dtype.isFloatingPoint else { + throw TurboQuantError.invalidMetalConfiguration("mixed-bit matmul input must be floating point") + } + guard (outputDType ?? x.dtype).isFloatingPoint else { + throw TurboQuantError.invalidMetalConfiguration( + "mixed-bit matmul output dtype must be floating point") + } + + let xRows = x.dim(0) + let xColumns = x.dim(1) + let weightRows = code.shape[0] + let weightColumns = code.shape[1] + let outputColumns: Int + if transpose { + guard xColumns == weightColumns else { + throw TurboQuantError.invalidMetalConfiguration( + "transpose matmul expects x columns \(xColumns) to match encoded weight columns \(weightColumns)" + ) + } + outputColumns = weightRows + } else { + guard xColumns == weightRows else { + throw TurboQuantError.invalidMetalConfiguration( + "matmul expects x columns \(xColumns) to match encoded weight rows \(weightRows)" + ) + } + outputColumns = weightColumns + } + + let outputShape = [xRows, outputColumns] + let elementCount = outputShape.reduce(1, *) + let configuration = TurboQuantConfiguration( + preset: code.preset, + role: code.role, + groupSize: code.groupSize, + backend: .metalPolarQJL, + seed: code.seed, + valueBits: code.valueBits + ) + return TurboQuantMetalKernels.matmul( + [ + x, + code.packedMagnitudes, + code.signs, + code.highPrecisionMask, + code.residualSigns, + code.scales, + ], + template: metalTemplate( + configuration: configuration, + valueCount: code.valueCount, + groupCount: code.groupCount, + magnitudeWordsPerGroup: code.magnitudeWordsPerGroup, + bitsetWordsPerGroup: code.bitsetWordsPerGroup + ) + [ + ("X_ROWS", xRows), + ("X_COLUMNS", xColumns), + ("WEIGHT_ROWS", weightRows), + ("WEIGHT_COLUMNS", weightColumns), + ("TRANSPOSE_WEIGHT", transpose), + ], + grid: (elementCount, 1, 1), + threadGroup: (Swift.max(1, Swift.min(elementCount, 256)), 1, 1), + outputShapes: [outputShape], + outputDTypes: [outputDType ?? x.dtype], + stream: stream + )[0] +} + +public func turboQuantEmptyAttentionCode( + layout: TurboQuantAttentionLayout, + preset: TurboQuantPreset = .turbo3_5, + role: TurboQuantTensorRole, + groupSize: Int = 64, + seed: UInt64 = 0x9E37_79B9_7F4A_7C15, + valueBits: Int? = nil +) throws -> TurboQuantAttentionCode { + try validateAttentionLayout(layout, role: role, groupSize: groupSize) + let resolvedValueBits = valueBits ?? preset.defaultValueBits + let bitsetShape = [ + layout.batchSize, layout.kvHeadCount, layout.capacity, + layout.groupsPerVector, layout.bitsetWordsPerGroup, + ] + let scalesPerGroup = metalScalesPerGroup(role: role) + return TurboQuantAttentionCode( + layout: layout, + preset: preset, + role: role, + groupSize: groupSize, + seed: seed, + valueBits: resolvedValueBits, + scalesPerGroup: scalesPerGroup, + packedMagnitudes: MLXArray.zeros( + [ + layout.batchSize, layout.kvHeadCount, layout.capacity, + layout.groupsPerVector, layout.magnitudeWordsPerGroup, + ], + dtype: .uint32 + ), + signs: MLXArray.zeros(bitsetShape, dtype: .uint32), + highPrecisionMask: MLXArray.zeros(bitsetShape, dtype: .uint32), + residualSigns: MLXArray.zeros(bitsetShape, dtype: .uint32), + scales: MLXArray.zeros( + [ + layout.batchSize, layout.kvHeadCount, layout.capacity, + layout.groupsPerVector, scalesPerGroup, + ], + dtype: .float32 + ) + ) +} + +public func turboQuantAttentionLayout( + for array: MLXArray, + preset: TurboQuantPreset = .turbo3_5, + role: TurboQuantTensorRole = .key, + groupSize: Int = 64, + valueBits: Int? = nil, + capacity: Int? = nil, + logicalLength: Int? = nil, + ringOffset: Int = 0, + pinnedPrefixLength: Int = 0 +) throws -> TurboQuantAttentionLayout { + try validateAttentionArray(array, groupSize: groupSize) + return try turboQuantAttentionLayout( + shape: array.shape, + dtype: array.dtype, + preset: preset, + role: role, + groupSize: groupSize, + valueBits: valueBits, + capacity: capacity, + logicalLength: logicalLength, + ringOffset: ringOffset, + pinnedPrefixLength: pinnedPrefixLength + ) +} + +public func turboQuantAttentionLayout( + shape: [Int], + dtype: DType = .float32, + preset: TurboQuantPreset = .turbo3_5, + role: TurboQuantTensorRole = .key, + groupSize: Int = 64, + valueBits: Int? = nil, + capacity: Int? = nil, + logicalLength: Int? = nil, + ringOffset: Int = 0, + pinnedPrefixLength: Int = 0 +) throws -> TurboQuantAttentionLayout { + try validateAttentionShape(shape, dtype: dtype, groupSize: groupSize) + let headDimension = shape[3] + let groupsPerVector = (headDimension + groupSize - 1) / groupSize + let resolvedCapacity = capacity ?? shape[2] + let resolvedLogicalLength = logicalLength ?? shape[2] + let layout = TurboQuantAttentionLayout( + batchSize: shape[0], + kvHeadCount: shape[1], + capacity: resolvedCapacity, + logicalLength: resolvedLogicalLength, + ringOffset: ringOffset, + pinnedPrefixLength: pinnedPrefixLength, + headDimension: headDimension, + groupsPerVector: groupsPerVector, + magnitudeWordsPerGroup: metalMagnitudeWordsPerGroup( + groupSize: groupSize, + preset: preset, + role: role, + valueBits: valueBits ?? preset.defaultValueBits + ), + bitsetWordsPerGroup: (groupSize + 31) / 32 + ) + try validateAttentionLayout(layout, role: role, groupSize: groupSize) + return layout +} + +public func turboQuantMetalEncodeAttention( + _ array: MLXArray, + configuration: TurboQuantConfiguration = TurboQuantConfiguration( + role: .key, + backend: .metalPolarQJL + ), + capacity: Int? = nil, + logicalLength: Int? = nil, + ringOffset: Int = 0, + pinnedPrefixLength: Int = 0, + stream: StreamOrDevice = .gpu +) throws -> TurboQuantAttentionCode { + try validateAttentionArray(array, groupSize: configuration.groupSize) + if configuration.role == .value { + try validateTurboQuantValueBits(configuration.resolvedValueBits) + } + try requireTurboQuantMetalAttention() + + let layout = try turboQuantAttentionLayout( + for: array, + preset: configuration.preset, + role: configuration.role, + groupSize: configuration.groupSize, + valueBits: configuration.resolvedValueBits, + capacity: capacity, + logicalLength: logicalLength, + ringOffset: ringOffset, + pinnedPrefixLength: pinnedPrefixLength + ) + guard layout.logicalLength <= layout.capacity else { + throw TurboQuantError.invalidMetalConfiguration( + "logical length cannot exceed compressed attention capacity" + ) + } + + let rowGroupCount = + layout.batchSize * layout.kvHeadCount + * array.dim(2) * layout.groupsPerVector + let bitsetShape = [ + layout.batchSize, layout.kvHeadCount, layout.capacity, + layout.groupsPerVector, layout.bitsetWordsPerGroup, + ] + let scalesPerGroup = metalScalesPerGroup(role: configuration.role) + let outputs = TurboQuantMetalKernels.encodeAttention( + [array], + template: attentionTemplate( + configuration: configuration, + layout: layout, + inputLength: array.dim(2), + outputLength: array.dim(2), + queryHeadCount: 0, + queryLength: 0, + outputDType: .float32, + causal: false + ), + grid: (rowGroupCount, 1, 1), + threadGroup: (Swift.max(1, Swift.min(rowGroupCount, 256)), 1, 1), + outputShapes: [ + [ + layout.batchSize, layout.kvHeadCount, layout.capacity, + layout.groupsPerVector, layout.magnitudeWordsPerGroup, + ], + bitsetShape, + bitsetShape, + bitsetShape, + [ + layout.batchSize, layout.kvHeadCount, layout.capacity, + layout.groupsPerVector, scalesPerGroup, + ], + ], + outputDTypes: [.uint32, .uint32, .uint32, .uint32, .float32], + initValue: 0, + stream: stream + ) + + return TurboQuantAttentionCode( + layout: layout, + preset: configuration.preset, + role: configuration.role, + groupSize: configuration.groupSize, + seed: configuration.seed, + valueBits: configuration.resolvedValueBits, + scalesPerGroup: scalesPerGroup, + packedMagnitudes: outputs[0], + signs: outputs[1], + highPrecisionMask: outputs[2], + residualSigns: outputs[3], + scales: outputs[4] + ) +} + +public func turboQuantMetalDecodeAttention( + _ code: TurboQuantAttentionCode, + outputDType: DType = .float32, + stream: StreamOrDevice = .gpu +) throws -> MLXArray { + try validateAttentionLayout(code.layout, role: code.role, groupSize: code.groupSize) + try requireTurboQuantMetalAttention() + + let outputShape = code.layout.logicalShape + let elementCount = outputShape.reduce(1, *) + return TurboQuantMetalKernels.decodeAttention( + [ + code.packedMagnitudes, + code.signs, + code.highPrecisionMask, + code.residualSigns, + code.scales, + ], + template: attentionTemplate( + configuration: TurboQuantConfiguration( + preset: code.preset, + role: code.role, + groupSize: code.groupSize, + backend: .metalPolarQJL, + seed: code.seed, + valueBits: code.valueBits + ), + layout: code.layout, + inputLength: code.layout.logicalLength, + outputLength: code.layout.logicalLength, + queryHeadCount: 0, + queryLength: 0, + outputDType: outputDType, + causal: false + ), + grid: (elementCount, 1, 1), + threadGroup: (Swift.max(1, Swift.min(elementCount, 256)), 1, 1), + outputShapes: [outputShape], + outputDTypes: [outputDType], + stream: stream + )[0] +} + +public func turboQuantMetalQK( + queries: MLXArray, + keyCode: TurboQuantAttentionCode, + scale: Float, + mask: MLXFast.ScaledDotProductAttentionMaskMode = .none, + stream: StreamOrDevice = .gpu +) throws -> MLXArray { + try validateAttentionQuery(queries, code: keyCode) + try requireTurboQuantMetalAttention() + guard keyCode.role == .key else { + throw TurboQuantError.invalidMetalConfiguration("QK requires a key code") + } + + let outputShape = [ + queries.dim(0), queries.dim(1), queries.dim(2), keyCode.layout.logicalLength, + ] + let elementCount = outputShape.reduce(1, *) + var scores = TurboQuantMetalKernels.qk( + [ + queries, + keyCode.packedMagnitudes, + keyCode.signs, + keyCode.highPrecisionMask, + keyCode.residualSigns, + keyCode.scales, + ], + template: attentionTemplate( + configuration: TurboQuantConfiguration( + preset: keyCode.preset, + role: keyCode.role, + groupSize: keyCode.groupSize, + backend: .metalPolarQJL, + seed: keyCode.seed, + valueBits: keyCode.valueBits + ), + layout: keyCode.layout, + inputLength: keyCode.layout.logicalLength, + outputLength: keyCode.layout.logicalLength, + queryHeadCount: queries.dim(1), + queryLength: queries.dim(2), + outputDType: .float32, + causal: false + ) + [("ATTENTION_SCALE_BITS", scale.bitPattern)], + grid: (elementCount, 1, 1), + threadGroup: (Swift.max(1, Swift.min(elementCount, 256)), 1, 1), + outputShapes: [outputShape], + outputDTypes: [.float32], + stream: stream + )[0] + + applyAttentionMask(&scores, mask: mask, stream: stream) + return scores +} + +public func turboQuantMetalAV( + attentionWeights: MLXArray, + valueCode: TurboQuantAttentionCode, + outputDType: DType = .float32, + stream: StreamOrDevice = .gpu +) throws -> MLXArray { + try requireTurboQuantMetalAttention() + guard valueCode.role == .value else { + throw TurboQuantError.invalidMetalConfiguration("AV requires a value code") + } + guard attentionWeights.ndim == 4 else { + throw TurboQuantError.invalidMetalConfiguration("attention weights must be [B, Hq, L, T]") + } + guard attentionWeights.dim(0) == valueCode.layout.batchSize, + attentionWeights.dim(3) == valueCode.layout.logicalLength + else { + throw TurboQuantError.invalidMetalConfiguration( + "attention weights do not match the compressed value layout" + ) + } + guard attentionWeights.dim(1) % valueCode.layout.kvHeadCount == 0 else { + throw TurboQuantError.invalidMetalConfiguration( + "query heads must be a multiple of KV heads" + ) + } + + let outputShape = [ + attentionWeights.dim(0), attentionWeights.dim(1), attentionWeights.dim(2), + valueCode.layout.headDimension, + ] + let elementCount = outputShape.reduce(1, *) + return TurboQuantMetalKernels.av( + [ + attentionWeights, + valueCode.packedMagnitudes, + valueCode.signs, + valueCode.highPrecisionMask, + valueCode.residualSigns, + valueCode.scales, + ], + template: attentionTemplate( + configuration: TurboQuantConfiguration( + preset: valueCode.preset, + role: valueCode.role, + groupSize: valueCode.groupSize, + backend: .metalPolarQJL, + seed: valueCode.seed, + valueBits: valueCode.valueBits + ), + layout: valueCode.layout, + inputLength: valueCode.layout.logicalLength, + outputLength: valueCode.layout.logicalLength, + queryHeadCount: attentionWeights.dim(1), + queryLength: attentionWeights.dim(2), + outputDType: outputDType, + causal: false + ), + grid: (elementCount, 1, 1), + threadGroup: (Swift.max(1, Swift.min(elementCount, 256)), 1, 1), + outputShapes: [outputShape], + outputDTypes: [outputDType], + stream: stream + )[0] +} + +public func turboQuantMetalScaledDotProductAttention( + queries: MLXArray, + keyCode: TurboQuantAttentionCode, + valueCode: TurboQuantAttentionCode, + scale: Float, + mask: MLXFast.ScaledDotProductAttentionMaskMode = .none, + sinks: MLXArray? = nil, + preferOnlineFused: Bool = true, + kernelProfile: TurboQuantKernelProfile? = nil, + stream: StreamOrDevice = .gpu +) throws -> MLXArray { + try validateAttentionPair(keyCode: keyCode, valueCode: valueCode) + try validateAttentionQuery(queries, code: keyCode) + try validateAttentionSinks(sinks, queryHeadCount: queries.dim(1)) + try requireTurboQuantMetalAttention() + + if sinks == nil, + preferOnlineFused, + keyCode.layout.headDimension == valueCode.layout.headDimension, + keyCode.layout.groupsPerVector == valueCode.layout.groupsPerVector, + turboQuantMetalSupportsOnlineFusedAttention(queries: queries, keyCode: keyCode, mask: mask) + { + return try turboQuantMetalOnlineFusedAttention( + queries: queries, + keyCode: keyCode, + valueCode: valueCode, + scale: scale, + mask: mask, + kernelProfile: kernelProfile + ?? TurboQuantRuntimeProbe.shared.selectedKernelProfileWithoutRunningProbe(), + outputDType: queries.dtype, + stream: stream + ) + } + + let scores = try turboQuantMetalQK( + queries: queries, + keyCode: keyCode, + scale: scale, + mask: mask, + stream: stream + ) + var logits = scores.asType(.float32) + logits = try prependAttentionSinks( + logits, + sinks: sinks, + queryHeadCount: queries.dim(1), + stream: stream + ) + var weights = softmax(logits, axis: -1, stream: stream) + if sinks != nil { + weights = weights[.ellipsis, 1...] + } + return try turboQuantMetalAV( + attentionWeights: weights, + valueCode: valueCode, + outputDType: queries.dtype, + stream: stream + ) +} + +public func turboQuantMetalSupportsOnlineFusedAttention( + queries: MLXArray, + keyCode: TurboQuantAttentionCode, + mask: MLXFast.ScaledDotProductAttentionMaskMode = .none +) -> Bool { + turboQuantMetalSupportsOnlineFusedAttention( + queryShape: queries.shape, + keyCode: keyCode, + mask: mask + ) +} + +public func turboQuantMetalSupportsOnlineFusedAttention( + queryShape: [Int], + keyCode: TurboQuantAttentionCode, + mask: MLXFast.ScaledDotProductAttentionMaskMode = .none +) -> Bool { + turboQuantMetalSupportsOnlineFusedAttention( + queryShape: queryShape, + keyLayout: keyCode.layout, + mask: mask + ) +} + +public func turboQuantMetalSupportsOnlineFusedAttention( + queryShape: [Int], + keyLayout: TurboQuantAttentionLayout, + mask: MLXFast.ScaledDotProductAttentionMaskMode = .none +) -> Bool { + guard queryShape.count == 4 else { return false } + guard queryShape[0] == keyLayout.batchSize, queryShape[2] <= 8 else { return false } + guard [64, 80, 96, 128, 192, 256].contains(queryShape[3]) else { return false } + guard queryShape[3] == keyLayout.headDimension else { return false } + switch mask { + case .none, .causal: + return true + case .array, .arrays: + return false + } +} + +private func turboQuantMetalOnlineFusedAttention( + queries: MLXArray, + keyCode: TurboQuantAttentionCode, + valueCode: TurboQuantAttentionCode, + scale: Float, + mask: MLXFast.ScaledDotProductAttentionMaskMode, + kernelProfile: TurboQuantKernelProfile, + outputDType: DType, + stream: StreamOrDevice +) throws -> MLXArray { + let outputShape = [queries.dim(0), queries.dim(1), queries.dim(2), queries.dim(3)] + let rowCount = queries.dim(0) * queries.dim(1) * queries.dim(2) + let threadgroupWidth = min(256, max(1, kernelProfile.fusedDecodeThreadgroupWidth)) + let causal: Bool + switch mask { + case .causal: + causal = true + case .none: + causal = false + case .array, .arrays: + throw TurboQuantError.invalidMetalConfiguration( + "online fused TurboQuant attention does not support materialized masks" + ) + } + + return TurboQuantMetalKernels.fusedAttention( + [ + queries, + keyCode.packedMagnitudes, + keyCode.signs, + keyCode.highPrecisionMask, + keyCode.residualSigns, + keyCode.scales, + valueCode.packedMagnitudes, + valueCode.signs, + valueCode.highPrecisionMask, + valueCode.residualSigns, + valueCode.scales, + ], + template: attentionTemplate( + configuration: TurboQuantConfiguration( + preset: keyCode.preset, + role: .key, + groupSize: keyCode.groupSize, + backend: .metalPolarQJL, + seed: keyCode.seed, + valueBits: valueCode.valueBits + ), + layout: keyCode.layout, + inputLength: keyCode.layout.logicalLength, + outputLength: keyCode.layout.logicalLength, + queryHeadCount: queries.dim(1), + queryLength: queries.dim(2), + outputDType: outputDType, + causal: causal + ) + [ + ("VALUE_SEED_HI", metalTemplateUInt32High(valueCode.seed)), + ("VALUE_SEED_LO", metalTemplateUInt32Low(valueCode.seed)), + ("VALUE_MAG_WORDS_PER_GROUP", valueCode.layout.magnitudeWordsPerGroup), + ("VALUE_SCALES_PER_GROUP", valueCode.scalesPerGroup), + ("ATTENTION_SCALE_BITS", scale.bitPattern), + ("THREADS_PER_ROW", threadgroupWidth), + ], + grid: (rowCount * threadgroupWidth, 1, 1), + threadGroup: (threadgroupWidth, 1, 1), + outputShapes: [outputShape], + outputDTypes: [outputDType], + stream: stream + )[0] +} + +public func requireTurboQuantBackend(_ backend: TurboQuantBackend) throws { + let availability = TurboQuantKernelAvailability.current + guard availability.supports(backend) else { + throw TurboQuantError.unsupportedBackend( + backend, + availability.fallbackReason(for: backend) ?? "Backend unavailable." + ) + } +} + +public func requireTurboQuantMetalAttention() throws { + guard metalRuntimeAvailable() else { + throw TurboQuantError.unsupportedBackend( + .metalPolarQJL, + "Metal runtime is unavailable for PolarQuant/QJL compressed attention." + ) + } + guard !TurboQuantRuntimeProbe.shared.isRunningSelfTest() else { return } + let probe = TurboQuantRuntimeProbe.shared.result() + guard probe.passed else { + throw TurboQuantError.unsupportedBackend( + .metalPolarQJL, + probe.failureReason ?? "PolarQuant/QJL compressed attention self-test has not passed." + ) + } +} + +public func requireTurboQuantMetalCodec() throws { + guard TurboQuantKernelAvailability.current.supportsMetalPolarQJLCodec else { + throw TurboQuantError.unsupportedBackend( + .metalPolarQJL, + "Metal runtime is unavailable for the PolarQuant/QJL codec." + ) + } +} + +private func encodeTurboQuantReference( + values: [Float], + shape: [Int], + configuration: TurboQuantConfiguration +) throws -> TurboQuantReferenceCode { + let expectedCount = shape.reduce(1, *) + guard expectedCount == values.count else { + throw TurboQuantError.invalidReferenceCode( + "shape \(shape) contains \(expectedCount) values but input has \(values.count)" + ) + } + + if configuration.role == .value { + return try encodeTurboQuantAffineValueReference( + values: values, + shape: shape, + configuration: configuration + ) + } + + if configuration.role == .key { + return try encodeTurboQuantProductReference( + values: values, + shape: shape, + configuration: configuration + ) + } + + let groupSize = configuration.groupSize + let baseBits = configuration.preset.baseMagnitudeBits + let highBits = configuration.preset.highMagnitudeBits + let groupCount = (values.count + groupSize - 1) / groupSize + var baseScales = Array(repeating: Float(1), count: groupCount) + var highScales = Array(repeating: Float(1), count: groupCount) + var residualScales = Array(repeating: Float(0), count: groupCount) + var signs = [UInt8](repeating: 0, count: packedBitByteCount(values.count)) + var highPrecisionMask = [UInt8](repeating: 0, count: packedBitByteCount(values.count)) + var residualSigns = [UInt8](repeating: 0, count: packedBitByteCount(values.count)) + var magnitudes = [UInt8]() + var magnitudeBitOffset = 0 + + for groupIndex in 0 ..< groupCount { + let start = groupIndex * groupSize + let end = Swift.min(start + groupSize, values.count) + let count = end - start + guard count > 0 else { continue } + + var transformed = Array(repeating: Float(0), count: count) + var maxAbs = Float(0) + for localIndex in 0 ..< count { + let absoluteIndex = start + localIndex + let value = preconditionedValue( + values[absoluteIndex], + index: absoluteIndex, + seed: configuration.seed + ) + transformed[localIndex] = value + maxAbs = Swift.max(maxAbs, Swift.abs(value)) + } + + let baseMax = Float((1 << baseBits) - 1) + let highMax = Float((1 << highBits) - 1) + let safeMaxAbs = Swift.max(maxAbs, Float.leastNonzeroMagnitude) + baseScales[groupIndex] = safeMaxAbs / baseMax + highScales[groupIndex] = safeMaxAbs / highMax + + let highPrecisionCount = mixedPrecisionHighCount( + valueCount: count, + baseBits: baseBits, + highBits: highBits, + targetBits: configuration.preset.targetMagnitudeBits + ) + var highPrecisionIndices = Set() + if highPrecisionCount > 0 { + let ranked = transformed.indices.sorted { lhs, rhs in + let leftMagnitude = Swift.abs(transformed[lhs]) + let rightMagnitude = Swift.abs(transformed[rhs]) + if leftMagnitude == rightMagnitude { + return lhs < rhs + } + return leftMagnitude > rightMagnitude + } + highPrecisionIndices = Set(ranked.prefix(highPrecisionCount)) + } + + var residuals = Array(repeating: Float(0), count: count) + var residualMagnitudeSum = Float(0) + for localIndex in 0 ..< count { + let value = transformed[localIndex] + let highPrecision = highPrecisionIndices.contains(localIndex) + let bits = highPrecision ? highBits : baseBits + let scale = highPrecision ? highScales[groupIndex] : baseScales[groupIndex] + let levelMax = Float((1 << bits) - 1) + let magnitude = Swift.abs(value) + let quantizedMagnitude = UInt8( + Swift.max(0, Swift.min(Int((magnitude / scale).rounded()), Int(levelMax))) + ) + let signedDecoded = (value.sign == .minus ? -1 : 1) * Float(quantizedMagnitude) * scale + let residual = value - signedDecoded + residuals[localIndex] = residual + residualMagnitudeSum += Swift.abs(residual) + } + if configuration.role != .value { + residualScales[groupIndex] = residualMagnitudeSum / Float(count) + } + + for localIndex in 0 ..< count { + let absoluteIndex = start + localIndex + let value = transformed[localIndex] + let highPrecision = highPrecisionIndices.contains(localIndex) + let bits = highPrecision ? highBits : baseBits + let scale = highPrecision ? highScales[groupIndex] : baseScales[groupIndex] + let levelMax = Float((1 << bits) - 1) + let magnitude = Swift.abs(value) + let quantizedMagnitude = UInt8( + Swift.max(0, Swift.min(Int((magnitude / scale).rounded()), Int(levelMax))) + ) + setPackedBit(&signs, index: absoluteIndex, value: value.sign == .minus) + setPackedBit(&highPrecisionMask, index: absoluteIndex, value: highPrecision) + if configuration.role != .value { + setPackedBit( + &residualSigns, index: absoluteIndex, + value: residuals[localIndex].sign == .minus) + } + appendPackedBits( + UInt32(quantizedMagnitude), + bitCount: bits, + bytes: &magnitudes, + bitOffset: &magnitudeBitOffset + ) + } + } + + if configuration.role == .value { + residualSigns.removeAll(keepingCapacity: false) + } + + return TurboQuantReferenceCode( + shape: shape, + preset: configuration.preset, + role: configuration.role, + format: .magnitudeResidualSign, + groupSize: groupSize, + seed: configuration.seed, + residualScale: configuration.qjlResidualScale, + baseMagnitudeBits: baseBits, + highMagnitudeBits: highBits, + valueCount: values.count, + baseScales: baseScales, + highScales: highScales, + residualScales: residualScales, + signs: Data(signs), + highPrecisionMask: Data(highPrecisionMask), + residualSigns: Data(residualSigns), + packedMagnitudes: Data(magnitudes) + ) +} + +private func decodeTurboQuantReference(_ code: TurboQuantReferenceCode) throws -> [Float] { + switch code.format { + case .affineValue: + return try decodeTurboQuantAffineValueReference(code) + case .turboQuantProd: + return try decodeTurboQuantProductReference(code) + case .magnitudeResidualSign: + break + } + + guard code.groupSize > 0 else { + throw TurboQuantError.invalidGroupSize(code.groupSize) + } + guard code.shape.reduce(1, *) == code.valueCount else { + throw TurboQuantError.invalidReferenceCode( + "shape \(code.shape) does not match value count \(code.valueCount)" + ) + } + + let groupCount = (code.valueCount + code.groupSize - 1) / code.groupSize + guard code.baseScales.count == groupCount, code.highScales.count == groupCount else { + throw TurboQuantError.invalidReferenceCode("scale table count does not match groups") + } + guard code.residualScales.isEmpty || code.residualScales.count == groupCount else { + throw TurboQuantError.invalidReferenceCode( + "residual scale table count does not match groups") + } + guard code.signs.count >= packedBitByteCount(code.valueCount), + code.highPrecisionMask.count >= packedBitByteCount(code.valueCount) + else { + throw TurboQuantError.invalidReferenceCode("bitset storage is truncated") + } + if code.role != .value && code.residualSigns.count < packedBitByteCount(code.valueCount) { + throw TurboQuantError.invalidReferenceCode("residual sign storage is truncated") + } + + var values = Array(repeating: Float(0), count: code.valueCount) + var magnitudeBitOffset = 0 + + for groupIndex in 0 ..< groupCount { + let start = groupIndex * code.groupSize + let end = Swift.min(start + code.groupSize, code.valueCount) + for absoluteIndex in start ..< end { + let highPrecision = getPackedBit(code.highPrecisionMask, index: absoluteIndex) + let bits = highPrecision ? code.highMagnitudeBits : code.baseMagnitudeBits + let scale = highPrecision ? code.highScales[groupIndex] : code.baseScales[groupIndex] + let magnitude = Float( + try readPackedBits( + code.packedMagnitudes, + bitOffset: &magnitudeBitOffset, + bitCount: bits + ) + ) + let sign: Float = getPackedBit(code.signs, index: absoluteIndex) ? -1 : 1 + var reconstructed = sign * magnitude * scale + + if code.role != .value { + let residualSign: Float = + getPackedBit(code.residualSigns, index: absoluteIndex) ? -1 : 1 + let residualScale = + code.residualScales.isEmpty + ? code.residualScale * scale + : code.residualScales[groupIndex] + reconstructed += residualSign * residualScale + } + + values[absoluteIndex] = unpreconditionedValue( + reconstructed, + index: absoluteIndex, + seed: code.seed + ) + } + } + + return values +} + +private func encodeTurboQuantAffineValueReference( + values: [Float], + shape: [Int], + configuration: TurboQuantConfiguration +) throws -> TurboQuantReferenceCode { + let groupSize = configuration.groupSize + let valueBits = configuration.resolvedValueBits + try validateTurboQuantValueBits(valueBits) + + let groupCount = (values.count + groupSize - 1) / groupSize + var scales = Array(repeating: Float(0), count: groupCount) + var zeros = Array(repeating: Float(0), count: groupCount) + var packed = [UInt8]() + var bitOffset = 0 + let levelMax = Float((1 << valueBits) - 1) + + for groupIndex in 0 ..< groupCount { + let start = groupIndex * groupSize + let end = Swift.min(start + groupSize, values.count) + guard start < end else { continue } + + var minimum = Float.greatestFiniteMagnitude + var maximum = -Float.greatestFiniteMagnitude + for index in start ..< end { + minimum = Swift.min(minimum, values[index]) + maximum = Swift.max(maximum, values[index]) + } + + let range = maximum - minimum + let scale = range > Float.leastNonzeroMagnitude ? range / levelMax : 0 + scales[groupIndex] = scale + zeros[groupIndex] = minimum + + for index in start ..< end { + let quantized: UInt32 + if scale == 0 { + quantized = 0 + } else { + quantized = UInt32( + Swift.max( + 0, + Swift.min( + Int(((values[index] - minimum) / scale).rounded()), + Int(levelMax) + ) + ) + ) + } + appendPackedBits( + quantized, + bitCount: valueBits, + bytes: &packed, + bitOffset: &bitOffset + ) + } + } + + return TurboQuantReferenceCode( + shape: shape, + preset: configuration.preset, + role: configuration.role, + format: .affineValue, + groupSize: groupSize, + seed: configuration.seed, + residualScale: configuration.qjlResidualScale, + baseMagnitudeBits: valueBits, + highMagnitudeBits: valueBits, + valueCount: values.count, + baseScales: scales, + highScales: zeros, + residualScales: [], + signs: Data(), + highPrecisionMask: Data(), + residualSigns: Data(), + packedMagnitudes: Data(packed) + ) +} + +private func decodeTurboQuantAffineValueReference(_ code: TurboQuantReferenceCode) throws + -> [Float] +{ + guard code.groupSize > 0 else { + throw TurboQuantError.invalidGroupSize(code.groupSize) + } + try validateTurboQuantValueBits(code.baseMagnitudeBits) + let groupCount = (code.valueCount + code.groupSize - 1) / code.groupSize + guard code.baseScales.count == groupCount, code.highScales.count == groupCount else { + throw TurboQuantError.invalidReferenceCode("affine value scale table count mismatch") + } + + var values = Array(repeating: Float(0), count: code.valueCount) + var bitOffset = 0 + for groupIndex in 0 ..< groupCount { + let start = groupIndex * code.groupSize + let end = Swift.min(start + code.groupSize, code.valueCount) + let scale = code.baseScales[groupIndex] + let zero = code.highScales[groupIndex] + for index in start ..< end { + let quantized = try readPackedBits( + code.packedMagnitudes, + bitOffset: &bitOffset, + bitCount: code.baseMagnitudeBits + ) + values[index] = zero + Float(quantized) * scale + } + } + return values +} + +private func encodeTurboQuantProductReference( + values: [Float], + shape: [Int], + configuration: TurboQuantConfiguration +) throws -> TurboQuantReferenceCode { + let groupSize = configuration.groupSize + let baseBits = Swift.max(1, configuration.preset.baseMagnitudeBits - 1) + let highBits = Swift.max(baseBits, configuration.preset.highMagnitudeBits - 1) + let targetBits = Swift.max(1, configuration.preset.targetMagnitudeBits - 1) + let groupCount = (values.count + groupSize - 1) / groupSize + var norms = Array(repeating: Float(0), count: groupCount) + var residualNorms = Array(repeating: Float(0), count: groupCount) + var qjlSigns = [UInt8](repeating: 0, count: packedBitByteCount(values.count)) + var packed = [UInt8]() + var bitOffset = 0 + + for groupIndex in 0 ..< groupCount { + let start = groupIndex * groupSize + let end = Swift.min(start + groupSize, values.count) + let count = end - start + guard count > 0 else { continue } + + var group = Array(values[start ..< end]) + let norm = sqrt(group.reduce(Float(0)) { $0 + $1 * $1 }) + norms[groupIndex] = norm + if norm > Float.leastNonzeroMagnitude { + for index in group.indices { + group[index] /= norm + } + } + + let rotated = applyTurboQuantRotation( + group, + seed: configuration.seed, + groupIndex: groupIndex, + inverse: false + ) + let highCount = mixedPrecisionHighCount( + valueCount: count, + baseBits: baseBits, + highBits: highBits, + targetBits: targetBits + ) + let highMask = productHighPrecisionMask( + valueCount: count, + highCount: highCount, + seed: configuration.seed, + groupIndex: groupIndex + ) + var quantizedRotated = Array(repeating: Float(0), count: count) + + for localIndex in 0 ..< count { + let bits = highMask[localIndex] ? highBits : baseBits + let codebook = turboQuantLloydMaxCodebook( + bits: bits, + coordinateStdDev: 1 / sqrt(Float(count)) + ) + let codeIndex = nearestCodebookIndex(rotated[localIndex], codebook: codebook) + quantizedRotated[localIndex] = codebook[codeIndex] + appendPackedBits( + UInt32(codeIndex), + bitCount: bits, + bytes: &packed, + bitOffset: &bitOffset + ) + } + + var residualSquared = Float(0) + for localIndex in 0 ..< count { + let residual = rotated[localIndex] - quantizedRotated[localIndex] + residualSquared += residual * residual + setPackedBit( + &qjlSigns, + index: start + localIndex, + value: residual.sign == .minus + ) + } + residualNorms[groupIndex] = norm * sqrt(residualSquared) + } + + return TurboQuantReferenceCode( + shape: shape, + preset: configuration.preset, + role: configuration.role, + format: .turboQuantProd, + groupSize: groupSize, + seed: configuration.seed, + residualScale: configuration.qjlResidualScale, + baseMagnitudeBits: baseBits, + highMagnitudeBits: highBits, + valueCount: values.count, + baseScales: norms, + highScales: residualNorms, + residualScales: [], + signs: Data(qjlSigns), + highPrecisionMask: Data(), + residualSigns: Data(), + packedMagnitudes: Data(packed) + ) +} + +private func decodeTurboQuantProductReference(_ code: TurboQuantReferenceCode) throws -> [Float] { + guard code.groupSize > 0 else { + throw TurboQuantError.invalidGroupSize(code.groupSize) + } + let groupCount = (code.valueCount + code.groupSize - 1) / code.groupSize + guard code.baseScales.count == groupCount, code.highScales.count == groupCount else { + throw TurboQuantError.invalidReferenceCode("TurboQuantProd norm table count mismatch") + } + + var values = Array(repeating: Float(0), count: code.valueCount) + var bitOffset = 0 + for groupIndex in 0 ..< groupCount { + let start = groupIndex * code.groupSize + let end = Swift.min(start + code.groupSize, code.valueCount) + let count = end - start + guard count > 0 else { continue } + + let highCount = mixedPrecisionHighCount( + valueCount: count, + baseBits: code.baseMagnitudeBits, + highBits: code.highMagnitudeBits, + targetBits: Swift.max(1, code.preset.targetMagnitudeBits - 1) + ) + let highMask = productHighPrecisionMask( + valueCount: count, + highCount: highCount, + seed: code.seed, + groupIndex: groupIndex + ) + var rotated = Array(repeating: Float(0), count: count) + for localIndex in 0 ..< count { + let bits = highMask[localIndex] ? code.highMagnitudeBits : code.baseMagnitudeBits + let codebook = turboQuantLloydMaxCodebook( + bits: bits, + coordinateStdDev: 1 / sqrt(Float(count)) + ) + let codeIndex = Int( + try readPackedBits( + code.packedMagnitudes, + bitOffset: &bitOffset, + bitCount: bits + ) + ) + guard codeIndex < codebook.count else { + throw TurboQuantError.invalidReferenceCode("TurboQuantProd codebook index overflow") + } + rotated[localIndex] = codebook[codeIndex] + } + + let unrotated = applyTurboQuantRotation( + rotated, + seed: code.seed, + groupIndex: groupIndex, + inverse: true + ) + let norm = code.baseScales[groupIndex] + for localIndex in 0 ..< count { + values[start + localIndex] = unrotated[localIndex] * norm + } + } + return values +} + +private func turboQuantProductInnerProduct(query: [Float], code: TurboQuantReferenceCode) throws + -> Float +{ + let groupCount = (code.valueCount + code.groupSize - 1) / code.groupSize + guard code.baseScales.count == groupCount, code.highScales.count == groupCount else { + throw TurboQuantError.invalidReferenceCode("TurboQuantProd norm table count mismatch") + } + guard code.signs.count >= packedBitByteCount(code.valueCount) else { + throw TurboQuantError.invalidReferenceCode("TurboQuantProd QJL sign storage is truncated") + } + + var total = Float(0) + var bitOffset = 0 + for groupIndex in 0 ..< groupCount { + let start = groupIndex * code.groupSize + let end = Swift.min(start + code.groupSize, code.valueCount) + let count = end - start + guard count > 0 else { continue } + + let highCount = mixedPrecisionHighCount( + valueCount: count, + baseBits: code.baseMagnitudeBits, + highBits: code.highMagnitudeBits, + targetBits: Swift.max(1, code.preset.targetMagnitudeBits - 1) + ) + let highMask = productHighPrecisionMask( + valueCount: count, + highCount: highCount, + seed: code.seed, + groupIndex: groupIndex + ) + var quantizedRotated = Array(repeating: Float(0), count: count) + for localIndex in 0 ..< count { + let bits = highMask[localIndex] ? code.highMagnitudeBits : code.baseMagnitudeBits + let codebook = turboQuantLloydMaxCodebook( + bits: bits, + coordinateStdDev: 1 / sqrt(Float(count)) + ) + let codeIndex = Int( + try readPackedBits( + code.packedMagnitudes, + bitOffset: &bitOffset, + bitCount: bits + ) + ) + guard codeIndex < codebook.count else { + throw TurboQuantError.invalidReferenceCode("TurboQuantProd codebook index overflow") + } + quantizedRotated[localIndex] = codebook[codeIndex] + } + + let queryRotated = applyTurboQuantRotation( + Array(query[start ..< end]), + seed: code.seed, + groupIndex: groupIndex, + inverse: false + ) + let norm = code.baseScales[groupIndex] + for localIndex in 0 ..< count { + total += norm * quantizedRotated[localIndex] * queryRotated[localIndex] + } + + let residualNorm = code.highScales[groupIndex] + if residualNorm > 0 { + var signDot = Float(0) + for localIndex in 0 ..< count { + let sign: Float = + getPackedBit(code.signs, index: start + localIndex) ? -1 : 1 + signDot += sign * queryRotated[localIndex] + } + total += residualNorm * sqrt(Float.pi / (2 * Float(count))) * signDot + } + } + return total +} + +private func turboQuantQuality( + original: [Float], + decoded: [Float], + seed: UInt64, + thresholds: TurboQuantQualityThresholds +) throws -> TurboQuantQualityReport { + guard !original.isEmpty else { + throw TurboQuantError.invalidQualityInput("quality input must not be empty") + } + guard original.count == decoded.count else { + throw TurboQuantError.invalidQualityInput("original and decoded counts differ") + } + + var squaredError = Float(0) + var squaredSignal = Float(0) + var maxAbsoluteError = Float(0) + var dot = Float(0) + var originalNormSquared = Float(0) + var decodedNormSquared = Float(0) + var probeOriginalDot = Float(0) + var probeDecodedDot = Float(0) + + for index in original.indices { + let lhs = original[index] + let rhs = decoded[index] + let delta = lhs - rhs + squaredError += delta * delta + squaredSignal += lhs * lhs + maxAbsoluteError = Swift.max(maxAbsoluteError, Swift.abs(delta)) + dot += lhs * rhs + originalNormSquared += lhs * lhs + decodedNormSquared += rhs * rhs + + let probe = deterministicProbeValue(index: index, seed: seed) + probeOriginalDot += probe * lhs + probeDecodedDot += probe * rhs + } + + let count = Float(original.count) + let mse = squaredError / count + let relativeMSE = squaredError / Swift.max(squaredSignal, Float.leastNonzeroMagnitude) + let cosineDenominator = sqrt(originalNormSquared) * sqrt(decodedNormSquared) + let cosineSimilarity = dot / Swift.max(cosineDenominator, Float.leastNonzeroMagnitude) + let innerProductRelativeError = + Swift.abs(probeOriginalDot - probeDecodedDot) + / Swift.max(Swift.abs(probeOriginalDot), Float.leastNonzeroMagnitude) + + return TurboQuantQualityReport( + mse: mse, + relativeMSE: relativeMSE, + maxAbsoluteError: maxAbsoluteError, + cosineSimilarity: cosineSimilarity, + innerProductRelativeError: innerProductRelativeError, + thresholds: thresholds + ) +} + +private func deterministicProbeValue(index: Int, seed: UInt64) -> Float { + var state = seed ^ 0xD1B5_4A32_D192_ED03 + state &+= UInt64(index) &* 0x9E37_79B9_7F4A_7C15 + state ^= state >> 30 + state &*= 0xBF58_476D_1CE4_E5B9 + state ^= state >> 27 + state &*= 0x94D0_49BB_1331_11EB + state ^= state >> 31 + let unit = Float(UInt32(truncatingIfNeeded: state)) / Float(UInt32.max) + return unit * 2 - 1 +} + +private func mixedPrecisionHighCount( + valueCount: Int, + baseBits: Int, + highBits: Int, + targetBits: Float +) -> Int { + guard highBits > baseBits else { return 0 } + let fraction = (targetBits - Float(baseBits)) / Float(highBits - baseBits) + let clampedFraction = Swift.max(0, Swift.min(1, fraction)) + return Int((Float(valueCount) * clampedFraction).rounded()) +} + +private func validateTurboQuantValueBits(_ bits: Int) throws { + guard (2 ... 8).contains(bits) else { + throw TurboQuantError.invalidReferenceCode( + "TurboQuant value bits must be in 2...8, got \(bits)" + ) + } +} + +private func productHighPrecisionMask( + valueCount: Int, + highCount: Int, + seed: UInt64, + groupIndex: Int +) -> [Bool] { + guard highCount > 0 else { return Array(repeating: false, count: valueCount) } + guard highCount < valueCount else { return Array(repeating: true, count: valueCount) } + + let ranked = (0 ..< valueCount).sorted { lhs, rhs in + let lhsRank = productChannelRank(seed: seed, groupIndex: groupIndex, localIndex: lhs) + let rhsRank = productChannelRank(seed: seed, groupIndex: groupIndex, localIndex: rhs) + if lhsRank == rhsRank { + return lhs < rhs + } + return lhsRank < rhsRank + } + var mask = Array(repeating: false, count: valueCount) + for index in ranked.prefix(highCount) { + mask[index] = true + } + return mask +} + +private func productChannelRank(seed: UInt64, groupIndex: Int, localIndex: Int) -> UInt64 { + var state = seed + state ^= UInt64(groupIndex) &* 0x9E37_79B9_7F4A_7C15 + state &+= UInt64(localIndex) &* 0xD1B5_4A32_D192_ED03 + state ^= state >> 30 + state &*= 0xBF58_476D_1CE4_E5B9 + state ^= state >> 27 + state &*= 0x94D0_49BB_1331_11EB + state ^= state >> 31 + return state +} + +private func turboQuantLloydMaxCodebook(bits: Int, coordinateStdDev: Float) -> [Float] { + let levelCount = Swift.max(2, 1 << bits) + let sigma = Swift.max(Double(coordinateStdDev), Double(Float.leastNonzeroMagnitude)) + var levels = (0 ..< levelCount).map { index -> Double in + let centered = (Double(index) + 0.5) / Double(levelCount) * 2 - 1 + return centered * 2.5 * sigma + } + + for _ in 0 ..< 16 { + var boundaries = Array(repeating: -Double.infinity, count: levelCount + 1) + boundaries[levelCount] = Double.infinity + if levelCount > 1 { + for index in 1 ..< levelCount { + boundaries[index] = (levels[index - 1] + levels[index]) * 0.5 + } + } + + for index in 0 ..< levelCount { + let lower = boundaries[index] / sigma + let upper = boundaries[index + 1] / sigma + let probability = normalCDF(upper) - normalCDF(lower) + if probability > 1e-12 { + levels[index] = sigma * (normalPDF(lower) - normalPDF(upper)) / probability + } + } + } + + return levels.map(Float.init) +} + +private func nearestCodebookIndex(_ value: Float, codebook: [Float]) -> Int { + var bestIndex = 0 + var bestDistance = Float.greatestFiniteMagnitude + for (index, level) in codebook.enumerated() { + let distance = Swift.abs(value - level) + if distance < bestDistance { + bestDistance = distance + bestIndex = index + } + } + return bestIndex +} + +private func normalPDF(_ x: Double) -> Double { + guard x.isFinite else { return 0 } + return exp(-0.5 * x * x) / sqrt(2 * Double.pi) +} + +private func normalCDF(_ x: Double) -> Double { + if x == Double.infinity { return 1 } + if x == -Double.infinity { return 0 } + return 0.5 * (1 + erf(x / sqrt(2))) +} + +private func applyTurboQuantRotation( + _ values: [Float], + seed: UInt64, + groupIndex: Int, + inverse: Bool +) -> [Float] { + guard values.count > 1 else { + return values.enumerated().map { localIndex, value in + randomSign(index: groupIndex &* 4099 &+ localIndex, seed: seed) ? -value : value + } + } + if isPowerOfTwo(values.count) { + return applyRandomizedHadamardRotation( + values, + seed: seed, + groupIndex: groupIndex, + inverse: inverse + ) + } + return applyDeterministicGivensRotation( + values, + seed: seed, + groupIndex: groupIndex, + inverse: inverse + ) +} + +private func isPowerOfTwo(_ value: Int) -> Bool { + value > 0 && (value & (value - 1)) == 0 +} + +private func applyRandomizedHadamardRotation( + _ values: [Float], + seed: UInt64, + groupIndex: Int, + inverse: Bool +) -> [Float] { + var result = values + if inverse { + fastHadamardTransform(&result) + applyRotationSigns(&result, seed: seed, groupIndex: groupIndex) + } else { + applyRotationSigns(&result, seed: seed, groupIndex: groupIndex) + fastHadamardTransform(&result) + } + let scale = 1 / sqrt(Float(values.count)) + for index in result.indices { + result[index] *= scale + } + return result +} + +private func fastHadamardTransform(_ values: inout [Float]) { + var width = 1 + while width < values.count { + var start = 0 + while start < values.count { + for offset in 0 ..< width { + let lhs = values[start + offset] + let rhs = values[start + offset + width] + values[start + offset] = lhs + rhs + values[start + offset + width] = lhs - rhs + } + start += width * 2 + } + width *= 2 + } +} + +private func applyRotationSigns(_ values: inout [Float], seed: UInt64, groupIndex: Int) { + for index in values.indices { + if randomSign(index: groupIndex &* 4099 &+ index, seed: seed) { + values[index] = -values[index] + } + } +} + +private func applyDeterministicGivensRotation( + _ values: [Float], + seed: UInt64, + groupIndex: Int, + inverse: Bool +) -> [Float] { + var result = values + let passes = Array(0 ..< 4) + let orderedPasses = inverse ? Array(passes.reversed()) : passes + for pass in orderedPasses { + let offset = pass % 2 + var index = offset + while index + 1 < result.count { + let angle = deterministicRotationAngle( + seed: seed, + groupIndex: groupIndex, + pass: pass, + pairIndex: index / 2 + ) * (inverse ? -1 : 1) + let c = cos(angle) + let s = sin(angle) + let lhs = result[index] + let rhs = result[index + 1] + result[index] = c * lhs - s * rhs + result[index + 1] = s * lhs + c * rhs + index += 2 + } + } + return result +} + +private func deterministicRotationAngle( + seed: UInt64, + groupIndex: Int, + pass: Int, + pairIndex: Int +) -> Float { + let rank = productChannelRank( + seed: seed ^ (UInt64(pass) &* 0xA24B_AED4_963E_E407), + groupIndex: groupIndex, + localIndex: pairIndex + ) + let unit = Float(UInt32(truncatingIfNeeded: rank)) / Float(UInt32.max) + return (unit - 0.5) * Float.pi +} + +private func packedBitByteCount(_ bitCount: Int) -> Int { + (bitCount + 7) / 8 +} + +private func setPackedBit(_ bytes: inout [UInt8], index: Int, value: Bool) { + guard value else { return } + let byteIndex = index / 8 + let bitIndex = index % 8 + bytes[byteIndex] |= UInt8(1 << bitIndex) +} + +private func getPackedBit(_ data: Data, index: Int) -> Bool { + let byteIndex = index / 8 + let bitIndex = index % 8 + guard byteIndex < data.count else { return false } + return (data[byteIndex] & UInt8(1 << bitIndex)) != 0 +} + +private func appendPackedBits( + _ value: UInt32, + bitCount: Int, + bytes: inout [UInt8], + bitOffset: inout Int +) { + for localBit in 0 ..< bitCount { + if bitOffset / 8 == bytes.count { + bytes.append(0) + } + let bitSet = (value & (1 << UInt32(localBit))) != 0 + if bitSet { + bytes[bitOffset / 8] |= UInt8(1 << (bitOffset % 8)) + } + bitOffset += 1 + } +} + +private func readPackedBits( + _ data: Data, + bitOffset: inout Int, + bitCount: Int +) throws -> UInt32 { + var value: UInt32 = 0 + for localBit in 0 ..< bitCount { + let byteIndex = bitOffset / 8 + guard byteIndex < data.count else { + throw TurboQuantError.invalidReferenceCode("packed magnitude storage is truncated") + } + if (data[byteIndex] & UInt8(1 << (bitOffset % 8))) != 0 { + value |= 1 << UInt32(localBit) + } + bitOffset += 1 + } + return value +} + +private func preconditionedValue(_ value: Float, index: Int, seed: UInt64) -> Float { + randomSign(index: index, seed: seed) ? -value : value +} + +private func unpreconditionedValue(_ value: Float, index: Int, seed: UInt64) -> Float { + randomSign(index: index, seed: seed) ? -value : value +} + +private func randomSign(index: Int, seed: UInt64) -> Bool { + var state = seed &+ UInt64(index) &* 0x9E37_79B9_7F4A_7C15 + state ^= state >> 30 + state &*= 0xBF58_476D_1CE4_E5B9 + state ^= state >> 27 + state &*= 0x94D0_49BB_1331_11EB + state ^= state >> 31 + return (state & 1) == 1 +} + +private func metalTemplateUInt32High(_ value: UInt64) -> UInt32 { + UInt32((value >> 32) & 0xFFFF_FFFF) +} + +private func metalTemplateUInt32Low(_ value: UInt64) -> UInt32 { + UInt32(value & 0xFFFF_FFFF) +} + +private func metalRuntimeAvailable() -> Bool { + #if canImport(Metal) + guard MTLCreateSystemDefaultDevice() != nil else { return false } + #endif + return metalLibraryResourceAvailable() +} + +private func metalLibraryResourceAvailable() -> Bool { + let fileManager = FileManager.default + var candidates: [URL] = [] + + if let executablePath = CommandLine.arguments.first, !executablePath.isEmpty { + let executableDirectory = URL(fileURLWithPath: executablePath).deletingLastPathComponent() + candidates.append(executableDirectory.appendingPathComponent("mlx.metallib")) + candidates.append(executableDirectory.appendingPathComponent("default.metallib")) + candidates.append(executableDirectory.appendingPathComponent("Resources/mlx.metallib")) + candidates.append(executableDirectory.appendingPathComponent("Resources/default.metallib")) + appendSwiftPMMetalBundleCandidates(from: executableDirectory, to: &candidates) + } + + if let executableDirectory = Bundle.main.executableURL?.deletingLastPathComponent() { + appendSwiftPMMetalBundleCandidates(from: executableDirectory, to: &candidates) + } + + let currentDirectory = URL(fileURLWithPath: fileManager.currentDirectoryPath) + candidates.append(currentDirectory.appendingPathComponent("mlx.metallib")) + candidates.append(currentDirectory.appendingPathComponent("default.metallib")) + + for bundle in [Bundle.main] + Bundle.allBundles { + if bundle.url(forResource: "default", withExtension: "metallib") != nil + || bundle.url(forResource: "mlx", withExtension: "metallib") != nil + { + return true + } + appendSwiftPMMetalBundleCandidates(from: bundle.bundleURL, to: &candidates) + if let resourceURL = bundle.resourceURL { + candidates.append(resourceURL.appendingPathComponent("default.metallib")) + candidates.append(resourceURL.appendingPathComponent("mlx.metallib")) + candidates.append( + resourceURL.appendingPathComponent("mlx-swift_Cmlx.bundle/default.metallib")) + candidates.append( + resourceURL.appendingPathComponent("mlx-swift_Cmlx.bundle/mlx.metallib")) + appendSwiftPMMetalBundleCandidates(from: resourceURL, to: &candidates) + } + } + + return candidates.contains { fileManager.fileExists(atPath: $0.path) } +} + +private func appendSwiftPMMetalBundleCandidates(from directory: URL, to candidates: inout [URL]) { + var root = directory + for _ in 0 ..< 5 { + candidates.append(root.appendingPathComponent("mlx-swift_Cmlx.bundle/default.metallib")) + candidates.append(root.appendingPathComponent("mlx-swift_Cmlx.bundle/mlx.metallib")) + + let parent = root.deletingLastPathComponent() + guard parent.path != root.path else { break } + root = parent + } +} + +private func detectedTurboQuantDeviceCapabilities() -> TurboQuantDeviceCapabilities { + let metalAvailable = metalRuntimeAvailable() + let physicalMemory = Int(ProcessInfo.processInfo.physicalMemory) + + #if canImport(Metal) + if let device = MTLCreateSystemDefaultDevice() { + let architecture: String + if #available(macOS 14.0, iOS 17.0, tvOS 17.0, *) { + architecture = device.architecture.name + } else { + architecture = device.name + } + + let recommendedWorkingSet: Int? + if device.recommendedMaxWorkingSetSize > UInt64(Int.max) { + recommendedWorkingSet = Int.max + } else if device.recommendedMaxWorkingSetSize > 0 { + recommendedWorkingSet = Int(device.recommendedMaxWorkingSetSize) + } else { + recommendedWorkingSet = nil + } + + return TurboQuantDeviceCapabilities( + metalAvailable: metalAvailable, + architectureName: architecture, + supportedGPUFamilies: turboQuantSupportedGPUFamilies(device), + maxBufferBytes: device.maxBufferLength, + recommendedWorkingSetBytes: recommendedWorkingSet, + physicalMemoryBytes: physicalMemory, + maxThreadgroupWidth: device.maxThreadsPerThreadgroup.width + ) + } + #endif + + return TurboQuantDeviceCapabilities( + metalAvailable: metalAvailable, + architectureName: "Unknown", + physicalMemoryBytes: physicalMemory + ) +} + +#if canImport(Metal) + private func turboQuantSupportedGPUFamilies(_ device: MTLDevice) -> [String: Bool] { + var families = [ + "apple7": device.supportsFamily(.apple7), + "apple8": device.supportsFamily(.apple8), + "apple9": device.supportsFamily(.apple9), + "apple10": device.supportsFamily(.apple10), + "mac2": device.supportsFamily(.mac2), + "metal3": device.supportsFamily(.metal3), + ] + #if targetEnvironment(simulator) + families["metal4"] = false + #else + if #available(macOS 26.0, iOS 26.0, tvOS 26.0, visionOS 26.0, *) { + families["metal4"] = device.supportsFamily(.metal4) + } else { + families["metal4"] = false + } + #endif + return families + } +#endif + +private func selectTurboQuantKernelProfile( + architectureName: String, + supportedGPUFamilies: [String: Bool], + recommendedWorkingSetBytes: Int? +) -> TurboQuantKernelProfile { + let architecture = architectureName.lowercased() + let workingSet = recommendedWorkingSetBytes ?? 0 + + if workingSet >= 10_000_000_000 + || architecture.contains("a19pro") + || architecture.contains("a19 pro") + { + return .sustainedA19Pro + } + + if supportedGPUFamilies["apple10"] == true + || supportedGPUFamilies["apple9"] == true + || supportedGPUFamilies["apple8"] == true + || workingSet >= 7_000_000_000 + || architecture.contains("a18") + || architecture.contains("a19") + { + return .wideA18A19 + } + + return .portableA16A17 +} + +public final class TurboQuantRuntimeProbe: @unchecked Sendable { + public static let shared = TurboQuantRuntimeProbe() + + private let lock = NSLock() + private var cachedResult: TurboQuantRuntimeProbeResult? + private var runningSelfTest = false + + private init() {} + + public static var current: TurboQuantRuntimeProbeResult { + shared.result() + } + + public func result() -> TurboQuantRuntimeProbeResult { + lock.lock() + if let cachedResult { + lock.unlock() + return cachedResult + } + lock.unlock() + + let result = run(on: detectedTurboQuantDeviceCapabilities()) + + lock.lock() + cachedResult = result + lock.unlock() + return result + } + + func selectedKernelProfileWithoutRunningProbe() -> TurboQuantKernelProfile { + lock.lock() + let cached = cachedResult?.selectedKernelProfile + lock.unlock() + if let cached { return cached } + + let capabilities = detectedTurboQuantDeviceCapabilities() + guard capabilities.metalAvailable else { return .mlxPackedFallback } + return selectTurboQuantKernelProfile( + architectureName: capabilities.architectureName, + supportedGPUFamilies: capabilities.supportedGPUFamilies, + recommendedWorkingSetBytes: capabilities.recommendedWorkingSetBytes + ) + } + + func isRunningSelfTest() -> Bool { + lock.lock() + let running = runningSelfTest + lock.unlock() + return running + } + + private func run(on capabilities: TurboQuantDeviceCapabilities) -> TurboQuantRuntimeProbeResult + { + guard capabilities.metalAvailable else { + return TurboQuantRuntimeProbeResult( + status: .failed, + metalRuntimeAvailable: false, + selectedKernelProfile: .mlxPackedFallback, + failureReason: "Metal runtime or bundled metallib is unavailable." + ) + } + + let selectedProfile = selectTurboQuantKernelProfile( + architectureName: capabilities.architectureName, + supportedGPUFamilies: capabilities.supportedGPUFamilies, + recommendedWorkingSetBytes: capabilities.recommendedWorkingSetBytes + ) + + lock.lock() + runningSelfTest = true + lock.unlock() + defer { + lock.lock() + runningSelfTest = false + lock.unlock() + } + + do { + let queryValues: [Float] = (0 ..< 512).map { index in + let position = Double(index) + return Float(sin(position * 0.07) + 0.25 * cos(position * 0.013)) + } + let keyValues: [Float] = (0 ..< 640).map { index in + let position = Double(index) + return Float(0.5 * cos(position * 0.05) + 0.1 * sin(position * 0.19)) + } + let valueValues: [Float] = (0 ..< 640).map { index in + let position = Double(index) + return Float(0.35 * sin(position * 0.09) - 0.15 * cos(position * 0.17)) + } + let queries = MLXArray(queryValues, [1, 4, 2, 64]) + let keys = MLXArray(keyValues, [1, 2, 5, 64]) + let values = MLXArray(valueValues, [1, 2, 5, 64]) + let encodeStart = Date.timeIntervalSinceReferenceDate + let keyCode = try turboQuantMetalEncodeAttention( + keys, + configuration: TurboQuantConfiguration( + preset: .turbo3_5, + role: .key, + groupSize: 64, + backend: .metalPolarQJL, + seed: 0x5EED_A11C_0000_0001 + ) + ) + let valueCode = try turboQuantMetalEncodeAttention( + values, + configuration: TurboQuantConfiguration( + preset: .turbo3_5, + role: .value, + groupSize: 64, + backend: .metalPolarQJL, + seed: 0x5EED_A11C_0000_0002 + ) + ) + let decodedKeys = try turboQuantMetalDecodeAttention(keyCode, outputDType: .float32) + let decodedValues = try turboQuantMetalDecodeAttention(valueCode, outputDType: .float32) + eval(decodedKeys, decodedValues) + let encodeDecodeLatency = Date.timeIntervalSinceReferenceDate - encodeStart + let encodeDecodePassed = + decodedKeys.shape == keys.shape + && decodedValues.shape == values.shape + + let scale = 1 / sqrt(Float(64)) + let reference = MLXFast.scaledDotProductAttention( + queries: queries, + keys: keys, + values: values, + scale: scale, + mask: .causal + ) + eval(reference) + + let qk = try turboQuantMetalQK( + queries: queries, + keyCode: keyCode, + scale: scale, + mask: .causal + ) + eval(qk) + let qkPassed = qk.shape == [1, 4, 2, 5] + + let twoStageStart = Date.timeIntervalSinceReferenceDate + let weights = softmax(qk.asType(DType.float32), axis: -1) + let av = try turboQuantMetalAV( + attentionWeights: weights, + valueCode: valueCode, + outputDType: .float32 + ) + eval(av) + let twoStageLatency = Date.timeIntervalSinceReferenceDate - twoStageStart + + let fusedStart = Date.timeIntervalSinceReferenceDate + let fused = try turboQuantMetalScaledDotProductAttention( + queries: queries, + keyCode: keyCode, + valueCode: valueCode, + scale: scale, + mask: .causal, + preferOnlineFused: true, + kernelProfile: selectedProfile + ) + eval(av, fused) + let referenceValues = reference.asArray(Float.self) + let fusedLatency = Date.timeIntervalSinceReferenceDate - fusedStart + let avValues = av.asArray(Float.self) + let fusedValues = fused.asArray(Float.self) + let maxDelta = zip(avValues, fusedValues).reduce(Float(0)) { current, pair in + Swift.max(current, Swift.abs(pair.0 - pair.1)) + } + let referenceEnergy = referenceValues.reduce(Float(0)) { partial, value in + partial + value * value + } + let fusedReferenceRelativeMSE = zip(fusedValues, referenceValues).reduce(Float(0)) { + current, pair in + let delta = pair.0 - pair.1 + return current + delta * delta + } / Swift.max(referenceEnergy, Float.leastNonzeroMagnitude) + let avPassed = av.shape == [1, 4, 2, 64] + let fusedPassed = + av.shape == fused.shape && maxDelta < 1e-3 + && fusedReferenceRelativeMSE < 0.5 + let passed = encodeDecodePassed && qkPassed && avPassed && fusedPassed + + return TurboQuantRuntimeProbeResult( + status: passed ? .passed : .failed, + metalRuntimeAvailable: true, + encodeDecodePassed: encodeDecodePassed, + qkPassed: qkPassed, + avPassed: avPassed, + tiledFusedPassed: fusedPassed, + selectedKernelProfile: passed ? selectedProfile : .mlxPackedFallback, + failureReason: passed ? nil : "TurboQuant Metal tiny-shape self-test failed.", + encodeDecodeLatencySeconds: encodeDecodeLatency, + twoStageLatencySeconds: twoStageLatency, + tiledFusedLatencySeconds: fusedLatency + ) + } catch { + return TurboQuantRuntimeProbeResult( + status: .failed, + metalRuntimeAvailable: true, + selectedKernelProfile: .mlxPackedFallback, + failureReason: String(describing: error) + ) + } + } +} + +private func validateMetalConfiguration( + array: MLXArray, + configuration: TurboQuantConfiguration +) throws { + guard array.size > 0 else { + throw TurboQuantError.invalidMetalConfiguration("empty arrays are not supported") + } + guard array.dtype.isFloatingPoint else { + throw TurboQuantError.invalidMetalConfiguration("input dtype must be floating point") + } + guard configuration.groupSize > 0 else { + throw TurboQuantError.invalidGroupSize(configuration.groupSize) + } + guard configuration.groupSize <= 128, configuration.groupSize % 32 == 0 else { + throw TurboQuantError.invalidMetalConfiguration( + "group size must be 32, 64, 96, or 128 for the Metal codec" + ) + } + if configuration.role == .value { + try validateTurboQuantValueBits(configuration.resolvedValueBits) + } + try requireTurboQuantMetalCodec() +} + +private func metalMagnitudeWordsPerGroup( + groupSize: Int, + preset: TurboQuantPreset, + role: TurboQuantTensorRole = .key, + valueBits: Int? = nil +) -> Int { + if role == .value { + let bitCount = groupSize * (valueBits ?? preset.defaultValueBits) + return (bitCount + 31) / 32 + } + let baseBits = Swift.max(1, preset.baseMagnitudeBits - 1) + let highBits = Swift.max(baseBits, preset.highMagnitudeBits - 1) + let highCount = mixedPrecisionHighCount( + valueCount: groupSize, + baseBits: baseBits, + highBits: highBits, + targetBits: Swift.max(1, preset.targetMagnitudeBits - 1) + ) + let bitCount = + groupSize * baseBits + + highCount * (highBits - baseBits) + return (bitCount + 31) / 32 +} + +private func metalScalesPerGroup(role: TurboQuantTensorRole) -> Int { + role == .value ? 2 : 3 +} + +private func metalTemplate( + configuration: TurboQuantConfiguration, + valueCount: Int, + groupCount: Int, + magnitudeWordsPerGroup: Int, + bitsetWordsPerGroup: Int +) -> [(String, any KernelTemplateArg)] { + [ + ("GROUP_SIZE", configuration.groupSize), + ("VALUE_COUNT", valueCount), + ("GROUP_COUNT", groupCount), + ("BASE_BITS", configuration.preset.baseMagnitudeBits), + ("HIGH_BITS", configuration.preset.highMagnitudeBits), + ("KEY_BASE_BITS", Swift.max(1, configuration.preset.baseMagnitudeBits - 1)), + ( + "KEY_HIGH_BITS", + Swift.max( + Swift.max(1, configuration.preset.baseMagnitudeBits - 1), + configuration.preset.highMagnitudeBits - 1 + ) + ), + ("HIGH_NUMERATOR", 1), + ("HIGH_DENOMINATOR", 2), + ("MAG_WORDS_PER_GROUP", magnitudeWordsPerGroup), + ("BITSET_WORDS_PER_GROUP", bitsetWordsPerGroup), + ("VALUE_BITS", configuration.resolvedValueBits), + ("SCALES_PER_GROUP", metalScalesPerGroup(role: configuration.role)), + ("ROLE", metalRoleValue(configuration.role)), + ("SEED_HI", metalTemplateUInt32High(configuration.seed)), + ("SEED_LO", metalTemplateUInt32Low(configuration.seed)), + ] +} + +private func metalRoleValue(_ role: TurboQuantTensorRole) -> Int { + switch role { + case .key: + 0 + case .value: + 1 + case .vector: + 2 + } +} + +private func validateAttentionArray(_ array: MLXArray, groupSize: Int) throws { + try validateAttentionShape(array.shape, dtype: array.dtype, groupSize: groupSize) +} + +private func validateAttentionShape(_ shape: [Int], dtype: DType, groupSize: Int) throws { + guard shape.count == 4 else { + throw TurboQuantError.invalidMetalConfiguration( + "attention tensors must have shape [B, H, T, D]" + ) + } + guard shape.reduce(1, *) > 0 else { + throw TurboQuantError.invalidMetalConfiguration("empty attention tensors are not supported") + } + guard dtype.isFloatingPoint else { + throw TurboQuantError.invalidMetalConfiguration( + "attention tensor dtype must be floating point") + } + guard groupSize > 0 else { + throw TurboQuantError.invalidGroupSize(groupSize) + } + guard groupSize <= 128, groupSize % 32 == 0 else { + throw TurboQuantError.invalidMetalConfiguration( + "group size must be 32, 64, 96, or 128 for compressed attention" + ) + } + guard shape[3] <= 512 else { + throw TurboQuantError.invalidMetalConfiguration( + "head dimension \(shape[3]) is not supported by compressed attention" + ) + } +} + +private func validateAttentionLayout( + _ layout: TurboQuantAttentionLayout, + role: TurboQuantTensorRole, + groupSize: Int +) throws { + guard role == .key || role == .value else { + throw TurboQuantError.invalidMetalConfiguration( + "compressed attention codes must be encoded as key or value" + ) + } + guard layout.layoutVersion == TurboQuantAttentionLayout.currentVersion else { + throw TurboQuantError.invalidMetalConfiguration( + "unsupported compressed attention layout version \(layout.layoutVersion)" + ) + } + guard layout.batchSize > 0, layout.kvHeadCount > 0, layout.capacity > 0, + layout.logicalLength >= 0, layout.logicalLength <= layout.capacity, + layout.headDimension > 0 + else { + throw TurboQuantError.invalidMetalConfiguration("invalid compressed attention layout shape") + } + guard layout.ringOffset >= 0, layout.ringOffset < layout.capacity else { + throw TurboQuantError.invalidMetalConfiguration("ring offset is outside cache capacity") + } + guard layout.pinnedPrefixLength >= 0, layout.pinnedPrefixLength <= layout.capacity else { + throw TurboQuantError.invalidMetalConfiguration("pinned prefix is outside cache capacity") + } + let ringCapacity = layout.capacity - layout.pinnedPrefixLength + if ringCapacity == 0 { + guard layout.ringOffset == 0 else { + throw TurboQuantError.invalidMetalConfiguration( + "ring offset must be zero without ring capacity") + } + } else { + guard layout.ringOffset < ringCapacity else { + throw TurboQuantError.invalidMetalConfiguration( + "ring offset is outside rotating region") + } + } + guard layout.groupsPerVector == (layout.headDimension + groupSize - 1) / groupSize else { + throw TurboQuantError.invalidMetalConfiguration("groups per vector does not match layout") + } +} + +private func validateAttentionQuery( + _ queries: MLXArray, + code: TurboQuantAttentionCode +) throws { + try validateAttentionArray(queries, groupSize: code.groupSize) + guard queries.dim(0) == code.layout.batchSize else { + throw TurboQuantError.invalidMetalConfiguration( + "query batch size does not match compressed attention cache" + ) + } + guard queries.dim(3) == code.layout.headDimension else { + throw TurboQuantError.invalidMetalConfiguration( + "query head dimension does not match compressed attention cache" + ) + } + guard queries.dim(1) % code.layout.kvHeadCount == 0 else { + throw TurboQuantError.invalidMetalConfiguration( + "query heads must be a multiple of KV heads" + ) + } +} + +private func validateAttentionPair( + keyCode: TurboQuantAttentionCode, + valueCode: TurboQuantAttentionCode +) throws { + try validateAttentionLayout(keyCode.layout, role: keyCode.role, groupSize: keyCode.groupSize) + try validateAttentionLayout( + valueCode.layout, role: valueCode.role, groupSize: valueCode.groupSize) + guard keyCode.role == .key, valueCode.role == .value else { + throw TurboQuantError.invalidMetalConfiguration( + "compressed attention requires key and value codes") + } + guard attentionLayoutsShareSequence(keyCode.layout, valueCode.layout) else { + throw TurboQuantError.invalidMetalConfiguration( + "key and value compressed sequence layouts differ" + ) + } + guard keyCode.preset == valueCode.preset, keyCode.groupSize == valueCode.groupSize else { + throw TurboQuantError.invalidMetalConfiguration("key and value compressed presets differ") + } +} + +private func attentionLayoutsShareSequence( + _ keyLayout: TurboQuantAttentionLayout, + _ valueLayout: TurboQuantAttentionLayout +) -> Bool { + keyLayout.layoutVersion == valueLayout.layoutVersion + && keyLayout.batchSize == valueLayout.batchSize + && keyLayout.kvHeadCount == valueLayout.kvHeadCount + && keyLayout.capacity == valueLayout.capacity + && keyLayout.logicalLength == valueLayout.logicalLength + && keyLayout.ringOffset == valueLayout.ringOffset + && keyLayout.pinnedPrefixLength == valueLayout.pinnedPrefixLength +} + +private func validateAttentionSinks(_ sinks: MLXArray?, queryHeadCount: Int) throws { + guard let sinks else { return } + guard sinks.ndim == 1, sinks.dim(0) == queryHeadCount else { + throw TurboQuantError.invalidMetalConfiguration( + "attention sinks must have shape [query heads]" + ) + } + guard sinks.dtype.isFloatingPoint else { + throw TurboQuantError.invalidMetalConfiguration("attention sinks must be floating point") + } +} + +private func prependAttentionSinks( + _ scores: MLXArray, + sinks: MLXArray?, + queryHeadCount: Int, + stream: StreamOrDevice +) throws -> MLXArray { + guard let sinks else { return scores } + try validateAttentionSinks(sinks, queryHeadCount: queryHeadCount) + let sinkScores = broadcast( + expandedDimensions(sinks.asType(.float32), axes: [0, 2, 3], stream: stream), + to: [scores.dim(0), scores.dim(1), scores.dim(2), 1], + stream: stream + ) + return concatenated([sinkScores, scores], axis: -1, stream: stream) +} + +private func applyAttentionMask( + _ scores: inout MLXArray, + mask: MLXFast.ScaledDotProductAttentionMaskMode, + stream: StreamOrDevice +) { + switch mask { + case .causal: + let (qL, kL) = (scores.dim(-2), scores.dim(-1)) + let qIndices = MLXArray(0 ..< qL) + MLXArray(kL - qL) + let kIndices = MLXArray(0 ..< kL) + let causalMask = greaterEqual( + expandedDimensions(qIndices, axis: -1), + expandedDimensions(kIndices, axis: -2), + stream: stream + ) + scores = `where`( + causalMask, + scores, + MLXArray(-Float.greatestFiniteMagnitude), + stream: stream + ) + + case .array(let maskArray): + if maskArray.dtype == .bool { + scores = `where`( + maskArray, + scores, + MLXArray(-Float.greatestFiniteMagnitude), + stream: stream + ) + } else { + scores = scores + maskArray + } + + case .arrays(let maskArrays): + if let maskArray = maskArrays.first { + if maskArray.dtype == .bool { + scores = `where`( + maskArray, + scores, + MLXArray(-Float.greatestFiniteMagnitude), + stream: stream + ) + } else { + scores = scores + maskArray + } + } + + case .none: + break + } +} + +private func attentionTemplate( + configuration: TurboQuantConfiguration, + layout: TurboQuantAttentionLayout, + inputLength: Int, + outputLength: Int, + queryHeadCount: Int, + queryLength: Int, + outputDType: DType, + causal: Bool +) -> [(String, any KernelTemplateArg)] { + [ + ("BATCH_SIZE", layout.batchSize), + ("KV_HEADS", layout.kvHeadCount), + ("QUERY_HEADS", queryHeadCount), + ("INPUT_LENGTH", inputLength), + ("OUTPUT_LENGTH", outputLength), + ("CAPACITY", layout.capacity), + ("LOGICAL_LENGTH", layout.logicalLength), + ("RING_OFFSET", layout.ringOffset), + ("PINNED_PREFIX_LENGTH", layout.pinnedPrefixLength), + ("QUERY_LENGTH", queryLength), + ("HEAD_DIM", layout.headDimension), + ("GROUP_SIZE", configuration.groupSize), + ("GROUPS_PER_VECTOR", layout.groupsPerVector), + ("BASE_BITS", configuration.preset.baseMagnitudeBits), + ("HIGH_BITS", configuration.preset.highMagnitudeBits), + ("KEY_BASE_BITS", Swift.max(1, configuration.preset.baseMagnitudeBits - 1)), + ( + "KEY_HIGH_BITS", + Swift.max( + Swift.max(1, configuration.preset.baseMagnitudeBits - 1), + configuration.preset.highMagnitudeBits - 1 + ) + ), + ("MAG_WORDS_PER_GROUP", layout.magnitudeWordsPerGroup), + ("BITSET_WORDS_PER_GROUP", layout.bitsetWordsPerGroup), + ("VALUE_BITS", configuration.resolvedValueBits), + ("SCALES_PER_GROUP", metalScalesPerGroup(role: configuration.role)), + ("ROLE", metalRoleValue(configuration.role)), + ("SEED_HI", metalTemplateUInt32High(configuration.seed)), + ("SEED_LO", metalTemplateUInt32Low(configuration.seed)), + ("OUTPUT_DTYPE", outputDType), + ("DO_CAUSAL", causal), + ] +} + +private enum TurboQuantMetalKernels { + static let encode = MLXFast.metalKernel( + name: "turboquant_polar_qjl_encode", + inputNames: ["x"], + outputNames: ["packed", "signs", "high_mask", "residual_signs", "scales"], + source: encodeSource, + header: vectorHeader + ) + + static let decode = MLXFast.metalKernel( + name: "turboquant_polar_qjl_decode", + inputNames: ["packed", "signs", "high_mask", "residual_signs", "scales"], + outputNames: ["out"], + source: decodeSource, + header: vectorHeader + ) + + static let matmul = MLXFast.metalKernel( + name: "turboquant_polar_qjl_matmul", + inputNames: ["x", "packed", "signs", "high_mask", "residual_signs", "scales"], + outputNames: ["out"], + source: matmulSource, + header: vectorHeader + ) + + static let encodeAttention = MLXFast.metalKernel( + name: "turboquant_attention_encode", + inputNames: ["x"], + outputNames: ["packed", "signs", "high_mask", "residual_signs", "scales"], + source: encodeAttentionSource, + header: attentionHeader + ) + + static let decodeAttention = MLXFast.metalKernel( + name: "turboquant_attention_decode", + inputNames: ["packed", "signs", "high_mask", "residual_signs", "scales"], + outputNames: ["out"], + source: decodeAttentionSource, + header: attentionHeader + ) + + static let qk = MLXFast.metalKernel( + name: "turboquant_attention_qk", + inputNames: ["q", "k_packed", "k_signs", "k_high_mask", "k_residual_signs", "k_scales"], + outputNames: ["scores"], + source: qkSource, + header: attentionHeader + ) + + static let av = MLXFast.metalKernel( + name: "turboquant_attention_av", + inputNames: [ + "weights", "v_packed", "v_signs", "v_high_mask", "v_residual_signs", "v_scales", + ], + outputNames: ["out"], + source: avSource, + header: attentionHeader + ) + + static let fusedAttention = MLXFast.metalKernel( + name: "turboquant_attention_fused_decode", + inputNames: [ + "q", + "k_packed", "k_signs", "k_high_mask", "k_residual_signs", "k_scales", + "v_packed", "v_signs", "v_high_mask", "v_residual_signs", "v_scales", + ], + outputNames: ["out"], + source: fusedAttentionSource, + header: attentionHeader + ) + + private static let vectorHeader = """ + inline ulong tq_vector_mix_index(ulong seed, ulong index) { + ulong mixed = seed + index * 0x9E3779B97F4A7C15ul; + mixed ^= mixed >> 30; + mixed *= 0xBF58476D1CE4E5B9ul; + mixed ^= mixed >> 27; + mixed *= 0x94D049BB133111EBul; + mixed ^= mixed >> 31; + return mixed; + } + + inline bool tq_vector_random_sign(ulong seed, ulong index) { + return (tq_vector_mix_index(seed, index) & 1ul) != 0ul; + } + + inline ulong tq_product_channel_rank(ulong seed, uint group_index, uint local_index) { + ulong state = seed; + state ^= ulong(group_index) * 0x9E3779B97F4A7C15ul; + state += ulong(local_index) * 0xD1B54A32D192ED03ul; + state ^= state >> 30; + state *= 0xBF58476D1CE4E5B9ul; + state ^= state >> 27; + state *= 0x94D049BB133111EBul; + state ^= state >> 31; + return state; + } + + inline bool tq_product_high_precision( + ulong seed, + uint group_index, + uint local, + uint count, + uint high_count + ) { + if (high_count == 0u) { + return false; + } + if (high_count >= count) { + return true; + } + ulong local_rank = tq_product_channel_rank(seed, group_index, local); + uint rank = 0u; + for (uint other = 0u; other < count; other++) { + ulong other_rank = tq_product_channel_rank(seed, group_index, other); + if (other_rank < local_rank || (other_rank == local_rank && other < local)) { + rank += 1u; + } + } + return rank < high_count; + } + + inline float tq_codebook_unit(uint bits, uint code) { + if (bits <= 1u) { + return code == 0u ? -0.797884561f : 0.797884561f; + } + if (bits == 2u) { + switch (min(code, 3u)) { + case 0u: return -1.510499245f; + case 1u: return -0.452819573f; + case 2u: return 0.452819573f; + default: return 1.510499245f; + } + } + if (bits == 3u) { + switch (min(code, 7u)) { + case 0u: return -2.175028018f; + case 1u: return -1.367204388f; + case 2u: return -0.773020220f; + case 3u: return -0.251312159f; + case 4u: return 0.251312159f; + case 5u: return 0.773020220f; + case 6u: return 1.367204388f; + default: return 2.175028018f; + } + } + switch (min(code, 15u)) { + case 0u: return -2.778927695f; + case 1u: return -2.124836923f; + case 2u: return -1.680512470f; + case 3u: return -1.321175453f; + case 4u: return -1.003692455f; + case 5u: return -0.707453186f; + case 6u: return -0.421537889f; + case 7u: return -0.140103661f; + case 8u: return 0.140103661f; + case 9u: return 0.421537889f; + case 10u: return 0.707453186f; + case 11u: return 1.003692455f; + case 12u: return 1.321175453f; + case 13u: return 1.680512470f; + case 14u: return 2.124836923f; + default: return 2.778927695f; + } + } + + inline float tq_codebook_level(uint bits, uint code, uint count) { + return tq_codebook_unit(bits, code) * rsqrt(float(max(count, 1u))); + } + + inline uint tq_nearest_codebook_index(float value, uint bits, uint count) { + uint level_count = 1u << bits; + uint best_index = 0u; + float best_distance = INFINITY; + for (uint code = 0u; code < level_count; code++) { + float distance = fabs(value - tq_codebook_level(bits, code, count)); + if (distance < best_distance) { + best_distance = distance; + best_index = code; + } + } + return best_index; + } + + inline void tq_fast_hadamard(thread float* values, uint count) { + for (uint width = 1u; width < count; width <<= 1u) { + for (uint start = 0u; start < count; start += width << 1u) { + for (uint offset = 0u; offset < width; offset++) { + float lhs = values[start + offset]; + float rhs = values[start + offset + width]; + values[start + offset] = lhs + rhs; + values[start + offset + width] = lhs - rhs; + } + } + } + } + + inline void tq_apply_rotation_signs( + thread float* values, + uint count, + ulong seed, + uint group_index + ) { + for (uint local = 0u; local < count; local++) { + ulong sign_index = ulong(group_index) * 4099ul + ulong(local); + if (tq_vector_random_sign(seed, sign_index)) { + values[local] = -values[local]; + } + } + } + + inline void tq_apply_givens_pass( + thread float* values, + uint count, + ulong seed, + uint group_index, + uint pass, + float direction + ) { + uint offset = pass & 1u; + for (uint index = offset; index + 1u < count; index += 2u) { + ulong angle_rank = tq_product_channel_rank( + seed ^ (ulong(pass) * 0xA24BAED4963EE407ul), + group_index, + index >> 1u); + float unit = float(uint(angle_rank)) / 4294967295.0f; + float angle = (unit - 0.5f) * 3.14159265358979323846f * direction; + float c = cos(angle); + float s = sin(angle); + float lhs = values[index]; + float rhs = values[index + 1u]; + values[index] = c * lhs - s * rhs; + values[index + 1u] = s * lhs + c * rhs; + } + } + + inline void tq_apply_product_rotation( + thread float* values, + uint count, + ulong seed, + uint group_index, + bool inverse + ) { + if (count <= 1u) { + tq_apply_rotation_signs(values, count, seed, group_index); + return; + } + if ((count & (count - 1u)) == 0u) { + if (inverse) { + tq_fast_hadamard(values, count); + tq_apply_rotation_signs(values, count, seed, group_index); + } else { + tq_apply_rotation_signs(values, count, seed, group_index); + tq_fast_hadamard(values, count); + } + float scale = rsqrt(float(count)); + for (uint local = 0u; local < count; local++) { + values[local] *= scale; + } + return; + } + if (inverse) { + for (uint pass_index = 0u; pass_index < 4u; pass_index++) { + tq_apply_givens_pass(values, count, seed, group_index, 3u - pass_index, -1.0f); + } + } else { + for (uint pass = 0u; pass < 4u; pass++) { + tq_apply_givens_pass(values, count, seed, group_index, pass, 1.0f); + } + } + } + + inline bool tq_flat_high_precision( + device const uint* high_mask, + uint group_id, + uint local, + uint bitset_words_per_group + ) { + uint bitset_base = group_id * bitset_words_per_group; + uint word_index = local >> 5; + uint word_bit = local & 31u; + return (high_mask[bitset_base + word_index] & (1u << word_bit)) != 0u; + } + + inline uint tq_read_flat_code( + device const uint* packed, + device const uint* high_mask, + uint group_id, + uint local, + uint mag_words_per_group, + uint bitset_words_per_group, + uint base_bits, + uint high_bits + ) { + uint packed_base = group_id * mag_words_per_group; + bool high_precision = tq_flat_high_precision( + high_mask, group_id, local, bitset_words_per_group); + uint bits = high_precision ? high_bits : base_bits; + uint bit_offset = 0u; + for (uint prior = 0u; prior < local; prior++) { + bool prior_high = tq_flat_high_precision( + high_mask, group_id, prior, bitset_words_per_group); + bit_offset += prior_high ? high_bits : base_bits; + } + + uint quantized = 0u; + for (uint bit = 0u; bit < bits; bit++) { + uint global_bit = bit_offset + bit; + uint packed_word = global_bit >> 5; + uint packed_bit = global_bit & 31u; + if ((packed[packed_base + packed_word] & (1u << packed_bit)) != 0u) { + quantized |= 1u << bit; + } + } + return quantized; + } + + inline float tq_decode_flat_value( + device const uint* packed, + device const uint* signs, + device const uint* high_mask, + device const uint* residual_signs, + device const float* scales, + uint index, + ulong seed, + uint role, + uint group_size, + uint mag_words_per_group, + uint bitset_words_per_group, + uint base_bits, + uint high_bits, + uint key_base_bits, + uint key_high_bits, + uint value_bits, + uint scales_per_group, + uint value_count + ) { + uint group_id = index / group_size; + uint local = index - group_id * group_size; + uint packed_base = group_id * mag_words_per_group; + if (role == 1u) { + uint bit_offset = local * value_bits; + uint quantized = 0u; + for (uint bit = 0u; bit < value_bits; bit++) { + uint global_bit = bit_offset + bit; + uint packed_word = global_bit >> 5; + uint packed_bit = global_bit & 31u; + if ((packed[packed_base + packed_word] & (1u << packed_bit)) != 0u) { + quantized |= 1u << bit; + } + } + uint scale_base = group_id * scales_per_group; + return scales[scale_base + 1u] + float(quantized) * scales[scale_base]; + } + + uint count = min(group_size, value_count - group_id * group_size); + thread float rotated[128]; + for (uint decode_local = 0u; decode_local < count; decode_local++) { + bool high_precision = tq_flat_high_precision( + high_mask, group_id, decode_local, bitset_words_per_group); + uint bits = high_precision ? key_high_bits : key_base_bits; + uint code = tq_read_flat_code( + packed, high_mask, group_id, decode_local, + mag_words_per_group, bitset_words_per_group, + key_base_bits, key_high_bits); + rotated[decode_local] = tq_codebook_level(bits, code, count); + } + tq_apply_product_rotation(rotated, count, seed, group_id, true); + return rotated[local] * scales[group_id * scales_per_group]; + } + """ + + private static let encodeSource = """ + uint group_id = thread_position_in_grid.x; + if (group_id >= GROUP_COUNT) { + return; + } + + uint start = group_id * GROUP_SIZE; + uint count = min(uint(GROUP_SIZE), uint(VALUE_COUNT) - start); + if (count == 0) { + return; + } + + thread float values[GROUP_SIZE]; + ulong seed = (ulong(uint(SEED_HI)) << 32) | ulong(uint(SEED_LO)); + + if (ROLE == 1) { + float minimum = INFINITY; + float maximum = -INFINITY; + for (uint local = 0; local < count; local++) { + float value = float(x[start + local]); + minimum = min(minimum, value); + maximum = max(maximum, value); + } + + float value_max = float((1 << VALUE_BITS) - 1); + float range = maximum - minimum; + float value_scale = range > 1.17549435e-38f ? range / value_max : 0.0f; + uint scale_base = group_id * uint(SCALES_PER_GROUP); + scales[scale_base] = value_scale; + scales[scale_base + 1] = minimum; + + uint packed_base = group_id * MAG_WORDS_PER_GROUP; + for (uint word = 0; word < MAG_WORDS_PER_GROUP; word++) { + packed[packed_base + word] = 0u; + } + + for (uint local = 0; local < count; local++) { + float value = float(x[start + local]); + uint quantized = value_scale == 0.0f + ? 0u + : uint(clamp(round((value - minimum) / value_scale), 0.0f, value_max)); + uint bit_offset = local * uint(VALUE_BITS); + for (uint bit = 0; bit < uint(VALUE_BITS); bit++) { + if ((quantized & (1u << bit)) != 0u) { + uint global_bit = bit_offset + bit; + uint packed_word = global_bit >> 5; + uint packed_bit = global_bit & 31u; + packed[packed_base + packed_word] |= 1u << packed_bit; + } + } + } + return; + } + + float norm_squared = 0.0f; + for (uint local = 0; local < count; local++) { + float value = float(x[start + local]); + values[local] = value; + norm_squared += value * value; + } + + float norm = sqrt(norm_squared); + float inv_norm = norm > 1.17549435e-38f ? 1.0f / norm : 0.0f; + for (uint local = 0; local < count; local++) { + values[local] *= inv_norm; + } + tq_apply_product_rotation(values, count, seed, group_id, false); + + uint scale_base = group_id * uint(SCALES_PER_GROUP); + scales[scale_base] = norm; + scales[scale_base + 1] = 0.0f; + scales[scale_base + 2] = 0.0f; + + uint bitset_base = group_id * BITSET_WORDS_PER_GROUP; + for (uint word = 0; word < BITSET_WORDS_PER_GROUP; word++) { + signs[bitset_base + word] = 0u; + high_mask[bitset_base + word] = 0u; + residual_signs[bitset_base + word] = 0u; + } + + uint packed_base = group_id * MAG_WORDS_PER_GROUP; + for (uint word = 0; word < MAG_WORDS_PER_GROUP; word++) { + packed[packed_base + word] = 0u; + } + + uint high_count = uint(round(float(count * uint(HIGH_NUMERATOR)) / float(uint(HIGH_DENOMINATOR)))); + float residual_squared = 0.0f; + uint bit_offset = 0; + for (uint local = 0; local < count; local++) { + bool high_precision = tq_product_high_precision(seed, group_id, local, count, high_count); + uint bits = high_precision ? uint(KEY_HIGH_BITS) : uint(KEY_BASE_BITS); + uint quantized = tq_nearest_codebook_index(values[local], bits, count); + float reconstructed = tq_codebook_level(bits, quantized, count); + + uint word_index = local >> 5; + uint word_bit = local & 31u; + uint mask_bit = 1u << word_bit; + if (high_precision) { + high_mask[bitset_base + word_index] |= mask_bit; + } + float residual = values[local] - reconstructed; + residual_squared += residual * residual; + if (residual < 0.0f) { + signs[bitset_base + word_index] |= mask_bit; + } + + for (uint bit = 0; bit < bits; bit++) { + if ((quantized & (1u << bit)) != 0u) { + uint global_bit = bit_offset + bit; + uint packed_word = global_bit >> 5; + uint packed_bit = global_bit & 31u; + packed[packed_base + packed_word] |= 1u << packed_bit; + } + } + bit_offset += bits; + } + scales[scale_base + 1] = norm * sqrt(residual_squared); + """ + + private static let decodeSource = """ + uint index = thread_position_in_grid.x; + if (index >= VALUE_COUNT) { + return; + } + + ulong seed = (ulong(uint(SEED_HI)) << 32) | ulong(uint(SEED_LO)); + uint group_id = index / uint(GROUP_SIZE); + uint local = index - group_id * uint(GROUP_SIZE); + uint packed_base = group_id * uint(MAG_WORDS_PER_GROUP); + if (ROLE == 1) { + uint bit_offset = local * uint(VALUE_BITS); + uint quantized = 0u; + for (uint bit = 0; bit < uint(VALUE_BITS); bit++) { + uint global_bit = bit_offset + bit; + uint packed_word = global_bit >> 5; + uint packed_bit = global_bit & 31u; + if ((packed[packed_base + packed_word] & (1u << packed_bit)) != 0u) { + quantized |= 1u << bit; + } + } + uint scale_base = group_id * uint(SCALES_PER_GROUP); + out[index] = scales[scale_base + 1] + float(quantized) * scales[scale_base]; + return; + } + + uint count = min(uint(GROUP_SIZE), uint(VALUE_COUNT) - group_id * uint(GROUP_SIZE)); + thread float rotated[GROUP_SIZE]; + uint bitset_base = group_id * uint(BITSET_WORDS_PER_GROUP); + for (uint decode_local = 0u; decode_local < count; decode_local++) { + uint word_index = decode_local >> 5; + uint word_bit = decode_local & 31u; + bool high_precision = (high_mask[bitset_base + word_index] & (1u << word_bit)) != 0u; + uint bits = high_precision ? uint(KEY_HIGH_BITS) : uint(KEY_BASE_BITS); + uint bit_offset = 0u; + for (uint prior = 0u; prior < decode_local; prior++) { + uint prior_word = prior >> 5; + uint prior_bit = prior & 31u; + bool prior_high = + (high_mask[bitset_base + prior_word] & (1u << prior_bit)) != 0u; + bit_offset += prior_high ? uint(KEY_HIGH_BITS) : uint(KEY_BASE_BITS); + } + uint code = 0u; + for (uint bit = 0u; bit < bits; bit++) { + uint global_bit = bit_offset + bit; + uint packed_word = global_bit >> 5; + uint packed_bit = global_bit & 31u; + if ((packed[packed_base + packed_word] & (1u << packed_bit)) != 0u) { + code |= 1u << bit; + } + } + rotated[decode_local] = tq_codebook_level(bits, code, count); + } + tq_apply_product_rotation(rotated, count, seed, group_id, true); + out[index] = rotated[local] * scales[group_id * uint(SCALES_PER_GROUP)]; + """ + + private static let matmulSource = """ + uint index = thread_position_in_grid.x; + uint total = uint(X_ROWS) * (TRANSPOSE_WEIGHT ? uint(WEIGHT_ROWS) : uint(WEIGHT_COLUMNS)); + if (index >= total) { + return; + } + + uint output_columns = TRANSPOSE_WEIGHT ? uint(WEIGHT_ROWS) : uint(WEIGHT_COLUMNS); + uint row = index / output_columns; + uint column = index - row * output_columns; + uint reduction = uint(X_COLUMNS); + ulong seed = (ulong(uint(SEED_HI)) << 32) | ulong(uint(SEED_LO)); + float sum = 0.0f; + + for (uint k = 0u; k < reduction; k++) { + uint x_index = row * uint(X_COLUMNS) + k; + uint weight_index = TRANSPOSE_WEIGHT + ? column * uint(WEIGHT_COLUMNS) + k + : k * uint(WEIGHT_COLUMNS) + column; + float weight = tq_decode_flat_value( + packed, signs, high_mask, residual_signs, scales, + weight_index, seed, uint(ROLE), + uint(GROUP_SIZE), uint(MAG_WORDS_PER_GROUP), uint(BITSET_WORDS_PER_GROUP), + uint(BASE_BITS), uint(HIGH_BITS), uint(KEY_BASE_BITS), uint(KEY_HIGH_BITS), + uint(VALUE_BITS), uint(SCALES_PER_GROUP), uint(VALUE_COUNT)); + sum += float(x[x_index]) * weight; + } + out[index] = sum; + """ + + private static let attentionHeader = """ + inline ulong tq_mix(ulong seed, uint index) { + ulong mixed = seed + ulong(index) * 0x9E3779B97F4A7C15ul; + mixed ^= mixed >> 30; + mixed *= 0xBF58476D1CE4E5B9ul; + mixed ^= mixed >> 27; + mixed *= 0x94D049BB133111EBul; + mixed ^= mixed >> 31; + return mixed; + } + + inline bool tq_random_sign(ulong seed, uint index) { + return (tq_mix(seed, index) & 1ul) != 0ul; + } + + inline ulong tq_mix_index(ulong seed, ulong index) { + ulong mixed = seed + index * 0x9E3779B97F4A7C15ul; + mixed ^= mixed >> 30; + mixed *= 0xBF58476D1CE4E5B9ul; + mixed ^= mixed >> 27; + mixed *= 0x94D049BB133111EBul; + mixed ^= mixed >> 31; + return mixed; + } + + inline bool tq_random_sign_index(ulong seed, ulong index) { + return (tq_mix_index(seed, index) & 1ul) != 0ul; + } + + inline ulong tq_product_channel_rank(ulong seed, uint group_index, uint local_index) { + ulong state = seed; + state ^= ulong(group_index) * 0x9E3779B97F4A7C15ul; + state += ulong(local_index) * 0xD1B54A32D192ED03ul; + state ^= state >> 30; + state *= 0xBF58476D1CE4E5B9ul; + state ^= state >> 27; + state *= 0x94D049BB133111EBul; + state ^= state >> 31; + return state; + } + + inline bool tq_product_high_precision( + ulong seed, + uint group_index, + uint local, + uint count, + uint high_count + ) { + if (high_count == 0u) { + return false; + } + if (high_count >= count) { + return true; + } + ulong local_rank = tq_product_channel_rank(seed, group_index, local); + uint rank = 0u; + for (uint other = 0u; other < count; other++) { + ulong other_rank = tq_product_channel_rank(seed, group_index, other); + if (other_rank < local_rank || (other_rank == local_rank && other < local)) { + rank += 1u; + } + } + return rank < high_count; + } + + inline float tq_codebook_unit(uint bits, uint code) { + if (bits <= 1u) { + return code == 0u ? -0.797884561f : 0.797884561f; + } + if (bits == 2u) { + switch (min(code, 3u)) { + case 0u: return -1.510499245f; + case 1u: return -0.452819573f; + case 2u: return 0.452819573f; + default: return 1.510499245f; + } + } + if (bits == 3u) { + switch (min(code, 7u)) { + case 0u: return -2.175028018f; + case 1u: return -1.367204388f; + case 2u: return -0.773020220f; + case 3u: return -0.251312159f; + case 4u: return 0.251312159f; + case 5u: return 0.773020220f; + case 6u: return 1.367204388f; + default: return 2.175028018f; + } + } + switch (min(code, 15u)) { + case 0u: return -2.778927695f; + case 1u: return -2.124836923f; + case 2u: return -1.680512470f; + case 3u: return -1.321175453f; + case 4u: return -1.003692455f; + case 5u: return -0.707453186f; + case 6u: return -0.421537889f; + case 7u: return -0.140103661f; + case 8u: return 0.140103661f; + case 9u: return 0.421537889f; + case 10u: return 0.707453186f; + case 11u: return 1.003692455f; + case 12u: return 1.321175453f; + case 13u: return 1.680512470f; + case 14u: return 2.124836923f; + default: return 2.778927695f; + } + } + + inline float tq_codebook_level(uint bits, uint code, uint count) { + return tq_codebook_unit(bits, code) * rsqrt(float(max(count, 1u))); + } + + inline uint tq_nearest_codebook_index(float value, uint bits, uint count) { + uint level_count = 1u << bits; + uint best_index = 0u; + float best_distance = INFINITY; + for (uint code = 0u; code < level_count; code++) { + float distance = fabs(value - tq_codebook_level(bits, code, count)); + if (distance < best_distance) { + best_distance = distance; + best_index = code; + } + } + return best_index; + } + + inline void tq_fast_hadamard(thread float* values, uint count) { + for (uint width = 1u; width < count; width <<= 1u) { + for (uint start = 0u; start < count; start += width << 1u) { + for (uint offset = 0u; offset < width; offset++) { + float lhs = values[start + offset]; + float rhs = values[start + offset + width]; + values[start + offset] = lhs + rhs; + values[start + offset + width] = lhs - rhs; + } + } + } + } + + inline void tq_apply_rotation_signs( + thread float* values, + uint count, + ulong seed, + uint group_index + ) { + for (uint local = 0u; local < count; local++) { + ulong sign_index = ulong(group_index) * 4099ul + ulong(local); + if (tq_random_sign_index(seed, sign_index)) { + values[local] = -values[local]; + } + } + } + + inline void tq_apply_givens_pass( + thread float* values, + uint count, + ulong seed, + uint group_index, + uint pass, + float direction + ) { + uint offset = pass & 1u; + for (uint index = offset; index + 1u < count; index += 2u) { + ulong angle_rank = tq_product_channel_rank( + seed ^ (ulong(pass) * 0xA24BAED4963EE407ul), + group_index, + index >> 1u); + float unit = float(uint(angle_rank)) / 4294967295.0f; + float angle = (unit - 0.5f) * 3.14159265358979323846f * direction; + float c = cos(angle); + float s = sin(angle); + float lhs = values[index]; + float rhs = values[index + 1u]; + values[index] = c * lhs - s * rhs; + values[index + 1u] = s * lhs + c * rhs; + } + } + + inline void tq_apply_product_rotation( + thread float* values, + uint count, + ulong seed, + uint group_index, + bool inverse + ) { + if (count <= 1u) { + tq_apply_rotation_signs(values, count, seed, group_index); + return; + } + if ((count & (count - 1u)) == 0u) { + if (inverse) { + tq_fast_hadamard(values, count); + tq_apply_rotation_signs(values, count, seed, group_index); + } else { + tq_apply_rotation_signs(values, count, seed, group_index); + tq_fast_hadamard(values, count); + } + float scale = rsqrt(float(count)); + for (uint local = 0u; local < count; local++) { + values[local] *= scale; + } + return; + } + if (inverse) { + for (uint pass_index = 0u; pass_index < 4u; pass_index++) { + tq_apply_givens_pass(values, count, seed, group_index, 3u - pass_index, -1.0f); + } + } else { + for (uint pass = 0u; pass < 4u; pass++) { + tq_apply_givens_pass(values, count, seed, group_index, pass, 1.0f); + } + } + } + + inline uint tq_bitset_offset( + uint batch, + uint head, + uint token, + uint group, + uint word, + uint kv_heads, + uint capacity, + uint groups_per_vector, + uint bitset_words_per_group + ) { + return (((batch * kv_heads + head) * capacity + token) + * groups_per_vector + group) * bitset_words_per_group + word; + } + + inline uint tq_packed_offset( + uint batch, + uint head, + uint token, + uint group, + uint word, + uint kv_heads, + uint capacity, + uint groups_per_vector, + uint mag_words_per_group + ) { + return (((batch * kv_heads + head) * capacity + token) + * groups_per_vector + group) * mag_words_per_group + word; + } + + inline uint tq_scale_offset( + uint batch, + uint head, + uint token, + uint group, + uint scale_index, + uint kv_heads, + uint capacity, + uint groups_per_vector + ) { + return ((((batch * kv_heads + head) * capacity + token) + * groups_per_vector + group) * 3u) + scale_index; + } + + inline uint tq_physical_token( + uint logical_token, + uint capacity, + uint ring_offset, + uint pinned_prefix_length + ) { + uint pinned = pinned_prefix_length; + if (logical_token < pinned) { + return logical_token; + } + uint ring_capacity = capacity - pinned; + if (ring_capacity == 0u) { + return min(logical_token, capacity - 1u); + } + uint ring_logical = logical_token - pinned; + return pinned + ((ring_offset + ring_logical) % ring_capacity); + } + + inline uint tq_read_magnitude( + device const uint* packed, + device const uint* high_mask, + uint batch, + uint head, + uint token, + uint group, + uint local, + uint kv_heads, + uint capacity, + uint groups_per_vector, + uint mag_words_per_group, + uint bitset_words_per_group, + uint base_bits, + uint high_bits + ) { + uint bitset_word = local >> 5; + uint bitset_bit = local & 31u; + bool high_precision = + (high_mask[tq_bitset_offset( + batch, head, token, group, bitset_word, + kv_heads, capacity, groups_per_vector, bitset_words_per_group)] + & (1u << bitset_bit)) != 0u; + uint bits = high_precision ? high_bits : base_bits; + + uint bit_offset = 0u; + for (uint prior = 0; prior < local; prior++) { + uint prior_word = prior >> 5; + uint prior_bit = prior & 31u; + bool prior_high = + (high_mask[tq_bitset_offset( + batch, head, token, group, prior_word, + kv_heads, capacity, groups_per_vector, bitset_words_per_group)] + & (1u << prior_bit)) != 0u; + bit_offset += prior_high ? high_bits : base_bits; + } + + uint quantized = 0u; + for (uint bit = 0; bit < bits; bit++) { + uint global_bit = bit_offset + bit; + uint packed_word = global_bit >> 5; + uint packed_bit = global_bit & 31u; + if ((packed[tq_packed_offset( + batch, head, token, group, packed_word, + kv_heads, capacity, groups_per_vector, mag_words_per_group)] + & (1u << packed_bit)) != 0u) { + quantized |= 1u << bit; + } + } + return quantized; + } + + inline uint tq_storage_group_index( + uint batch, + uint head, + uint token, + uint group, + uint kv_heads, + uint capacity, + uint groups_per_vector + ) { + return ((batch * kv_heads + head) * capacity + token) * groups_per_vector + group; + } + + inline float tq_decode_attention_value( + device const uint* packed, + device const uint* signs, + device const uint* high_mask, + device const uint* residual_signs, + device const float* scales, + uint batch, + uint head, + uint token, + uint dimension, + ulong seed, + uint role, + uint group_size, + uint kv_heads, + uint capacity, + uint groups_per_vector, + uint mag_words_per_group, + uint bitset_words_per_group, + uint base_bits, + uint high_bits, + uint value_bits, + uint key_base_bits, + uint key_high_bits, + uint head_dim, + thread float* rotated + ) { + uint group = dimension / group_size; + uint local = dimension - group * group_size; + if (role == 1u) { + uint bit_offset = local * value_bits; + uint quantized = 0u; + for (uint bit = 0; bit < value_bits; bit++) { + uint global_bit = bit_offset + bit; + uint packed_word = global_bit >> 5; + uint packed_bit = global_bit & 31u; + if ((packed[tq_packed_offset( + batch, head, token, group, packed_word, + kv_heads, capacity, groups_per_vector, mag_words_per_group)] + & (1u << packed_bit)) != 0u) { + quantized |= 1u << bit; + } + } + uint scale_base = ((((batch * kv_heads + head) * capacity + token) + * groups_per_vector + group) * 2u); + return scales[scale_base + 1u] + float(quantized) * scales[scale_base]; + } + + uint group_start = group * group_size; + uint count = min(group_size, head_dim - group_start); + uint storage_group = tq_storage_group_index( + batch, head, token, group, kv_heads, capacity, groups_per_vector); + for (uint decode_local = 0u; decode_local < count; decode_local++) { + uint bitset_word = decode_local >> 5; + uint bitset_bit = decode_local & 31u; + bool high_precision = + (high_mask[tq_bitset_offset( + batch, head, token, group, bitset_word, + kv_heads, capacity, groups_per_vector, bitset_words_per_group)] + & (1u << bitset_bit)) != 0u; + uint bits = high_precision ? key_high_bits : key_base_bits; + uint code = tq_read_magnitude( + packed, high_mask, batch, head, token, group, decode_local, + kv_heads, capacity, groups_per_vector, + mag_words_per_group, bitset_words_per_group, + key_base_bits, key_high_bits); + rotated[decode_local] = tq_codebook_level(bits, code, count); + } + tq_apply_product_rotation(rotated, count, seed, storage_group, true); + return rotated[local] * scales[tq_scale_offset( + batch, head, token, group, 0u, kv_heads, capacity, groups_per_vector)]; + } + + inline float tq_product_attention_inner_product_group( + device const uint* packed, + device const uint* signs, + device const uint* high_mask, + device const float* scales, + thread float* query_values, + uint batch, + uint head, + uint token, + uint group, + ulong seed, + uint group_size, + uint kv_heads, + uint capacity, + uint groups_per_vector, + uint mag_words_per_group, + uint bitset_words_per_group, + uint key_base_bits, + uint key_high_bits, + uint head_dim + ) { + uint group_start = group * group_size; + uint count = min(group_size, head_dim - group_start); + uint storage_group = tq_storage_group_index( + batch, head, token, group, kv_heads, capacity, groups_per_vector); + tq_apply_product_rotation(query_values, count, seed, storage_group, false); + + float quantized_dot = 0.0f; + float sign_dot = 0.0f; + for (uint local = 0u; local < count; local++) { + uint bitset_word = local >> 5; + uint bitset_bit = local & 31u; + uint bit_mask = 1u << bitset_bit; + bool high_precision = + (high_mask[tq_bitset_offset( + batch, head, token, group, bitset_word, + kv_heads, capacity, groups_per_vector, bitset_words_per_group)] & bit_mask) != 0u; + uint bits = high_precision ? key_high_bits : key_base_bits; + uint code = tq_read_magnitude( + packed, high_mask, batch, head, token, group, local, + kv_heads, capacity, groups_per_vector, + mag_words_per_group, bitset_words_per_group, + key_base_bits, key_high_bits); + quantized_dot += query_values[local] * tq_codebook_level(bits, code, count); + float qjl_sign = + (signs[tq_bitset_offset( + batch, head, token, group, bitset_word, + kv_heads, capacity, groups_per_vector, bitset_words_per_group)] & bit_mask) != 0u + ? -1.0f : 1.0f; + sign_dot += qjl_sign * query_values[local]; + } + + float norm = scales[tq_scale_offset( + batch, head, token, group, 0u, kv_heads, capacity, groups_per_vector)]; + float residual_norm = scales[tq_scale_offset( + batch, head, token, group, 1u, kv_heads, capacity, groups_per_vector)]; + float residual = residual_norm * sqrt(3.14159265358979323846f / (2.0f * float(count))) * sign_dot; + return norm * quantized_dot + residual; + } + """ + + private static let encodeAttentionSource = """ + uint row_group_id = thread_position_in_grid.x; + uint kv_heads = uint(KV_HEADS); + uint capacity = uint(CAPACITY); + uint groups_per_vector = uint(GROUPS_PER_VECTOR); + uint mag_words_per_group = uint(MAG_WORDS_PER_GROUP); + uint bitset_words_per_group = uint(BITSET_WORDS_PER_GROUP); + uint total = uint(BATCH_SIZE) * kv_heads * uint(INPUT_LENGTH) * groups_per_vector; + if (row_group_id >= total) { + return; + } + + uint group = row_group_id % groups_per_vector; + uint token = (row_group_id / groups_per_vector) % uint(INPUT_LENGTH); + uint head = (row_group_id / (groups_per_vector * uint(INPUT_LENGTH))) % kv_heads; + uint batch = row_group_id / (groups_per_vector * uint(INPUT_LENGTH) * kv_heads); + if (token >= capacity) { + return; + } + + uint group_start = group * uint(GROUP_SIZE); + uint count = min(uint(GROUP_SIZE), uint(HEAD_DIM) - group_start); + if (ROLE == 1) { + float minimum = INFINITY; + float maximum = -INFINITY; + for (uint local = 0; local < count; local++) { + uint dimension = group_start + local; + uint input_index = + (((batch * uint(KV_HEADS) + head) * uint(INPUT_LENGTH) + token) + * uint(HEAD_DIM)) + dimension; + float value = float(x[input_index]); + minimum = min(minimum, value); + maximum = max(maximum, value); + } + + float value_max = float((1 << VALUE_BITS) - 1); + float range = maximum - minimum; + float value_scale = range > 1.17549435e-38f ? range / value_max : 0.0f; + uint scale_base = ((((batch * kv_heads + head) * capacity + token) + * groups_per_vector + group) * 2u); + scales[scale_base] = value_scale; + scales[scale_base + 1u] = minimum; + + for (uint word = 0; word < mag_words_per_group; word++) { + packed[tq_packed_offset(batch, head, token, group, word, kv_heads, capacity, groups_per_vector, mag_words_per_group)] = 0u; + } + for (uint local = 0; local < count; local++) { + uint dimension = group_start + local; + uint input_index = + (((batch * uint(KV_HEADS) + head) * uint(INPUT_LENGTH) + token) + * uint(HEAD_DIM)) + dimension; + float value = float(x[input_index]); + uint quantized = value_scale == 0.0f + ? 0u + : uint(clamp(round((value - minimum) / value_scale), 0.0f, value_max)); + uint bit_offset = local * uint(VALUE_BITS); + for (uint packed_bit = 0; packed_bit < uint(VALUE_BITS); packed_bit++) { + if ((quantized & (1u << packed_bit)) != 0u) { + uint global_bit = bit_offset + packed_bit; + uint packed_word = global_bit >> 5; + uint packed_word_bit = global_bit & 31u; + packed[tq_packed_offset(batch, head, token, group, packed_word, kv_heads, capacity, groups_per_vector, mag_words_per_group)] |= + 1u << packed_word_bit; + } + } + } + return; + } + + thread float values[GROUP_SIZE]; + ulong seed = (ulong(uint(SEED_HI)) << 32) | ulong(uint(SEED_LO)); + uint storage_group = tq_storage_group_index( + batch, head, token, group, kv_heads, capacity, groups_per_vector); + float norm_squared = 0.0f; + + for (uint local = 0; local < count; local++) { + uint dimension = group_start + local; + uint input_index = + (((batch * uint(KV_HEADS) + head) * uint(INPUT_LENGTH) + token) + * uint(HEAD_DIM)) + dimension; + float value = float(x[input_index]); + values[local] = value; + norm_squared += value * value; + } + + float norm = sqrt(norm_squared); + float inv_norm = norm > 1.17549435e-38f ? 1.0f / norm : 0.0f; + for (uint local = 0; local < count; local++) { + values[local] *= inv_norm; + } + tq_apply_product_rotation(values, count, seed, storage_group, false); + + scales[tq_scale_offset(batch, head, token, group, 0u, kv_heads, capacity, groups_per_vector)] = norm; + scales[tq_scale_offset(batch, head, token, group, 1u, kv_heads, capacity, groups_per_vector)] = 0.0f; + scales[tq_scale_offset(batch, head, token, group, 2u, kv_heads, capacity, groups_per_vector)] = 0.0f; + + for (uint word = 0; word < bitset_words_per_group; word++) { + signs[tq_bitset_offset(batch, head, token, group, word, kv_heads, capacity, groups_per_vector, bitset_words_per_group)] = 0u; + high_mask[tq_bitset_offset(batch, head, token, group, word, kv_heads, capacity, groups_per_vector, bitset_words_per_group)] = 0u; + residual_signs[tq_bitset_offset(batch, head, token, group, word, kv_heads, capacity, groups_per_vector, bitset_words_per_group)] = 0u; + } + for (uint word = 0; word < mag_words_per_group; word++) { + packed[tq_packed_offset(batch, head, token, group, word, kv_heads, capacity, groups_per_vector, mag_words_per_group)] = 0u; + } + + uint high_count = uint(round(float(count) * 0.5f)); + float residual_squared = 0.0f; + uint bit_offset = 0u; + for (uint local = 0; local < count; local++) { + bool high_precision = tq_product_high_precision(seed, storage_group, local, count, high_count); + uint bits = high_precision ? uint(KEY_HIGH_BITS) : uint(KEY_BASE_BITS); + uint quantized = tq_nearest_codebook_index(values[local], bits, count); + float reconstructed = tq_codebook_level(bits, quantized, count); + + uint word = local >> 5; + uint bit = local & 31u; + uint mask = 1u << bit; + if (high_precision) { + high_mask[tq_bitset_offset(batch, head, token, group, word, kv_heads, capacity, groups_per_vector, bitset_words_per_group)] |= mask; + } + float residual = values[local] - reconstructed; + residual_squared += residual * residual; + if (residual < 0.0f) { + signs[tq_bitset_offset(batch, head, token, group, word, kv_heads, capacity, groups_per_vector, bitset_words_per_group)] |= mask; + } + + for (uint packed_bit = 0; packed_bit < bits; packed_bit++) { + if ((quantized & (1u << packed_bit)) != 0u) { + uint global_bit = bit_offset + packed_bit; + uint packed_word = global_bit >> 5; + uint packed_word_bit = global_bit & 31u; + packed[tq_packed_offset(batch, head, token, group, packed_word, kv_heads, capacity, groups_per_vector, mag_words_per_group)] |= + 1u << packed_word_bit; + } + } + bit_offset += bits; + } + scales[tq_scale_offset(batch, head, token, group, 1u, kv_heads, capacity, groups_per_vector)] = + norm * sqrt(residual_squared); + """ + + private static let qkSource = """ + uint index = thread_position_in_grid.x; + uint total = uint(BATCH_SIZE) * uint(QUERY_HEADS) * uint(QUERY_LENGTH) * uint(LOGICAL_LENGTH); + if (index >= total) { + return; + } + + float attention_scale = as_type(uint(ATTENTION_SCALE_BITS)); + uint logical_token = index % uint(LOGICAL_LENGTH); + uint q_token = (index / uint(LOGICAL_LENGTH)) % uint(QUERY_LENGTH); + uint q_head = (index / (uint(LOGICAL_LENGTH) * uint(QUERY_LENGTH))) % uint(QUERY_HEADS); + uint batch = index / (uint(LOGICAL_LENGTH) * uint(QUERY_LENGTH) * uint(QUERY_HEADS)); + uint repeats = uint(QUERY_HEADS) / uint(KV_HEADS); + uint kv_head = q_head / repeats; + uint physical_token = tq_physical_token( + logical_token, uint(CAPACITY), uint(RING_OFFSET), uint(PINNED_PREFIX_LENGTH)); + + float sum = 0.0f; + ulong seed = (ulong(uint(SEED_HI)) << 32) | ulong(uint(SEED_LO)); + for (uint group = 0u; group < uint(GROUPS_PER_VECTOR); group++) { + uint group_start = group * uint(GROUP_SIZE); + uint count = min(uint(GROUP_SIZE), uint(HEAD_DIM) - group_start); + thread float query_values[GROUP_SIZE]; + for (uint local = 0u; local < count; local++) { + uint dimension = group_start + local; + uint q_index = + (((batch * uint(QUERY_HEADS) + q_head) * uint(QUERY_LENGTH) + q_token) + * uint(HEAD_DIM)) + dimension; + query_values[local] = float(q[q_index]); + } + sum += tq_product_attention_inner_product_group( + k_packed, k_signs, k_high_mask, k_scales, query_values, + batch, kv_head, physical_token, group, seed, + uint(GROUP_SIZE), uint(KV_HEADS), uint(CAPACITY), uint(GROUPS_PER_VECTOR), + uint(MAG_WORDS_PER_GROUP), uint(BITSET_WORDS_PER_GROUP), + uint(KEY_BASE_BITS), uint(KEY_HIGH_BITS), uint(HEAD_DIM)); + } + scores[index] = sum * attention_scale; + """ + + private static let decodeAttentionSource = """ + uint index = thread_position_in_grid.x; + uint total = uint(BATCH_SIZE) * uint(KV_HEADS) * uint(LOGICAL_LENGTH) * uint(HEAD_DIM); + if (index >= total) { + return; + } + + uint dimension = index % uint(HEAD_DIM); + uint logical_token = (index / uint(HEAD_DIM)) % uint(LOGICAL_LENGTH); + uint head = (index / (uint(HEAD_DIM) * uint(LOGICAL_LENGTH))) % uint(KV_HEADS); + uint batch = index / (uint(HEAD_DIM) * uint(LOGICAL_LENGTH) * uint(KV_HEADS)); + uint physical_token = tq_physical_token( + logical_token, uint(CAPACITY), uint(RING_OFFSET), uint(PINNED_PREFIX_LENGTH)); + thread float decode_scratch[GROUP_SIZE]; + out[index] = tq_decode_attention_value( + packed, signs, high_mask, residual_signs, scales, + batch, head, physical_token, dimension, + (ulong(uint(SEED_HI)) << 32) | ulong(uint(SEED_LO)), uint(ROLE), + uint(GROUP_SIZE), uint(KV_HEADS), uint(CAPACITY), uint(GROUPS_PER_VECTOR), + uint(MAG_WORDS_PER_GROUP), uint(BITSET_WORDS_PER_GROUP), uint(BASE_BITS), uint(HIGH_BITS), + uint(VALUE_BITS), uint(KEY_BASE_BITS), uint(KEY_HIGH_BITS), uint(HEAD_DIM), + decode_scratch); + """ + + private static let avSource = """ + uint index = thread_position_in_grid.x; + uint total = uint(BATCH_SIZE) * uint(QUERY_HEADS) * uint(QUERY_LENGTH) * uint(HEAD_DIM); + if (index >= total) { + return; + } + + uint dimension = index % uint(HEAD_DIM); + uint q_token = (index / uint(HEAD_DIM)) % uint(QUERY_LENGTH); + uint q_head = (index / (uint(HEAD_DIM) * uint(QUERY_LENGTH))) % uint(QUERY_HEADS); + uint batch = index / (uint(HEAD_DIM) * uint(QUERY_LENGTH) * uint(QUERY_HEADS)); + uint repeats = uint(QUERY_HEADS) / uint(KV_HEADS); + uint kv_head = q_head / repeats; + + float sum = 0.0f; + thread float decode_scratch[GROUP_SIZE]; + for (uint logical_token = 0; logical_token < uint(LOGICAL_LENGTH); logical_token++) { + uint physical_token = tq_physical_token( + logical_token, uint(CAPACITY), uint(RING_OFFSET), uint(PINNED_PREFIX_LENGTH)); + uint weight_index = + (((batch * uint(QUERY_HEADS) + q_head) * uint(QUERY_LENGTH) + q_token) + * uint(LOGICAL_LENGTH)) + logical_token; + float value = tq_decode_attention_value( + v_packed, v_signs, v_high_mask, v_residual_signs, v_scales, + batch, kv_head, physical_token, dimension, + (ulong(uint(SEED_HI)) << 32) | ulong(uint(SEED_LO)), 1u, + uint(GROUP_SIZE), uint(KV_HEADS), uint(CAPACITY), uint(GROUPS_PER_VECTOR), + uint(MAG_WORDS_PER_GROUP), uint(BITSET_WORDS_PER_GROUP), uint(BASE_BITS), uint(HIGH_BITS), + uint(VALUE_BITS), uint(KEY_BASE_BITS), uint(KEY_HIGH_BITS), uint(HEAD_DIM), + decode_scratch); + sum += float(weights[weight_index]) * value; + } + out[index] = sum; + """ + + private static let fusedAttentionSource = """ + constexpr uint threads_per_row = uint(THREADS_PER_ROW); + uint lane = thread_position_in_threadgroup.x; + uint row = threadgroup_position_in_grid.x; + uint total_rows = uint(BATCH_SIZE) * uint(QUERY_HEADS) * uint(QUERY_LENGTH); + if (row >= total_rows) { + return; + } + + threadgroup float partial[256]; + threadgroup float tile_weights[256]; + threadgroup uint tile_physical_tokens[256]; + + float attention_scale = as_type(uint(ATTENTION_SCALE_BITS)); + uint q_token = row % uint(QUERY_LENGTH); + uint q_head = (row / uint(QUERY_LENGTH)) % uint(QUERY_HEADS); + uint batch = row / (uint(QUERY_LENGTH) * uint(QUERY_HEADS)); + uint repeats = uint(QUERY_HEADS) / uint(KV_HEADS); + uint kv_head = q_head / repeats; + uint causal_limit = uint(LOGICAL_LENGTH) - uint(QUERY_LENGTH) + q_token; + ulong key_seed = (ulong(uint(SEED_HI)) << 32) | ulong(uint(SEED_LO)); + + float row_max = -INFINITY; + for (uint logical_token = lane; logical_token < uint(LOGICAL_LENGTH); logical_token += threads_per_row) { + if (DO_CAUSAL && logical_token > causal_limit) { + continue; + } + uint physical_token = tq_physical_token( + logical_token, uint(CAPACITY), uint(RING_OFFSET), uint(PINNED_PREFIX_LENGTH)); + float score = 0.0f; + for (uint group = 0u; group < uint(GROUPS_PER_VECTOR); group++) { + uint group_start = group * uint(GROUP_SIZE); + uint count = min(uint(GROUP_SIZE), uint(HEAD_DIM) - group_start); + thread float query_values[GROUP_SIZE]; + for (uint local = 0u; local < count; local++) { + uint dimension = group_start + local; + uint q_index = + (((batch * uint(QUERY_HEADS) + q_head) * uint(QUERY_LENGTH) + q_token) + * uint(HEAD_DIM)) + dimension; + query_values[local] = float(q[q_index]); + } + score += tq_product_attention_inner_product_group( + k_packed, k_signs, k_high_mask, k_scales, query_values, + batch, kv_head, physical_token, group, key_seed, + uint(GROUP_SIZE), uint(KV_HEADS), uint(CAPACITY), uint(GROUPS_PER_VECTOR), + uint(MAG_WORDS_PER_GROUP), uint(BITSET_WORDS_PER_GROUP), + uint(KEY_BASE_BITS), uint(KEY_HIGH_BITS), uint(HEAD_DIM)); + } + row_max = max(row_max, score * attention_scale); + } + partial[lane] = row_max; + threadgroup_barrier(mem_flags::mem_threadgroup); + for (uint stride = threads_per_row >> 1; stride > 0u; stride >>= 1) { + if (lane < stride) { + partial[lane] = max(partial[lane], partial[lane + stride]); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + row_max = partial[0]; + + float row_sum = 0.0f; + for (uint logical_token = lane; logical_token < uint(LOGICAL_LENGTH); logical_token += threads_per_row) { + if (DO_CAUSAL && logical_token > causal_limit) { + continue; + } + uint physical_token = tq_physical_token( + logical_token, uint(CAPACITY), uint(RING_OFFSET), uint(PINNED_PREFIX_LENGTH)); + float score = 0.0f; + for (uint group = 0u; group < uint(GROUPS_PER_VECTOR); group++) { + uint group_start = group * uint(GROUP_SIZE); + uint count = min(uint(GROUP_SIZE), uint(HEAD_DIM) - group_start); + thread float query_values[GROUP_SIZE]; + for (uint local = 0u; local < count; local++) { + uint dimension = group_start + local; + uint q_index = + (((batch * uint(QUERY_HEADS) + q_head) * uint(QUERY_LENGTH) + q_token) + * uint(HEAD_DIM)) + dimension; + query_values[local] = float(q[q_index]); + } + score += tq_product_attention_inner_product_group( + k_packed, k_signs, k_high_mask, k_scales, query_values, + batch, kv_head, physical_token, group, key_seed, + uint(GROUP_SIZE), uint(KV_HEADS), uint(CAPACITY), uint(GROUPS_PER_VECTOR), + uint(MAG_WORDS_PER_GROUP), uint(BITSET_WORDS_PER_GROUP), + uint(KEY_BASE_BITS), uint(KEY_HIGH_BITS), uint(HEAD_DIM)); + } + float weight = exp(score * attention_scale - row_max); + row_sum += weight; + } + partial[lane] = row_sum; + threadgroup_barrier(mem_flags::mem_threadgroup); + for (uint stride = threads_per_row >> 1; stride > 0u; stride >>= 1) { + if (lane < stride) { + partial[lane] += partial[lane + stride]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + row_sum = partial[0]; + + float inv_sum = 1.0f / max(row_sum, 1.17549435e-38f); + if (lane < uint(HEAD_DIM)) { + uint out_index = + (((batch * uint(QUERY_HEADS) + q_head) * uint(QUERY_LENGTH) + q_token) + * uint(HEAD_DIM)) + lane; + out[out_index] = 0.0f; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (uint tile_start = 0u; tile_start < uint(LOGICAL_LENGTH); tile_start += threads_per_row) { + uint logical_token = tile_start + lane; + bool active = logical_token < uint(LOGICAL_LENGTH) + && (!DO_CAUSAL || logical_token <= causal_limit); + float weight = 0.0f; + uint physical_token = 0u; + if (active) { + physical_token = tq_physical_token( + logical_token, uint(CAPACITY), uint(RING_OFFSET), uint(PINNED_PREFIX_LENGTH)); + float score = 0.0f; + for (uint group = 0u; group < uint(GROUPS_PER_VECTOR); group++) { + uint group_start = group * uint(GROUP_SIZE); + uint count = min(uint(GROUP_SIZE), uint(HEAD_DIM) - group_start); + thread float query_values[GROUP_SIZE]; + for (uint local = 0u; local < count; local++) { + uint dimension = group_start + local; + uint q_index = + (((batch * uint(QUERY_HEADS) + q_head) * uint(QUERY_LENGTH) + q_token) + * uint(HEAD_DIM)) + dimension; + query_values[local] = float(q[q_index]); + } + score += tq_product_attention_inner_product_group( + k_packed, k_signs, k_high_mask, k_scales, query_values, + batch, kv_head, physical_token, group, key_seed, + uint(GROUP_SIZE), uint(KV_HEADS), uint(CAPACITY), uint(GROUPS_PER_VECTOR), + uint(MAG_WORDS_PER_GROUP), uint(BITSET_WORDS_PER_GROUP), + uint(KEY_BASE_BITS), uint(KEY_HIGH_BITS), uint(HEAD_DIM)); + } + weight = exp(score * attention_scale - row_max) * inv_sum; + } + tile_weights[lane] = weight; + tile_physical_tokens[lane] = physical_token; + threadgroup_barrier(mem_flags::mem_threadgroup); + + thread float decode_scratch[GROUP_SIZE]; + for (uint dimension = 0; dimension < uint(HEAD_DIM); dimension++) { + float contribution = 0.0f; + if (active) { + float value = tq_decode_attention_value( + v_packed, v_signs, v_high_mask, v_residual_signs, v_scales, + batch, kv_head, tile_physical_tokens[lane], dimension, + (ulong(uint(VALUE_SEED_HI)) << 32) | ulong(uint(VALUE_SEED_LO)), 1u, + uint(GROUP_SIZE), uint(KV_HEADS), uint(CAPACITY), uint(GROUPS_PER_VECTOR), + uint(VALUE_MAG_WORDS_PER_GROUP), uint(BITSET_WORDS_PER_GROUP), uint(BASE_BITS), uint(HIGH_BITS), + uint(VALUE_BITS), uint(KEY_BASE_BITS), uint(KEY_HIGH_BITS), uint(HEAD_DIM), + decode_scratch); + contribution = tile_weights[lane] * value; + } + partial[lane] = contribution; + threadgroup_barrier(mem_flags::mem_threadgroup); + for (uint stride = threads_per_row >> 1; stride > 0u; stride >>= 1) { + if (lane < stride) { + partial[lane] += partial[lane + stride]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + if (lane == 0u) { + uint out_index = + (((batch * uint(QUERY_HEADS) + q_head) * uint(QUERY_LENGTH) + q_token) + * uint(HEAD_DIM)) + dimension; + out[out_index] = float(out[out_index]) + partial[0]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + } + """ +} diff --git a/Tests/MLXTests/MLXFastKernelTests.swift b/Tests/MLXTests/MLXFastKernelTests.swift index 82f1dca8..27aec609 100644 --- a/Tests/MLXTests/MLXFastKernelTests.swift +++ b/Tests/MLXTests/MLXFastKernelTests.swift @@ -72,6 +72,29 @@ class MLXFastKernelTests: XCTestCase { XCTAssertTrue(allClose(out[1], full([3, 2], values: -2)).all().item()) } + func testCustomKernelUInt32TemplateArgPreservesHighBits() { + let kernel = MLXFast.metalKernel( + name: "uint32_template_arg_test", + inputNames: [], + outputNames: ["out"], + source: """ + uint elem = thread_position_in_grid.x; + out[elem] = TOKEN == 0xDEADBEEFu ? 1.0f : 0.0f; + """) + + let out = kernel( + [], + template: [ + ("TOKEN", UInt32(0xDEAD_BEEF)) + ], + grid: (1, 1, 1), + threadGroup: (1, 1, 1), + outputShapes: [[1]], + outputDTypes: [.float32]) + + XCTAssertEqual(out[0].item(Float.self), 1) + } + func testFastSDPA() { // https://github.com/ml-explore/mlx-swift/issues/172 // this will just make sure the MLXFast.scaled_dot_product_attention is diff --git a/Tests/MLXTests/QuantizationTests.swift b/Tests/MLXTests/QuantizationTests.swift index 0edbd545..fe75840c 100644 --- a/Tests/MLXTests/QuantizationTests.swift +++ b/Tests/MLXTests/QuantizationTests.swift @@ -5,7 +5,43 @@ import MLX import MLXNN import XCTest +#if canImport(Metal) + import Metal +#endif + class QuantizationTests: XCTestCase { + private func requireMLXRuntime() throws { + guard TurboQuantKernelAvailability.current.supportsMetalPolarQJLCodec else { + throw XCTSkip("MLX runtime metallib unavailable in this package context") + } + } + + private func relativeMSE(_ lhs: [Float], _ rhs: [Float]) -> Float { + let squaredError = zip(lhs, rhs).reduce(Float(0)) { partial, pair in + let delta = pair.0 - pair.1 + return partial + delta * delta + } + let signal = lhs.reduce(Float(0)) { $0 + $1 * $1 } + return squaredError / max(signal, Float.leastNonzeroMagnitude) + } + + private func pearsonCorrelation(_ lhs: [Float], _ rhs: [Float]) -> Float { + let count = Float(lhs.count) + let lhsMean = lhs.reduce(Float(0), +) / count + let rhsMean = rhs.reduce(Float(0), +) / count + var numerator = Float(0) + var lhsVariance = Float(0) + var rhsVariance = Float(0) + for (left, right) in zip(lhs, rhs) { + let lhsCentered = left - lhsMean + let rhsCentered = right - rhsMean + numerator += lhsCentered * rhsCentered + lhsVariance += lhsCentered * lhsCentered + rhsVariance += rhsCentered * rhsCentered + } + return numerator / max(sqrt(lhsVariance * rhsVariance), Float.leastNonzeroMagnitude) + } + func testQuantizedLinearShapeDesc() { let linear1 = Linear(512, 1024) let quantized1 = linear1.toQuantized(groupSize: 64, bits: 4) @@ -39,4 +75,779 @@ class QuantizationTests: XCTestCase { let quantized = QuantizedLinear(64, 64, groupSize: 32, bits: 4, mode: .mxfp4) XCTAssertNil(quantized.biases) } + + func testTurboQuantPackedRoundTrip() throws { + try requireMLXRuntime() + + let x = MLXArray.ones([1, 32], dtype: .float32, stream: .device(.cpu)) + let configuration = TurboQuantConfiguration(preset: .turbo3_5, groupSize: 32) + let packed = turboQuantized(x, configuration: configuration, stream: .device(.cpu)) + let decoded = turboDequantized(packed, configuration: configuration, stream: .device(.cpu)) + + XCTAssertEqual(decoded.shape, x.shape) + XCTAssertTrue(allClose(decoded, x).item(Bool.self)) + } + + func testTurboQuantMatmulShape() throws { + try requireMLXRuntime() + + let x = MLXArray.ones([2, 32], dtype: .float32, stream: .device(.cpu)) + let w = MLXArray.ones([4, 32], dtype: .float32, stream: .device(.cpu)) + let configuration = TurboQuantConfiguration(preset: .turbo2_5, groupSize: 32) + let packed = turboQuantized(w, configuration: configuration, stream: .device(.cpu)) + let output = turboQuantizedMM( + x, packed, configuration: configuration, stream: .device(.cpu)) + + XCTAssertEqual(output.shape, [2, 4]) + } + + func testTurboQuantReferenceCodecIsDeterministic() throws { + try requireMLXRuntime() + + let values = (0 ..< 128).map { index in + Float(sin(Double(index) * 0.17) + cos(Double(index) * 0.03)) + } + let x = MLXArray(values, [2, 64]) + let configuration = TurboQuantConfiguration( + preset: .turbo3_5, + role: .key, + groupSize: 32, + backend: .polarQJLReference, + seed: 42 + ) + + let first = try turboQuantReferenceEncode(x, configuration: configuration) + let second = try turboQuantReferenceEncode(x, configuration: configuration) + + XCTAssertEqual(first, second) + XCTAssertEqual(first.shape, [2, 64]) + XCTAssertEqual(first.format, TurboQuantReferenceFormat.turboQuantProd) + XCTAssertGreaterThan(first.storageByteCount, 0) + XCTAssertFalse(first.highScales.isEmpty) + } + + func testTurboQuantReferenceCodecUsesFullWidthSeed() throws { + try requireMLXRuntime() + + let values = (0 ..< 128).map { index in + Float(sin(Double(index) * 0.11) + cos(Double(index) * 0.19)) + } + let x = MLXArray(values, [2, 64]) + let lowSeedConfiguration = TurboQuantConfiguration( + preset: .turbo3_5, + role: .key, + groupSize: 64, + backend: .polarQJLReference, + seed: 0x0000_0000_0123_4567 + ) + let highSeedConfiguration = TurboQuantConfiguration( + preset: .turbo3_5, + role: .key, + groupSize: 64, + backend: .polarQJLReference, + seed: 0xDEAD_BEEF_0123_4567 + ) + + let lowSeed = try turboQuantReferenceEncode(x, configuration: lowSeedConfiguration) + let highSeed = try turboQuantReferenceEncode(x, configuration: highSeedConfiguration) + + XCTAssertNotEqual(lowSeed.signs, highSeed.signs) + } + + func testTurboQuantReferenceCodecDistortionThreshold() throws { + try requireMLXRuntime() + + let values = (0 ..< 256).map { index in + let position = Double(index) + let sineTerm = sin(position * 0.11) * 0.7 + let cosineTerm = cos(position * 0.07) * 0.3 + return Float(sineTerm + cosineTerm) + } + let x = MLXArray(values, [4, 64]) + let configuration = TurboQuantConfiguration( + preset: .turbo3_5, + role: .vector, + groupSize: 64, + backend: .polarQJLReference, + seed: 17 + ) + + let code = try turboQuantReferenceEncode(x, configuration: configuration) + let decoded = try turboQuantReferenceDecode(code).asArray(Float.self) + let mse = + zip(values, decoded) + .map { lhs, rhs in + let delta = lhs - rhs + return delta * delta + } + .reduce(Float(0), +) / Float(values.count) + + XCTAssertLessThan(mse, 0.01) + } + + func testTurboQuantReferenceQualityGatePassesFixture() throws { + try requireMLXRuntime() + + let values = (0 ..< 256).map { index in + let position = Double(index) + let sineTerm = sin(position * 0.09) * 0.5 + let cosineTerm = cos(position * 0.13) * 0.25 + return Float(sineTerm + cosineTerm) + } + let x = MLXArray(values, [4, 64]) + let configuration = TurboQuantConfiguration( + preset: .turbo3_5, + role: .key, + groupSize: 64, + backend: .polarQJLReference, + seed: 99 + ) + + let report = try turboQuantReferenceQuality(x, configuration: configuration) + + XCTAssertLessThan(report.relativeMSE, 0.085) + XCTAssertGreaterThan(report.cosineSimilarity, 0.955) + } + + func testTurboQuantReferenceValueBitsStorageAccounting() throws { + try requireMLXRuntime() + + let values = (0 ..< 256).map { index in + let position = Double(index) + let sineTerm = 0.4 * sin(position * 0.07) + let cosineTerm = 0.15 * cos(position * 0.17) + return Float(sineTerm + cosineTerm) + } + let x = MLXArray(values, [4, 64]) + let twoBit = try turboQuantReferenceEncode( + x, + configuration: TurboQuantConfiguration( + preset: .turbo3_5, + role: .value, + groupSize: 64, + backend: .polarQJLReference, + valueBits: 2 + ) + ) + let fourBit = try turboQuantReferenceEncode( + x, + configuration: TurboQuantConfiguration( + preset: .turbo3_5, + role: .value, + groupSize: 64, + backend: .polarQJLReference, + valueBits: 4 + ) + ) + + XCTAssertEqual(twoBit.format, TurboQuantReferenceFormat.affineValue) + XCTAssertEqual(fourBit.format, TurboQuantReferenceFormat.affineValue) + XCTAssertLessThan(twoBit.approximateBitsPerValue, 3.1) + XCTAssertLessThan(fourBit.approximateBitsPerValue, 5.1) + XCTAssertLessThan(twoBit.storageByteCount, fourBit.storageByteCount) + } + + func testTurboQuantProductInnerProductBiasAndRetrieval() throws { + try requireMLXRuntime() + + let queryValues = (0 ..< 64).map { index in + let position = Double(index) + let sineTerm = 0.35 * sin(position * 0.13) + let cosineTerm = 0.2 * cos(position * 0.05) + return Float(sineTerm + cosineTerm) + } + let needleValues = queryValues.map { $0 * 1.35 } + let query = MLXArray(queryValues, [64]) + let keys = (0 ..< 16).map { keyIndex in + (0 ..< 64).map { dim in + if keyIndex == 7 { return needleValues[dim] } + let position = Double(keyIndex * 64 + dim) + return Float(0.25 * sin(position * 0.071) - 0.18 * cos(position * 0.113)) + } + } + + var exactScores: [Float] = [] + var estimatedScores: [Float] = [] + for (keyIndex, keyValues) in keys.enumerated() { + let exactScore = zip(queryValues, keyValues).reduce(Float(0)) { partial, pair in + partial + pair.0 * pair.1 + } + exactScores.append(exactScore) + let code = try turboQuantReferenceEncode( + MLXArray(keyValues, [64]), + configuration: TurboQuantConfiguration( + preset: .turbo3_5, + role: .key, + groupSize: 64, + backend: .polarQJLReference, + seed: UInt64(0x600D_0000 + keyIndex) + ) + ) + estimatedScores.append(try turboQuantReferenceInnerProduct(query: query, code: code)) + } + + XCTAssertEqual(estimatedScores.enumerated().max(by: { $0.element < $1.element })?.offset, 7) + XCTAssertGreaterThan(pearsonCorrelation(exactScores, estimatedScores), 0.7) + + let target = MLXArray(keys[3], [64]) + let exact = exactScores[3] + let estimates = try (0 ..< 32).map { seedOffset in + let code = try turboQuantReferenceEncode( + target, + configuration: TurboQuantConfiguration( + preset: .turbo3_5, + role: .key, + groupSize: 64, + backend: .polarQJLReference, + seed: UInt64(0xB1A5_0000 + seedOffset) + ) + ) + return try turboQuantReferenceInnerProduct(query: query, code: code) + } + let average = estimates.reduce(Float(0), +) / Float(estimates.count) + XCTAssertLessThan(abs(average - exact) / max(abs(exact), Float.leastNonzeroMagnitude), 0.25) + } + + func testTurboQuantBackendAvailabilityContract() throws { + XCTAssertNoThrow(try requireTurboQuantBackend(.mlxPacked)) + XCTAssertNoThrow(try requireTurboQuantBackend(.polarQJLReference)) + + let availability = TurboQuantKernelAvailability.current + if availability.supportsMetalPolarQJL { + XCTAssertNoThrow(try requireTurboQuantBackend(.metalPolarQJL)) + XCTAssertEqual(availability.runtimeBackend(for: .metalPolarQJL), .metalPolarQJL) + XCTAssertNil(availability.fallbackReason(for: .metalPolarQJL)) + } else { + XCTAssertThrowsError(try requireTurboQuantBackend(.metalPolarQJL)) + XCTAssertEqual(availability.runtimeBackend(for: .metalPolarQJL), .mlxPacked) + XCTAssertNotNil(availability.fallbackReason(for: .metalPolarQJL)) + } + } + + func testTurboQuantDeviceCapabilitiesAndProbeContract() throws { + let capabilities = TurboQuantDeviceCapabilities.current + let availability = TurboQuantKernelAvailability.current + + XCTAssertFalse(capabilities.architectureName.isEmpty) + XCTAssertEqual(capabilities.runtimeProbe, TurboQuantRuntimeProbe.current) + XCTAssertEqual(availability.selfTestStatus, capabilities.runtimeProbe.status) + XCTAssertEqual( + availability.selectedKernelProfile, capabilities.runtimeProbe.selectedKernelProfile) + + if availability.supportsMetalPolarQJLAttention { + XCTAssertEqual(capabilities.runtimeProbe.status, .passed) + XCTAssertNotEqual(capabilities.runtimeProbe.selectedKernelProfile, .mlxPackedFallback) + XCTAssertNil(capabilities.runtimeProbe.failureReason) + } else { + XCTAssertNotEqual(capabilities.runtimeProbe.status, .notRun) + XCTAssertEqual(availability.runtimeBackend(for: .metalPolarQJL), .mlxPacked) + } + } + + func testTurboQuantRuntimeProbeAvailabilityIsActionable() throws { + let probe = TurboQuantRuntimeProbe.current + let availability = TurboQuantKernelAvailability.current + + XCTAssertNotEqual(probe.status, .notRun) + XCTAssertEqual(availability.selfTestStatus, probe.status) + XCTAssertEqual(availability.selfTestFailureReason, probe.failureReason) + + if probe.passed { + XCTAssertTrue(probe.metalRuntimeAvailable) + XCTAssertTrue(availability.supportsMetalPolarQJLCodec) + XCTAssertTrue(availability.supportsMetalPolarQJLAttention) + XCTAssertTrue(probe.encodeDecodePassed) + XCTAssertTrue(probe.qkPassed) + XCTAssertTrue(probe.avPassed) + XCTAssertTrue(probe.tiledFusedPassed) + XCTAssertNotNil(probe.encodeDecodeLatencySeconds) + XCTAssertNotNil(probe.twoStageLatencySeconds) + XCTAssertNotNil(probe.tiledFusedLatencySeconds) + XCTAssertNil(probe.failureReason) + } else { + XCTAssertFalse(availability.supportsMetalPolarQJLAttention) + XCTAssertEqual(availability.runtimeBackend(for: .metalPolarQJL), .mlxPacked) + XCTAssertNotNil(probe.failureReason) + } + } + + func testTurboQuantSwiftPMMetalLibraryResourceIsLoadableWhenMetalDeviceExists() throws { + #if canImport(Metal) + guard MTLCreateSystemDefaultDevice() != nil else { + throw XCTSkip("No Metal device available") + } + + let probe = TurboQuantRuntimeProbe.current + XCTAssertTrue( + probe.metalRuntimeAvailable, + probe.failureReason ?? "Expected SwiftPM-packaged default.metallib to be loadable" + ) + XCTAssertTrue(TurboQuantKernelAvailability.current.supportsMetalPolarQJLCodec) + #else + throw XCTSkip("Metal framework unavailable") + #endif + } + + func testTurboQuantMetalCodecRoundTripWhenAvailable() throws { + guard TurboQuantKernelAvailability.current.supportsMetalPolarQJLCodec else { + throw XCTSkip("Metal runtime unavailable") + } + + let values = (0 ..< 128).map { index in + Float(sin(Double(index) * 0.05)) + } + let x = MLXArray(values, [2, 64]) + for seed in [UInt64(0xDEAD_BEEF_0000_0017), UInt64(0x0000_0000_DEAD_BEEF)] { + let configuration = TurboQuantConfiguration( + preset: .turbo3_5, + role: .key, + groupSize: 64, + backend: .metalPolarQJL, + seed: seed + ) + + let code = try turboQuantMetalEncode(x, configuration: configuration) + let decoded = try turboQuantMetalDecode(code).asArray(Float.self) + XCTAssertEqual(code.shape, [2, 64]) + XCTAssertLessThan(relativeMSE(values, decoded), 0.1) + } + } + + func testTurboQuantMetalCodecUsesGPUStreamWhenDefaultDeviceIsCPU() throws { + guard TurboQuantKernelAvailability.current.supportsMetalPolarQJLCodec else { + throw XCTSkip("Metal runtime unavailable") + } + + let values = (0 ..< 128).map { index in + Float(sin(Double(index) * 0.07)) + } + let x = MLXArray(values, [2, 64]) + let configuration = TurboQuantConfiguration( + preset: .turbo3_5, + role: .key, + groupSize: 64, + backend: .metalPolarQJL, + seed: 0xDEAD_BEEF_0000_0017 + ) + + try Device.withDefaultDevice(.cpu) { + XCTAssertTrue(StreamOrDevice.default.description.contains("cpu")) + + let code = try turboQuantMetalEncode(x, configuration: configuration) + let decoded = try turboQuantMetalDecode(code).asArray(Float.self) + + XCTAssertEqual(code.shape, [2, 64]) + XCTAssertEqual(decoded.count, values.count) + } + } + + func testTurboQuantMetalMatmulMatchesDecodedReferenceWhenAvailable() throws { + guard TurboQuantKernelAvailability.current.supportsMetalPolarQJLCodec else { + throw XCTSkip("Metal runtime unavailable") + } + + let xValues = (0 ..< 192).map { index in + let position = Double(index) + return Float(0.4 * sin(position * 0.07) + 0.2 * cos(position * 0.17)) + } + let wValues = (0 ..< 320).map { index in + let position = Double(index) + return Float(0.3 * cos(position * 0.05) - 0.15 * sin(position * 0.11)) + } + let x = MLXArray(xValues, [3, 64]) + let w = MLXArray(wValues, [5, 64]) + let configuration = TurboQuantConfiguration( + preset: .turbo3_5, + role: .vector, + groupSize: 64, + backend: .metalPolarQJL, + seed: 0xC0FF_EE00_0000_0042 + ) + + let code = try turboQuantMetalEncode(w, configuration: configuration) + let decoded = try turboQuantMetalDecode(code, dtype: .float32) + let reference = matmul(x, decoded.transposed()) + let output = try turboQuantizedMM(x, code, transpose: true, outputDType: .float32) + + XCTAssertEqual(output.shape, [3, 5]) + XCTAssertTrue(allClose(output, reference, rtol: 1e-4, atol: 1e-4).item(Bool.self)) + XCTAssertEqual(code.magnitudeWordsPerGroup, 5) + + let columnMajorWeight = decoded.transposed() + let columnCode = try turboQuantMetalEncode(columnMajorWeight, configuration: configuration) + let columnReference = matmul(x, try turboQuantMetalDecode(columnCode, dtype: .float32)) + let columnOutput = try turboQuantizedMM( + x, columnCode, transpose: false, outputDType: .float32) + + XCTAssertEqual(columnOutput.shape, [3, 5]) + XCTAssertTrue( + allClose(columnOutput, columnReference, rtol: 1e-4, atol: 1e-4).item(Bool.self)) + } + + func testTurboQuantAttentionLayoutIsRowWise() throws { + let layout = try turboQuantAttentionLayout(shape: [1, 2, 3, 80], groupSize: 64) + + XCTAssertEqual(layout.layoutVersion, 4) + XCTAssertEqual(layout.logicalShape, [1, 2, 3, 80]) + XCTAssertEqual(layout.pinnedPrefixLength, 0) + XCTAssertEqual(layout.groupsPerVector, 2) + XCTAssertEqual(layout.bitsetWordsPerGroup, 2) + } + + func testTurboQuantCompressedAttentionUsesProductEstimatorWhenAvailable() throws { + guard TurboQuantKernelAvailability.current.supportsMetalPolarQJLAttention else { + throw XCTSkip("Metal compressed attention unavailable") + } + + let qValues: [Float] = (0 ..< 512).map { index in + let position = Double(index) + return Float(sin(position * 0.03) + 0.2 * cos(position * 0.11)) + } + let kValues: [Float] = (0 ..< 640).map { index in + let position = Double(index) + return Float(cos(position * 0.05) * 0.5 + sin(position * 0.17) * 0.1) + } + let vValues: [Float] = (0 ..< 640).map { index in + let position = Double(index) + return Float(sin(position * 0.07) * 0.25 - cos(position * 0.13) * 0.2) + } + let queries = MLXArray(qValues, [1, 4, 2, 64]) + let keys = MLXArray(kValues, [1, 2, 5, 64]) + let values = MLXArray(vValues, [1, 2, 5, 64]) + let keyCode = try turboQuantMetalEncodeAttention( + keys, + configuration: TurboQuantConfiguration( + preset: .turbo3_5, + role: .key, + groupSize: 64, + backend: .metalPolarQJL, + seed: 11 + ) + ) + let valueCode = try turboQuantMetalEncodeAttention( + values, + configuration: TurboQuantConfiguration( + preset: .turbo3_5, + role: .value, + groupSize: 64, + backend: .metalPolarQJL, + seed: 13 + ) + ) + let fullPrecisionReference = MLXFast.scaledDotProductAttention( + queries: queries, + keys: keys, + values: values, + scale: 1 / sqrt(Float(64)), + mask: .causal + ) + + let twoStage = try turboQuantMetalScaledDotProductAttention( + queries: queries, + keyCode: keyCode, + valueCode: valueCode, + scale: 1 / sqrt(Float(64)), + mask: .causal, + preferOnlineFused: false + ) + let fused = try turboQuantMetalScaledDotProductAttention( + queries: queries, + keyCode: keyCode, + valueCode: valueCode, + scale: 1 / sqrt(Float(64)), + mask: .causal, + preferOnlineFused: true + ) + + XCTAssertEqual(twoStage.shape, [1, 4, 2, 64]) + XCTAssertEqual(fused.shape, [1, 4, 2, 64]) + XCTAssertTrue(allClose(fused, twoStage, rtol: 1e-4, atol: 1e-4).item(Bool.self)) + XCTAssertLessThan( + relativeMSE( + fullPrecisionReference.asArray(Float.self), + fused.asArray(Float.self) + ), + 0.12 + ) + XCTAssertLessThan( + relativeMSE( + fullPrecisionReference.asArray(Float.self), + twoStage.asArray(Float.self) + ), + 0.12 + ) + } + + func testTurboQuantCompressedAttentionSupportsBatchedInputsWhenAvailable() throws { + guard TurboQuantKernelAvailability.current.supportsMetalPolarQJLAttention else { + throw XCTSkip("Metal compressed attention unavailable") + } + + let qValues: [Float] = (0 ..< 1024).map { index in + let position = Double(index) + return Float(0.3 * sin(position * 0.021) + 0.17 * cos(position * 0.071)) + } + let kValues: [Float] = (0 ..< 1280).map { index in + let position = Double(index) + return Float(0.25 * cos(position * 0.037) - 0.11 * sin(position * 0.097)) + } + let vValues: [Float] = (0 ..< 1280).map { index in + let position = Double(index) + return Float(0.19 * sin(position * 0.043) + 0.13 * cos(position * 0.083)) + } + let queries = MLXArray(qValues, [2, 4, 2, 64]) + let keys = MLXArray(kValues, [2, 2, 5, 64]) + let values = MLXArray(vValues, [2, 2, 5, 64]) + let keyCode = try turboQuantMetalEncodeAttention( + keys, + configuration: TurboQuantConfiguration( + preset: .turbo3_5, + role: .key, + groupSize: 64, + backend: .metalPolarQJL, + seed: 31 + ) + ) + let valueCode = try turboQuantMetalEncodeAttention( + values, + configuration: TurboQuantConfiguration( + preset: .turbo3_5, + role: .value, + groupSize: 64, + backend: .metalPolarQJL, + seed: 37 + ) + ) + let fullPrecisionReference = MLXFast.scaledDotProductAttention( + queries: queries, + keys: keys, + values: values, + scale: 1 / sqrt(Float(64)), + mask: .causal + ) + + let twoStage = try turboQuantMetalScaledDotProductAttention( + queries: queries, + keyCode: keyCode, + valueCode: valueCode, + scale: 1 / sqrt(Float(64)), + mask: .causal, + preferOnlineFused: false + ) + let fused = try turboQuantMetalScaledDotProductAttention( + queries: queries, + keyCode: keyCode, + valueCode: valueCode, + scale: 1 / sqrt(Float(64)), + mask: .causal, + preferOnlineFused: true + ) + + XCTAssertEqual(twoStage.shape, [2, 4, 2, 64]) + XCTAssertEqual(fused.shape, [2, 4, 2, 64]) + XCTAssertTrue(allClose(fused, twoStage, rtol: 1e-4, atol: 1e-4).item(Bool.self)) + XCTAssertLessThan( + relativeMSE( + fullPrecisionReference.asArray(Float.self), + fused.asArray(Float.self) + ), + 0.12 + ) + } + + func testTurboQuantCompressedAttentionSupportsSinksWhenAvailable() throws { + guard TurboQuantKernelAvailability.current.supportsMetalPolarQJLAttention else { + throw XCTSkip("Metal compressed attention unavailable") + } + + let qValues: [Float] = (0 ..< 512).map { index in + let position = Double(index) + return Float(0.24 * sin(position * 0.031) + 0.12 * cos(position * 0.089)) + } + let kValues: [Float] = (0 ..< 640).map { index in + let position = Double(index) + return Float(0.2 * cos(position * 0.047) - 0.08 * sin(position * 0.101)) + } + let vValues: [Float] = (0 ..< 640).map { index in + let position = Double(index) + return Float(0.18 * sin(position * 0.053) + 0.09 * cos(position * 0.077)) + } + let queries = MLXArray(qValues, [1, 4, 2, 64]) + let keys = MLXArray(kValues, [1, 2, 5, 64]) + let values = MLXArray(vValues, [1, 2, 5, 64]) + let sinks = MLXArray([0.3 as Float, -0.2, 0.1, -0.4]) + let keyCode = try turboQuantMetalEncodeAttention( + keys, + configuration: TurboQuantConfiguration( + preset: .turbo3_5, + role: .key, + groupSize: 64, + backend: .metalPolarQJL, + seed: 41 + ) + ) + let valueCode = try turboQuantMetalEncodeAttention( + values, + configuration: TurboQuantConfiguration( + preset: .turbo3_5, + role: .value, + groupSize: 64, + backend: .metalPolarQJL, + seed: 43 + ) + ) + let reference = MLXFast.scaledDotProductAttention( + queries: queries, + keys: keys, + values: values, + scale: 1 / sqrt(Float(64)), + mask: .causal, + sinks: sinks + ) + + let output = try turboQuantMetalScaledDotProductAttention( + queries: queries, + keyCode: keyCode, + valueCode: valueCode, + scale: 1 / sqrt(Float(64)), + mask: .causal, + sinks: sinks, + preferOnlineFused: true + ) + + XCTAssertEqual(output.shape, [1, 4, 2, 64]) + XCTAssertLessThan( + relativeMSE( + reference.asArray(Float.self), + output.asArray(Float.self) + ), + 0.12 + ) + } + + func testTurboQuantCompressedAttentionSupportsSplitKeyValueDimensionsWhenAvailable() throws { + guard TurboQuantKernelAvailability.current.supportsMetalPolarQJLAttention else { + throw XCTSkip("Metal compressed attention unavailable") + } + + let qValues: [Float] = (0 ..< 512).map { index in + let position = Double(index) + return Float(0.21 * sin(position * 0.029) + 0.16 * cos(position * 0.061)) + } + let kValues: [Float] = (0 ..< 640).map { index in + let position = Double(index) + return Float(0.18 * cos(position * 0.041) - 0.12 * sin(position * 0.087)) + } + let vValues: [Float] = (0 ..< 800).map { index in + let position = Double(index) + return Float(0.22 * sin(position * 0.049) + 0.10 * cos(position * 0.093)) + } + let queries = MLXArray(qValues, [1, 4, 2, 64]) + let keys = MLXArray(kValues, [1, 2, 5, 64]) + let values = MLXArray(vValues, [1, 2, 5, 80]) + let keyCode = try turboQuantMetalEncodeAttention( + keys, + configuration: TurboQuantConfiguration( + preset: .turbo3_5, + role: .key, + groupSize: 64, + backend: .metalPolarQJL, + seed: 51 + ) + ) + let valueCode = try turboQuantMetalEncodeAttention( + values, + configuration: TurboQuantConfiguration( + preset: .turbo3_5, + role: .value, + groupSize: 64, + backend: .metalPolarQJL, + seed: 53 + ) + ) + + let scores = try turboQuantMetalQK( + queries: queries, + keyCode: keyCode, + scale: 1 / sqrt(Float(64)), + mask: .causal + ) + let twoStage = try turboQuantMetalAV( + attentionWeights: softmax(scores.asType(.float32), axis: -1), + valueCode: valueCode, + outputDType: queries.dtype + ) + let fusedPreferred = try turboQuantMetalScaledDotProductAttention( + queries: queries, + keyCode: keyCode, + valueCode: valueCode, + scale: 1 / sqrt(Float(64)), + mask: .causal, + preferOnlineFused: true + ) + + XCTAssertEqual(twoStage.shape, [1, 4, 2, 80]) + XCTAssertEqual(fusedPreferred.shape, [1, 4, 2, 80]) + XCTAssertTrue(allClose(fusedPreferred, twoStage, rtol: 1e-4, atol: 1e-4).item(Bool.self)) + } + + func testTurboQuantAttentionDecodeHonorsRotatingLayoutWhenAvailable() throws { + guard TurboQuantKernelAvailability.current.supportsMetalPolarQJLAttention else { + throw XCTSkip("Metal compressed attention unavailable") + } + + let capacity = 6 + let headDimension = 64 + let physicalValues = (0 ..< capacity).flatMap { token in + Array(repeating: Float(token + 1) * 0.25, count: headDimension) + } + let physical = MLXArray(physicalValues, [1, 1, capacity, headDimension]) + let code = try turboQuantMetalEncodeAttention( + physical, + configuration: TurboQuantConfiguration( + preset: .turbo3_5, + role: .value, + groupSize: 64, + backend: .metalPolarQJL, + seed: 29 + ), + capacity: capacity, + logicalLength: capacity, + ringOffset: 2, + pinnedPrefixLength: 2 + ) + + let decoded = try turboQuantMetalDecodeAttention(code, outputDType: .float32) + let expectedTokenOrder = [0, 1, 4, 5, 2, 3] + let expectedValues = expectedTokenOrder.flatMap { token in + Array(repeating: Float(token + 1) * 0.25, count: headDimension) + } + let expected = MLXArray(expectedValues, [1, 1, capacity, headDimension]) + + XCTAssertTrue(allClose(decoded, expected, rtol: 1e-6, atol: 1e-6).item(Bool.self)) + } + + func testTurboQuantOnlineFusedSupportContract() throws { + let keyLayout = try turboQuantAttentionLayout(shape: [1, 2, 8, 64], groupSize: 64) + + XCTAssertTrue( + turboQuantMetalSupportsOnlineFusedAttention( + queryShape: [1, 4, 1, 64], + keyLayout: keyLayout, + mask: .none + ) + ) + } + + func testTurboQuantOnlineFusedSupportsLargeContextContract() throws { + let keyLayout = try turboQuantAttentionLayout(shape: [1, 2, 513, 64], groupSize: 64) + + XCTAssertTrue( + turboQuantMetalSupportsOnlineFusedAttention( + queryShape: [1, 4, 1, 64], + keyLayout: keyLayout, + mask: .none + ) + ) + } } diff --git a/tools/build-swiftpm-metallib.sh b/tools/build-swiftpm-metallib.sh new file mode 100755 index 00000000..08103bdd --- /dev/null +++ b/tools/build-swiftpm-metallib.sh @@ -0,0 +1,76 @@ +#!/bin/bash +# Build the default Metal library resource used by SwiftPM Cmlx builds. + +set -euo pipefail + +if [[ $# -ne 1 ]]; then + echo "usage: $0 OUTPUT_METALLIB" >&2 + exit 64 +fi + +OUTPUT="$1" +SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd) +ROOT_DIR=$(realpath "${SCRIPT_DIR}/..") +KERNELS_DIR="${ROOT_DIR}/Source/Cmlx/mlx/mlx/backend/metal/kernels" + +METAL=$(xcrun -sdk macosx -find metal) +METALLIB=$(xcrun -sdk macosx -find metallib) +TMP_DIR=$(mktemp -d) +trap 'rm -rf "${TMP_DIR}"' EXIT + +DEPLOYMENT_TARGET="${MACOSX_DEPLOYMENT_TARGET:-14.0}" + +metal_version=$( + printf '%s\n' '__METAL_VERSION__' | + "${METAL}" "-mmacosx-version-min=${DEPLOYMENT_TARGET}" -E -x metal -P - | + tail -1 | + tr -d '[:space:]' +) +metal_version=${metal_version:-0} + +kernels=( + "arg_reduce" + "conv" + "gemv" + "layer_norm" + "random" + "rms_norm" + "rope" + "scaled_dot_product_attention" +) + +if (( metal_version >= 320 )); then + kernels+=("fence") +fi + +metal_flags=( + -x metal + -Wall + -Wextra + -fno-fast-math + -Wno-c++17-extensions + -Wno-c++20-extensions + -mmacosx-version-min="${DEPLOYMENT_TARGET}" +) + +if (( metal_version >= 400 )); then + metal_flags+=(-std=metal4.0) +elif (( metal_version >= 320 )); then + metal_flags+=(-std=metal3.2) +elif (( metal_version >= 310 )); then + metal_flags+=(-std=metal3.1) +elif (( metal_version >= 300 )); then + metal_flags+=(-std=metal3.0) +fi + +air_files=() +for kernel in "${kernels[@]}"; do + source="${KERNELS_DIR}/${kernel}.metal" + air="${TMP_DIR}/${kernel}.air" + "${METAL}" "${metal_flags[@]}" -c "${source}" -I"${ROOT_DIR}/Source/Cmlx/mlx" -o "${air}" + air_files+=("${air}") +done + +mkdir -p "$(dirname "${OUTPUT}")" +"${METALLIB}" "${air_files[@]}" -o "${TMP_DIR}/default.metallib" +mv "${TMP_DIR}/default.metallib" "${OUTPUT}" diff --git a/tools/fix-metal-includes.sh b/tools/fix-metal-includes.sh deleted file mode 100755 index 622d4311..00000000 --- a/tools/fix-metal-includes.sh +++ /dev/null @@ -1,109 +0,0 @@ -#!/bin/bash -# Fixing include path for mlx-swift metal headers - -set -euo pipefail - -SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) -ROOT_DIR=$(realpath "${SCRIPT_DIR}/..") - -# Where the files end up -OUTPUT_DIR="${ROOT_DIR}/Source/Cmlx/mlx-generated/metal" - -# The Cmlx source dir -CMLX_MLX_DIR="${ROOT_DIR}/Source/Cmlx/mlx" - -# sub-directory of Cmlx source containing the kernels -KERNELS_INCLUDE_PATH="mlx/backend/metal/kernels" - -KERNELS_DIR="${CMLX_MLX_DIR}/${KERNELS_INCLUDE_PATH}" - -# list of kernels files to process -# see Source/Cmlx/mlx/mlx/backend/metal/kernels/CMakeLists.txt -KERNEL_LIST=" \ -arg_reduce.metal \ -conv.metal \ -gemv.metal \ -layer_norm.metal \ -random.metal \ -rms_norm.metal \ -rope.metal \ -scaled_dot_product_attention.metal \ -steel/attn/kernels/steel_attention.metal" - -# We fixup all the header files AND the listed kernel files -HEADERS=$(find "${KERNELS_DIR}" -name "*.h") -KERNELS=$(for file in ${KERNEL_LIST}; do echo "${KERNELS_DIR}/${file}"; done) - -# Regular expression to replace include directives -PATTERN="^#include \"${KERNELS_INCLUDE_PATH}/([^\"]*)\"" - -mkdir -p "${OUTPUT_DIR}" - -# Mimic the original logic in PrepareMetalShaders::transformIncludes -# Returns rootPath, a string containing a sequence of "../../" to prefix the -# include path -function replaceIncludePrefix { - #Extract components up to the output dir and drop the last one - #swift: let pathUnderKernels = url.pathComponents.drop { $0 != "output" }.dropLast() - - absolutePath=$(realpath "${1}") - absoluteOut=$(realpath "${OUTPUT_DIR}") - remainingPath=${absolutePath#"$absoluteOut"/} - - # Doing the `dropLast` with `dirname`, handling the case where it returns `.`` - remainingPath=$(dirname "${remainingPath}" | sed -E 's|^\.$||') - - # Build the root path - # swift: let rootPath =Array(repeating: "..", count: pathUnderKernels.count - 1).joined(separator: "/") - # + ((pathUnderKernels.count - 1 == 0) ? "" : "/") - IFS='/' read -r -a path <<< "${remainingPath}" - count=${#path[@]} - - if [ "$count" -le 0 ]; then - root_path="" - else - root_path=$(printf "../%.0s" $(seq 1 "${count}")) - fi - echo "${root_path}" -} - -# First pass : copy the files if needed -for src in ${HEADERS} ${KERNELS}; do - - relative_path=${src#"$KERNELS_DIR"/} - dest=${OUTPUT_DIR}/${relative_path} - - # If destination file doesn't exist or if it's older than the source - # copy from source and replace the #include directives - if [ ! -e "$dest" ] || [ "$src" -nt "$dest" ]; then - echo "${src} -> ${dest}" - mkdir -p "$(dirname "${dest}")" - cp -p "${src}" "${dest}" - else - echo "Skipping $src (more recent destination)" - fi - -done - -# second pass: update the include lines -# iterating on src to only process the list of files we copied -# (in case the destination directory has other unrelated files) -for src in ${HEADERS} ${KERNELS}; do - - relative_path=${src#"$KERNELS_DIR"/} - dest=${OUTPUT_DIR}/${relative_path} - prefix=$(replaceIncludePrefix "${dest}") - - # for each matching input line, compute the relative path, then replace the line - while read -r includeLine; do - includePath=$(echo "${includeLine}" | sed -E -n "s|${PATTERN}|\1|p") - - # Note the absence of "/" between the prefix and the path - replace="${prefix}${includePath}" - - # Replace the include line with the new one - echo sed -i '' -e "s|${KERNELS_INCLUDE_PATH}/${includePath}|${replace}|" "${dest}" - sed -i '' -e "s|${KERNELS_INCLUDE_PATH}/${includePath}|${replace}|" "${dest}" - - done < <(grep -E -o "${PATTERN}" "${dest}") -done diff --git a/tools/update-mlx.sh b/tools/update-mlx.sh index 940a3f33..36312b4a 100755 --- a/tools/update-mlx.sh +++ b/tools/update-mlx.sh @@ -75,6 +75,7 @@ make cpu_compiled_preamble cd .. +# Remove stale copied Metal sources from the deleted embedded fallback path. rm -rf Source/Cmlx/mlx-generated/metal rm -f Source/Cmlx/mlx-generated/* cp build/mlx/backend/metal/jit/* Source/Cmlx/mlx-generated @@ -89,8 +90,5 @@ for x in Source/Cmlx/mlx-generated/*.cpp ; do \ done; rm Source/Cmlx/mlx-generated/*.tmp -# Update the headers -./tools/fix-metal-includes.sh - # prepare xcodeproj files ./tools/update-mlx-xcodeproj.sh