Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/infinicore/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import contextlib

import torch
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不能在框架里这么引入torch

import infinicore.context as context
import infinicore.nn as nn

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef __PAGED_ATTENTION_METAX_H__
#define __PAGED_ATTENTION_METAX_H__

#include "../paged_attention.h"

DESCRIPTOR(metax)

#endif // __PAGED_ATTENTION_METAX_H__
149 changes: 149 additions & 0 deletions src/infiniop/ops/paged_attention/metax/paged_attention_metax.maca
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
#include "../../../devices/metax/metax_common.h"
#include "../../../devices/metax/metax_kernel_common.h"
#ifdef ENABLE_METAX_MC_API
#include <mccub/block/block_reduce.cuh>
#else
#include <hccub/block/block_reduce.cuh>
#endif

#include "../../../reduce/cuda/reduce.cuh"
#include "../cuda/kernel.cuh"
#include "paged_attention_metax.h"

template <typename Tdata, typename Tcompute, size_t HEAD_SIZE, size_t NUM_THREADS>
INFINIOP_METAX_KERNEL pagedAttention(
Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
const int64_t *block_tables, const int64_t *seq_lens, const float *alibi_slopes,
const size_t num_kv_heads, const float scale, const size_t max_num_blocks_per_seq,
const size_t block_size,
const ptrdiff_t q_stride,
const ptrdiff_t kv_block_stride,
const ptrdiff_t kv_head_stride,
const ptrdiff_t o_stride) {
op::paged_attention::cuda::pagedAttentionKernel<Tdata, Tcompute, HEAD_SIZE, NUM_THREADS>(
out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, num_kv_heads, scale,
max_num_blocks_per_seq, block_size, q_stride, kv_block_stride, kv_head_stride, o_stride);
}

namespace op::paged_attention::metax {

struct Descriptor::Opaque {
std::shared_ptr<device::metax::Handle::Internal> internal;
};

Descriptor::~Descriptor() {
delete _opaque;
}

infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t q_desc,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t block_tables_desc,
infiniopTensorDescriptor_t seq_lens_desc,
const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
float scale) {
auto info = PagedAttentionInfo::create(out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc, seq_lens_desc, alibi_slopes_desc, scale);
CHECK_RESULT(info);
*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
info.take(), 0, handle->device, handle->device_id);

return INFINI_STATUS_SUCCESS;
}

template <size_t HEAD_SIZE, size_t NUM_THREADS>
infiniStatus_t launchKernel(void *out, const void *q, const void *k_cache, const void *v_cache,
infiniDtype_t dtype,
const void *block_tables, const void *seq_lens, const void *alibi_slopes,
size_t num_heads, size_t num_seqs,
size_t num_kv_heads, float scale, size_t max_num_blocks_per_seq, size_t block_size,
ptrdiff_t q_stride, ptrdiff_t kv_block_stride, ptrdiff_t kv_head_stride, ptrdiff_t o_stride,
hcStream_t stream) {
dim3 grid(uint64_t(num_heads), uint64_t(num_seqs), 1);
dim3 block(NUM_THREADS);
size_t shared_mem_size = (HEAD_SIZE + max_num_blocks_per_seq * block_size + 2) * sizeof(float);

if (dtype == INFINI_DTYPE_F16) {
pagedAttention<half, float, HEAD_SIZE, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(half *)out,
(const half *)q, (const half *)k_cache, (const half *)v_cache,
(const int64_t *)block_tables, (const int64_t *)seq_lens, (const float *)alibi_slopes, num_kv_heads,
scale, max_num_blocks_per_seq, block_size,
q_stride, kv_block_stride, kv_head_stride, o_stride);
} else if (dtype == INFINI_DTYPE_BF16) {
pagedAttention<cuda_bfloat16, float, HEAD_SIZE, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(cuda_bfloat16 *)out, (const cuda_bfloat16 *)q, (const cuda_bfloat16 *)k_cache, (const cuda_bfloat16 *)v_cache,
(const int64_t *)block_tables, (const int64_t *)seq_lens, (const float *)alibi_slopes, num_kv_heads,
scale, max_num_blocks_per_seq, block_size,
q_stride, kv_block_stride, kv_head_stride, o_stride);
} else if (dtype == INFINI_DTYPE_F32) {
pagedAttention<float, float, HEAD_SIZE, NUM_THREADS>
<<<grid, block, shared_mem_size, stream>>>(
(float *)out, (const float *)q, (const float *)k_cache, (const float *)v_cache,
(const int64_t *)block_tables, (const int64_t *)seq_lens, (const float *)alibi_slopes, num_kv_heads,
scale, max_num_blocks_per_seq, block_size,
q_stride, kv_block_stride, kv_head_stride, o_stride);
} else {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
return INFINI_STATUS_SUCCESS;
}

infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *out, const void *q, const void *k_cache, const void *v_cache,
const void *block_tables, const void *seq_lens, const void *alibi_slopes,
void *stream_) const {
hcStream_t stream = (hcStream_t)stream_;

#define LAUNCH_HEADSIZE_BLOCKSIZE(__H_SIZE, __B_SIZE) \
launchKernel<__H_SIZE, __B_SIZE>( \
out, q, k_cache, v_cache, _info.dtype, block_tables, seq_lens, alibi_slopes, \
_info.num_heads, _info.num_seqs, \
_info.num_kv_heads, _info.scale, _info.max_num_blocks_per_seq, _info.block_size, \
_info.q_stride, _info.kv_block_stride, _info.kv_head_stride, _info.o_stride, \
stream);

#define SWITCH_HEAD_SIZE(__B_SIZE) \
switch (_info.head_size) { \
case 16: \
LAUNCH_HEADSIZE_BLOCKSIZE(16, __B_SIZE) \
break; \
case 32: \
LAUNCH_HEADSIZE_BLOCKSIZE(32, __B_SIZE) \
break; \
case 64: \
LAUNCH_HEADSIZE_BLOCKSIZE(64, __B_SIZE) \
break; \
case 128: \
LAUNCH_HEADSIZE_BLOCKSIZE(128, __B_SIZE) \
break; \
case 256: \
LAUNCH_HEADSIZE_BLOCKSIZE(256, __B_SIZE) \
break; \
default: \
return INFINI_STATUS_BAD_TENSOR_SHAPE; \
}

int max_threads = _opaque->internal->maxThreadsPerBlock();
if (max_threads >= METAX_BLOCK_SIZE_1024) {
SWITCH_HEAD_SIZE(METAX_BLOCK_SIZE_1024)
} else if (max_threads >= METAX_BLOCK_SIZE_512) {
SWITCH_HEAD_SIZE(METAX_BLOCK_SIZE_512)
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}

#undef LAUNCH_HEADSIZE_BLOCKSIZE
#undef SWITCH_HEAD_SIZE

return INFINI_STATUS_SUCCESS;
}

} // namespace op::paged_attention::metax
30 changes: 15 additions & 15 deletions src/infiniop/ops/paged_attention/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
#ifdef ENABLE_NVIDIA_API
#include "nvidia/paged_attention_nvidia.cuh"
#endif
// #ifdef ENABLE_METAX_API
// #include "metax/paged_attention_metax.h"
// #endif
#ifdef ENABLE_METAX_API
#include "metax/paged_attention_metax.h"
#endif

