Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 55 additions & 17 deletions src/kernels.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "kernels.cuh"
#include <iostream>
#include <cmath>

// TODO perhaps share these between main.cpp and kernels.cu to not duplicate them?

Expand All @@ -16,6 +17,41 @@ constexpr int V_OFFSET = BLOCK_SIZE * KV_DIM * sizeof(__nv_bfloat16);
constexpr int BLOCK_BYTES = V_OFFSET * 2; // * 2 because K and V
constexpr int MAX_BLOCKS_PER_SEQ = MAX_SEQ_LEN / BLOCK_SIZE; // 2048 / 16 = 128

// Llama 3 RoPE scaling parameters
constexpr float ROPE_FACTOR = 32.0f;
constexpr float ROPE_LOW_FREQ_FACTOR = 1.0f;
constexpr float ROPE_HIGH_FREQ_FACTOR = 4.0f;
constexpr int ROPE_ORIGINAL_MAX_POS = 8192;
constexpr int NUM_ROPE_PAIRS = HEAD_DIM / 2; // 32

// Precomputed Llama3-scaled inverse frequencies
__device__ __constant__ float d_rope_freqs[NUM_ROPE_PAIRS];

void initRopeFreqs()
{
float freqs[NUM_ROPE_PAIRS];
float low_freq_wavelen = (float)ROPE_ORIGINAL_MAX_POS / ROPE_LOW_FREQ_FACTOR;
float high_freq_wavelen = (float)ROPE_ORIGINAL_MAX_POS / ROPE_HIGH_FREQ_FACTOR;

for (int i = 0; i < NUM_ROPE_PAIRS; ++i)
{
float inv_freq = 1.0f / powf(500000.0f, (float)(2 * i) / HEAD_DIM);
float wavelen = 2.0f * M_PI / inv_freq;

if (wavelen > low_freq_wavelen)
{
inv_freq /= ROPE_FACTOR;
}
else if (wavelen > high_freq_wavelen)
{
float smooth = (ROPE_ORIGINAL_MAX_POS / wavelen - ROPE_LOW_FREQ_FACTOR) / (ROPE_HIGH_FREQ_FACTOR - ROPE_LOW_FREQ_FACTOR);
inv_freq = (1.0f - smooth) * inv_freq / ROPE_FACTOR + smooth * inv_freq;
}
freqs[i] = inv_freq;
}
cudaMemcpyToSymbol(d_rope_freqs, freqs, NUM_ROPE_PAIRS * sizeof(float));
}

// prefill / shared

// gpu_input_tokens - N tokens
Expand Down Expand Up @@ -90,14 +126,15 @@ __global__ void ropeKernel(__nv_bfloat16 *input, int num_tokens, int proj_dim)
{
if (2 * threadIdx.x + 1 + blockIdx.x * proj_dim < num_tokens * proj_dim)
{
// TODO: precompute thetas, angles and perhaps sin/cos vals and reuse it across all kernel invocations
int double_i = 2 * (threadIdx.x % 32);
float theta = 1.0 / (pow(500000.0, ((float)double_i / HEAD_DIM)));
float angle = blockIdx.x * theta;
__nv_bfloat16 prev_2i = input[2 * threadIdx.x + blockIdx.x * proj_dim];
__nv_bfloat16 prev_2i_1 = input[2 * threadIdx.x + 1 + blockIdx.x * proj_dim];
input[2 * threadIdx.x + blockIdx.x * proj_dim] = (__nv_bfloat16)((float)prev_2i * cos(angle) - (float)prev_2i_1 * sin(angle));
input[2 * threadIdx.x + 1 + blockIdx.x * proj_dim] = (__nv_bfloat16)((float)prev_2i * sin(angle) + (float)prev_2i_1 * cos(angle));
int pair_idx = threadIdx.x % NUM_ROPE_PAIRS;
float angle = blockIdx.x * d_rope_freqs[pair_idx];
float cos_a = cosf(angle);
float sin_a = sinf(angle);
int idx = 2 * threadIdx.x + blockIdx.x * proj_dim;
__nv_bfloat16 prev_2i = input[idx];
__nv_bfloat16 prev_2i_1 = input[idx + 1];
input[idx] = (__nv_bfloat16)((float)prev_2i * cos_a - (float)prev_2i_1 * sin_a);
input[idx + 1] = (__nv_bfloat16)((float)prev_2i * sin_a + (float)prev_2i_1 * cos_a);
}
}

Expand Down Expand Up @@ -280,16 +317,17 @@ void embeddingGatherDecode(int *gpu_last_tokens, int num_tokens, __nv_bfloat16 *

__global__ void ropeKernelDecode(__nv_bfloat16 *input, int position_in_sequence, int proj_dim)
{
if (2 * threadIdx.x + 1 < proj_dim) // TODO: check correctness
if (2 * threadIdx.x + 1 < proj_dim)
{
// TODO: precompute thetas, angles and perhaps sin/cos vals and reuse it across all kernel invocations
int double_i = 2 * (threadIdx.x % 32);
float theta = 1.0 / (pow(500000.0, ((float)double_i / HEAD_DIM)));
float angle = position_in_sequence * theta;
__nv_bfloat16 prev_2i = input[2 * threadIdx.x];
__nv_bfloat16 prev_2i_1 = input[2 * threadIdx.x + 1];
input[2 * threadIdx.x] = (__nv_bfloat16)((float)prev_2i * cos(angle) - (float)prev_2i_1 * sin(angle));
input[2 * threadIdx.x + 1] = (__nv_bfloat16)((float)prev_2i * sin(angle) + (float)prev_2i_1 * cos(angle));
int pair_idx = threadIdx.x % NUM_ROPE_PAIRS;
float angle = position_in_sequence * d_rope_freqs[pair_idx];
float cos_a = cosf(angle);
float sin_a = sinf(angle);
int idx = 2 * threadIdx.x;
__nv_bfloat16 prev_2i = input[idx];
__nv_bfloat16 prev_2i_1 = input[idx + 1];
input[idx] = (__nv_bfloat16)((float)prev_2i * cos_a - (float)prev_2i_1 * sin_a);
input[idx + 1] = (__nv_bfloat16)((float)prev_2i * sin_a + (float)prev_2i_1 * cos_a);
}
}

Expand Down
3 changes: 3 additions & 0 deletions src/kernels.cuh
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
#pragma once
#include <cuda_bf16.h>

// init
void initRopeFreqs();

// prefill
void embeddingGather(int *gpu_input_tokens, __nv_bfloat16 *gpu_input_embeds, __nv_bfloat16 *embed_tokens, int num_input_tokens);
void rmsNorm(__nv_bfloat16 *input, __nv_bfloat16 *output, nv_bfloat16 *norm_weights, int num_tokens);
Expand Down
4 changes: 2 additions & 2 deletions src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,6 @@ void prefill(std::vector<int> &prompt, std::queue<std::vector<int>> &queue, int
CUBLAS_COMPUTE_32F,
CUBLAS_GEMM_DEFAULT);

// RoPE now

rope(q_proj, prompt_len, EMBEDDING_LENGTH);
rope(k_proj_temp_buf, prompt_len, KV_DIM);

Expand Down Expand Up @@ -563,6 +561,8 @@ int main(int argc, char *argv[])
return 1;
}

initRopeFreqs();

Weights weights{};
if (loadWeights(weights) != 0)
{
Expand Down