This is just an umbrella issue to get started. Feel free to modify / fill in the blanks / link sub-issues and related discussions.
cc: @antiagainst @MaheshRavishankar @qedawkins @raikonenfnu @hanhanW @bjacob
The gfx940 ISA supports two fp8 formats: fp8 and bf8. Both formats are supported by mfma, including operands of mixed formats: https://llvm.org/docs/AMDGPU/AMDGPUAsmGFX940.html#vop3.
FP8 mfma is plumbed through the AMDGPU LLVM backend: https://reviews.llvm.org/D129906, for example:
// From the clang OpenCL test; v16f is a 16-wide float vector.
typedef float v16f __attribute__((ext_vector_type(16)));

// CHECK-GFX940-LABEL: @test_mfma_f32_32x32x16_fp8_bf8
// CHECK-GFX940: call <16 x float> @llvm.amdgcn.mfma.f32.32x32x16.fp8.bf8(i64 %a, i64 %b, <16 x float> %c, i32 0, i32 0, i32 0)
void test_mfma_f32_32x32x16_fp8_bf8(global v16f* out, long a, long b, v16f c)
{
  *out = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(a, b, c, 0, 0, 0);
}
The fp8 operands are packed as i64 (eight fp8 values per register). The only other amdgcn intrinsics for fp8 types are the cvt builtins for type conversions: https://github.com/llvm/llvm-project/blob/cd3942059eed7b7185f26bc583ac287a995db0d0/clang/include/clang/Basic/BuiltinsAMDGPU.def#L400-L407
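At the MLIR level these scalar conversions surface as arith.extf / arith.truncf on the builtin f8 types. A minimal sketch (function name hypothetical; it assumes the f8E4M3FNUZ mapping discussed further down, and whether these ops actually lower to the cvt intrinsics is an open question revisited later in this thread):

func.func @fp8_roundtrip(%x: f8E4M3FNUZ) -> f8E4M3FNUZ {
  // Widen fp8 to f32, then truncate back; each step corresponds in
  // principle to one of the amdgcn cvt conversions.
  %w = arith.extf %x : f8E4M3FNUZ to f32
  %t = arith.truncf %w : f32 to f8E4M3FNUZ
  return %t : f8E4M3FNUZ
}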
FP8 is E4M3 (inference-focused) while BF8 is E5M2 (training-focused): https://www.amd.com/en/products/accelerators/instinct/mi300/mi300a.html.
OCP 8-bit Floating Point Specification (OFP8)
Related paper with an overview of fp8 types: FP8 Formats for Deep Learning
Related blog post with overview of fp8 support for H100: https://lambdalabs.com/blog/nvidia-hopper-h100-and-fp8-support
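The precision/range trade-off between the two formats is easy to see from their largest finite values (a quick worked computation, not taken from the links above):

$$\text{E4M3 max} = 1.75 \times 2^{8} = 448 \qquad \text{E5M2 max} = 1.75 \times 2^{15} = 57344$$

E5M2 gives up one mantissa bit for roughly 128x more dynamic range, which is why it suits gradients while E4M3 suits activations and weights.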
FP8 support in LLVM/MLIR:
RFC from Sep '22 by @stellaraccident: https://discourse.llvm.org/t/rfc-add-apfloat-and-mlir-type-support-for-fp8-e5m2/65279.
- e5m2 landed here: https://reviews.llvm.org/D133823; the MLIR syntax is f8E5M2.
Since then, the other types plumbed all the way through MLIR are:
// Builtin f8 types available in MLIR core:
Float8E4M3FNType f8E4M3FNTy;
Float8E5M2FNUZType f8E5M2FNUZTy;
Float8E4M3FNUZType f8E4M3FNUZTy;
Float8E4M3B11FNUZType f8E4M3B11FNUZTy;
// Their textual syntax, from the asm printer:
.Case<Float8E5M2Type>([&](Type) { os << "f8E5M2"; })
.Case<Float8E4M3FNType>([&](Type) { os << "f8E4M3FN"; })
.Case<Float8E5M2FNUZType>([&](Type) { os << "f8E5M2FNUZ"; })
.Case<Float8E4M3FNUZType>([&](Type) { os << "f8E4M3FNUZ"; })
.Case<Float8E4M3B11FNUZType>([&](Type) { os << "f8E4M3B11FNUZ"; })
func.func @float_attrs_pass() {
  "test.float_attrs"() {
    // CHECK: float_attr = 2.000000e+00 : f8E5M2
    float_attr = 2. : f8E5M2
  } : () -> ()
  "test.float_attrs"() {
    // CHECK: float_attr = 2.000000e+00 : f8E4M3FN
    float_attr = 2. : f8E4M3FN
  } : () -> ()
  "test.float_attrs"() {
    // CHECK: float_attr = 2.000000e+00 : f8E5M2FNUZ
    float_attr = 2. : f8E5M2FNUZ
  } : () -> ()
  "test.float_attrs"() {
    // CHECK: float_attr = 2.000000e+00 : f8E4M3FNUZ
    float_attr = 2. : f8E4M3FNUZ
  } : () -> ()
  "test.float_attrs"() {
    // CHECK: float_attr = 2.000000e+00 : f8E4M3B11FNUZ
    float_attr = 2. : f8E4M3B11FNUZ
  } : () -> ()
  // ... (further cases elided)
  return
}
// APFloat semantics fields: {maxExponent, minExponent,
// precision (including the implicit bit), sizeInBits[, nonfiniteBehavior, nanEncoding]}.
static constexpr fltSemantics semFloat8E5M2 = {15, -14, 3, 8};
static constexpr fltSemantics semFloat8E5M2FNUZ = {
15, -15, 3, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero};
static constexpr fltSemantics semFloat8E4M3FN = {
8, -6, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::AllOnes};
static constexpr fltSemantics semFloat8E4M3FNUZ = {
7, -7, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero};
static constexpr fltSemantics semFloat8E4M3B11FNUZ = {
4, -10, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero};
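As a sanity check on these fields (my own arithmetic, not from the snippet): with precision p counting the implicit bit, the largest finite value of the FNUZ formats is

$$\left(2 - 2^{1-p}\right) \times 2^{\text{maxExponent}}$$

so for semFloat8E4M3FNUZ = {7, -7, 4, 8} that is $(2 - 2^{-3}) \times 2^{7} = 1.875 \times 128 = 240$, the published E4M3FNUZ maximum. (For semFloat8E4M3FN the all-ones NaN encoding steals the top code point, so the maximum is $1.75 \times 2^{8} = 448$ rather than 480.)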
> FP8 is E4M3 (inference-focused) while BF8 is E5M2 (training-focused): https://www.amd.com/en/products/accelerators/instinct/mi300/mi300a.html.
>
> If a model is trained with 2-bit mantissas (E5M2), how is the 3rd bit of mantissa in E4M3 going to be useful in inference?

This is explained a bit in the NVIDIA doc linked in my previous comment:

> During training neural networks both of these types may be utilized. Typically forward activations and weights require more precision, so E4M3 datatype is best used during forward pass. In the backward pass, however, gradients flowing through the network typically are less susceptible to the loss of precision, but require higher dynamic range. Therefore they are best stored using E5M2 data format. H100 TensorCores provide support for any combination of these types as the inputs, enabling us to store each tensor using its preferred precision.
Support in MLIR/LLVM/AMDGPU already seems quite promising, so, as discussed this morning, the plan is to first show a very simple example using fp8 in IREE, something like:
module {
  func.func @matmul_static(%arg0: tensor<32x32xi8>, %arg1: tensor<32x32xi8>, %arg2: tensor<32x32xf32>) -> tensor<32x32xf32> {
    %0 = tensor.bitcast %arg0 : tensor<32x32xi8> to tensor<32x32xf8E4M3FNUZ>
    %1 = tensor.bitcast %arg1 : tensor<32x32xi8> to tensor<32x32xf8E4M3FNUZ>
    %2 = linalg.matmul ins(%0, %1 : tensor<32x32xf8E4M3FNUZ>, tensor<32x32xf8E4M3FNUZ>) outs(%arg2 : tensor<32x32xf32>) -> tensor<32x32xf32>
    return %2 : tensor<32x32xf32>
  }
}
or, to avoid the need to also handle mfma at the same time, just something as simple as
#map = affine_map<(d0) -> (d0)>
module {
  func.func @extend_i8(%arg0: tensor<32xi8>) -> tensor<32xf32> {
    %0 = tensor.bitcast %arg0 : tensor<32xi8> to tensor<32xf8E4M3FNUZ>
    %1 = tensor.empty() : tensor<32xf32>
    %2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%0 : tensor<32xf8E4M3FNUZ>) outs(%1 : tensor<32xf32>) {
    ^bb0(%in: f8E4M3FNUZ, %out: f32):
      %3 = arith.extf %in : f8E4M3FNUZ to f32
      linalg.yield %3 : f32
    } -> tensor<32xf32>
    return %2 : tensor<32xf32>
  }
}
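After tiling and vectorization, the extend example boils down to vector-level ops like the following (a hypothetical sketch of what the backend has to handle; %vi8 stands for a loaded vector):

// Reinterpret the raw bytes as fp8, then widen; vector.bitcast is legal
// here because both element types are 8 bits wide.
%v8 = vector.bitcast %vi8 : vector<32xi8> to vector<32xf8E4M3FNUZ>
%vf = arith.extf %v8 : vector<32xf8E4M3FNUZ> to vector<32xf32>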
amdgcn's fp8 maps to f8E4M3FNUZ, while bf8 maps to f8E5M2FNUZ.
Also, https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html talks a bit about fp8 on NVIDIA GPUs, which is a useful reference.
In general, fp8 is used in a very ad-hoc way right now: the ISAs just provide conversion and tensor/matrix core ops. For training we also have different fp8 scaling factors for different tensors and need model/framework-level handling there, so that is also quite ad-hoc.
So, as we've discussed in the meeting, getting a minimal matmul to exercise fp8 + tensor/matrix cores in IREE/SHARK would be a good start and a foundation for everything else. We can then build the other parts on top.
Explanation of the LLVM fp semantics naming convention:
F is for "finite" (no infinities), N for a special NaN encoding, UZ for unsigned zero (no -0).
Source: https://github.com/jax-ml/ml_dtypes?tab=readme-ov-file#float8_e5m2fnuz
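Decoding the MLIR type names under that convention (annotations mine):

f8E5M2         // IEEE-style: 5 exponent / 2 mantissa bits, has +-inf and NaN
f8E4M3FN       // finite only, NaN encoded as exponent+mantissa all ones
f8E4M3FNUZ     // finite only, NaN encoded as the -0 bit pattern, zero unsigned
f8E5M2FNUZ     // same non-finite behavior with the 5/2 bit split
f8E4M3B11FNUZ  // as f8E4M3FNUZ but with exponent bias 11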
Looking through support in MLIR and lowering into NVVM/ROCDL, it seems to be already there as well.
MFMA to ROCDL intrinsics:
- MFMA instruction lowering (link)
- Conversion instruction lowering (link: https://github.com/llvm/llvm-project/blob/d9a9872ec4760762fdc467ef283cea302a3742e5/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp#L683)
Tensor core instruction lowering:
- WGMMA instruction support (link)
- I didn't find support for conversion instructions, which is strange.
So for the examples in this comment #2054 (comment), the extension/truncation should just pass through and compile on AMD. For the mfma support, it would be great if we could just take a single matmul of the exact mfma shape and have it lower to that operation. Literally all tile sizes would be 1: it should vectorize to vector.contract, lower to amdgpu.mfma, and then to the ROCDL intrinsics (see the sketch below).
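As a sketch of that single-tile path (my own illustration; the operand types assume the f8E4M3FNUZ mapping above, and the 32x32x16 shape matches the mfma from the first comment):

#mA = affine_map<(d0, d1, d2) -> (d0, d2)>
#mB = affine_map<(d0, d1, d2) -> (d2, d1)>
#mC = affine_map<(d0, d1, d2) -> (d0, d1)>
// A whole-tile contraction that should map 1:1 onto amdgpu.mfma and
// from there onto llvm.amdgcn.mfma.f32.32x32x16.fp8.fp8.
%d = vector.contract {indexing_maps = [#mA, #mB, #mC],
                      iterator_types = ["parallel", "parallel", "reduction"],
                      kind = #vector.kind<add>}
     %a, %b, %c
     : vector<32x16xf8E4M3FNUZ>, vector<16x32xf8E4M3FNUZ> into vector<32x32xf32>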