__C infiniStatus_t infiniopCreatePagedAttentionDescriptor(
infiniopHandle_t handle,
Expand All @@ -34,9 +34,9 @@ __C infiniStatus_t infiniopCreatePagedAttentionDescriptor(
#ifdef ENABLE_NVIDIA_API
CREATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
// #ifdef ENABLE_METAX_API
// CREATE(INFINI_DEVICE_METAX, metax)
// #endif
#ifdef ENABLE_METAX_API
CREATE(INFINI_DEVICE_METAX, metax)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
Expand All @@ -55,9 +55,9 @@ __C infiniStatus_t infiniopGetPagedAttentionWorkspaceSize(
#ifdef ENABLE_NVIDIA_API
GET(INFINI_DEVICE_NVIDIA, nvidia)
#endif
// #ifdef ENABLE_METAX_API
// GET(INFINI_DEVICE_METAX, metax)
// #endif
#ifdef ENABLE_METAX_API
GET(INFINI_DEVICE_METAX, metax)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
Expand All @@ -80,9 +80,9 @@ __C infiniStatus_t infiniopPagedAttention(
#ifdef ENABLE_NVIDIA_API
CALCULATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
// #ifdef ENABLE_METAX_API
// CALCULATE(INFINI_DEVICE_METAX, metax)
// #endif
#ifdef ENABLE_METAX_API
CALCULATE(INFINI_DEVICE_METAX, metax)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
Expand All @@ -100,9 +100,9 @@ __C infiniStatus_t infiniopDestroyPagedAttentionDescriptor(
#ifdef ENABLE_NVIDIA_API
DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
#endif
// #ifdef ENABLE_METAX_API
// DESTROY(INFINI_DEVICE_METAX, metax)
// #endif
#ifdef ENABLE_METAX_API
DESTROY(INFINI_DEVICE_METAX, metax)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#ifndef __PAGED_ATTENTION_PREFILL_METAX_H__
#define __PAGED_ATTENTION_PREFILL_METAX_H__

#include "../paged_attention_prefill.h"

DESCRIPTOR(metax)

#endif // __PAGED_ATTENTION_PREFILL_METAX_H__
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
#include <float.h>
#include <math.h>
#include <stdint.h>

#include "../../../devices/metax/metax_common.h"
#include "../../../devices/metax/metax_kernel_common.h"
#include "../cuda/kernel.cuh"
#include "paged_attention_prefill_metax.h"

template <typename Tdata, typename Tcompute>
infiniStatus_t launchPagedAttentionPrefill(
Tdata *out, const Tdata *q, const Tdata *k_cache, const Tdata *v_cache,
const int64_t *block_tables,
const int64_t *seq_lens,
const int64_t *cum_seq_lens_q,
const float *alibi_slopes,
const size_t num_heads,
const size_t num_seqs,
const size_t num_kv_heads,
const float scale,
const size_t max_num_blocks_per_seq,
const size_t block_size,
const size_t total_q_tokens,
const size_t head_size,
const ptrdiff_t kv_block_stride,
const ptrdiff_t kv_head_stride,
const ptrdiff_t q_stride,
const ptrdiff_t q_head_stride,
hcStream_t stream) {

if (total_q_tokens == 0 || num_heads == 0) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}

if (head_size == 0 || head_size > 1024) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}

dim3 grid(static_cast<unsigned int>(total_q_tokens), static_cast<unsigned int>(num_heads));
dim3 block(static_cast<unsigned int>(head_size));

op::paged_attention_prefill::cuda::pagedAttentionPrefillKernel<Tdata, Tcompute>
<<<grid, block, 0, stream>>>(
out, q, k_cache, v_cache,
block_tables, seq_lens, cum_seq_lens_q, alibi_slopes,
num_heads, num_kv_heads, scale,
max_num_blocks_per_seq, block_size,
kv_block_stride, kv_head_stride,
q_stride, q_head_stride,
head_size,
num_seqs);

return INFINI_STATUS_SUCCESS;
}

namespace op::paged_attention_prefill::metax {

struct Descriptor::Opaque {
std::shared_ptr<device::metax::Handle::Internal> internal;
};

Descriptor::~Descriptor() {
delete _opaque;
}

infiniStatus_t Descriptor::create(
infiniopHandle_t handle,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t q_desc,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t block_tables_desc,
infiniopTensorDescriptor_t seq_lens_desc,
infiniopTensorDescriptor_t cum_seq_lens_q_desc,
const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
float scale) {

auto info = PagedAttentionPrefillInfo::create(
out_desc, q_desc, k_cache_desc, v_cache_desc,
block_tables_desc, seq_lens_desc,
cum_seq_lens_q_desc,
alibi_slopes_desc, scale);

CHECK_RESULT(info);

*desc_ptr = new Descriptor(
new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
info.take(), 0, handle->device, handle->device_id);

return INFINI_STATUS_SUCCESS;
}

infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *out, const void *q, const void *k_cache, const void *v_cache,
const void *block_tables,
const void *seq_lens,
const void *cum_seq_lens_q,
const void *alibi_slopes,
void *stream_) const {

hcStream_t stream = (hcStream_t)stream_;

#define LAUNCH_KERNEL(Tdata, Tcompute) \
launchPagedAttentionPrefill<Tdata, Tcompute>( \
(Tdata *)out, (const Tdata *)q, (const Tdata *)k_cache, (const Tdata *)v_cache, \
(const int64_t *)block_tables, (const int64_t *)seq_lens, (const int64_t *)cum_seq_lens_q, \
(const float *)alibi_slopes, \
_info.num_heads, _info.num_seqs, _info.num_kv_heads, \
_info.scale, _info.max_num_blocks_per_seq, \
_info.block_size, _info.total_q_tokens, \
_info.head_size, \
_info.kv_block_stride, _info.kv_head_stride, \
_info.q_stride, _info.q_head_stride, \
stream)

if (_info.dtype == INFINI_DTYPE_F16) {
return LAUNCH_KERNEL(half, float);
} else if (_info.dtype == INFINI_DTYPE_BF16) {
return LAUNCH_KERNEL(cuda_bfloat16, float);
} else if (_info.dtype == INFINI_DTYPE_F32) {
return LAUNCH_KERNEL(float, float);
}

return INFINI_STATUS_BAD_TENSOR_DTYPE;
}

} // namespace op::paged_attention_prefill::metax
Loading
Loading