Hello guys, I really appreciate your work on Hidet. It is an awesome tool, and it really makes developers' lives easier when writing custom schedules for their CUDA kernels for performance optimization👍👍!
To test Hidet's features, I am currently writing an example of the Flash Attention Transformer (link to the research work: https://arxiv.org/abs/2205.14135) using the Hidet tool stack. I have written my custom testing setup (which contains my own host/device memory allocation, performance tracking, and precision comparison code) in my "flash_attention_main.cu", and I am trying to call the kernel functions in the Hidet-generated CUDA dynamic library.
May I know if there is a standard way of doing this? I tried using "dlopen" to load the library and launch the kernel functions, but unfortunately it did not work properly. I therefore manually copied the Hidet-generated CUDA source code into two separate header files, "flash_attention_kernel_func.h" and "normal_transformer_kernel_func.h", and included them in my "flash_attention_main.cu". I then compile "flash_attention_main.cu" directly, and everything works properly that way.
Let me share some source code below for illustration.
Here is my flash_attention_example.py, which includes the flash attention custom schedule and the normal approach.
import os
import math
import time
import numpy as np
import torch
import torch.nn as nn
torch.manual_seed(123)
# NOTE: this script is a simplified implementation of the following research work using Hidet
# Dao, T., Fu, D., Ermon, S., Rudra, A., & Ré, C. (2022). Flashattention: Fast and memory-efficient exact attention with io-awareness. Advances in Neural Information Processing Systems, 35, 16344-16359.
# link to paper: https://arxiv.org/abs/2205.14135
import hidet
from hidet.ir.compute import compute, reduce
from hidet.ir.task import Task
from hidet.ir.func import IRModule
from hidet.ir.primitives.cuda.atomic import atomic_add
from hidet.lang import f16, spatial, repeat, tensor, attr, grid, printf
from hidet.lang.cuda import blockIdx, threadIdx, syncthreads
from hidet.graph.ops.definitions.utils import input_like
from hidet.ir.expr import cast, address
from hidet.ir.primitives import exp, max, printf
# define Flash Attention Task
class FlashAttentionTask(Task):
    """Hidet task computing ``softmax(Q @ K^T) @ V`` for one ``[N, d]`` head.

    The task's compute definition is the straightforward ("normal") attention
    pipeline built from ``compute``/``reduce``. Unless
    ``disable_flash_attention`` is set, the instance attribute
    ``implement_cuda`` is rebound so that the hand-written schedule returned by
    ``flash_attention_schedule`` is used instead.

    Attributes set on the instance:
        define: extra compiler flag for the host test program;
            ``"-DRUN_FLASH_ATTN"`` when flash attention is enabled, otherwise
            the empty string.
    """

    # Require: Matrices Q, K, V of shape N x d in HBM, on-chip SRAM of size M.
    # NOTE: typical SRAM size 100 kB, default to 48 kB
    # NOTE: max thread num is set to 1024
    def __init__(self,N=512,d=128,H=16,B=1,M=48*1024,ratio=12,max_thread_num=1024,disable_flash_attention=False):
        # 1. set block sizes Bc = ceil(M/(ratio*d)), Br = min(Bc, d).
        #    (The paper uses a divisor of 4d; here the factor is the `ratio`
        #    argument.)
        col_block = math.ceil(M/(ratio*d))
        row_block = min(col_block, d)
        row_tiles = math.ceil(N/row_block)
        col_tiles = math.ceil(N/col_block)

        def make_input(label):
            # Symbolic fp16 [N, d] input modelled on a random CUDA tensor.
            return input_like(hidet.randn([N, d], dtype='float16', device='cuda'), name=label)

        GLOBAL_Q = make_input('GLOBAL_Q')
        GLOBAL_K = make_input('GLOBAL_K')
        GLOBAL_V = make_input('GLOBAL_V')

        def normal_transformer():
            # Scores: QK[i, j] = sum_k Q[i, k] * K[j, k]
            def qk_entry(i, j):
                return reduce(
                    shape=[d],
                    fcompute=lambda k: GLOBAL_Q[i, k] * GLOBAL_K[j, k],
                    reduce_type='sum',
                )

            matmulQK = compute(name='GLOBAL_QK', shape=[N, N], fcompute=qk_entry)

            # Row-wise maximum of the scores.
            def max_val(i):
                return reduce(shape=[N], fcompute=lambda j: matmulQK[i, j], reduce_type='max')

            S = compute(name='S', shape=[N, N], fcompute=lambda i, j: matmulQK[i, j] - max_val(i))
            exp_s = compute(name='exp_s', shape=[N, N], fcompute=lambda i, j: exp(S[i, j]))

            # Row-wise sum of the exponentials (softmax denominator).
            def exp_sum(i):
                return reduce(shape=[N], fcompute=lambda j: exp_s[i, j], reduce_type='sum')

            softmax = compute('softmax', shape=[N, N], fcompute=lambda i, j: exp_s[i, j] / exp_sum(i))

            # Output: O[i, j] = sum_k softmax[i, k] * V[k, j]
            def pv_entry(i, j):
                return reduce(
                    shape=[N],
                    fcompute=lambda k: softmax[i, k] * GLOBAL_V[k, j],
                    reduce_type='sum',
                )

            return compute(name='GLOBAL_O', shape=[N, d], fcompute=pv_entry)

        super().__init__(
            name='flash_attention_task',
            inputs=[GLOBAL_Q, GLOBAL_K, GLOBAL_V],
            outputs=[normal_transformer()],
            attributes={
                'B' : B,
                'H' : H,
                'N' : N,
                'd' : d,
                'Bc' : col_block,
                'Br' : row_block,
                'Tc' : col_tiles,
                'Tr' : row_tiles,
                'BLK' : row_tiles,
                'THD' : row_block * col_block,
                'MAX_THD' : max_thread_num
            },
        )

        use_flash = not disable_flash_attention
        if use_flash:
            # Route CUDA implementation requests to the custom schedule.
            self.implement_cuda = self.flash_attention_implement_cuda
        self.define = "-DRUN_FLASH_ATTN" if use_flash else ""

    def allow_epilogue(self) -> bool:
        """Always return False."""
        return False

    def flash_attention_implement_cuda(self, working_dir: str) -> IRModule:
        """Return the template-based schedule; ``working_dir`` is not used."""
        # override this method to use template-based scheduling
        return flash_attention_schedule(self)
