Dynamic Scale Factor Calculations for Key/Value Scales With FP8 KV Caching (#317)

* Changed _k_scale and _v_scale to tensors

* fixed rocm paged attention with tensor kv scales

* Added on the fly scale factor calculation (see the Python sketch after the file summary below)

* trying to fix attn metadata

* fixed AttentionMetadata issue, updated description for calculate-kv-scales flag in arg_utils.py

* Changed K and V scale constants

* Removed unneeded comment

* Changes to pass format.sh, also fixed lingering k_scale/v_scale : float

* Fix for TP > 1

* Ran format.sh

* Removed legacy kv_scale loading from the json file

* Removed the outdated kv cache docs

* Revert some unwanted changes

---------

Co-authored-by: Gregory Shtrasberg <[email protected]>
Signed-off-by: Gregory Shtrasberg <[email protected]>
micah-wil and gshtras committed Dec 18, 2024
1 parent ca5f54a commit b302bfb
Showing 36 changed files with 180 additions and 1,315 deletions.
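Note: with an FP8 KV cache, each cached key/value tensor is stored as 8-bit quantized data plus a scale, and the headline change here is computing those scales on the fly from the observed tensor ranges (gated by the calculate-kv-scales flag mentioned above) rather than loading them from a JSON file. A minimal Python sketch of the idea, not vLLM's exact code; the function name, the clamp floor, and the E4M3 maximum of 448.0 (240.0 for the fnuz variant used on some AMD GPUs) are illustrative assumptions:

import torch

FP8_E4M3_MAX = 448.0  # assumed format maximum; 240.0 for float8_e4m3fnuz

def compute_kv_scales(key: torch.Tensor, value: torch.Tensor):
    # Per-tensor dynamic scales chosen so key/scale and value/scale fit
    # the representable FP8 range; clamped to avoid division by zero.
    k_scale = (key.abs().amax().float() / FP8_E4M3_MAX).clamp(min=1e-6)
    v_scale = (value.abs().amax().float() / FP8_E4M3_MAX).clamp(min=1e-6)
    return k_scale, v_scale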
10 changes: 5 additions & 5 deletions csrc/attention/attention_kernels.cuh
@@ -105,7 +105,7 @@ __device__ void paged_attention_kernel(
     const int max_num_blocks_per_seq,
     const float* __restrict__ alibi_slopes,  // [num_heads]
     const int q_stride, const int kv_block_stride, const int kv_head_stride,
-    const float k_scale, const float v_scale, const int tp_rank,
+    const float* k_scale, const float* v_scale, const int tp_rank,
     const int blocksparse_local_blocks, const int blocksparse_vert_stride,
     const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
   const int seq_idx = blockIdx.y;
@@ -285,7 +285,7 @@ __device__ void paged_attention_kernel(
         Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
             k_ptr + offset1 * BLOCK_SIZE * x + offset2);
         k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
-            k_vec_quant, k_scale);
+            k_vec_quant, *k_scale);
       }
     }

@@ -415,7 +415,7 @@ __device__ void paged_attention_kernel(
           *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
       // Vector conversion from V_quant_vec to V_vec.
       v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
-                                                                v_scale);
+                                                                *v_scale);
     }
     if (block_idx == num_seq_blocks - 1) {
       // NOTE(woosuk): When v_vec contains the tokens that are out of the
@@ -513,7 +513,7 @@ __global__ void paged_attention_v1_kernel(
     const int max_num_blocks_per_seq,
     const float* __restrict__ alibi_slopes,  // [num_heads]
     const int q_stride, const int kv_block_stride, const int kv_head_stride,
-    const float k_scale, const float v_scale, const int tp_rank,
+    const float* k_scale, const float* v_scale, const int tp_rank,
     const int blocksparse_local_blocks, const int blocksparse_vert_stride,
     const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
   paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
Expand Down Expand Up @@ -549,7 +549,7 @@ __global__ void paged_attention_v2_kernel(
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float k_scale, const float v_scale, const int tp_rank,
const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
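Note: the kernels above now take the scales as const float* and dereference them at the point of use, instead of receiving a float by value at launch time; presumably this is what allows the scales to be recomputed on the fly and updated in place without touching the kernel arguments. On the host side each scale therefore lives in GPU memory as a one-element float32 tensor. A minimal Python sketch (values illustrative):

import torch

# One-element float32 tensors on the GPU; the C++ launchers pass
# data_ptr() through to the kernels as const float*.
k_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")
v_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")

# An in-place update changes what the kernels read on their next
# launch without altering any launch arguments.
k_scale.fill_(0.023)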
17 changes: 10 additions & 7 deletions csrc/attention/paged_attention_v1.cu
@@ -41,7 +41,7 @@
       out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
       scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
       alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
-      k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
+      k_scale_ptr, v_scale_ptr, tp_rank, blocksparse_local_blocks, \
       blocksparse_vert_stride, blocksparse_block_size, \
       blocksparse_head_sliding_step);

@@ -53,10 +53,10 @@ void paged_attention_v1_launcher(
     torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
     torch::Tensor& value_cache, int num_kv_heads, float scale,
     torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
-    const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
-    float v_scale, const int tp_rank, const int blocksparse_local_blocks,
-    const int blocksparse_vert_stride, const int blocksparse_block_size,
-    const int blocksparse_head_sliding_step) {
+    const c10::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
+    torch::Tensor& v_scale, const int tp_rank,
+    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
+    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
   int num_seqs = query.size(0);
   int num_heads = query.size(1);
   int head_size = query.size(2);
@@ -80,6 +80,8 @@ void paged_attention_v1_launcher(
   CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
   int* block_tables_ptr = block_tables.data_ptr<int>();
   int* seq_lens_ptr = seq_lens.data_ptr<int>();
+  const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
+  const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());

   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
   int padded_max_seq_len =
@@ -177,8 +179,9 @@ void paged_attention_v1(
     torch::Tensor& seq_lens,  // [num_seqs]
     int64_t block_size, int64_t max_seq_len,
     const c10::optional<torch::Tensor>& alibi_slopes,
-    const std::string& kv_cache_dtype, double k_scale, double v_scale,
-    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
+    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
+    torch::Tensor& v_scale, const int64_t tp_rank,
+    const int64_t blocksparse_local_blocks,
     const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
     const int64_t blocksparse_head_sliding_step) {
   const bool is_block_sparse = (blocksparse_vert_stride > 1);
17 changes: 10 additions & 7 deletions csrc/attention/paged_attention_v2.cu
@@ -37,7 +37,7 @@
       exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
       value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
       seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
-      kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \
+      kv_block_stride, kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, \
       blocksparse_local_blocks, blocksparse_vert_stride, \
       blocksparse_block_size, blocksparse_head_sliding_step); \
   vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
@@ -54,10 +54,10 @@ void paged_attention_v2_launcher(
     torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
     torch::Tensor& value_cache, int num_kv_heads, float scale,
     torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
-    const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
-    float v_scale, const int tp_rank, const int blocksparse_local_blocks,
-    const int blocksparse_vert_stride, const int blocksparse_block_size,
-    const int blocksparse_head_sliding_step) {
+    const c10::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
+    torch::Tensor& v_scale, const int tp_rank,
+    const int blocksparse_local_blocks, const int blocksparse_vert_stride,
+    const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
   int num_seqs = query.size(0);
   int num_heads = query.size(1);
   int head_size = query.size(2);
@@ -84,6 +84,8 @@ void paged_attention_v2_launcher(
   CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
   int* block_tables_ptr = block_tables.data_ptr<int>();
   int* seq_lens_ptr = seq_lens.data_ptr<int>();
+  const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
+  const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());

   constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
   int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
@@ -188,8 +190,9 @@ void paged_attention_v2(
     torch::Tensor& seq_lens,  // [num_seqs]
     int64_t block_size, int64_t max_seq_len,
     const c10::optional<torch::Tensor>& alibi_slopes,
-    const std::string& kv_cache_dtype, double k_scale, double v_scale,
-    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
+    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
+    torch::Tensor& v_scale, const int64_t tp_rank,
+    const int64_t blocksparse_local_blocks,
     const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
     const int64_t blocksparse_head_sliding_step) {
   const bool is_block_sparse = (blocksparse_vert_stride > 1);
6 changes: 3 additions & 3 deletions csrc/cache.h
@@ -18,15 +18,15 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
 void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
                        torch::Tensor& key_cache, torch::Tensor& value_cache,
                        torch::Tensor& slot_mapping,
-                       const std::string& kv_cache_dtype, const double k_scale,
-                       const double v_scale);
+                       const std::string& kv_cache_dtype,
+                       torch::Tensor& k_scale, torch::Tensor& v_scale);

 void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
                              torch::Tensor& key_cache,
                              torch::Tensor& value_cache,
                              torch::Tensor& slot_mapping,
                              const std::string& kv_cache_dtype,
-                             const double k_scale, const double v_scale);
+                             torch::Tensor& k_scale, torch::Tensor& v_scale);

 // Just for unittest
 void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
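Note: with the declarations above, callers pass the scales as tensors rather than doubles. A minimal sketch of exercising the flash variant from Python, assuming a vLLM build that includes this commit and registers the op under torch.ops._C_cache_ops; shapes follow the comments in the header, and the "fp8" dtype string and all sizes are illustrative:

import torch

num_tokens, num_heads, head_size = 4, 8, 128
num_blocks, block_size = 16, 16

key = torch.randn(num_tokens, num_heads, head_size,
                  dtype=torch.float16, device="cuda")
value = torch.randn_like(key)
# Flash layout per the header: [num_blocks, block_size, num_heads, head_size];
# uint8 storage backs the FP8 cache.
key_cache = torch.zeros(num_blocks, block_size, num_heads, head_size,
                        dtype=torch.uint8, device="cuda")
value_cache = torch.zeros_like(key_cache)
slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device="cuda")
k_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")
v_scale = torch.tensor([1.0], dtype=torch.float32, device="cuda")

torch.ops._C_cache_ops.reshape_and_cache_flash(
    key, value, key_cache, value_cache, slot_mapping,
    "fp8", k_scale, v_scale)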
30 changes: 17 additions & 13 deletions csrc/cache_kernels.cu
@@ -159,8 +159,8 @@ __global__ void reshape_and_cache_kernel(
     //                   block_size]
     const int64_t* __restrict__ slot_mapping,  // [num_tokens]
     const int key_stride, const int value_stride, const int num_heads,
-    const int head_size, const int block_size, const int x, const float k_scale,
-    const float v_scale) {
+    const int head_size, const int block_size, const int x,
+    const float* k_scale, const float* v_scale) {
   const int64_t token_idx = blockIdx.x;
   const int64_t slot_idx = slot_mapping[token_idx];
   if (slot_idx < 0) {
@@ -196,9 +196,9 @@ __global__ void reshape_and_cache_kernel(
       value_cache[tgt_value_idx] = tgt_value;
     } else {
       key_cache[tgt_key_idx] =
-          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, k_scale);
+          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
       value_cache[tgt_value_idx] =
-          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, v_scale);
+          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale);
     }
   }
 }
@@ -214,7 +214,7 @@ __global__ void reshape_and_cache_flash_kernel(
     const int64_t* __restrict__ slot_mapping,  // [num_tokens]
     const int block_stride, const int key_stride, const int value_stride,
     const int num_heads, const int head_size, const int block_size,
-    const float k_scale, const float v_scale) {
+    const float* k_scale, const float* v_scale) {
   const int64_t token_idx = blockIdx.x;
   const int64_t slot_idx = slot_mapping[token_idx];
   // NOTE: slot_idx can be -1 if the token is padded
@@ -239,9 +239,9 @@ __global__ void reshape_and_cache_flash_kernel(
       value_cache[tgt_key_value_idx] = tgt_value;
     } else {
       key_cache[tgt_key_value_idx] =
-          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, k_scale);
+          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
       value_cache[tgt_key_value_idx] =
-          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, v_scale);
+          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale);
     }
   }
 }
@@ -258,7 +258,9 @@
       reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
       reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
       slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
-      num_heads, head_size, block_size, x, k_scale, v_scale);
+      num_heads, head_size, block_size, x, \
+      reinterpret_cast<const float*>(k_scale.data_ptr()), \
+      reinterpret_cast<const float*>(v_scale.data_ptr()));

 void reshape_and_cache(
     torch::Tensor& key,  // [num_tokens, num_heads, head_size]
@@ -268,8 +270,8 @@ void reshape_and_cache(
     torch::Tensor&
         value_cache,  // [num_blocks, num_heads, head_size, block_size]
     torch::Tensor& slot_mapping,  // [num_tokens]
-    const std::string& kv_cache_dtype, const double k_scale,
-    const double v_scale) {
+    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
+    torch::Tensor& v_scale) {
   int num_tokens = key.size(0);
   int num_heads = key.size(1);
   int head_size = key.size(2);
@@ -299,7 +301,9 @@
       reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
       reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
       slot_mapping.data_ptr<int64_t>(), block_stride, key_stride, \
-      value_stride, num_heads, head_size, block_size, k_scale, v_scale);
+      value_stride, num_heads, head_size, block_size, \
+      reinterpret_cast<const float*>(k_scale.data_ptr()), \
+      reinterpret_cast<const float*>(v_scale.data_ptr()));

 void reshape_and_cache_flash(
     torch::Tensor& key,  // [num_tokens, num_heads, head_size]
@@ -308,8 +312,8 @@ void reshape_and_cache_flash(
     torch::Tensor&
         value_cache,  // [num_blocks, block_size, num_heads, head_size]
     torch::Tensor& slot_mapping,  // [num_tokens] or [num_actual_tokens]
-    const std::string& kv_cache_dtype, const double k_scale,
-    const double v_scale) {
+    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
+    torch::Tensor& v_scale) {
   // NOTE(woosuk): In vLLM V1, key.size(0) can be different from
   // slot_mapping.size(0) because of padding for CUDA graphs.
   // In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because
10 changes: 6 additions & 4 deletions csrc/ops.h
@@ -34,8 +34,9 @@ void paged_attention_v1(
     torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
     torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
     int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
-    const std::string& kv_cache_dtype, double k_scale, double v_scale,
-    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
+    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
+    torch::Tensor& v_scale, const int64_t tp_rank,
+    const int64_t blocksparse_local_blocks,
     const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
     const int64_t blocksparse_head_sliding_step);

@@ -45,8 +46,9 @@ void paged_attention_v2(
     torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
     torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
     int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
-    const std::string& kv_cache_dtype, double k_scale, double v_scale,
-    const int64_t tp_rank, const int64_t blocksparse_local_blocks,
+    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
+    torch::Tensor& v_scale, const int64_t tp_rank,
+    const int64_t blocksparse_local_blocks,
     const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
     const int64_t blocksparse_head_sliding_step);

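Note: these declarations back the registered torch ops, so the change is visible from Python as a schema change (float k_scale, float v_scale becomes Tensor k_scale, Tensor v_scale). A quick check, assuming a build that includes this commit; the import path is one common way of loading vLLM's _C extension and is an assumption here:

import torch
import vllm._custom_ops  # noqa: F401 -- importing this loads the _C extension

# k_scale / v_scale should now appear as Tensor, not float, in the schema.
print(torch.ops._C.paged_attention_v1.default._schema)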
17 changes: 10 additions & 7 deletions csrc/rocm/attention.cu
@@ -218,7 +218,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
     scalar_t* __restrict__ out,  // [num_seqs, num_heads, max_num_partitions,
                                  //  head_size]
     scalar_t* __restrict__ final_out,  // [num_seqs, num_heads, head_size]
-    int max_ctx_blocks, float k_scale, float v_scale) {
+    int max_ctx_blocks, const float* k_scale, const float* v_scale) {
   constexpr int NWARPS = NUM_THREADS / WARP_SIZE;
   const int warpid = threadIdx.x / WARP_SIZE;
   const int laneid = threadIdx.x % WARP_SIZE;
@@ -406,7 +406,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
           // Vlocalb8[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d];
           const _B8x8 Vlocalb8 = v_ptrh8be[d];
           Vlocal[h][b * BLOCK_SIZE / 8 + d] =
-              scaled_convert_b8x8<scalar_t, KV_DTYPE>(Vlocalb8, v_scale);
+              scaled_convert_b8x8<scalar_t, KV_DTYPE>(Vlocalb8, *v_scale_ptr);
         }
       }
     }
@@ -416,7 +416,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
 #pragma unroll
       for (int d = 0; d < KHELOOP; d++) {
         Klocal[d] =
-            scaled_convert_b8x8<scalar_t, KV_DTYPE>(Klocalb8[d], k_scale);
+            scaled_convert_b8x8<scalar_t, KV_DTYPE>(Klocalb8[d], *k_scale_ptr);
       }
     }

@@ -890,7 +890,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
     scalar_t* __restrict__ out,  // [num_seqs, num_heads, max_num_partitions,
                                  //  head_size]
     scalar_t* __restrict__ final_out,  // [num_seqs, num_heads, head_size]
-    int max_ctx_blocks, float k_scale, float v_scale) {
+    int max_ctx_blocks, const float* k_scale, const float* v_scale) {
   UNREACHABLE_CODE
 }

@@ -919,7 +919,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
       block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \
       alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
       exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \
-      k_scale, v_scale);
+      k_scale_ptr, v_scale_ptr);

 template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
           int BLOCK_SIZE, int HEAD_SIZE, int PARTITION_SIZE = 512>
@@ -929,7 +929,7 @@ void paged_attention_custom_launcher(
     torch::Tensor& value_cache, const int num_kv_heads, float scale,
     torch::Tensor& block_tables, torch::Tensor& context_lens,
     int max_context_len, const c10::optional<torch::Tensor>& alibi_slopes,
-    float k_scale, float v_scale) {
+    torch::Tensor& k_scale, torch::Tensor& v_scale) {
   int num_seqs = query.size(0);
   int num_heads = query.size(1);
   int head_size = query.size(2);
@@ -953,6 +953,8 @@ void paged_attention_custom_launcher(
   KVT* value_cache_ptr = reinterpret_cast<KVT*>(value_cache.data_ptr());
   int* block_tables_ptr = block_tables.data_ptr<int>();
   int* context_lens_ptr = context_lens.data_ptr<int>();
+  const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
+  const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());

   const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE);
   const int max_num_partitions =
@@ -1087,7 +1089,8 @@ void paged_attention(
     torch::Tensor& context_lens,  // [num_seqs]
     int64_t block_size, int64_t max_context_len,
     const c10::optional<torch::Tensor>& alibi_slopes,
-    const std::string& kv_cache_dtype, double k_scale, double v_scale) {
+    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
+    torch::Tensor& v_scale) {
   const int head_size = query.size(2);
   if (kv_cache_dtype == "auto") {
     if (query.dtype() == at::ScalarType::Half) {
4 changes: 2 additions & 2 deletions csrc/rocm/ops.h
@@ -10,5 +10,5 @@ void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
                      torch::Tensor& context_lens, int64_t block_size,
                      int64_t max_context_len,
                      const c10::optional<torch::Tensor>& alibi_slopes,
-                     const std::string& kv_cache_dtype, double k_scale,
-                     double v_scale);
+                     const std::string& kv_cache_dtype, torch::Tensor& k_scale,
+                     torch::Tensor& v_scale);
2 changes: 1 addition & 1 deletion csrc/rocm/torch_bindings.cpp
@@ -27,7 +27,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
       " int max_context_len,"
      " Tensor? alibi_slopes,"
       " str kv_cache_dtype,"
-      " float k_scale, float v_scale) -> ()");
+      " Tensor k_scale, Tensor v_scale) -> ()");
   rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention);
 }
