diff --git a/src/kernels.cu b/src/kernels.cu
index f3a5d98..5236437 100644
--- a/src/kernels.cu
+++ b/src/kernels.cu
@@ -1,5 +1,8 @@
 #include "kernels.cuh"
 #include
+#include
+#include <cstdio>
+#include <cstdlib>
 
 // TODO perhaps share these between main.cpp and kernels.cu to not duplicate them?
 
@@ -16,6 +19,55 @@ constexpr int V_OFFSET = BLOCK_SIZE * KV_DIM * sizeof(__nv_bfloat16);
 constexpr int BLOCK_BYTES = V_OFFSET * 2; // * 2 because K and V
 constexpr int MAX_BLOCKS_PER_SEQ = MAX_SEQ_LEN / BLOCK_SIZE; // 2048 / 16 = 128
 
+// Llama 3 RoPE scaling parameters
+constexpr float ROPE_FACTOR = 32.0f;
+constexpr float ROPE_LOW_FREQ_FACTOR = 1.0f;
+constexpr float ROPE_HIGH_FREQ_FACTOR = 4.0f;
+constexpr int ROPE_ORIGINAL_MAX_POS = 8192;
+constexpr int NUM_ROPE_PAIRS = HEAD_DIM / 2; // 32
+
+// Precomputed Llama3-scaled inverse frequencies
+__device__ __constant__ float d_rope_freqs[NUM_ROPE_PAIRS];
+
+// Builds the table of per-pair inverse frequencies on the host and copies it
+// into d_rope_freqs. Starting from 1 / 500000^(2i / HEAD_DIM), a pair whose
+// wavelength exceeds ROPE_ORIGINAL_MAX_POS / ROPE_LOW_FREQ_FACTOR is divided by
+// ROPE_FACTOR, a pair whose wavelength is at most
+// ROPE_ORIGINAL_MAX_POS / ROPE_HIGH_FREQ_FACTOR is left untouched, and a pair in
+// between gets a linear blend of the two.
+// Exits the process if the copy fails, because the RoPE kernels read this table.
+void initRopeFreqs()
+{
+    float freqs[NUM_ROPE_PAIRS];
+    float low_freq_wavelen = (float)ROPE_ORIGINAL_MAX_POS / ROPE_LOW_FREQ_FACTOR;
+    float high_freq_wavelen = (float)ROPE_ORIGINAL_MAX_POS / ROPE_HIGH_FREQ_FACTOR;
+
+    for (int i = 0; i < NUM_ROPE_PAIRS; ++i)
+    {
+        float inv_freq = 1.0f / powf(500000.0f, (float)(2 * i) / HEAD_DIM);
+        float wavelen = 2.0f * M_PI / inv_freq;
+
+        if (wavelen > low_freq_wavelen)
+        {
+            inv_freq /= ROPE_FACTOR;
+        }
+        else if (wavelen > high_freq_wavelen)
+        {
+            float smooth = (ROPE_ORIGINAL_MAX_POS / wavelen - ROPE_LOW_FREQ_FACTOR) / (ROPE_HIGH_FREQ_FACTOR - ROPE_LOW_FREQ_FACTOR);
+            inv_freq = (1.0f - smooth) * inv_freq / ROPE_FACTOR + smooth * inv_freq;
+        }
+        freqs[i] = inv_freq;
+    }
+
+    // The copy's status is checked so a failed upload cannot pass silently.
+    cudaError_t err = cudaMemcpyToSymbol(d_rope_freqs, freqs, sizeof(freqs));
+    if (err != cudaSuccess)
+    {
+        std::fprintf(stderr, "initRopeFreqs: cudaMemcpyToSymbol failed: %s\n", cudaGetErrorString(err));
+        std::exit(EXIT_FAILURE);
+    }
+}
+
 // prefill / shared
 
 // gpu_input_tokens - N tokens
