
hidet's Introduction

Hidet: An Open-Source Deep Learning Compiler

Documentation | Research Paper | Releases | Contributing


Hidet is an open-source deep learning compiler, written in Python. It supports end-to-end compilation of DNN models from PyTorch and ONNX to efficient CUDA kernels. A series of graph-level and operator-level optimizations are applied to improve performance.

Currently, hidet focuses on optimizing inference workloads on NVIDIA GPUs, and requires:

  • Linux OS
  • CUDA Toolkit 11.6+
  • Python 3.8+

Getting Started

Installation

pip install hidet

You can also try the nightly build version or build from source.

Usage

Optimize a PyTorch model through hidet (requires PyTorch 2.0):

import torch

# Define pytorch model
model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True).cuda().eval()
x = torch.rand(1, 3, 224, 224).cuda()

# Compile the model through Hidet
# Optional: set optimization options (see our documentation for more details)
#   import hidet 
#   hidet.torch.dynamo_config.search_space(2)  # tune each tunable operator
#   hidet.torch.dynamo_config.use_fp16()       # use float16 for acceleration
model_opt = torch.compile(model, backend='hidet')  

# Run the optimized model
y = model_opt(x)
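
As a quick sanity check (not part of the original README), you can compare the optimized output against eager mode; the tolerances below are arbitrary and should be tuned to your model and precision settings:

# Compare the hidet-compiled model against eager PyTorch
y_ref = model(x)
torch.testing.assert_close(y, y_ref, rtol=1e-3, atol=1e-3)
print('outputs match within tolerance')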

See the tutorials in our documentation to learn about other usages.

Publication

Hidet originates from the following research work:

Hidet: Task-Mapping Programming Paradigm for Deep Learning Tensor Programs
Yaoyao Ding, Cody Hao Yu, Bojian Zheng, Yizhi Liu, Yida Wang, and Gennady Pekhimenko.
ASPLOS '23

If you use Hidet in your research, you are welcome to cite our paper.

Development

Hidet is currently under active development by a team at CentML Inc.

Contributing

We welcome contributions from the community. Please see contribution guide for more details.

License

Hidet is released under the Apache 2.0 license.

hidet's People

Contributors

yaoyaoding, hjjq, aalanli, xinli-git, ktong821, soodoshll, destefy, andreslavescu, vadiklyutiy, ldy1998, bolinsnlhm, xiaocenxiaocen, serach24, dbabokin, eltociear, fishingguy456, hunbssfy, digital-nomad-cheng, yudi0201


hidet's Issues

[Bug] Inconsistent definition of the inputs parameter of operators

Currently, operators have different constructor signatures depending on the number of input tensors. For example:

  • UnaryOp(Tensor)
  • BinaryOp(Tensor, Tensor)
  • ReduceOp(List[Tensor])

This causes a problem in

return cls(*inputs, **attributes).run()

No matter which signature the constructor uses, the input tensors are flattened and stored in the inputs attribute of an Operator (with type List[Tensor]). When we need to reuse the saved inputs, it is hard to handle each signature separately. Maybe it would be better to use a unified signature, like Op(List[Tensor])?
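
For illustration, a minimal sketch of the unified signature being proposed; the class and attribute names below are hypothetical and not the current hidet API:

from typing import List

class Tensor:  # stand-in for hidet.Tensor in this sketch
    pass

class Operator:
    # hypothetical unified constructor: every operator takes a flat list of input tensors,
    # so cls(op.inputs, **op.attributes) works the same for unary, binary and reduce ops
    def __init__(self, inputs: List[Tensor], **attributes):
        self.inputs: List[Tensor] = list(inputs)
        self.attributes = attributes

class ReduceOp(Operator):
    def __init__(self, inputs: List[Tensor], axis: int = 0, keep_dim: bool = False):
        super().__init__(inputs, axis=axis, keep_dim=keep_dim)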

[Bug] operator.gt got an object with type <class 'int'>

Hi @yaoyaoding, I encountered a bug when running T5Model with hidet:

DEBUG:hidet.graph.frontend.torch.interpreter:interpreting node 34: %gt : [#users=1] = call_function[target=operator.gt](args = (%sub_1, 0), kwargs = {})
Traceback (most recent call last):
  File "hidet-path/python/hidet/graph/frontend/torch/interpreter.py", line 260, in forward
    hidet_env[node.name] = hidet_func(*hidet_args, **hidet_kwargs)
  File "hidet-path/python/hidet/graph/frontend/torch/register_functions.py", line 694, in gt
    return ops.greater(a, b)
  File "hidet-path/python/hidet/graph/ops/definitions/compare.py", line 85, in greater
    return GreaterOp(x, y).get_output(0)
  File "hidet-path/python/hidet/graph/ops/definitions/compare.py", line 34, in __init__
    super().__init__(x, y, lambda a, b: a > b, name='gt')
  File "hidet-path/python/hidet/graph/ops/definitions/arithmetic.py", line 125, in __init__
    task=BinaryElementwiseTask(name, input_like(x, 'x'), input_like(y, 'y'), op=op),
  File "hidet-path/python/hidet/graph/ops/definitions/utils/tensor_utils.py", line 26, in input_like
    raise TypeError('Expect a hidet.Tensor, but got an object with type {}'.format(type(tensor)))
TypeError: Expect a hidet.Tensor, but got an object with type <class 'int'>

Could you please help me fix it?
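
For context, the error comes from operator.gt receiving a plain Python int (the constant 0) where a hidet.Tensor is expected. Below is a hedged sketch of the kind of fix, promoting Python scalars to hidet tensors before the binary op is built; hidet.asarray and the gt handler shown here are assumptions about the frontend, not the actual patch:

import hidet
from hidet import Tensor

def promote_scalar(value, like: Tensor) -> Tensor:
    # Convert Python int/float operands to hidet tensors on the same device,
    # so BinaryElementwiseTask always receives two hidet.Tensor inputs.
    if isinstance(value, (int, float)):
        return hidet.asarray(value, dtype=like.dtype, device=like.device)  # assumed helper
    return value

# sketch of how the registered handler for operator.gt could use it:
# def gt(a, b):
#     b = promote_scalar(b, like=a)   # a is already a hidet.Tensor in this trace
#     return ops.greater(a, b)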

Post-Scheduling Fusion with TensorCore

Hello, I just read your Hidet paper, and it looks pretty powerful, but I have a few questions.

Does Hidet support codegen for TensorCore right now? And can it codegen for fused operators like batchmatmul+add+reshape+transpose? This fused operator comes from BERT, and I want to fuse them into one complex operator to speed up the network execution.

Below is my test script. Although I've set the precision and mma flags, I do not find wmma instructions in the generated CUDA code. How can I run this operator on Tensor Cores? Hoping for your help, thanks!

import hidet

# change the cache directory
hidet.option.cache_dir('./outs/cache')

# save the tensor program level ir in operator cache
hidet.option.save_lower_ir()


def main():
    # construct a simple graph
    x = hidet.symbol([16, 256, 512], device='cuda')
    w = hidet.randn([16, 512, 512], device='cuda')
    b = hidet.randn([512], device='cuda')
    x = hidet.ops.batch_matmul(x, w)
    x = x + b
    x = hidet.ops.reshape(x, [16, 256, 8, 64])
    x = hidet.ops.transpose(x, [0, 2, 1, 3])
    
    # x = hidet.ops.pad(x, [3, 3, 3, 3])
    # x = hidet.ops.conv2d(x, w, stride=2)
    # x = hidet.ops.relu(x)
    
    graph = hidet.trace_from(x)
    print(graph)

    # graph optimizations
    with hidet.graph.PassContext() as ctx:
        # save the computation graph level ir
        ctx.save_graph_instrument(out_dir='./outs/graphs')
        ctx.set_precision(dtype='float16')
        ctx.set_reduce_precision(dtype='float32')
        ctx.set_mma('mma')
        graph_opt = hidet.graph.optimize(graph)

    # run the optimized graph
    xx = hidet.randn([16, 256, 512], device='cuda')
    yy = graph_opt(xx)


if __name__ == '__main__':
    main()
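
Since the script enables save_lower_ir() and sets the cache directory, one way to check whether Tensor Core instructions were emitted is to scan the generated .cu files under that cache; the cache layout assumed below is simply './outs/cache/**/*.cu':

import glob

def find_tensor_core_kernels(cache_dir='./outs/cache'):
    # Report generated CUDA sources that contain mma/wmma instructions.
    hits = []
    for path in glob.glob(f'{cache_dir}/**/*.cu', recursive=True):
        with open(path) as f:
            text = f.read()
        if 'wmma' in text or 'mma.sync' in text:
            hits.append(path)
    return hits

print(find_tensor_core_kernels())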

[Bug] Stable Diffusion Compilation Errors

Describe the bug
Stable Diffusion pipeline compilation does not function properly. Even ignoring errors as described in the related issue, nvcc and python modules eventually start erroring out, and when the compilation is finally "done," the speed is the same as eager mode.

#202

To Reproduce
I posted the simple test script here: https://github.com/AlphaAtlas/Diffusion-Compilaton-Testing/blob/main/hidet_test.py

Along with a full log of the run on my machine: https://github.com/AlphaAtlas/Diffusion-Compilaton-Testing/blob/main/hidet.log

Environment

  • Hidet nightly (as of this post)
  • OS: Arch Linux
  • GPU: RTX 2060
  • Others: Nvidia Driver 530.41.03, Python 3.11, CUDA 12.1, Torch 2.1 Nightly

I understand diffusion and torch+cu121 are probably a work in progress. 👍 But I figured I would post my findings here anyway.

On a side note, this was tested with dynamic=False, but dynamic=True is almost a practical requirement for using stable diffusion outside of testing.

[Bug] Set resource limit throws error

Describe the bug
This is more of a Python error, honestly.
It happens because the call below does not work in some cases (I am root).

import resource
resource.setrlimit(resource.RLIMIT_STACK, (2**29, -1))

But I couldn't see a recommended Python version, and other people might also run into it.
This happens during import hidet.

Setting resource limit throws

resource.setrlimit(resource.RLIMIT_STACK, (2**29, -1))

ValueError: not allowed to raise maximum limit

/opt/conda/lib/python3.10/site-packages/hidet/__init__.py:15 in <module>                         │
│                                                                                                  │
│   12 """                                                                                         │
│   13 Hidet is an open-source DNN inference framework based on compilation.                       │
│   14 """                                                                                         │
│ ❱ 15 from . import option                                                                        │
│   16 from . import ir                                                                            │
│   17 from . import backend                                                                       │
│   18 from . import utils                                                                         │
│                                                                                                  │
│ /opt/conda/lib/python3.10/site-packages/hidet/option.py:112 in <module>                          │
│                                                                                                  │
│   109 │   )                                                                                      │
│   110                                                                                            │
│   111                                                                                            │
│ ❱ 112 register_hidet_options()                                                                   │
│   113                                                                                            │
│   114                                                                                            │
│   115 class OptionContext:                                                                       │
│                                                                                                  │
│ /opt/conda/lib/python3.10/site-packages/hidet/option.py:61 in register_hidet_options             │
│                                                                                                  │
│    58                                                                                            │
│    59                                                                                            │
│    60 def register_hidet_options():                                                              │
│ ❱  61 │   from hidet.utils import git_utils                                                      │
│    62 │                                                                                          │
│    63 │   register_option(                                                                       │
│    64 │   │   name='bench_config',                                                               │
│                                                                                                  │
│ /opt/conda/lib/python3.10/site-packages/hidet/utils/__init__.py:18 in <module>                   │
│                                                                                                  │
│   15 from . import netron                                                                        │
│   16 from . import transformers_utils                                                            │
│   17 from . import structure                                                                     │
│ ❱ 18 from . import stack_limit                                                                   │
│   19                                                                                             │
│   20 from .py import prod, Timer, repeat_until_converge, COLORS, get_next_file_index, factori    │
│   21 from .py import same_list, strict_zip, index_of, initialize, gcd, lcm, error_tolerance,     │
│                                                                                                  │
│ /opt/conda/lib/python3.10/site-packages/hidet/utils/stack_limit.py:19 in <module>                │
│                                                                                                  │
│   16 import resource                                                                             │
│   17                                                                                             │
│   18 # allow up to 128MB stack space                                                             │
│ ❱ 19 resource.setrlimit(resource.RLIMIT_STACK, (2**29, -1))                                      │
│   20                                                                                             │
│   21 # allow up to 10^5 recursive python calls, increase this when needed                        │
│   22 sys.setrecursionlimit(100000)                         

To Reproduce
pip install hidet
import hidet

Expected behavior
Maybe it would be better to try/except and warn.
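
A minimal sketch of that suggestion for hidet/utils/stack_limit.py, wrapping the call instead of failing the whole import:

import resource
import sys
import warnings

try:
    # same limit hidet tries to set at import time
    resource.setrlimit(resource.RLIMIT_STACK, (2**29, -1))
except ValueError as e:
    warnings.warn(f'Could not raise the stack size limit ({e}); keeping the current limit.')

# the recursion limit does not need elevated privileges
sys.setrecursionlimit(100000)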

Environment
Using the latest AWS container with amazonaws.com/pytorch-training:2.0.0-gpu-py310

  • OS: Ubuntu 20.04.6 LTS (Focal Fossa)
  • Python 3.10.8
  • GPU: Not relevant
  • Others: Not relevant


[Bug] Hidet requires more GPU memory than native torch.compile

Hi,
thanks for the great work again!

Describe the bug
Currently, I want to compile LLaMA-7B on an NVIDIA 4090 GPU using Hidet for faster inference.
However, I have encountered an issue where Hidet requires more GPU memory compared to the native torch.compile method.

For example, given an input_ids tensor with shape [2, 128], the native torch.compile can compile the model successfully and infer the output correctly, but when using Hidet, an OOM exception is raised during the first inference step.

I am wondering why Hidet requires more GPU memory.
Do I need some additional configurations to save the GPU memory?

To Reproduce
The following script is the example code to compile the LLaMA.

import torch
from transformers import LlamaModel, LlamaConfig
import hidet

print("Initialize the model")
configuration = LlamaConfig()
model = LlamaModel(configuration).half().eval().cuda()

BATCH_SIZE = 2
SEQ_LEN = 128
input_ids = torch.zeros(2, 128, dtype=torch.long).cuda()

with torch.no_grad():
    hidet.torch.dynamo_config.use_tensor_core(True)
    hidet.torch.dynamo_config.search_space(2) 
    hidet.torch.dynamo_config.use_fp16(True)
    hidet.torch.dynamo_config.use_fp16_reduction(True)

    print("Start to compile")
    # Compile the model using Hidet
    model_opt = torch.compile(model, backend='hidet')
    #model_opt = torch.compile(model, mode="reduce-overhead")
    print("Start to inference")
    model_opt(
        input_ids=input_ids,
        output_hidden_states=True
    )

However, the current Hidet has some unsupported operations for the HuggingFace LLaMA.
I forked the repo and added some operations based on the latest version in this branch, which can compile the model successfully.

Environment

  • Python Package Requirements
transformers == 4.28.1
torch == 2.0.0
sentencepiece == 0.1.99
  • System Environment
    • OS: Ubuntu 22.04.1 LTS
    • GPU: RTX 4090
    • CUDA Version: 11.8
    • Driver Version: 520.61.05

torch.nn.Identity is not supported yet

Describe the bug
Some errors occurred when I compiled dinov2.

To Reproduce

model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
model = model.to('cuda')
model_opt = torch.compile(model, backend='hidet')

Error information

torch._dynamo.exc.BackendCompilerFailed: hidet_backend raised NotImplementedError: The following modules/functions are not supported by hidet yet:
  torch.nn.Identity

Environment

  • OS: Ubuntu 22.04
  • GPU: A100

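A hedged sketch of the kind of mapping that would close this gap, assuming hidet's torch frontend exposes a register_module decorator and a HidetModule base class for its supported modules (the import path below is an assumption, not a confirmed API):

import torch
from hidet.graph.frontend.torch.registry import register_module, HidetModule  # assumed import path

@register_module(torch.nn.Identity)
class HidetIdentity(HidetModule):
    # nn.Identity is a no-op, so the hidet mapping can simply pass its input through
    def __call__(self, x):
        return x

Until such a mapping exists, setting torch._dynamo.config.suppress_errors = True at least falls back to eager mode for the unsupported subgraph.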

[Feature] Export optimized cuda graph

Is there a way to export the optimized CUDA graph to a runtime library / CUDA kernels and call it from the host?
If there is, which APIs should I check to do this?
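
A hedged sketch of the flow to look at, built from the FlowGraph utilities that appear elsewhere on this page; hidet.save_graph, hidet.load_graph, and FlowGraph.cuda_graph() are the APIs I would check first, but treat the exact names and signatures as assumptions to verify against the documentation:

import hidet

# toy graph standing in for the real model
x = hidet.symbol([1, 3, 224, 224], device='cuda')
w = hidet.randn([16, 3, 3, 3], device='cuda')
y = hidet.ops.relu(hidet.ops.conv2d(x, w))
graph_opt = hidet.graph.optimize(hidet.trace_from(y))

# persist the optimized flow graph to disk and reload it later (assumed API)
hidet.save_graph(graph_opt, 'model_opt.hidet')
graph_loaded = hidet.load_graph('model_opt.hidet')

# capture the kernels into a CUDA graph for low-overhead replay (assumed API);
# check the CudaGraph class for the exact run/replay call signature
cuda_graph = graph_loaded.cuda_graph()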

[Bug] cuda code compilation error

I have the following CUDA code generated by hidet:

#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <hidet/runtime/cuda_context.h>
#include <hidet/runtime/cpu_context.h>
typedef float tfloat32_t;
#define __float_to_tf32(x) (x)
/*
Task(
  name: clip
  parameters:
    x: tensor(int64, [1, 4, 1, 1])
    y: tensor(int64, [1, 4, 1, 1])
  inputs: [x]
  outputs: [y]
  computations:
    x: tensor(int64, [1, 4, 1, 1])
    y: int64[1, 4, 1, 1] where y[v, v_1, v_2, v_3] = generic_min(generic_max(x[v, v_1, v_2, v_3], int64(-1)), int64(1))
  attributes: {}
)
*/
extern "C" {

__global__ void __launch_bounds__(500) hidet_compute_y(int64_t * __restrict__ x, int64_t * __restrict__ y) {
  if ((int)threadIdx.x < 4) {
    y[((int)threadIdx.x % 4)] = min(max(x[((int)threadIdx.x % 4)], -1ll), 1ll);
  }
}

__host__ void hidet_clip(int32_t num_args, int32_t * __restrict__ arg_types, void* * __restrict__ args) {
  hidet_compute_y<<<1, 500, 0, (cudaStream_t)get_cuda_stream()>>>(((int64_t*)(args[0])), ((int64_t*)(args[1])));
}

}

And nvcc raises an error:

 error: more than one instance of overloaded function "max" matches the argument list:
            function "max(long, long)"
/usr/local/cuda/bin/../targets/x86_64-linux/include/crt/math_functions.hpp(1008): here
            function "max(long, unsigned long)"
/usr/local/cuda/bin/../targets/x86_64-linux/include/crt/math_functions.hpp(1043): here
            function "max(long long, long long)"
/usr/local/cuda/bin/../targets/x86_64-linux/include/crt/math_functions.hpp(1077): here
            function "max(unsigned long long, long long)"
/usr/local/cuda/bin/../targets/x86_64-linux/include/crt/math_functions.hpp(1092): here
            argument types are: (int64_t, long long)

BART, Pegasus, GPT2 model benchmarks are slower compared to vanilla ORT

Hey @yaoyaoding!
First of all, amazing work with Hidet!
I have recently been experimenting with hidet to see if it can outperform ORT.
Surprisingly, ORT with IO binding on an ONNX graph (BART, Pegasus, GPT2) without any graph optimisations outperforms hidet's optimised flow graph, even with search space 2 (on an NVIDIA A100).
Did you previously run any benchmark comparisons between hidet and ORT? I would love to help debug this!

Also, I have experimented with transformer-deploy, which performs better than both vanilla ORT and hidet. Replicating the optimisations from transformer-deploy would be a good next step. I would love to help with this as well!
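
For reference, a minimal GPU timing sketch (not from this thread) that can be applied to both the hidet-compiled model and the baseline; the warm-up and iteration counts are arbitrary:

import torch

def bench(fn, inputs, warmup=10, iters=100):
    # Time a callable on the GPU with CUDA events after a warm-up phase.
    for _ in range(warmup):
        fn(**inputs)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(**inputs)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # milliseconds per iteration

# e.g. latency_ms = bench(model_opt, {'input_ids': input_ids})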

[Bug] softmax() got an unexpected keyword argument '_stacklevel'

Describe the bug
I tried to optimize the HuggingFace BERT model with hidet in PyTorch. It reports the following error:

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias']
This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Compiling cuda task full(shape=[], value=-10000.0, dtype=float32)...
Compiling cuda task slice(data=int64[1, 512])...
Compiling cuda task take(data=float32[512, 768], indices=int64[1, 12])...
Compiling cuda task rearrange(x=float32[768, 768])...
Compiling cuda task full(shape=[], value=8.0, dtype=float32)...
Traceback (most recent call last):
File "bert_test.py", line 29, in
output = model_opt(**encoded_input)
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 111, in call
return self.dynamo_ctx(self._orig_mod.call)(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 247, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 394, in catch_errors
return callback(frame, cache_size, hooks)
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 446, in _convert_frame
result = inner_convert(frame, cache_size, hooks)
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 113, in _fn
return fn(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 293, in _convert_frame_assert
return _compile(
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/utils.py", line 169, in time_wrapper
r = func(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 358, in _compile
out_code = transform_code_object(code, transform)
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/bytecode_transformation.py", line 683, in transform_code_object
transformations(instructions, code_options)
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/convert_frame.py", line 345, in transform
tracer.run()
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 1890, in run
super().run()
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 609, in run
and self.step()
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 569, in step
getattr(self, inst.opname)(inst)
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/symbolic_convert.py", line 1977, in RETURN_VALUE
self.output.compile_subgraph(
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/output_graph.py", line 648, in compile_subgraph
self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/output_graph.py", line 694, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/utils.py", line 169, in time_wrapper
r = func(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/output_graph.py", line 776, in call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/output_graph.py", line 772, in call_user_compiler
compiled_fn = compiler_fn(gm, self.fake_example_inputs())
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/debug_utils.py", line 1093, in debug_wrapper
compiled_gm = compiler_fn(gm, example_inputs)
File "/home/wuruofan/workspace/hidet-main/python/hidet/graph/frontend/torch/dynamo_backends.py", line 140, in hidet_backend
output = interpreter(*symbolic_inputs)
File "/home/wuruofan/workspace/hidet-main/python/hidet/graph/frontend/torch/interpreter.py", line 152, in call
return self.forward(*args)
File "/home/wuruofan/workspace/hidet-main/python/hidet/graph/frontend/torch/interpreter.py", line 262, in forward
self._raise_exception(e, hidet_func, hidet_args, hidet_kwargs)
File "/home/wuruofan/workspace/hidet-main/python/hidet/graph/frontend/torch/interpreter.py", line 218, in _raise_exception
raise type(exception)(
torch._dynamo.exc.BackendCompilerFailed: backend='hidet_backend' raised:
TypeError: softmax() got an unexpected keyword argument '_stacklevel', occurred when calling softmax with
args: (<hidet.Tensor object at 0x7f4a9850d370>, -1)
kwargs: {'_stacklevel': 5}
softmax is defined at
File "/home/wuruofan/workspace/hidet-main/python/hidet/graph/frontend/torch/register_functions.py", line 208

You can suppress this exception and fall back to eager by setting:
torch._dynamo.config.suppress_errors = True

To Reproduce
from transformers import BertTokenizer, BertModel, BertConfig
import torch

config = BertConfig(num_hidden_layers=2)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained("bert-base-uncased").to(device="cuda:0")

model_opt = torch.compile(model, backend='hidet')

text = "Replace me by any text you'd like."
encoded_input = tokenizer(text, return_tensors='pt').to(device="cuda:0")
output = model_opt(**encoded_input)

Expected behavior
It seems that there is some problem with the softmax op. Do you have a good solution for it? Or is there any other way I can run the BERT model with torch dynamo?
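
A hedged sketch of the kind of patch that would resolve this: the softmax mapping registered in register_functions.py just needs to accept (and ignore) the extra keyword arguments PyTorch passes, such as _stacklevel. The decorator and ops.softmax call below mirror the file referenced in the traceback, but treat the import path and signatures as assumptions:

import torch
from hidet.graph import ops
from hidet.graph.frontend.torch.register_functions import register_function  # assumed import path

@register_function(torch.nn.functional.softmax)
def softmax(x, dim=-1, _stacklevel: int = 3, dtype=None):
    # _stacklevel (and dtype, in this sketch) are PyTorch-internal details that can be ignored
    return ops.softmax(x, axis=dim)

In the short term, setting torch._dynamo.config.suppress_errors = True (as the error message suggests) falls back to eager mode for the failing subgraph.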

Environment

  • Python version: 3.8.10
  • Hidet version: 0.2.2.dev
  • PyTorch version: 2.1.0+cu118
  • Numpy version: 1.23
  • OS: Ubuntu 22.04.2 LTS
  • GPU: NVIDIA A10
  • GPU driver: 525.85.12

[Enquiry] developing Flash Attention Transformer example using Hidet

Hello guys, really appreciate your work on Hidet. It is an awesome tool, and it really makes developers' lives easier when writing custom schedules for their CUDA kernels for performance optimization 👍👍!

To test Hidet's features, I am currently writing an example of the Flash Attention Transformer (link to research work: https://arxiv.org/abs/2205.14135) using the Hidet tool stack. I have written my custom testing setup (which contains my own host/device memory allocation, performance tracking, and precision comparison code) in my "flash_attention_main.cu", and I am trying to call the kernel functions in the Hidet-generated CUDA dynamic library.

May I know if there is a standard way of doing this? I tried using "dlopen" to load the library and launch the kernel functions, but unfortunately it is not working properly. I therefore just manually copied the Hidet-generated CUDA source code into two separate header files, "flash_attention_kernel_func.h" and "normal_transformer_kernel_func.h", and included them in my "flash_attention_main.cu". Compiling "flash_attention_main.cu" directly also works properly.

Let me share some source code below for illustration.

Here is my flash_attention_example.py, which includes the flash attention custom schedule and the normal approach.

import os
import math
import time
import numpy as np
import torch
import torch.nn as nn
torch.manual_seed(123)

# NOTE: this script is a simplified implementation of the following research work using Hidet
# Dao, T., Fu, D., Ermon, S., Rudra, A., & Ré, C. (2022). Flashattention: Fast and memory-efficient exact attention with io-awareness. Advances in Neural Information Processing Systems, 35, 16344-16359.
# link to paper: https://arxiv.org/abs/2205.14135

import hidet
from hidet.ir.compute import compute, reduce
from hidet.ir.task import Task
from hidet.ir.func import IRModule
from hidet.ir.primitives.cuda.atomic import atomic_add
from hidet.lang import f16, spatial, repeat, tensor, attr, grid, printf
from hidet.lang.cuda import blockIdx, threadIdx, syncthreads
from hidet.graph.ops.definitions.utils import input_like
from hidet.ir.expr import cast, address
from hidet.ir.primitives import exp, max, printf

# define Flash Attention Task
class FlashAttentionTask(Task):

    def allow_epilogue(self) -> bool:
        return False

    def flash_attention_implement_cuda(self, working_dir: str) -> IRModule:
        # override this method to use template-based scheduling
        return flash_attention_schedule(self)
    
    # Require: Matrices Q, K, V (N x d) in HBM, on-chip SRAM of size M.
    # NOTE: typical SRAM size 100 kB, default to 48 kB
    # NOTE: max thread num is set to 1024
    def __init__(self,N=512,d=128,H=16,B=1,M=48*1024,ratio=12,max_thread_num=1024,disable_flash_attention=False):

        # 1. set block sizes Bc = ceil(M/(4d)), Br = min(M/(4d),d)
        Bc = math.ceil(M/(ratio*d))
        Br = min(math.ceil(M/(ratio*d)),d)
        Tr = math.ceil(N/Br)
        Tc = math.ceil(N/Bc)
        GLOBAL_Q = input_like(hidet.randn([N, d], dtype='float16', device='cuda'),name='GLOBAL_Q')
        GLOBAL_K = input_like(hidet.randn([N, d], dtype='float16', device='cuda'),name='GLOBAL_K')
        GLOBAL_V = input_like(hidet.randn([N, d], dtype='float16', device='cuda'),name='GLOBAL_V')
        
        def normal_transformer():
            matmulQK = compute(
                    name = 'GLOBAL_QK',
                    shape = [N, N],
                    fcompute = lambda i, j: reduce(
                        shape=[d],
                        fcompute=lambda k: GLOBAL_Q[i, k] * GLOBAL_K[j, k],
                        reduce_type='sum',
                    )
                )

            max_val = lambda i : reduce(shape=[N], fcompute=lambda j: matmulQK[i,j], reduce_type='max')
            S = compute(
                    name = 'S',
                    shape = [N, N],
                    fcompute = lambda i,j: matmulQK[i,j] - max_val(i)
                )
            exp_s = compute(
                    name = 'exp_s',
                    shape = [N, N],
                    fcompute = lambda i,j: exp(S[i,j])
                )
            exp_sum = lambda i : reduce(shape=[N], fcompute=lambda j: exp_s[i,j], reduce_type='sum')
            softmax = compute('softmax', shape=[N,N], fcompute=lambda i,j: exp_s[i,j] / exp_sum(i))
            matmulPV = compute(
                    name = 'GLOBAL_O',
                    shape = [N, d],
                    fcompute = lambda i, j: reduce(
                        shape=[N],
                        fcompute=lambda k: softmax[i, k] * GLOBAL_V[k, j],
                        reduce_type='sum',
                    )
                )
            return matmulPV
        
        super().__init__(
            name='flash_attention_task',
            inputs=[GLOBAL_Q,GLOBAL_K,GLOBAL_V],
            outputs=[normal_transformer()],
            attributes={
                'B' : B,
                'H' : H,
                'N' : N,
                'd' : d,
                'Bc' : Bc,
                'Br' : Br,
                'Tc' : Tc,
                'Tr' : Tr,
                'BLK' : Tr,
                'THD' : Br * Bc,
                'MAX_THD' : max_thread_num
            },
        )
        if not disable_flash_attention:
            self.implement_cuda = self.flash_attention_implement_cuda
            self.define = "-DRUN_FLASH_ATTN"
        else:
            self.define = ""

# define custom schedule
def flash_attention_schedule(task:FlashAttentionTask) -> IRModule:
    
    print_debug = False

    B = task.attrs['B']
    H = task.attrs['H']
    N = task.attrs['N']
    d = task.attrs['d']
    Bc = task.attrs['Bc']
    Br = task.attrs['Br']
    Tr = task.attrs['Tr']
    Tc = task.attrs['Tc']

    dims = ( task.attrs['BLK'] )
    threads = task.attrs['THD']
    assert threads <= task.attrs['MAX_THD'], f'err: {threads} not < {task.attrs["MAX_THD"]}'
    assert d % Bc == 0, f'err: d is not divisible by Bc'
    assert d % Br == 0, f'err: d is not divisible by Br'


    largest_fp16_value = 65504

    print(f'task.attrs {task.attrs}')
    
    
    # define the tensor program
    with hidet.script_module() as module:
        """Flash attention kernel."""

        @hidet.script
        def QK_matmul_compute(A:f16[Br,d],B:f16[d,Bc],C:f16[Br,Bc]):
            for m,n in spatial(Br,Bc).on(threadIdx.x):
                C[m,n] = 0.0
            syncthreads()
            for m,k,n in spatial(Br,1,Bc).repeat(1,d,1).on(threadIdx.x):   
                atomic_add(~C[m,n],A[m,k] * B[k,n])
            syncthreads()

        @hidet.script
        def PV_matmul_compute(A:f16[Br,Bc],B:f16[Bc,d],C:f16[Br,d]):
            for m,n in spatial(Br,Bc).repeat(1,d//Bc).on(threadIdx.x):
                C[m,n] = 0.0
            syncthreads()
            for m,k,n in spatial(Br,1,Bc).repeat(1,Bc,d//Bc).on(threadIdx.x):   
                atomic_add(~C[m,n],A[m,k] * B[k,n])
            syncthreads()

        @hidet.script
        def rowmax_compute(A:f16[Br,Bc],M:f16[Br],T:f16[Br,Bc]):
            for i,j in spatial(Br,Bc).on(threadIdx.x):
                T.write([i,j],A[i,j],protected=True)
            syncthreads()

            for i,j in spatial(Br,Bc).on(threadIdx.x):
                k = 1
                while k < Bc:
                    if j % (k*2) == 0:
                        T.write([i,j],max(T[i,j],T[i,j+k]),protected=True)
                    syncthreads()
                    k *= 2

            for i in spatial(Br).on(threadIdx.x):
                if threadIdx.x < Br:
                    M[i] = T[i,0]
            syncthreads()

        @hidet.script
        def rowsum_compute(A:f16[Br,Bc],L:f16[Br],T:f16[Br,Bc]):
            for i,j in spatial(Br,Bc).on(threadIdx.x):
                T.write([i,j],A[i,j],protected=True)
            syncthreads()

            for i,j in spatial(Br,Bc).on(threadIdx.x):
                k = 1
                while k < Bc:
                    if j % (k*2) == 0:
                        T.write([i,j],(T[i,j]+T[i,j+k]),protected=True)
                    syncthreads()
                    k *= 2

            for i in spatial(Br).on(threadIdx.x):
                if threadIdx.x < Br:
                    L[i] = T[i,0]
            syncthreads()

        @hidet.script
        def local_softmax_compute(S:f16[Br,Bc],M:f16[Br]):
            for i,j in spatial(Br,Bc).on(threadIdx.x):
                if False and blockIdx.x==0:
                    printf("S[i,j] before %d %d %d %d : %f - %f\n",blockIdx.x,threadIdx.x,i,j,cast(S[i,j],"float32"),cast(M[i],"float32"))
                S[i,j] = exp(S[i,j] - M[i])
                if False and blockIdx.x==0:
                    printf("S[i,j] %d %d %d %d : %f\n",blockIdx.x,threadIdx.x,i,j,cast(S[i,j],"float32"))
            syncthreads()
        
        @hidet.script
        def local_update_compute(M:f16[Br],M_new:f16[Br],M_local:f16[Br],L:f16[Br],L_new:f16[Br],L_local:f16[Br]):
            for i in spatial(Br).on(threadIdx.x):
                if threadIdx.x < Br:
                    M_new[i] = max(M[i],M_local[i])
                    L_new[i] = exp(M[i] - M_new[i]) * L[i] + exp(M_local[i] - M_new[i]) * L_local[i]
            syncthreads()

        @hidet.script
        def global_update_compute(PV:f16[Br,d],O:f16[Br,d],M_local:f16[Br],M_new:f16[Br],M:f16[Br],L_new:f16[Br],L:f16[Br]):
            for i,j in spatial(Br,Bc).repeat(1,(d//Bc)).on(threadIdx.x):
                O.write(
                    [i,j],
                    ((L_new[i]**-1) * (L[i]*exp(M[i]-M_new[i])) * O[i,j]) + (exp(M_local[i]-M_new[i]) * PV[i,j]),
                    protected=True
                )
            syncthreads()

        @hidet.script
        def flash_attention_kernel(
            Q: f16[N,d],
            K: f16[N,d],
            V: f16[N,d],
            O: f16[N,d]
        ):
            
            attr.cuda_grid_dim = dims
            attr.cuda_block_dim = threads

            # Init O=(0), N x d in HBM
            for i,j in spatial(Br,Bc).repeat(1,(d//Bc)).on(threadIdx.x):
                offset_i = blockIdx.x * (Br)
                O[offset_i:,:].write([i,j], 0, protected=True)
            syncthreads()

            smem_q = tensor('shared', 'float16', [Br, d])
            smem_k = tensor('shared', 'float16', [d, Bc]) # transposed
            smem_v = tensor('shared', 'float16', [Bc, d])
            smem_o = tensor('shared', 'float16', [Br, d])
            
            smem_l = tensor('shared', 'float16', [Br])
            smem_l_local = tensor('shared', 'float16', [Br])
            smem_l_new = tensor('shared', 'float16', [Br])
            smem_m = tensor('shared', 'float16', [Br])
            smem_m_local = tensor('shared', 'float16', [Br])
            smem_m_new = tensor('shared', 'float16', [Br])
            smem_sp = tensor('shared', 'float16', [Br,Bc])
            smem_pv = tensor('shared', 'float16', [Br,d])
            smem_temp = tensor('shared', 'float16', [Br,Bc])

            for a,b in spatial(Br,Bc).repeat(1,(d//Bc)).on(threadIdx.x):
                # load Qi from HBM to on-chip SRAM
                # initialization of o,l,m
                offset_i = blockIdx.x * (Br)
                smem_q[a,b] = Q[offset_i:,:].read([a,b],protected=True)
                smem_o[a,b] = 0
                smem_l[a] = 0
                smem_m[a] = -largest_fp16_value
            syncthreads()

            if print_debug and (blockIdx.x==0 and threadIdx.x==0):
                idx = 0
                for i,j in grid(Br,d):
                    printf("idx: %d, Q val: %f\n",idx,cast(smem_q[i,j],"float32"))
                    idx += 1
            syncthreads()

            for j in grid(Tc):

                for a,b in spatial(Bc,Br).repeat(1,(d//Br)).on(threadIdx.x):
                    # load Kj,Vj from HBM to on-chip SRAM
                    offset_j = j * (Bc)
                    smem_k[b,a] = K[offset_j:,:].read([a,b],protected=True)
                    smem_v[a,b] = V[offset_j:,:].read([a,b],protected=True)
                syncthreads()
                
                if print_debug and (blockIdx.x==0 and threadIdx.x==0):
                    idx = 0
                    for i,j in grid(d,Bc):
                        printf("idx: %d, K val: %f\n",idx,cast(smem_k[i,j],"float32"))
                        idx += 1
                    for i,j in grid(Bc,d):
                        printf("idx: %d, V val: %f\n",idx,cast(smem_v[i,j],"float32"))
                        idx += 1
                syncthreads()
                
                # on chip, compute Sij = Qi @ (Kj)^T, Br X Bc
                QK_matmul_compute(smem_q,smem_k,smem_sp)
                if print_debug and (blockIdx.x==0 and threadIdx.x==0):
                    idx = 0
                    for i,j in grid(Br,Bc):
                        printf("idx: %d, S val: %f\n",idx,cast(smem_sp[i,j],"float32"))
                        idx += 1
                syncthreads()
                
                # on chip, compute m'_ij = rowmax(Sij), Br; Pij = exp(Sij - m'_ij), Br x Bc (pointwise); l'_ij = rowsum(P'ij), Br
                rowmax_compute(smem_sp,smem_m_local,smem_temp)

                if print_debug and (blockIdx.x==0 and threadIdx.x==0):
                    for i in grid(Br):
                        printf("i: %d M val: %f\n",i,cast(smem_m_local[i],"float32"))
                        # for j in grid(Bc):
                        #     printf("j: %d, S val: %f\n",j,cast(smem_sp[i,j],"float32"))
                syncthreads()
                
                local_softmax_compute(smem_sp,smem_m_local)

                if print_debug and (blockIdx.x==0 and threadIdx.x==0):
                    idx = 0
                    for i,j in grid(Br,Bc):
                        printf("idx: %d, P val: %f\n",idx,cast(smem_sp[i,j],"float32"))
                        idx += 1
                syncthreads()


                rowsum_compute(smem_sp,smem_l_local,smem_temp)

                if print_debug and (blockIdx.x==0 and threadIdx.x==0):
                    for i in grid(Br):
                        printf("i: %d L val: %f\n",i,cast(smem_l_local[i],"float32"))
                        # for j in grid(Bc):
                        #     printf("j: %d, P val: %f\n",j,cast(smem_sp[i,j],"float32"))
                syncthreads()

                
                # on chip, compute m_new_i = max(m_i,m'_ij), Br; l_new_i = e^(m_i - m_new_i) * l_i + e^(m'_ij - m_i_new) * l'_ij, Br
                local_update_compute(smem_m,smem_m_new,smem_m_local,smem_l,smem_l_new,smem_l_local)
                if print_debug and (blockIdx.x==0 and threadIdx.x==0):
                    for i in grid(Br):
                        printf("i: %d smem_m val: %f\n",i,cast(smem_m[i],"float32"))
                        printf("i: %d smem_m_new val: %f\n",i,cast(smem_m[i],"float32"))
                        printf("i: %d smem_m_local val: %f\n",i,cast(smem_m[i],"float32"))
                        printf("i: %d smem_l val: %f\n",i,cast(smem_m[i],"float32"))
                        printf("i: %d smem_l_new val: %f\n",i,cast(smem_m[i],"float32"))
                        printf("i: %d smem_l_local val: %f\n",i,cast(smem_m[i],"float32"))
                syncthreads()
                # write Oi = diag(l_i_new)^-1 * (diag(l_i)*e^(m_i-m_i_new) @ Oi + e^*m'_ij-m_i_new*(P'ij @ Vj))

                PV_matmul_compute(smem_sp,smem_v,smem_pv)
                if print_debug and (blockIdx.x==0 and threadIdx.x==0):
                    idx = 0
                    for i,j in grid(Br,d):
                        printf("idx: %d, PV val: %f\n",idx,cast(smem_pv[i,j],"float32"))
                        idx += 1
                syncthreads()

                global_update_compute(smem_pv,smem_o,smem_m_local,smem_m_new,smem_m,smem_l_new,smem_l)

                if j + 1 == Tc:
                    for i,j in spatial(Br,Bc).repeat(1,(d//Bc)).on(threadIdx.x):
                        offset_i = blockIdx.x * (Br)
                        O[offset_i:,:].write([i,j], smem_o[i,j], protected=True)
                    syncthreads()

                # write l_i = l_i_new, m_i = m_i_new
                for i in spatial(Br).on(threadIdx.x):
                    if threadIdx.x < Br:
                        smem_m[i] = smem_m_new[i]
                        smem_l[i] = smem_l_new[i]
                syncthreads()

            if print_debug and (blockIdx.x==15 and threadIdx.x==0):
                idx = 0
                for i,j in grid(Br,d):
                    offset_i = blockIdx.x * (Br)
                    printf("blockIdx %d : output idx: %d, val: %f\n",blockIdx.x,idx,cast(O[offset_i+i,j],"float32"))
                    idx += 1
            syncthreads()
            return

        @hidet.script
        def flash_attention_launch_func( 
            G_Q: f16[B, H, N, d],
            G_K: f16[B, H, N, d],
            G_V: f16[B, H, N, d],
            G_O: f16[B, H, N, d]
        ):
            # NOTE: this section needs to be written in flash_attention_main.cu
            for b,h in grid(B,H):
                flash_attention_kernel(
                    address(G_Q[b,h,0,0]),
                    address(G_K[b,h,0,0]),
                    address(G_V[b,h,0,0]),
                    address(G_O[b,h,0,0])
                )
            
    # build ir module
    ir_module = module.ir_module()
    return ir_module

# gen Python gold data as reference
def gen_gold(attrs,r1=-3,r2=3):

    Q = torch.FloatTensor(attrs['B'],attrs['H'],attrs['N'],attrs['d']).uniform_(r1, r2).half()
    K = torch.FloatTensor(attrs['B'],attrs['H'],attrs['N'],attrs['d']).uniform_(r1, r2).half()
    V = torch.FloatTensor(attrs['B'],attrs['H'],attrs['N'],attrs['d']).uniform_(r1, r2).half()
    t = time.process_time()
    Q.half().numpy().tofile('mat_Q.bin')
    K.half().numpy().tofile('mat_K.bin')
    V.half().numpy().tofile('mat_V.bin')
    S = torch.from_numpy(Q.numpy() @ torch.transpose(K, -2, -1).numpy())

    row_max, _ = torch.max(S,dim=-1)
    S = torch.from_numpy(np.exp((S - row_max.reshape(attrs['B'],attrs['H'],attrs['N'],1)).numpy()))
    row_sum = torch.sum(S,dim=-1).reshape(attrs['B'],attrs['H'],attrs['N'],1)
    P = S / row_sum

    # TODO: test with softmax float precision
    # P = nn.Softmax(dim=-1)(S.float()).half()
    O = torch.from_numpy(P.numpy() @ V.numpy())
    elapsed_time = (time.process_time() - t)*1000
    print(f"Python gold gen run elapsed time {round(elapsed_time,3)} msec")
    O.half().numpy().tofile('gold_mat_O.bin')

# run task
def run_task(disable_flash_attention=False):
    # define the task here
    flash_attention_task = FlashAttentionTask(disable_flash_attention=disable_flash_attention)
    # build the task
    ret = flash_attention_task.build(target='cuda')

    # copy source file and lib to current directory
    source_path = ret.src_path
    library_path = ret.lib_path
    print(f'source_path {source_path} library_path {library_path}')

    import shutil
    shutil.copyfile(source_path,os.path.join("./","flash_attention_"+os.path.basename(source_path)))
    shutil.copyfile(library_path,os.path.join("./","flash_attention_"+os.path.basename(library_path)))

    # generate golden data
    gen_gold(flash_attention_task.attrs)

    def exe_f(command='', shell=True):
        print(f'running {command}')
        import subprocess
        process = subprocess.Popen(command, shell=shell)
        code = process.wait()
        process.communicate()
        return code
    
    # launch testcase flash_attention_main.cu
    HIDET_CUDA_INCLUDE_PATH = "../cuda-samples-master/Common/"
    CUDA_SAMPLES_INCLUDE_PATH = "../../include/"
    ret = exe_f(f'nvcc flash_attention_main.cu {flash_attention_task.define} -gencode arch=compute_86,code=sm_86 -I {CUDA_SAMPLES_INCLUDE_PATH} -I {HIDET_CUDA_INCLUDE_PATH} -std=c++11 -o fa.out && ./fa.out -BATCH={flash_attention_task.attrs["B"]} -HEAD={flash_attention_task.attrs["H"]} -BLK={flash_attention_task.attrs["BLK"]} -THD={flash_attention_task.attrs["THD"]}')
    print('test done' if ret==0 else 'test error')

# main function
if __name__ == '__main__':
    # normal approach execution
    run_task(disable_flash_attention=True)
    # flash attention approach execution
    run_task(disable_flash_attention=False)

Here is my flash_attention_main.cu, which includes the performance tracking, precision comparison & memory allocation operations, and it lauches the test kernels.

// System includes
#include <stdio.h>
#include <sys/stat.h>
#include <dlfcn.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <assert.h>
#include <vector>

// CUDA runtime
#include <cuda_runtime.h>
#include <cuda_profiler_api.h>

// Helper functions and utilities to work with CUDA,
#include <helper_functions.h>
#include <helper_cuda.h>
#include <cuda_fp16.h>

// Import kernel functions
#include "flash_attention_kernel_func.h"
#include "normal_transformer_kernel_func.h"


// test function, execute kernel, compare with gold data
int flash_attention_test(
    unsigned int B, unsigned int H,
    unsigned int block_size, unsigned int thread_size,
    half *h_Q, unsigned int size_Q,
    half *h_K, unsigned int size_K,
    half *h_V, unsigned int size_V,
    half *h_gold_O, unsigned int size_O)
{

    cudaStream_t stream;
    const unsigned int BH = B * H;
    // Allocate device memory
    half *d_Q, *d_K, *d_V, *d_O, *h_O;
    checkCudaErrors(cudaMallocHost(&h_O, size_O * sizeof(half)));

    if (h_O == NULL)
    {
        fprintf(stderr, "Failed to allocate host matrix O!\n");
        exit(EXIT_FAILURE);
    }

    checkCudaErrors(cudaMalloc(reinterpret_cast<void **>(&d_Q), size_Q * sizeof(half)));
    checkCudaErrors(cudaMalloc(reinterpret_cast<void **>(&d_K), size_K * sizeof(half)));
    checkCudaErrors(cudaMalloc(reinterpret_cast<void **>(&d_V), size_V * sizeof(half)));
    checkCudaErrors(cudaMalloc(reinterpret_cast<void **>(&d_O), size_O * sizeof(half)));
    // Allocate CUDA events that we'll use for timing
    cudaEvent_t start, stop;
    checkCudaErrors(cudaEventCreate(&start));
    checkCudaErrors(cudaEventCreate(&stop));

    checkCudaErrors(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));

    // copy host memory to device
    checkCudaErrors(
        cudaMemcpyAsync(d_Q, h_Q, size_Q * sizeof(half), cudaMemcpyHostToDevice, stream));
    checkCudaErrors(
        cudaMemcpyAsync(d_K, h_K, size_K * sizeof(half), cudaMemcpyHostToDevice, stream));
    checkCudaErrors(
        cudaMemcpyAsync(d_V, h_V, size_V * sizeof(half), cudaMemcpyHostToDevice, stream));

    const unsigned int k_size_Q = (size_Q / BH);
    const unsigned int k_size_K = (size_K / BH);
    const unsigned int k_size_V = (size_V / BH);
    const unsigned int k_size_O = (size_O / BH);
    printf("k_size_Q %u k_size_K %u k_size_V %u k_size_O %u\n", k_size_Q, k_size_K, k_size_V, k_size_O);

    // Record the start event
    checkCudaErrors(cudaEventRecord(start, stream));

    const int32_t num_args = 4;

    for (unsigned int b = 0; b < B; b++)
    {
        for (unsigned int h = 0; h < H; h++)
        {
            unsigned int offset_index = (b * H) + h;

            half *param[num_args] = {
                d_Q + offset_index * k_size_Q,
                d_K + offset_index * k_size_K,
                d_V + offset_index * k_size_V,
                d_O + offset_index * k_size_O};

#ifdef RUN_FLASH_ATTN
            // run flash attention kernel
            flash_attention_kernel<<<dim3(16, 1, 1), dim3(1024, 1, 1), 0, (cudaStream_t)stream>>>(((half *)(param[0])), ((half *)(param[1])), ((half *)(param[2])), ((half *)(param[3])));
#else
            // run normal transformer kernel
            uint8_t *buffer;

            checkCudaErrors(cudaMalloc(reinterpret_cast<void **>(&buffer), int64_t(2097152ll)));

            half *GLOBAL_QK = ((half *)(&buffer[(((int64_t)(0)) + (int64_t(0ll) * ((int64_t)(1))))]));
            half *S = ((half *)(&buffer[(((int64_t)(0)) + (int64_t(524288ll) * ((int64_t)(1))))]));
            half *exp_s = ((half *)(&buffer[(((int64_t)(0)) + (int64_t(1048576ll) * ((int64_t)(1))))]));
            half *softmax = ((half *)(&buffer[(((int64_t)(0)) + (int64_t(1572864ll) * ((int64_t)(1))))]));

            hidet_compute_GLOBAL_QK<<<dim3(512, 1, 1), dim3(512, 1, 1), 0, (cudaStream_t)stream>>>(param[0], param[1], GLOBAL_QK);
            hidet_compute_S<<<dim3(512, 1, 1), dim3(512, 1, 1), 0, (cudaStream_t)stream>>>(GLOBAL_QK, param[0], param[1], S);
            hidet_compute_exp_s<<<dim3(512, 1, 1), dim3(512, 1, 1), 0, (cudaStream_t)stream>>>(S, exp_s);
            hidet_compute_softmax<<<dim3(512, 1, 1), dim3(512, 1, 1), 0, (cudaStream_t)stream>>>(S, param[0], param[1], GLOBAL_QK, exp_s, softmax);
            hidet_compute_GLOBAL_O<<<dim3(128, 1, 1), dim3(512, 1, 1), 0, (cudaStream_t)stream>>>(softmax, ((half *)(param[2])), param[0], param[1], GLOBAL_QK, S, exp_s, ((half *)(param[3])));
#endif // RUN_FLASH_ATTN

        }
    }

    checkCudaErrors(cudaStreamSynchronize(stream));

    // Record the stop event
    checkCudaErrors(cudaEventRecord(stop, stream));
    printf("test done !!!\n");

    // Wait for the stop event to complete
    checkCudaErrors(cudaEventSynchronize(stop));

    float msecTotal = 0.0f;
    checkCudaErrors(cudaEventElapsedTime(&msecTotal, start, stop));

    // Compute and print the performance
#if RUN_FLASH_ATTN
    printf("flash attention elapsed time = %.3f msec\n", msecTotal);
#else
    printf("normal approach elapsed time = %.3f msec\n", msecTotal);
#endif // RUN_FLASH_ATTN
    // Copy result from device to host
    checkCudaErrors(
        cudaMemcpyAsync(h_O, d_O, size_O * sizeof(half), cudaMemcpyDeviceToHost, stream));
    checkCudaErrors(cudaStreamSynchronize(stream));

    printf("Checking computed result for correctness: \n");

    double eps = 0.01; // 1% error with python output
    const unsigned int max_print_count = 100;
    uint32_t total_count = 0;
    uint32_t total_err_count = 0;
    for (int i = 0; i < static_cast<int>(size_O); i++)
    {
        double gold_val = fabs((double)h_gold_O[i]);
        double abs_val = fabs((double)h_O[i]);
        double abs_err = fabs(abs_val - gold_val);
        double rel_err = abs_err / abs_val;

        if (rel_err > eps)
        {
            if (total_err_count < max_print_count)
                printf("Error! Matrix[%05d]=%.8f, ref=%.8f error term %E is > %E\n",
                       i, (double)h_O[i], (double)h_gold_O[i], rel_err, eps);
            total_err_count++;
        }
        total_count++;
    }
    double error_ratio = (double)total_err_count / (double)total_count;
    bool correct = error_ratio < eps;
    printf("total count %u total error count %u (%.8f %%)\n", total_count, total_err_count, error_ratio * 100);
    printf("%s\n", correct ? "Result = PASS" : "Result = FAIL");

    // Clean up memory
    checkCudaErrors(cudaFree(d_Q));
    checkCudaErrors(cudaFree(d_K));
    checkCudaErrors(cudaFree(d_V));
    checkCudaErrors(cudaFree(d_O));
    checkCudaErrors(cudaEventDestroy(start));
    checkCudaErrors(cudaEventDestroy(stop));

    if (correct)
    {
        return EXIT_SUCCESS;
    }
    else
    {
        return EXIT_FAILURE;
    }
}

inline bool file_exists(const std::string &name)
{
    struct stat buffer;
    return (stat(name.c_str(), &buffer) == 0);
}

void load_data(std::vector<half> &matrix, const std::string bin_file)
{
    printf("loading %s\n", bin_file.c_str());
    assert(file_exists(bin_file) && "Error! binary file doesn't exist");

    std::ifstream fin(bin_file, std::ios::binary);
    half elem;
    while (fin.read(reinterpret_cast<char *>(&elem), sizeof(half)))
    {
        matrix.push_back(elem);
    }
}

int main(int argc, char **argv)
{
    printf("[Flash Attention Using CUDA] - Starting...\n");

    if (checkCmdLineFlag(argc, (const char **)argv, "help") ||
        checkCmdLineFlag(argc, (const char **)argv, "?"))
    {

        printf("Usage -device=n (n >= 0 for deviceID)\n");
        printf("      -BATCH=number of Batch\n");
        printf("      -HEAD=number of Head\n");
        printf("      -BLK=block size\n");
        printf("      -THD=thread size\n");
        exit(EXIT_SUCCESS);
    }

    // This will pick the best possible CUDA capable device, otherwise
    // override the device ID based on input provided at the command line
    int dev = findCudaDevice(argc, (const char **)argv);

    unsigned int batch = 1;
    if (checkCmdLineFlag(argc, (const char **)argv, "BATCH"))
    {
        batch = getCmdLineArgumentInt(argc, (const char **)argv, "BATCH");
    }
    unsigned int head = 1;
    if (checkCmdLineFlag(argc, (const char **)argv, "HEAD"))
    {
        head = getCmdLineArgumentInt(argc, (const char **)argv, "HEAD");
    }
    unsigned int block_size = 1;
    if (checkCmdLineFlag(argc, (const char **)argv, "BLK"))
    {
        block_size = getCmdLineArgumentInt(argc, (const char **)argv, "BLK");
    }
    unsigned int thread_size = 1;
    if (checkCmdLineFlag(argc, (const char **)argv, "THD"))
    {
        thread_size = getCmdLineArgumentInt(argc, (const char **)argv, "THD");
    }

    // load Q
    std::vector<half> mat_Q;
    load_data(mat_Q, "./mat_Q.bin");

    // load K
    std::vector<half> mat_K;
    load_data(mat_K, "./mat_K.bin");

    // load V
    std::vector<half> mat_V;
    load_data(mat_V, "./mat_V.bin");

    // load golden data O
    std::vector<half> gold_mat_O;
    load_data(gold_mat_O, "./gold_mat_O.bin");

    printf("batch %u head %u block_size %u thread_size %u\n", batch, head, block_size, thread_size);

    printf("Q size %lu K size %lu V size %lu O size %lu\n", mat_Q.size(), mat_K.size(), mat_V.size(), gold_mat_O.size());

    checkCudaErrors(cudaProfilerStart());
    int result = flash_attention_test(
        batch, head, block_size, thread_size,
        &mat_Q[0], mat_Q.size(),
        &mat_K[0], mat_K.size(),
        &mat_V[0], mat_V.size(),
        &gold_mat_O[0], gold_mat_O.size());
    checkCudaErrors(cudaProfilerStop());

    exit(result);
}

Here are flash_attention_kernel_func.h and normal_transformer_kernel_func.h, respectively.

// flash_attention_kernel_func.h
__global__ void __launch_bounds__(1024) flash_attention_kernel(half *__restrict__ Q, half *__restrict__ K, half *__restrict__ V, half *__restrict__ O)
{
    for (int32_t i = 0; (i < 4); i = (i + 1))
    {
        O[(((((int)blockIdx.x * 32) + ((int)threadIdx.x / 32)) * 128) + ((((int)threadIdx.x % 32) * 4) + i))] = ((half)(0));
    }
    __syncthreads();
    __shared__ half smem_q[4096];
    __shared__ half smem_k[4096];
    __shared__ half smem_v[4096];
    __shared__ half smem_o[4096];
    __shared__ half smem_l[32];
    __shared__ half smem_l_local[32];
    __shared__ half smem_l_new[32];
    __shared__ half smem_m[32];
    __shared__ half smem_m_local[32];
    __shared__ half smem_m_new[32];
    __shared__ half smem_sp[1024];
    __shared__ half smem_pv[4096];
    __shared__ half smem_temp[1024];
    for (int32_t i_1 = 0; (i_1 < 4); i_1 = (i_1 + 1))
    {
        smem_q[((((int)threadIdx.x / 32) * 128) + ((((int)threadIdx.x % 32) * 4) + i_1))] = Q[(((((int)blockIdx.x * 32) + ((int)threadIdx.x / 32)) * 128) + ((((int)threadIdx.x % 32) * 4) + i_1))];
        smem_o[((((int)threadIdx.x / 32) * 128) + ((((int)threadIdx.x % 32) * 4) + i_1))] = ((half)(0));
        smem_l[((int)threadIdx.x / 32)] = ((half)(0));
        smem_m[((int)threadIdx.x / 32)] = ((half)((-65504)));
    }
    __syncthreads();
    __syncthreads();
    for (int32_t j = 0; (j < 16); j = (j + 1))
    {
        for (int32_t i_2 = 0; (i_2 < 4); i_2 = (i_2 + 1))
        {
            int32_t offset_j = (j * 32);
            smem_k[((((((int)threadIdx.x % 32) * 4) + i_2) * 32) + ((int)threadIdx.x / 32))] = K[(((offset_j + ((int)threadIdx.x / 32)) * 128) + ((((int)threadIdx.x % 32) * 4) + i_2))];
            smem_v[((((int)threadIdx.x / 32) * 128) + ((((int)threadIdx.x % 32) * 4) + i_2))] = V[(((offset_j + ((int)threadIdx.x / 32)) * 128) + ((((int)threadIdx.x % 32) * 4) + i_2))];
        }
        __syncthreads();
        __syncthreads();
        half *A = smem_q;
        half *B = smem_k;
        half *C = smem_sp;
        C[((((int)threadIdx.x / 32) * 32) + ((int)threadIdx.x % 32))] = ((half)(0.0f));
        __syncthreads();
        for (int32_t i_3 = 0; (i_3 < 128); i_3 = (i_3 + 1))
        {
            atomicAdd(&C[((((int)threadIdx.x / 32) * 32) + ((int)threadIdx.x % 32))], (A[((((int)threadIdx.x / 32) * 128) + i_3)] * B[((i_3 * 32) + ((int)threadIdx.x % 32))]));
        }
        __syncthreads();
        __syncthreads();
        half *A_1 = smem_sp;
        half *M = smem_m_local;
        half *T = smem_temp;
        T[((((int)threadIdx.x / 32) * 32) + ((int)threadIdx.x % 32))] = A_1[((((int)threadIdx.x / 32) * 32) + ((int)threadIdx.x % 32))];
        __syncthreads();
        int32_t k = 1;
        while ((k < 32))
        {
            if ((((int)threadIdx.x % 32) % (k * 2)) == 0)
            {
                T[((((int)threadIdx.x / 32) * 32) + ((int)threadIdx.x % 32))] = __hmax(T[((((int)threadIdx.x / 32) * 32) + ((int)threadIdx.x % 32))], T[((((int)threadIdx.x / 32) * 32) + (((int)threadIdx.x % 32) + k))]);
            }
            __syncthreads();
            k = (k * 2);
        }
        if ((int)threadIdx.x < 32)
        {
            M[((int)threadIdx.x % 32)] = T[(((int)threadIdx.x % 32) * 32)];
        }
        __syncthreads();
        __syncthreads();
        half *S = smem_sp;
        half *M_1 = smem_m_local;
        S[((((int)threadIdx.x / 32) * 32) + ((int)threadIdx.x % 32))] = hexp((S[((((int)threadIdx.x / 32) * 32) + ((int)threadIdx.x % 32))] - M_1[((int)threadIdx.x / 32)]));
        __syncthreads();
        __syncthreads();
        half *A_2 = smem_sp;
        half *L = smem_l_local;
        half *T_1 = smem_temp;
        T_1[((((int)threadIdx.x / 32) * 32) + ((int)threadIdx.x % 32))] = A_2[((((int)threadIdx.x / 32) * 32) + ((int)threadIdx.x % 32))];
        __syncthreads();
        int32_t k_1 = 1;
        while ((k_1 < 32))
        {
            if ((((int)threadIdx.x % 32) % (k_1 * 2)) == 0)
            {
                T_1[((((int)threadIdx.x / 32) * 32) + ((int)threadIdx.x % 32))] = (T_1[((((int)threadIdx.x / 32) * 32) + ((int)threadIdx.x % 32))] + T_1[((((int)threadIdx.x / 32) * 32) + (((int)threadIdx.x % 32) + k_1))]);
            }
            __syncthreads();
            k_1 = (k_1 * 2);
        }
        if ((int)threadIdx.x < 32)
        {
            L[((int)threadIdx.x % 32)] = T_1[(((int)threadIdx.x % 32) * 32)];
        }
        __syncthreads();
        __syncthreads();
        half *M_2 = smem_m;
        half *M_new = smem_m_new;
        half *M_local = smem_m_local;
        half *L_1 = smem_l;
        half *L_new = smem_l_new;
        half *L_local = smem_l_local;
        if ((int)threadIdx.x < 32)
        {
            M_new[((int)threadIdx.x % 32)] = __hmax(M_2[((int)threadIdx.x % 32)], M_local[((int)threadIdx.x % 32)]);
            L_new[((int)threadIdx.x % 32)] = ((hexp((M_2[((int)threadIdx.x % 32)] - M_new[((int)threadIdx.x % 32)])) * L_1[((int)threadIdx.x % 32)]) + (hexp((M_local[((int)threadIdx.x % 32)] - M_new[((int)threadIdx.x % 32)])) * L_local[((int)threadIdx.x % 32)]));
        }
        __syncthreads();
        __syncthreads();
        half *A_3 = smem_sp;
        half *B_1 = smem_v;
        half *C_1 = smem_pv;
        for (int32_t i_4 = 0; (i_4 < 4); i_4 = (i_4 + 1))
        {
            C_1[((((int)threadIdx.x / 32) * 128) + ((((int)threadIdx.x % 32) * 4) + i_4))] = ((half)(0.0f));
        }
        __syncthreads();
        for (int32_t i_5 = 0; (i_5 < 32); i_5 = (i_5 + 1))
        {
            for (int32_t i_6 = 0; (i_6 < 4); i_6 = (i_6 + 1))
            {
                atomicAdd(&C_1[((((int)threadIdx.x / 32) * 128) + ((((int)threadIdx.x % 32) * 4) + i_6))], (A_3[((((int)threadIdx.x / 32) * 32) + i_5)] * B_1[((i_5 * 128) + ((((int)threadIdx.x % 32) * 4) + i_6))]));
            }
        }
        __syncthreads();
        __syncthreads();
        half *PV = smem_pv;
        half *O_1 = smem_o;
        half *M_local_1 = smem_m_local;
        half *M_new_1 = smem_m_new;
        half *M_3 = smem_m;
        half *L_new_1 = smem_l_new;
        half *L_2 = smem_l;
        for (int32_t i_7 = 0; (i_7 < 4); i_7 = (i_7 + 1))
        {
            O_1[((((int)threadIdx.x / 32) * 128) + ((((int)threadIdx.x % 32) * 4) + i_7))] = (((((half)(powf((float)(L_new_1[((int)threadIdx.x / 32)]), ((float)((-1)))))) * (L_2[((int)threadIdx.x / 32)] * hexp((M_3[((int)threadIdx.x / 32)] - M_new_1[((int)threadIdx.x / 32)])))) * O_1[((((int)threadIdx.x / 32) * 128) + ((((int)threadIdx.x % 32) * 4) + i_7))]) + (hexp((M_local_1[((int)threadIdx.x / 32)] - M_new_1[((int)threadIdx.x / 32)])) * PV[((((int)threadIdx.x / 32) * 128) + ((((int)threadIdx.x % 32) * 4) + i_7))]));
        }
        __syncthreads();
        if ((j + 1) == 16)
        {
            for (int32_t i_8 = 0; (i_8 < 4); i_8 = (i_8 + 1))
            {
                O[(((((int)blockIdx.x * 32) + ((int)threadIdx.x / 32)) * 128) + ((((int)threadIdx.x % 32) * 4) + i_8))] = smem_o[((((int)threadIdx.x / 32) * 128) + ((((int)threadIdx.x % 32) * 4) + i_8))];
            }
            __syncthreads();
        }
        if ((int)threadIdx.x < 32)
        {
            smem_m[((int)threadIdx.x % 32)] = smem_m_new[((int)threadIdx.x % 32)];
            smem_l[((int)threadIdx.x % 32)] = smem_l_new[((int)threadIdx.x % 32)];
        }
        __syncthreads();
    }
    __syncthreads();
    return;
}
// normal_transformer_func.h
__global__ void __launch_bounds__(512) hidet_compute_GLOBAL_QK(half * __restrict__ GLOBAL_Q, half * __restrict__ GLOBAL_K, half * __restrict__ GLOBAL_QK) {
  half acc_Sum = half(0.0f);
  for (int32_t v = 0; (v < 128); v = (v + 1)) {
    acc_Sum = (acc_Sum + (GLOBAL_Q[(((int)blockIdx.x * 128) + v)] * GLOBAL_K[(((int)threadIdx.x * 128) + v)]));
  } 
  GLOBAL_QK[(((int)blockIdx.x * 512) + (int)threadIdx.x)] = acc_Sum;
}

__global__ void __launch_bounds__(512) hidet_compute_S(half * __restrict__ GLOBAL_QK, half * __restrict__ GLOBAL_Q, half * __restrict__ GLOBAL_K, half * __restrict__ S) {
  half acc_Max = half(-65504.0f);
  for (int32_t v = 0; (v < 512); v = (v + 1)) {
    acc_Max = __hmax(acc_Max, GLOBAL_QK[(((int)blockIdx.x * 512) + v)]);
  } 
  S[(((int)blockIdx.x * 512) + (int)threadIdx.x)] = (GLOBAL_QK[(((int)blockIdx.x * 512) + (int)threadIdx.x)] - acc_Max);
}

__global__ void __launch_bounds__(512) hidet_compute_exp_s(half * __restrict__ S, half * __restrict__ exp_s) {
  exp_s[(((int)blockIdx.x * 512) + (int)threadIdx.x)] = hexp(S[(((int)blockIdx.x * 512) + (int)threadIdx.x)]);
}

__global__ void __launch_bounds__(512) hidet_compute_softmax(half * __restrict__ S, half * __restrict__ GLOBAL_Q, half * __restrict__ GLOBAL_K, half * __restrict__ GLOBAL_QK, half * __restrict__ exp_s, half * __restrict__ softmax) {
  half acc_Sum = half(0.0f);
  for (int32_t v = 0; (v < 512); v = (v + 1)) {
    acc_Sum = (acc_Sum + hexp(S[(((int)blockIdx.x * 512) + v)]));
  } 
  softmax[(((int)blockIdx.x * 512) + (int)threadIdx.x)] = (hexp(S[(((int)blockIdx.x * 512) + (int)threadIdx.x)]) / acc_Sum);
}

__global__ void __launch_bounds__(512) hidet_compute_GLOBAL_O(half * __restrict__ softmax, half * __restrict__ GLOBAL_V, half * __restrict__ GLOBAL_Q, half * __restrict__ GLOBAL_K, half * __restrict__ GLOBAL_QK, half * __restrict__ S, half * __restrict__ exp_s, half * __restrict__ GLOBAL_O) {
  half acc_Sum = half(0.0f);
  for (int32_t v = 0; (v < 512); v = (v + 1)) {
    acc_Sum = (acc_Sum + (softmax[((((((int)blockIdx.x * 512) + (int)threadIdx.x) / 128) * 512) + v)] * GLOBAL_V[((v * 128) + ((((int)blockIdx.x * 512) + (int)threadIdx.x) % 128))]));
  } 
  GLOBAL_O[((((((int)blockIdx.x * 512) + (int)threadIdx.x) / 128) * 128) + ((((int)blockIdx.x * 512) + (int)threadIdx.x) % 128))] = acc_Sum;
}
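
To make the two code paths above easier to compare, here is a hedged NumPy sketch of the math they are based on: the hidet_compute_* kernels perform a plain row-wise softmax followed by a matmul with V, while the flash-attention kernel accumulates the same result tile by tile using the online-softmax recurrence over a running max m, running sum l, and partial output O. The shapes (sequence length 512, head dim 128, K/V tiles of 32 rows) are assumptions inferred from the kernels' constants, and this sketch is not a line-by-line transcription of the generated code.

import numpy as np

def naive_attention(Q, K, V):
    # Plain softmax(Q K^T) V, row-wise, as in the hidet_compute_* kernels
    S = Q @ K.T
    P = np.exp(S - S.max(axis=1, keepdims=True))
    P = P / P.sum(axis=1, keepdims=True)
    return P @ V

def flash_attention(Q, K, V, tile=32):
    n, d = Q.shape
    O = np.zeros((n, d))
    l = np.zeros((n, 1))                 # running softmax denominator per row
    m = np.full((n, 1), -np.inf)         # running row max
    for j in range(0, K.shape[0], tile):
        S = Q @ K[j:j + tile].T          # scores against one K/V tile
        m_local = S.max(axis=1, keepdims=True)
        P = np.exp(S - m_local)
        l_local = P.sum(axis=1, keepdims=True)
        m_new = np.maximum(m, m_local)
        l_new = np.exp(m - m_new) * l + np.exp(m_local - m_new) * l_local
        O = (l * np.exp(m - m_new) * O + np.exp(m_local - m_new) * (P @ V[j:j + tile])) / l_new
        m, l = m_new, l_new
    return O

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((512, 128)) for _ in range(3))
assert np.allclose(naive_attention(Q, K, V), flash_attention(Q, K, V))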

Again, really wonderful work on Hidet! Any help would be greatly appreciated 🙏 If any further information is needed, please let me know.

[Bug] arguments of clip drop after fusion

This snippet gives the wrong result

import hidet
import numpy as np
import onnx

from onnx import numpy_helper, TensorProto
from onnx.helper import (
    make_model, make_node, make_graph,
    make_tensor_value_info)

X = make_tensor_value_info('X', TensorProto.FLOAT, [None, None])
Y = make_tensor_value_info('Y', TensorProto.FLOAT, [None, None])

v_min = numpy_helper.from_array(np.array(-1, dtype='float32'), name='v_min')
v_max = numpy_helper.from_array(np.array(1, dtype='float32'), name='v_max')

vmin = make_node('Constant', [], ['vmin'], value=v_min) 
vmax = make_node('Constant', [], ['vmax'], value=v_max) 

node = make_node('Clip',['X', 'vmin', 'vmax'], ['Y'])
graph = make_graph([vmin, vmax, node], 'foo', [X], [Y]) 
onnx_model = make_model(graph)

hidet.option.search_space(1)
hidet_onnx_module = hidet.graph.frontend.from_onnx(onnx_model)
hidet_onnx_module = hidet_onnx_module.to_cuda()

input_names = hidet_onnx_module.input_names

inputs = [hidet.asarray(np.array([6], dtype='float32'), device='cuda')]
symbol_data = [hidet.symbol_like(inputs[0])]
symbol_output = hidet_onnx_module(*symbol_data)
graph: hidet.FlowGraph = hidet.trace_from(symbol_output, inputs=symbol_data)
with hidet.graph.PassContext() as ctx:
    graph_opt: hidet.FlowGraph = hidet.graph.optimize(graph)
    print(graph_opt)
cuda_graph = graph_opt.cuda_graph()
outputs = cuda_graph.run(inputs)

print(outputs[0].detach().cpu().numpy())

It returns 6 instead of 1. The CUDA code generated for the clip op just copies the input tensor without any conditional statements; vmin and vmax seem to be dropped. It works correctly when fuse_operator_pass is disabled. (A NumPy check of the expected result follows below.)
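
For reference, the expected result is easy to confirm with a plain NumPy check (no hidet involved):

import numpy as np

# Clip(X, min=-1, max=1) applied to [6] should yield [1], not [6]
print(np.clip(np.array([6], dtype='float32'), -1.0, 1.0))   # -> [1.]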

[Bug] fusion rewrite fails

Running this snippet:

import torch
import hidet
import onnx

class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self, x, y):
        y = torch.min(y, dim=0)[0]
        z = x / y
        return z

device = 'cuda'

model = Foo()
model.to(device)

x = torch.rand([1, 1, 1, 1, 1], device=device)
y = torch.rand([2, 2], device=device)

z = model(x, y)
print(z.shape)

torch.onnx.export(model, (x, y), 'tmp.onnx', input_names = ['x', 'y'],
                  output_names = ['z'])
model = onnx.load('tmp.onnx')

hidet.torch.dynamo_config.search_space(1)

x = hidet.from_torch(x)
y = hidet.from_torch(y)
symbol_data = [hidet.symbol_like(x), hidet.symbol_like(y)]
hidet_onnx_module = hidet.graph.frontend.from_onnx(model)
symbol_output = hidet_onnx_module(*symbol_data)
graph: hidet.FlowGraph = hidet.trace_from(symbol_output, inputs=symbol_data)
with hidet.graph.PassContext() as ctx:
    graph_opt: hidet.FlowGraph = hidet.graph.optimize(graph)
cuda_graph = graph_opt.cuda_graph()
outputs = cuda_graph.run([x, y])

raises error message:

  File "/home/su/accdiff/thirdparty/hidet/python/hidet/transforms/tools/apply_prologue_epilogue.py", line 172, in visit_BufferStoreStmt
    remap: Dict[Var, Expr] = {a: b for a, b in strict_zip(tc.axes, out_indices)}
  File "/home/su/accdiff/thirdparty/hidet/python/hidet/utils/py.py", line 55, in strict_zip
    raise ValueError(
ValueError: Expect two sequence have the same length in zip, got length 5 and 1.
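
For context, eager PyTorch handles this broadcast without issue, so the failure appears specific to the prologue/epilogue rewrite during fusion. A minimal shape check in plain PyTorch (no hidet):

import torch

x = torch.rand([1, 1, 1, 1, 1])
y = torch.rand([2, 2])
y_min = torch.min(y, dim=0)[0]    # shape: (2,)
z = x / y_min                     # broadcasts to (1, 1, 1, 1, 2)
print(y_min.shape, z.shape)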

[Bug] FP64 reduce

See the snippet below:

#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <hidet/runtime/cuda_context.h>
#include <hidet/runtime/cpu_context.h>
typedef float tfloat32_t;
#define __float_to_tf32(x) (x)
/*
Task(
  name: reduce_min
  parameters:
    x: tensor(float64, [1])
    y: tensor(float64, [])
  inputs: [x]
  outputs: [y]
  computations:
    x: tensor(float64, [1])
    y: float64[] where y[] = reduce([1], (v) => x[v], minreduce)
  attributes: {dims: [0], keep_dim: 0, reduce_type: min, accumulate_dtype: float32}
)
*/
extern "C" {

__global__ void __launch_bounds__(32) hidet_reduce_min_grid(double * __restrict__ x, double * __restrict__ y) {
  // label: reduce schedule
  float rv = 3.4028234663852886e+38f;
  if ((int)threadIdx.x < 1) {
    rv = ((float)(fmin(rv, x[(int)threadIdx.x])));
  }
  int32_t mask = __activemask();
  rv = fminf(rv, __shfl_down_sync(mask, rv, 16, 32));
  rv = fminf(rv, __shfl_down_sync(mask, rv, 8, 32));
  rv = fminf(rv, __shfl_down_sync(mask, rv, 4, 32));
  rv = fminf(rv, __shfl_down_sync(mask, rv, 2, 32));
  rv = fminf(rv, __shfl_down_sync(mask, rv, 1, 32));
  rv = __shfl_sync(mask, rv, 0, 32);
  rv = rv;
  if ((int)threadIdx.x < 1) {
    if ((int)threadIdx.x == 0) {
      y[0] = ((double)(rv));
    }
  }
}

__host__ void hidet_reduce_min(int32_t num_args, int32_t * __restrict__ arg_types, void* * __restrict__ args) {
  assert(((void)"Expect 2 arguments", (num_args == 2)));
  assert(((void)"The 0-th argument should be tensor(float64, [1])", (arg_types[0] == 3)));
  assert(((void)"The 1-th argument should be tensor(float64, [])", (arg_types[1] == 3)));
  hidet_reduce_min_grid<<<dim3(1, 1, 1), dim3(32, 1, 1), 0, (cudaStream_t)get_cuda_stream()>>>(((double*)(args[0])), ((double*)(args[1])));
}

}

Several problems here:

  • Why use fp32 accumulation in an fp64 operator? We probably should not use fp32 as the default accumulation dtype (what if the dtype is integer?); see the sketch after this list.
  • The fp32 accumulator further leads to a call to fmin(float, double), which only exists as a host function, causing a compile error.
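
One possible direction, sketched below, is to derive the accumulator dtype from the input dtype instead of unconditionally defaulting to float32. The helper name default_accumulate_dtype is hypothetical and only illustrates the idea:

# Hypothetical helper: pick an accumulator dtype based on the input dtype,
# rather than defaulting to float32 for every reduce.
def default_accumulate_dtype(dtype: str) -> str:
    if dtype in ('float16', 'bfloat16', 'float32'):
        return 'float32'
    if dtype == 'float64':
        return 'float64'   # avoids the host-only fmin(float, double) overload
    if dtype.startswith('int') or dtype.startswith('uint'):
        return dtype       # keep integer reductions exact
    return dtype

print(default_accumulate_dtype('float64'))   # float64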

Question about Complex datatype support

The inductor backend of PyTorch 2.0 does not officially support complex data types yet (pytorch/pytorch#93424); I am wondering whether hidet currently has the same limitation.

If it does, does it rely on other parts of PyTorch 2.0 (e.g., inductor, dynamo, etc.) to fully support complex, or can complex support be added separately?

[Bug] Unexpected behavior when inputs and outputs overlap

If a tensor appears in both the inputs and outputs of a FlowGraph (a direct link), a dummy output is created that overrides the variable name and prevents reading from where the input is actually stored.

Though I don't think users will write such code intentionally, the current framework won't stop them from doing so, which might be confusing in some cases. If the semantics of FlowGraph are SSA, the inputs/outputs interface exposed to users deserves a sanity check. Furthermore, if in-place operators are supported in the future, an input and an output could share the same storage.

[Bug] broadcast_shape parameter type error

Hi!

I encountered this problem when I tried to optimize an ONNX model. In

new_shape = hidet.graph.ops.definitions.arithmetic.broadcast_shape(data.shape, new_shape)

data.shape is a tuple. And in the definition of broadcast_shape,

x_shape = [1] + x_shape

x_shape may be a tuple, in which case [1] + x_shape raises TypeError: can only concatenate list (not "tuple") to list.
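
A minimal sketch of the kind of fix being suggested: normalize the incoming shapes to lists before prepending. This is a simplified, hypothetical version of broadcast_shape that skips hidet's error checking:

def broadcast_shape(x_shape, y_shape):
    # Accept tuples as well as lists by normalizing first
    x_shape, y_shape = list(x_shape), list(y_shape)
    while len(x_shape) < len(y_shape):
        x_shape = [1] + x_shape
    while len(y_shape) < len(x_shape):
        y_shape = [1] + y_shape
    return [max(a, b) for a, b in zip(x_shape, y_shape)]

print(broadcast_shape((3, 1), [2, 1, 4]))   # [2, 3, 4]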

[Feature]

hi @yaoyaoding,
Hidet is amazing and I have been learning it recently.
Are there any official tutorials for testing matrix multiplication with Hidet? Thank you!
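
Until a dedicated tutorial exists, here is a minimal sketch of running a matrix multiplication through hidet directly, assuming the hidet.randn and hidet.ops.matmul APIs shown in the documentation:

import hidet

# Build two random CUDA tensors and multiply them with hidet's matmul operator.
a = hidet.randn([1024, 1024], device='cuda')
b = hidet.randn([1024, 1024], device='cuda')
c = hidet.ops.matmul(a, b)
print(c.shape)   # [1024, 1024]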

[Bug] pip install fails, Cannot open include file: 'nvtx3/nvToolsExt.h': No such file or directory ; CUDA Toolkit installed, works in VS19

Describe the bug
Platform : Windows 10
pip install fails at stage "Building wheel for nvtx" with error Cannot open include file: 'nvtx3/nvToolsExt.h': No such file or directory
Cuda toolkit 12.1 is installed. Importing the module in VS19 works (starting a project with CUDA Toolkit).

To Reproduce
pip install hidet

Expected behavior
installation should succeed

Environment

  • OS: Windows 10
  • GPU: GTX 1660Ti
  • Others: CUDA Toolkit 12.1.1

Additional context
Installing in conda environment

[Bug] Performance regression of softmax following PR #122

Describe the bug
There is a performance regression for the softmax operator following the change with PR #122
To Reproduce

git checkout 3869540f33ede74a22922f33721e0b66699c5412

Run the following script with pytest

import pytest
import numpy as np

import hidet as hi
from hidet import ops

from hidet.testing import check_unary

import git
repo = git.Repo(search_parent_directories=True)
sha = repo.head.object.hexsha
print(sha)
hi.option.cache_dir('./cache_{}'.format(sha))


def numpy_softmax(data, axis):
    data = np.exp(data - np.max(data, axis, keepdims=True))
    data = data / np.sum(data, axis, keepdims=True)
    return data


@pytest.mark.parametrize(
    "shape, axis",
    [[[16, 1024, 1024], 2],] 
)
def test_softmax(shape, axis):
    check_unary(
        shape, lambda x: numpy_softmax(x, axis), lambda x: ops.softmax(x, axis), dtype='float16', atol=1, rtol=1, device='cuda',
    )
git checkout HEAD^ # previous commit

Run the script again with pytest

Diff the generated kernel code

diff -r cache_3869540f33ede74a22922f33721e0b66699c5412 cache_e57b798a427fbfb13c386f0ae34839a909fbfb7a

There is a big increase in index calculation and redundant checks. This is possibly introduced by the changes in TaskMapping but I was not able to locate exactly where.

Expected behavior
Performance should be similar; the kernel code should not introduce excessive index calculations and if statements.

Environment

  • OS: [ Ubuntu 22.04]
  • GPU: [RTX 3090]
  • Others: [e.g. NVIDIA GPU 520.61.05]

Additional context

[Bug] binary arithmetic with CUDA scalar

In

It looks like in order to compute a binary operator with a scalar on the GPU, we first need to copy it to the CPU (is this expected? will the synchronization affect performance?). The cpu() function then raises an error saying we should first detach the variable in

return self.cpu().numpy().tolist()

If we add detach there, a second problem appears: the .numpy() (or to_dlpack) method does not support a tensor with a single element and shape []. The error message is:

  File "/home/su/accdiff/thirdparty/hidet/python/hidet/graph/ops/definitions/arithmetic.py", line 529, in divide
    return binary_arithmetic(
  File "/home/su/accdiff/thirdparty/hidet/python/hidet/graph/ops/definitions/arithmetic.py", line 479, in binary_arithmetic
    x = x.dtype(x.item())
  File "/home/su/accdiff/thirdparty/hidet/python/hidet/graph/tensor.py", line 528, in item
    ret = self.squeeze(dims=list(range(len(self.shape)))).tolist()
  File "/home/su/accdiff/thirdparty/hidet/python/hidet/graph/tensor.py", line 488, in tolist
    ret = ret.numpy()
  File "/home/su/accdiff/thirdparty/hidet/python/hidet/graph/tensor.py", line 930, in numpy
    return np.from_dlpack(self)
  File "/home/su/accdiff/thirdparty/hidet/python/hidet/graph/tensor.py", line 467, in __dlpack__
    return to_dlpack(self)
  File "/home/su/accdiff/thirdparty/hidet/python/hidet/graph/impl/dlpack.py", line 251, in to_dlpack
    return DLManagedTensorContext(tensor).capsuled_dltensor()
  File "/home/su/accdiff/thirdparty/hidet/python/hidet/graph/impl/dlpack.py", line 229, in __init__
    data=tensor.storage.addr,
AttributeError: 'NoneType' object has no attribute 'addr'

[Bug] Hidet numpy version conflicts with Google colab version

Describe the bug
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow 2.12.0 requires numpy<1.24,>=1.22, but you have numpy 1.24.3 which is incompatible.
numba 0.56.4 requires numpy<1.24,>=1.18, but you have numpy 1.24.3 which is incompatible.

To Reproduce
https://colab.research.google.com/drive/1K0gF1lBuEx0s3i8Pm3JMw5_Njs4mgxeZ#scrollTo=oK2NFZu0BZXJ

Environment

  • Google Colab
  • T4 Gpu

[Bug] torch dynamo cannot find hidet backend

Describe the bug
I'm trying to optimize PyTorch models with hidet. However, whether hidet is installed via pip or from source, torch.compile produces an error saying the backend cannot be found:

File "cell1.py", line 12, in
model_opt = torch.compile(model, backend='hidet')
File "/usr/local/lib/python3.8/dist-packages/torch/init.py", line 1419, in compile
return torch._dynamo.optimize(backend=backend, nopython=fullgraph, dynamic=dynamic, disable=disable)(model)
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 426, in optimize
backend = get_compiler_fn(backend)
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 356, in get_compiler_fn
compiler_fn = lookup_backend(compiler_fn)
File "/usr/local/lib/python3.8/dist-packages/torch/_dynamo/eval_frame.py", line 365, in lookup_backend
compiler_fn = BACKENDS[compiler_fn]
KeyError: 'hidet'

To Reproduce
gallery/getting-started/quick-start.py

Expected behavior
Theoretically, torch dynamo should have registered the hidet backend after a successful installation. Do you have any idea what might be going wrong?

Environment

  • Python verison: 3.8.10
  • Hidet version: 0.2.2.dev
  • PyTorch version: 2.0.0+cu116
  • OS: Ubuntu 22.04.2 LTS
  • GPU: NVIDIA A10
  • GPU driver: 525.85.12

[Bug]

test code is as follows:

#!/usr/bin/python3
import torch

# Define pytorch model
model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True).cuda().eval()
x = torch.rand(1, 3, 224, 224).cuda()

# Compile the model through Hidet
# Optional: set optimization options (see our documentation for more details)
#   import hidet
#   hidet.torch.dynamo_config.search_space(2)  # tune each tunable operator
#   hidet.torch.dynamo_config.use_fp16()       # use float16 for acceleration
model_opt = torch.compile(model, backend='hidet')

# Run the optimized model
y = model_opt(x)

Traceback (most recent call last):
  File "./test.py", line 13, in <module>
    model_opt = torch.compile(model, backend='hidet')
  File "/home/user02/.local/lib/python3.8/site-packages/torch/__init__.py", line 1441, in compile
    return torch._dynamo.optimize(backend=backend, nopython=fullgraph, dynamic=dynamic, disable=disable)(model)
  File "/home/user02/.local/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 424, in optimize
    backend = get_compiler_fn(backend)
  File "/home/user02/.local/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 363, in get_compiler_fn
    compiler_fn = lookup_backend(compiler_fn)
  File "/home/user02/.local/lib/python3.8/site-packages/torch/_dynamo/backends/registry.py", line 58, in lookup_backend
    _lazy_import_entry_point(compiler_fn)
  File "/home/user02/.local/lib/python3.8/site-packages/torch/_dynamo/backends/registry.py", line 100, in _lazy_import_entry_point
    eps = [ep for ep in backend_eps[group_name] if ep.name == backend_name]
KeyError: 'torch_dynamo_backends'

May I ask why this error occurs and how to solve it? Thanks.

[Question] how to use with facebook dinov2

reproduce

model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
model = model.to('cuda')
torch._dynamo.config.verbose = True
hidet.torch.dynamo_config.use_tensor_core(True)
hidet.torch.dynamo_config.search_space(2)
model_opt = torch.compile(model, backend='hidet')

some error

torch._dynamo.exc.BackendCompilerFailed: hidet_backend raised NotImplementedError: The following modules/functions are not supported by hidet yet:
  torch.nn.Identity
Environment

  • OS: Ubuntu 22.04
  • GPU: A100

[Bug] Use int64 in argmax

Now hidet uses int32 as the return type of ArgReduceTask,

extent=x_shape[dim], fcompute=reduce_fcompute, reduce_type=reduce_type, index_dtype='int32'

which is misaligned with torch and ONNX, which return int64, leading to incompatibility with other operators. For example, concatenation requires all inputs to have the same dtype, so concatenating the output of argmax with an int64 tensor is legal in torch but illegal in hidet.

Here is a simple snippet:

import torch
import hidet
import onnx

class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self, x, y):
        y = torch.argmax(y, dim=0)
        print(y.dtype) # int64
        return torch.concat([x, y])

device = 'cuda'

model = Foo()
model.to(device)

x = torch.ones([5], dtype=torch.int64, device=device)
y = torch.rand([5, 5], device=device)
z = model(x, y)
print(z.shape)

torch.onnx.export(model, (x, y), 'tmp.onnx', input_names = ['x', 'y'],
                  output_names = ['z'])
model = onnx.load('tmp.onnx')

hidet.torch.dynamo_config.search_space(1)

x = hidet.from_torch(x)
y = hidet.from_torch(y)
symbol_data = [hidet.symbol_like(x), hidet.symbol_like(y)]
hidet_onnx_module = hidet.graph.frontend.from_onnx(model)
symbol_output = hidet_onnx_module(*symbol_data)
graph: hidet.FlowGraph = hidet.trace_from(symbol_output, inputs=symbol_data)
with hidet.graph.PassContext() as ctx:
    graph_opt: hidet.FlowGraph = hidet.graph.optimize(graph)
cuda_graph = graph_opt.cuda_graph()
outputs = cuda_graph.run([x, y])

which raises an error:

ValueError: concat: expect all tensors have the same dtype, but got:
Tensor(shape=(5,), dtype='int64', device='cuda:0')
Tensor(shape=(5,), dtype='int32', device='cuda:0')
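
A possible interim workaround, sketched below, is to insert an explicit cast to int64 after argmax on the hidet side so the dtypes match before concatenation. The exact operator signatures used here (hidet.ops.argmax, hidet.ops.cast, hidet.ops.concat) are assumptions on my side; the proper fix would be for ArgReduceTask to emit int64 indices in the first place.

import numpy as np
import hidet

# Hypothetical workaround: cast the int32 argmax result back to int64 before concat.
x = hidet.asarray(np.ones(5, dtype='int64'), device='cuda')
y = hidet.randn([5, 5], device='cuda')
idx = hidet.ops.argmax(y, dim=0)        # int32 under the current behavior
idx = hidet.ops.cast(idx, 'int64')
z = hidet.ops.concat([x, idx], axis=0)
print(z.dtype, z.shape)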

[Feature]Does hidet support training?

Does hidet support training? I saw that hidet is now a backend inside torch dynamo, so does this mean we can also reuse AOT autograd with the hidet backend to do training? If training is supported, how does it work; and if not, why does it not work yet?

Question

hi @yaoyaoding,
I recently tested matrix multiplication using hidet, and I have a question about the following compilation output:
Compling cuda task batch_matmul (a=float32(1,4352,4096),b=float32(1,4096,4096),c=float32(1,4352,4096),batch_size=1,m_size=4352,n_size=4096,k_size=4096,mma='mma')...
Compling cpu task cast(x=float64(1,4352,4096),y=float32(1,4352,4096))...
Compling cpu task cast(x=float64(1,4096,4096),y=float32(1,4096,4096))...
Compling:100%...

My question is:
Since the CUDA and CPU tasks are compiled separately, does each of them perform the matrix multiplication on its own, or does each perform only part of the work?

Thank you very much again!

[Bug] Some hidet tensor methods do not support symbolic tensors?

Hi, thanks for the great work!

I am wondering why some hidet tensor methods (e.g., to, cuda, and cpu) do not support symbolic tensors.

import torch
import torch.nn as nn
import hidet

device = torch.device('cuda')   # assumed: the snippet moves the model and inputs to CUDA below

class TestMode(nn.Module):
    def __init__(self):
        super().__init__()

        self.conv = nn.Linear(10, 10)

    def forward(self, x):
        z = x.unsqueeze(0).expand(4, 4, 512).to(torch.device("cuda"))
        return z

if __name__ == "__main__":
    model = TestMode()
    model = model.eval().half()
    model = model.to(device)
    hidet.torch.dynamo_config.search_space(2)
    hidet.torch.dynamo_config.use_fp16()
    model_opt = torch.compile(model, backend='hidet')

    tokens = torch.zeros(20, 10).cuda()
    model_opt(tokens)

In the above test case, the following exception is raised:
NotImplementedError: hidet: Tensor.to(..., device=...) is not supported for symbolic tensors., occurred when calling tensor_to(Tensor(shape=(4, 4, 512), dtype='bool', device='cuda:0'), device(type='cuda'))

I think .to(device) is a common operation in deep learning models, as in the Hugging Face LLaMA implementation.

Are there any concerns or limitations regarding these operations for symbolic tracing?
Looking forward to your response. Thanks!
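
As a possible workaround until Tensor.to(..., device=...) is supported for symbolic tensors, the device transfer can be kept outside the compiled region. A sketch (not an official recommendation), assuming the compiled module only performs shape operations:

import torch
import hidet   # ensures the 'hidet' dynamo backend is registered

class ExpandOnly(torch.nn.Module):
    def forward(self, x):
        # only shape ops inside the compiled graph; no .to(device) here
        return x.unsqueeze(0).expand(4, x.shape[0], x.shape[1])

model = ExpandOnly().eval().cuda()
model_opt = torch.compile(model, backend='hidet')
tokens = torch.zeros(4, 512).cuda()    # move inputs to CUDA before the call
print(model_opt(tokens).shape)         # torch.Size([4, 4, 512])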

Half function undefined during compilation

Trying to run softmax with fp16, I get the error message error: identifier "__hmax" is undefined.

It seems some functions don't exist on hardware older than Ampere. Please refer to this
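
As the note above suggests, the half-precision intrinsics involved (__hmax and friends) are only available on compute capability 8.0 (Ampere) and newer. A quick way to check the GPU before enabling the fp16 softmax path, using PyTorch's device query:

import torch

major, minor = torch.cuda.get_device_capability()
print('compute capability: {}.{}'.format(major, minor))
if major < 8:
    # Pre-Ampere GPU: half intrinsics such as __hmax are unavailable in device code,
    # so fall back to an fp32 softmax instead of the fp16 path.
    print('falling back to fp32 softmax')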