# define custom schedule
def flash_attention_schedule(task:FlashAttentionTask) -> IRModule:
    """Build the Hidet-script IR module implementing tiled flash attention.

    Tiling parameters are read from ``task.attrs``: ``Br`` x ``Bc`` is the
    score tile handled per step, ``Tc`` the number of K/V tiles, ``BLK`` the
    grid dimension and ``THD`` the block dimension. Each thread block keeps one
    ``Br`` x ``d`` slice of Q plus the running row statistics (``m``: running
    row maximum, ``l``: running row sum of exponentials) in shared memory and
    iterates over the K/V tiles, following Algorithm 1 of the FlashAttention
    paper:

        m_new = max(m, m~)
        l_new = e^(m - m_new) * l + e^(m~ - m_new) * l~
        O     = diag(l_new)^-1 * (diag(l) * e^(m - m_new) * O
                                  + e^(m~ - m_new) * (P~ @ V))

    Args:
        task: the task whose ``attrs`` carry B, H, N, d, Bc, Br, Tr, Tc, BLK,
            THD and MAX_THD.

    Returns:
        The IR module produced by ``hidet.script_module``.

    Raises:
        AssertionError: if the thread count exceeds ``MAX_THD`` or ``d`` is
            not a multiple of ``Bc`` / ``Br``.
    """
    # Set to True to emit the device-side printf diagnostics below.
    print_debug = False
    B = task.attrs['B']
    H = task.attrs['H']
    N = task.attrs['N']
    d = task.attrs['d']
    Bc = task.attrs['Bc']
    Br = task.attrs['Br']
    Tr = task.attrs['Tr']
    Tc = task.attrs['Tc']
    dims = ( task.attrs['BLK'] )
    threads = task.attrs['THD']
    assert threads <= task.attrs['MAX_THD'], f'err: {threads} not <= {task.attrs["MAX_THD"]}'
    assert d % Bc == 0, f'err: d ({d}) is not divisible by Bc ({Bc})'
    assert d % Br == 0, f'err: d ({d}) is not divisible by Br ({Br})'
    # Used (negated) as the initial running maximum, standing in for -inf.
    largest_fp16_value = 65504
    print(f'task.attrs {task.attrs}')
    # define the tensor program
    with hidet.script_module() as module:
        """Flash attention kernel."""

        @hidet.script
        def QK_matmul_compute(A:f16[Br,d],B:f16[d,Bc],C:f16[Br,Bc]):
            # C = A @ B, one (m, n) output element per thread, accumulated over k.
            for m,n in spatial(Br,Bc).on(threadIdx.x):
                C[m,n] = 0.0
            syncthreads()
            for m,k,n in spatial(Br,1,Bc).repeat(1,d,1).on(threadIdx.x):
                atomic_add(~C[m,n],A[m,k] * B[k,n])
            syncthreads()

        @hidet.script
        def PV_matmul_compute(A:f16[Br,Bc],B:f16[Bc,d],C:f16[Br,d]):
            # C = A @ B; each thread covers d//Bc output columns.
            for m,n in spatial(Br,Bc).repeat(1,d//Bc).on(threadIdx.x):
                C[m,n] = 0.0
            syncthreads()
            for m,k,n in spatial(Br,1,Bc).repeat(1,Bc,d//Bc).on(threadIdx.x):
                atomic_add(~C[m,n],A[m,k] * B[k,n])
            syncthreads()

        @hidet.script
        def rowmax_compute(A:f16[Br,Bc],M:f16[Br],T:f16[Br,Bc]):
            # M[i] = max_j A[i, j], via a pairwise tree reduction in scratch T.
            # NOTE(review): T[i, j+k] is read unguarded, so this presumably
            # requires Bc to be a power of two — confirm.
            for i,j in spatial(Br,Bc).on(threadIdx.x):
                T.write([i,j],A[i,j],protected=True)
            syncthreads()
            for i,j in spatial(Br,Bc).on(threadIdx.x):
                k = 1
                while k < Bc:
                    if j % (k*2) == 0:
                        T.write([i,j],max(T[i,j],T[i,j+k]),protected=True)
                    # Barrier sits outside the `if` so every thread reaches it.
                    syncthreads()
                    k *= 2
            for i in spatial(Br).on(threadIdx.x):
                if threadIdx.x < Br:
                    M[i] = T[i,0]
            syncthreads()

        @hidet.script
        def rowsum_compute(A:f16[Br,Bc],L:f16[Br],T:f16[Br,Bc]):
            # L[i] = sum_j A[i, j]; same reduction scheme as rowmax_compute.
            for i,j in spatial(Br,Bc).on(threadIdx.x):
                T.write([i,j],A[i,j],protected=True)
            syncthreads()
            for i,j in spatial(Br,Bc).on(threadIdx.x):
                k = 1
                while k < Bc:
                    if j % (k*2) == 0:
                        T.write([i,j],(T[i,j]+T[i,j+k]),protected=True)
                    syncthreads()
                    k *= 2
            for i in spatial(Br).on(threadIdx.x):
                if threadIdx.x < Br:
                    L[i] = T[i,0]
            syncthreads()

        @hidet.script
        def local_softmax_compute(S:f16[Br,Bc],M:f16[Br]):
            # In place: S[i, j] <- exp(S[i, j] - M[i]).
            for i,j in spatial(Br,Bc).on(threadIdx.x):
                if print_debug and blockIdx.x==0:
                    printf("S[i,j] before %d %d %d %d : %f - %f\n",blockIdx.x,threadIdx.x,i,j,cast(S[i,j],"float32"),cast(M[i],"float32"))
                S[i,j] = exp(S[i,j] - M[i])
                if print_debug and blockIdx.x==0:
                    printf("S[i,j] %d %d %d %d : %f\n",blockIdx.x,threadIdx.x,i,j,cast(S[i,j],"float32"))
            syncthreads()

        @hidet.script
        def local_update_compute(M:f16[Br],M_new:f16[Br],M_local:f16[Br],L:f16[Br],L_new:f16[Br],L_local:f16[Br]):
            # Merge the tile statistics (M_local, L_local) into the running
            # ones (M, L), producing (M_new, L_new).
            for i in spatial(Br).on(threadIdx.x):
                if threadIdx.x < Br:
                    M_new[i] = max(M[i],M_local[i])
                    L_new[i] = exp(M[i] - M_new[i]) * L[i] + exp(M_local[i] - M_new[i]) * L_local[i]
            syncthreads()

        @hidet.script
        def global_update_compute(PV:f16[Br,d],O:f16[Br,d],M_local:f16[Br],M_new:f16[Br],M:f16[Br],L_new:f16[Br],L:f16[Br]):
            # O <- diag(L_new)^-1 * (diag(L) * e^(M - M_new) * O + e^(M_local - M_new) * PV)
            # The 1/L_new factor must scale BOTH terms; otherwise the
            # contribution of the current tile is left unnormalised.
            for i,j in spatial(Br,Bc).repeat(1,(d//Bc)).on(threadIdx.x):
                O.write(
                    [i,j],
                    (L_new[i]**-1) * (((L[i]*exp(M[i]-M_new[i])) * O[i,j]) + (exp(M_local[i]-M_new[i]) * PV[i,j])),
                    protected=True
                )
            syncthreads()

        @hidet.script
        def flash_attention_kernel(
            Q: f16[N,d],
            K: f16[N,d],
            V: f16[N,d],
            O: f16[N,d]
        ):
            attr.cuda_grid_dim = dims
            attr.cuda_block_dim = threads
            # Init O=(0), N x d in HBM
            for i,j in spatial(Br,Bc).repeat(1,(d//Bc)).on(threadIdx.x):
                offset_i = blockIdx.x * (Br)
                O[offset_i:,:].write([i,j], 0, protected=True)
            syncthreads()
            smem_q = tensor('shared', 'float16', [Br, d])
            smem_k = tensor('shared', 'float16', [d, Bc])  # transposed
            smem_v = tensor('shared', 'float16', [Bc, d])
            smem_o = tensor('shared', 'float16', [Br, d])
            smem_l = tensor('shared', 'float16', [Br])
            smem_l_local = tensor('shared', 'float16', [Br])
            smem_l_new = tensor('shared', 'float16', [Br])
            smem_m = tensor('shared', 'float16', [Br])
            smem_m_local = tensor('shared', 'float16', [Br])
            smem_m_new = tensor('shared', 'float16', [Br])
            smem_sp = tensor('shared', 'float16', [Br,Bc])
            smem_pv = tensor('shared', 'float16', [Br,d])
            smem_temp = tensor('shared', 'float16', [Br,Bc])
            for a,b in spatial(Br,Bc).repeat(1,(d//Bc)).on(threadIdx.x):
                # load Qi from HBM to on-chip SRAM
                # initialization of o,l,m
                offset_i = blockIdx.x * (Br)
                smem_q[a,b] = Q[offset_i:,:].read([a,b],protected=True)
                smem_o[a,b] = 0
                smem_l[a] = 0
                smem_m[a] = -largest_fp16_value
            syncthreads()
            if print_debug and (blockIdx.x==0 and threadIdx.x==0):
                idx = 0
                for r,c in grid(Br,d):
                    printf("idx: %d, Q val: %f\n",idx,cast(smem_q[r,c],"float32"))
                    idx += 1
            syncthreads()
            for j in grid(Tc):
                for a,b in spatial(Bc,Br).repeat(1,(d//Br)).on(threadIdx.x):
                    # load Kj,Vj from HBM to on-chip SRAM
                    offset_j = j * (Bc)
                    smem_k[b,a] = K[offset_j:,:].read([a,b],protected=True)
                    smem_v[a,b] = V[offset_j:,:].read([a,b],protected=True)
                syncthreads()
                if print_debug and (blockIdx.x==0 and threadIdx.x==0):
                    idx = 0
                    # r/c (not i/j) so the tile index `j` is not shadowed.
                    for r,c in grid(d,Bc):
                        printf("idx: %d, K val: %f\n",idx,cast(smem_k[r,c],"float32"))
                        idx += 1
                    for r,c in grid(Bc,d):
                        printf("idx: %d, V val: %f\n",idx,cast(smem_v[r,c],"float32"))
                        idx += 1
                syncthreads()
                # on chip, compute Sij = Qi @ (Kj)^T, Br X Bc
                QK_matmul_compute(smem_q,smem_k,smem_sp)
                if print_debug and (blockIdx.x==0 and threadIdx.x==0):
                    idx = 0
                    for r,c in grid(Br,Bc):
                        printf("idx: %d, S val: %f\n",idx,cast(smem_sp[r,c],"float32"))
                        idx += 1
                syncthreads()
                # on chip, compute m'_ij = rowmax(Sij), Br; Pij = exp(Sij - m'_ij), Br x Bc (pointwise); l'_ij = rowsum(P'ij), Br
                rowmax_compute(smem_sp,smem_m_local,smem_temp)
                if print_debug and (blockIdx.x==0 and threadIdx.x==0):
                    for r in grid(Br):
                        printf("i: %d M val: %f\n",r,cast(smem_m_local[r],"float32"))
                syncthreads()
                local_softmax_compute(smem_sp,smem_m_local)
                if print_debug and (blockIdx.x==0 and threadIdx.x==0):
                    idx = 0
                    for r,c in grid(Br,Bc):
                        printf("idx: %d, P val: %f\n",idx,cast(smem_sp[r,c],"float32"))
                        idx += 1
                syncthreads()
                rowsum_compute(smem_sp,smem_l_local,smem_temp)
                if print_debug and (blockIdx.x==0 and threadIdx.x==0):
                    for r in grid(Br):
                        printf("i: %d L val: %f\n",r,cast(smem_l_local[r],"float32"))
                syncthreads()
                # on chip, compute m_new_i = max(m_i,m'_ij), Br; l_new_i = e^(m_i - m_new_i) * l_i + e^(m'_ij - m_i_new) * l'_ij, Br
                local_update_compute(smem_m,smem_m_new,smem_m_local,smem_l,smem_l_new,smem_l_local)
                if print_debug and (blockIdx.x==0 and threadIdx.x==0):
                    for r in grid(Br):
                        # Each label prints its own buffer.
                        printf("i: %d smem_m val: %f\n",r,cast(smem_m[r],"float32"))
                        printf("i: %d smem_m_new val: %f\n",r,cast(smem_m_new[r],"float32"))
                        printf("i: %d smem_m_local val: %f\n",r,cast(smem_m_local[r],"float32"))
                        printf("i: %d smem_l val: %f\n",r,cast(smem_l[r],"float32"))
                        printf("i: %d smem_l_new val: %f\n",r,cast(smem_l_new[r],"float32"))
                        printf("i: %d smem_l_local val: %f\n",r,cast(smem_l_local[r],"float32"))
                syncthreads()
                # write Oi = diag(l_i_new)^-1 * (diag(l_i)*e^(m_i-m_i_new) @ Oi + e^(m'_ij-m_i_new)*(P'ij @ Vj))
                PV_matmul_compute(smem_sp,smem_v,smem_pv)
                if print_debug and (blockIdx.x==0 and threadIdx.x==0):
                    idx = 0
                    for r,c in grid(Br,d):
                        printf("idx: %d, PV val: %f\n",idx,cast(smem_pv[r,c],"float32"))
                        idx += 1
                syncthreads()
                global_update_compute(smem_pv,smem_o,smem_m_local,smem_m_new,smem_m,smem_l_new,smem_l)
                # After the last K/V tile, flush the accumulated output to HBM.
                if j + 1 == Tc:
                    for r,c in spatial(Br,Bc).repeat(1,(d//Bc)).on(threadIdx.x):
                        offset_i = blockIdx.x * (Br)
                        O[offset_i:,:].write([r,c], smem_o[r,c], protected=True)
                syncthreads()
                # write l_i = l_i_new, m_i = m_i_new
                for i in spatial(Br).on(threadIdx.x):
                    if threadIdx.x < Br:
                        smem_m[i] = smem_m_new[i]
                        smem_l[i] = smem_l_new[i]
                syncthreads()
            if print_debug and (blockIdx.x==15 and threadIdx.x==0):
                idx = 0
                for r,c in grid(Br,d):
                    offset_i = blockIdx.x * (Br)
                    printf("blockIdx %d : output idx: %d, val: %f\n",blockIdx.x,idx,cast(O[offset_i+r,c],"float32"))
                    idx += 1
            syncthreads()
            return

        @hidet.script
        def flash_attention_launch_func(
            G_Q: f16[B, H, N, d],
            G_K: f16[B, H, N, d],
            G_V: f16[B, H, N, d],
            G_O: f16[B, H, N, d]
        ):
            # NOTE: this section needs to be written in flash_attention_main.cu
            # One kernel launch per (batch, head) slice.
            for b,h in grid(B,H):
                flash_attention_kernel(
                    address(G_Q[b,h,0,0]),
                    address(G_K[b,h,0,0]),
                    address(G_V[b,h,0,0]),
                    address(G_O[b,h,0,0])
                )
    # build ir module
    ir_module = module.ir_module()
    return ir_module
# gen Python gold data as reference
def gen_gold(attrs,r1=-3,r2=3):
    """Generate random fp16 Q/K/V and the reference attention output on disk.

    Q, K and V of shape ``[B, H, N, d]`` (taken from ``attrs``) are drawn
    uniformly from ``[r1, r2)`` and rounded to float16. The reference
    ``softmax(Q @ K^T) @ V`` (row-max subtracted, no scaling) is computed with
    NumPy/PyTorch on float16 data.

    Files written to the current directory: ``mat_Q.bin``, ``mat_K.bin``,
    ``mat_V.bin`` and ``gold_mat_O.bin`` (raw float16).

    The printed elapsed time is CPU time (``time.process_time``) and also
    covers writing the three input files.
    """
    dims = (attrs['B'], attrs['H'], attrs['N'], attrs['d'])

    def draw():
        # One uniform fp16 sample tensor; called in Q, K, V order.
        return torch.FloatTensor(*dims).uniform_(r1, r2).half()

    Q = draw()
    K = draw()
    V = draw()

    started = time.process_time()
    for data, file_name in ((Q, 'mat_Q.bin'), (K, 'mat_K.bin'), (V, 'mat_V.bin')):
        data.numpy().tofile(file_name)

    # Scores and numerically-stabilised exponentials.
    scores = torch.from_numpy(Q.numpy() @ torch.transpose(K, -2, -1).numpy())
    row_max = scores.max(dim=-1).values.reshape(*dims[:3], 1)
    exp_scores = torch.from_numpy(np.exp((scores - row_max).numpy()))
    row_sum = exp_scores.sum(dim=-1).reshape(*dims[:3], 1)
    probs = exp_scores / row_sum
    # TODO: test with softmax float precision
    # probs = nn.Softmax(dim=-1)(scores.float()).half()
    gold = torch.from_numpy(probs.numpy() @ V.numpy())

    elapsed_time = (time.process_time() - started)*1000
    print(f"Python gold gen run elapsed time {round(elapsed_time,3)} msec")
    gold.half().numpy().tofile('gold_mat_O.bin')
# run task
def run_task(disable_flash_attention=False):
    """Build the task, stage its artifacts, then compile and run the CUDA test.

    Steps:
      1. Build ``FlashAttentionTask`` for CUDA and copy the generated source
         and library into the current directory with a ``flash_attention_``
         prefix.
      2. Write the input and golden data files via ``gen_gold``.
      3. Compile ``flash_attention_main.cu`` with nvcc and, if that succeeds,
         run ``./fa.out``. Prints ``test done`` when the last step returns 0,
         otherwise ``test error``.

    Args:
        disable_flash_attention: when True, build/run the normal attention
            path (no ``-DRUN_FLASH_ATTN`` define).
    """
    import shutil
    import subprocess

    # define the task here
    flash_attention_task = FlashAttentionTask(disable_flash_attention=disable_flash_attention)
    # build the task
    ret = flash_attention_task.build(target='cuda')
    # copy source file and lib to current directory
    source_path = ret.src_path
    library_path = ret.lib_path
    print(f'source_path {source_path} library_path {library_path}')
    for artifact in (source_path, library_path):
        shutil.copyfile(artifact, os.path.join("./", "flash_attention_" + os.path.basename(artifact)))
    # generate golden data
    gen_gold(flash_attention_task.attrs)

    def exe_f(command):
        # Run an argv list without a shell and return its exit code.
        print(f'running {" ".join(command)}')
        return subprocess.run(command, shell=False).returncode

    # launch testcase flash_attention_main.cu
    # NOTE(review): "../../include/" is presumably Hidet's include directory;
    # the original constant names were swapped relative to these values.
    HIDET_INCLUDE_PATH = "../../include/"
    CUDA_SAMPLES_INCLUDE_PATH = "../cuda-samples-master/Common/"
    attrs = flash_attention_task.attrs
    compile_cmd = [
        'nvcc', 'flash_attention_main.cu',
        # `define` may be the empty string; split() then contributes nothing.
        *flash_attention_task.define.split(),
        '-gencode', 'arch=compute_86,code=sm_86',
        '-I', HIDET_INCLUDE_PATH,
        '-I', CUDA_SAMPLES_INCLUDE_PATH,
        '-std=c++11',
        '-o', 'fa.out',
    ]
    run_cmd = [
        './fa.out',
        f'-BATCH={attrs["B"]}',
        f'-HEAD={attrs["H"]}',
        f'-BLK={attrs["BLK"]}',
        f'-THD={attrs["THD"]}',
    ]
    ret = exe_f(compile_cmd)
    if ret == 0:
        # Only run the binary when compilation succeeded (was `&&` in a shell).
        ret = exe_f(run_cmd)
    print('test done' if ret==0 else 'test error')
# main function
if __name__ == '__main__':
    # Run the normal approach first (flash attention disabled), then the
    # flash attention approach.
    for disable in (True, False):
        run_task(disable_flash_attention=disable)
Here is my flash_attention_main.cu, which includes the performance tracking, precision comparison, and memory allocation operations, and launches the test kernels.
// System includes
#include <assert.h>
#include <dlfcn.h>
#include <math.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/stat.h>
// C++ standard library
#include <fstream>
#include <string>
#include <vector>
// CUDA runtime
#include <cuda_runtime.h>
#include <cuda_profiler_api.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
// Helper functions and utilities to work with CUDA,
#include <helper_functions.h>
#include <helper_cuda.h>
#include <cuda_fp16.h>
// Import kernel functions
#include "flash_attention_kernel_func.h"
#include "normal_transformer_kernel_func.h"
// test function, execute kernel, compare with gold data
int flash_attention_test(
unsigned int B, unsigned int H,
unsigned int block_size, unsigned int thread_size,
half *h_Q, unsigned int size_Q,
half *h_K, unsigned int size_K,
half *h_V, unsigned int size_V,
half *h_gold_O, unsigned int size_O)
{
cudaStream_t stream;
const unsigned int BH = B * H;
// Allocate device memory
half *d_Q, *d_K, *d_V, *d_O, *h_O;
checkCudaErrors(cudaMallocHost(&h_O, size_O * sizeof(half)));
if (h_O == NULL)
{
fprintf(stderr, "Failed to allocate host matrix O!\n");
exit(EXIT_FAILURE);
}
checkCudaErrors(cudaMalloc(reinterpret_cast<void **>(&d_Q), size_Q * sizeof(half)));
checkCudaErrors(cudaMalloc(reinterpret_cast<void **>(&d_K), size_K * sizeof(half)));
checkCudaErrors(cudaMalloc(reinterpret_cast<void **>(&d_V), size_V * sizeof(half)));
checkCudaErrors(cudaMalloc(reinterpret_cast<void **>(&d_O), size_O * sizeof(half)));
// Allocate CUDA events that we'll use for timing
cudaEvent_t start, stop;
checkCudaErrors(cudaEventCreate(&start));
checkCudaErrors(cudaEventCreate(&stop));
checkCudaErrors(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
// copy host memory to device
checkCudaErrors(
cudaMemcpyAsync(d_Q, h_Q, size_Q * sizeof(half), cudaMemcpyHostToDevice, stream));
checkCudaErrors(
cudaMemcpyAsync(d_K, h_K, size_K * sizeof(half), cudaMemcpyHostToDevice, stream));
checkCudaErrors(
cudaMemcpyAsync(d_V, h_V, size_V * sizeof(half), cudaMemcpyHostToDevice, stream));
const unsigned int k_size_Q = (size_Q / BH);
const unsigned int k_size_K = (size_K / BH);
const unsigned int k_size_V = (size_V / BH);
const unsigned int k_size_O = (size_O / BH);
printf("k_size_Q %u k_size_K %u k_size_V %u k_size_O %u\n", k_size_Q, k_size_K, k_size_V, k_size_O);
// Record the start event
checkCudaErrors(cudaEventRecord(start, stream));
const int32_t num_args = 4;
for (unsigned int b = 0; b < B; b++)
{
for (unsigned int h = 0; h < H; h++)
{
unsigned int offset_index = (b * H) + h;
half *param[num_args] = {
d_Q + offset_index * k_size_Q,
d_K + offset_index * k_size_K,
d_V + offset_index * k_size_V,
d_O + offset_index * k_size_O};
#ifdef RUN_FLASH_ATTN
// run flash attention kernel
flash_attention_kernel<<<dim3(16, 1, 1), dim3(1024, 1, 1), 0, (cudaStream_t)stream>>>(((half *)(param[0])), ((half *)(param[1])), ((half *)(param[2])), ((half *)(param[3])));
#else
// run normal transformer kernel
uint8_t *buffer;
checkCudaErrors(cudaMalloc(reinterpret_cast<void **>(&buffer), int64_t(2097152ll)));
half *GLOBAL_QK = ((half *)(&buffer[(((int64_t)(0)) + (int64_t(0ll) * ((int64_t)(1))))]));
half *S = ((half *)(&buffer[(((int64_t)(0)) + (int64_t(524288ll) * ((int64_t)(1))))]));
half *exp_s = ((half *)(&buffer[(((int64_t)(0)) + (int64_t(1048576ll) * ((int64_t)(1))))]));
half *softmax = ((half *)(&buffer[(((int64_t)(0)) + (int64_t(1572864ll) * ((int64_t)(1))))]));
hidet_compute_GLOBAL_QK<<<dim3(512, 1, 1), dim3(512, 1, 1), 0, (cudaStream_t)stream>>>(param[0], param[1], GLOBAL_QK);
hidet_compute_S<<<dim3(512, 1, 1), dim3(512, 1, 1), 0, (cudaStream_t)stream>>>(GLOBAL_QK, param[0], param[1], S);
hidet_compute_exp_s<<<dim3(512, 1, 1), dim3(512, 1, 1), 0, (cudaStream_t)stream>>>(S, exp_s);
hidet_compute_softmax<<<dim3(512, 1, 1), dim3(512, 1, 1), 0, (cudaStream_t)stream>>>(S, param[0], param[1], GLOBAL_QK, exp_s, softmax);
hidet_compute_GLOBAL_O<<<dim3(128, 1, 1), dim3(512, 1, 1), 0, (cudaStream_t)stream>>>(softmax, ((half *)(param[2])), param[0], param[1], GLOBAL_QK, S, exp_s, ((half *)(param[3])));
#endif // RUN_FLASH_ATTN
}
}
checkCudaErrors(cudaStreamSynchronize(stream));
// Record the stop event
checkCudaErrors(cudaEventRecord(stop, stream));
printf("test done !!!\n");
// Wait for the stop event to complete
checkCudaErrors(cudaEventSynchronize(stop));
float msecTotal = 0.0f;
checkCudaErrors(cudaEventElapsedTime(&msecTotal, start, stop));
// Compute and print the performance
#if RUN_FLASH_ATTN
printf("flash attention elapsed time = %.3f msec\n", msecTotal);
#else
printf("normal approach elapsed time = %.3f msec\n", msecTotal);
#endif // RUN_FLASH_ATTN
// Copy result from device to host
checkCudaErrors(
cudaMemcpyAsync(h_O, d_O, size_O * sizeof(half), cudaMemcpyDeviceToHost, stream));
checkCudaErrors(cudaStreamSynchronize(stream));
printf("Checking computed result for correctness: \n");
double eps = 0.01; // 1% error with python output
const unsigned int max_print_count = 100;
uint32_t total_count = 0;
uint32_t total_err_count = 0;
for (int i = 0; i < static_cast<int>(size_O); i++)
{
double gold_val = fabs((double)h_gold_O[i]);
double abs_val = fabs((double)h_O[i]);
double abs_err = fabs(abs_val - gold_val);
double rel_err = abs_err / abs_val;
if (rel_err > eps)
{
if (total_err_count < max_print_count)
printf("Error! Matrix[%05d]=%.8f, ref=%.8f error term %E is > %E\n",
i, (double)h_O[i], (double)h_gold_O[i], rel_err, eps);
total_err_count++;
}
total_count++;
}
double error_ratio = (double)total_err_count / (double)total_count;
bool correct = error_ratio < eps;
printf("total count %u total error count %u (%.8f %%)\n", total_count, total_err_count, error_ratio * 100);
printf("%s\n", correct ? "Result = PASS" : "Result = FAIL");
// Clean up memory
checkCudaErrors(cudaFree(d_Q));
checkCudaErrors(cudaFree(d_K));
checkCudaErrors(cudaFree(d_V));
checkCudaErrors(cudaFree(d_O));
checkCudaErrors(cudaEventDestroy(start));
checkCudaErrors(cudaEventDestroy(stop));
if (correct)
{
return EXIT_SUCCESS;
}
else
{
return EXIT_FAILURE;
}
}
// Return true when stat() succeeds on the given path.
inline bool file_exists(const std::string &name)
{
    struct stat info;
    const int status = stat(name.c_str(), &info);
    return status == 0;
}
void load_data(std::vector<half> &matrix, const std::string bin_file)
{
printf("loading %s\n", bin_file.c_str());
assert(file_exists(bin_file) && "Error! binary file doesn't exist");
std::ifstream fin(bin_file, std::ios::binary);
half elem;
while (fin.read(reinterpret_cast<char *>(&elem), sizeof(half)))
{
matrix.push_back(elem);
}
}
int main(int argc, char **argv)
{
printf("[Flash Attention Using CUDA] - Starting...\n");
if (checkCmdLineFlag(argc, (const char **)argv, "help") ||
checkCmdLineFlag(argc, (const char **)argv, "?"))
{
printf("Usage -device=n (n >= 0 for deviceID)\n");
printf(" -BATCH=number of Batch\n");
printf(" -HEAD=number of Head\n");
printf(" -BLK=block size\n");
printf(" -THD=thread size\n");
exit(EXIT_SUCCESS);
}
// This will pick the best possible CUDA capable device, otherwise
// override the device ID based on input provided at the command line
int dev = findCudaDevice(argc, (const char **)argv);
unsigned int batch = 1;
if (checkCmdLineFlag(argc, (const char **)argv, "BATCH"))
{
batch = getCmdLineArgumentInt(argc, (const char **)argv, "BATCH");
}
unsigned int head = 1;
if (checkCmdLineFlag(argc, (const char **)argv, "HEAD"))
{
head = getCmdLineArgumentInt(argc, (const char **)argv, "HEAD");
}
unsigned int block_size = 1;
if (checkCmdLineFlag(argc, (const char **)argv, "BLK"))
{
block_size = getCmdLineArgumentInt(argc, (const char **)argv, "BLK");
}
unsigned int thread_size = 1;
if (checkCmdLineFlag(argc, (const char **)argv, "THD"))
{
thread_size = getCmdLineArgumentInt(argc, (const char **)argv, "THD");
}
// load Q
std::vector<half> mat_Q;
load_data(mat_Q, "./mat_Q.bin");
// load K
std::vector<half> mat_K;
load_data(mat_K, "./mat_K.bin");
// load V
std::vector<half> mat_V;
load_data(mat_V, "./mat_V.bin");
// load golden data O
std::vector<half> gold_mat_O;
load_data(gold_mat_O, "./gold_mat_O.bin");
printf("batch %u head %u block_size %u thread_size %u\n", batch, head, block_size, thread_size);
printf("Q size %lu K size %lu V size %lu O size %lu\n", mat_Q.size(), mat_K.size(), mat_V.size(), gold_mat_O.size());
checkCudaErrors(cudaProfilerStart());
int result = flash_attention_test(
batch, head, block_size, thread_size,
&mat_Q[0], mat_Q.size(),
&mat_K[0], mat_K.size(),
&mat_V[0], mat_V.size(),
&gold_mat_O[0], gold_mat_O.size());
checkCudaErrors(cudaProfilerStop());
exit(result);
}
Here are flash_attention_kernel_func.h and normal_transformer_kernel_func.h, respectively.
Again, really wonderful work on Hidet! Any help would be much appreciated 🙏 If any further information is needed, please let me know.