@@ -90,14 +142,18 @@ __global__ void ropeKernel(__nv_bfloat16 *input, int num_tokens, int proj_dim)
 {
-    if (2 * threadIdx.x + 1 + blockIdx.x * proj_dim < num_tokens * proj_dim)
+    // blockIdx.x is both the row index and the rotation position; thread t rotates
+    // the pair (2t, 2t + 1) of that row. The bound is checked per row so a thread
+    // past the end of the row never touches another row with this block's position.
+    if (blockIdx.x < num_tokens && 2 * threadIdx.x + 1 < proj_dim)
     {
-        // TODO: precompute thetas, angles and perhaps sin/cos vals and reuse it across all kernel invocations
-        int double_i = 2 * (threadIdx.x % 32);
-        float theta = 1.0 / (pow(500000.0, ((float)double_i / HEAD_DIM)));
-        float angle = blockIdx.x * theta;
-        __nv_bfloat16 prev_2i = input[2 * threadIdx.x + blockIdx.x * proj_dim];
-        __nv_bfloat16 prev_2i_1 = input[2 * threadIdx.x + 1 + blockIdx.x * proj_dim];
-        input[2 * threadIdx.x + blockIdx.x * proj_dim] = (__nv_bfloat16)((float)prev_2i * cos(angle) - (float)prev_2i_1 * sin(angle));
-        input[2 * threadIdx.x + 1 + blockIdx.x * proj_dim] = (__nv_bfloat16)((float)prev_2i * sin(angle) + (float)prev_2i_1 * cos(angle));
+        int pair_idx = threadIdx.x % NUM_ROPE_PAIRS;
+        float angle = blockIdx.x * d_rope_freqs[pair_idx];
+        float cos_a = cosf(angle);
+        float sin_a = sinf(angle);
+        int idx = 2 * threadIdx.x + blockIdx.x * proj_dim;
+        __nv_bfloat16 prev_2i = input[idx];
+        __nv_bfloat16 prev_2i_1 = input[idx + 1];
+        input[idx] = (__nv_bfloat16)((float)prev_2i * cos_a - (float)prev_2i_1 * sin_a);
+        input[idx + 1] = (__nv_bfloat16)((float)prev_2i * sin_a + (float)prev_2i_1 * cos_a);
     }
 }
 
@@ -280,16 +336,19 @@ void embeddingGatherDecode(int *gpu_last_tokens, int num_tokens, __nv_bfloat16 *
 
+// Rotates, in place, the (2t, 2t + 1) pairs of the first proj_dim values of input,
+// using position_in_sequence as the rotation position and d_rope_freqs as the table.
 __global__ void ropeKernelDecode(__nv_bfloat16 *input, int position_in_sequence, int proj_dim)
 {
-    if (2 * threadIdx.x + 1 < proj_dim) // TODO: check correctness
+    if (2 * threadIdx.x + 1 < proj_dim)
     {
-        // TODO: precompute thetas, angles and perhaps sin/cos vals and reuse it across all kernel invocations
-        int double_i = 2 * (threadIdx.x % 32);
-        float theta = 1.0 / (pow(500000.0, ((float)double_i / HEAD_DIM)));
-        float angle = position_in_sequence * theta;
-        __nv_bfloat16 prev_2i = input[2 * threadIdx.x];
-        __nv_bfloat16 prev_2i_1 = input[2 * threadIdx.x + 1];
-        input[2 * threadIdx.x] = (__nv_bfloat16)((float)prev_2i * cos(angle) - (float)prev_2i_1 * sin(angle));
-        input[2 * threadIdx.x + 1] = (__nv_bfloat16)((float)prev_2i * sin(angle) + (float)prev_2i_1 * cos(angle));
+        int pair_idx = threadIdx.x % NUM_ROPE_PAIRS;
+        float angle = position_in_sequence * d_rope_freqs[pair_idx];
+        float cos_a = cosf(angle);
+        float sin_a = sinf(angle);
+        int idx = 2 * threadIdx.x;
+        __nv_bfloat16 prev_2i = input[idx];
+        __nv_bfloat16 prev_2i_1 = input[idx + 1];
+        input[idx] = (__nv_bfloat16)((float)prev_2i * cos_a - (float)prev_2i_1 * sin_a);
+        input[idx + 1] = (__nv_bfloat16)((float)prev_2i * sin_a + (float)prev_2i_1 * cos_a);
     }
 }
 
diff --git a/src/kernels.cuh b/src/kernels.cuh
index 19ff54e..8ad4953 100644
--- a/src/kernels.cuh
+++ b/src/kernels.cuh
@@ -1,6 +1,11 @@
 #pragma once
 #include
 
+// init
+// Fills the device-side table of scaled RoPE inverse frequencies that the RoPE
+// kernels read. Call it before launching any of them.
+void initRopeFreqs();
+
 // prefill
 void embeddingGather(int *gpu_input_tokens, __nv_bfloat16 *gpu_input_embeds, __nv_bfloat16 *embed_tokens, int num_input_tokens);
 void rmsNorm(__nv_bfloat16 *input, __nv_bfloat16 *output, nv_bfloat16 *norm_weights, int num_tokens);
diff --git a/src/main.cpp b/src/main.cpp
index 3a901b3..771253f 100644
--- a/src/main.cpp
+++ b/src/main.cpp
@@ -240,8 +240,6 @@ void prefill(std::vector &prompt, std::queue> &queue, int
                      CUBLAS_COMPUTE_32F,
                      CUBLAS_GEMM_DEFAULT);
 
-        // RoPE now
-
         rope(q_proj, prompt_len, EMBEDDING_LENGTH);
         rope(k_proj_temp_buf, prompt_len, KV_DIM);
 
@@ -563,6 +561,8 @@ int main(int argc, char *argv[])
         return 1;
     }
 
+    initRopeFreqs();
+
     Weights weights{};
     if (loadWeights(weights) != 0)
     {