Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -203,10 +203,12 @@ let cmlx = Target.target(
cSettings: [
.headerSearchPath("mlx"),
.headerSearchPath("mlx-c"),
.headerSearchPath("turbo-quant"),
],
cxxSettings: cxxSettings + [
.headerSearchPath("mlx"),
.headerSearchPath("mlx-c"),
.headerSearchPath("turbo-quant"),
.headerSearchPath("json/single_include/nlohmann"),
.headerSearchPath("fmt/include"),
.define("MLX_VERSION", to: "\"0.31.1\""),
Expand Down
42 changes: 42 additions & 0 deletions Source/Cmlx/include/mlx/c/fast.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,48 @@ int mlx_fast_scaled_dot_product_attention(
const mlx_array sinks /* may be null */,
const mlx_stream s);

// TurboQuant KV cache compression
//
// The three entry points below are declared exactly once; each takes the
// stream to run on as its last argument, like the other functions in this
// header.

// Encode `keys` and `values` for the compressed KV cache.
//
// Results are returned through the four output pointers: a "polar" array and
// a "residual" array for the keys (`res_polar_k`, `res_residual_k`) and for
// the values (`res_polar_v`, `res_residual_v`).
//
// NOTE(review): `k_bits` is presumably the bit width used when quantizing the
// keys; there is no matching parameter for the values -- confirm whether the
// value bit width is fixed by the implementation.
// NOTE(review): the int return presumably follows this header's status-code
// convention (0 on success) -- confirm against the implementation.
int mlx_fast_turbo_encode(
    mlx_array* res_polar_k,
    mlx_array* res_polar_v,
    mlx_array* res_residual_k,
    mlx_array* res_residual_v,
    const mlx_array keys,
    const mlx_array values,
    int k_bits,
    const mlx_stream s);

// Decode a packed key array, writing the result to `res`.
//
// NOTE(review): how `packed` relates to the polar/residual outputs of
// mlx_fast_turbo_encode is not visible from this header -- document the
// expected layout once confirmed.
int mlx_fast_turbo_decode_k(
    mlx_array* res,
    const mlx_array packed,
    const mlx_stream s);

// Decode a packed value array, writing the result to `res`.
//
// NOTE(review): same open question about the layout of `packed` as for
// mlx_fast_turbo_decode_k.
int mlx_fast_turbo_decode_v(
    mlx_array* res,
    const mlx_array packed,
    const mlx_stream s);

/**@}*/

#ifdef __cplusplus
Expand Down
Loading