
Comments (10)

kuhar commented on June 8, 2024

This is just an umbrella issue to get started. Feel free to modify / fill in the blanks / link sub-issues and related discussions.
cc: @antiagainst @MaheshRavishankar @qedawkins @raikonenfnu @hanhanW @bjacob

kuhar commented on June 8, 2024

The gfx940 ISA supports two fp8 formats: fp8 and bf8. Both formats are supported by mfma, including operands of mixed formats: https://llvm.org/docs/AMDGPU/AMDGPUAsmGFX940.html#vop3.

FP8 mfma is plumbed through the amdgpu llvm backend: https://reviews.llvm.org/D129906, for example:

// CHECK-GFX940-LABEL: @test_mfma_f32_32x32x16_fp8_bf8
// CHECK-GFX940: call <16 x float> @llvm.amdgcn.mfma.f32.32x32x16.fp8.bf8(i64 %a, i64 %b, <16 x float> %c, i32 0, i32 0, i32 0)
// Vector typedef used by the test:
typedef float v16f __attribute__((ext_vector_type(16)));

void test_mfma_f32_32x32x16_fp8_bf8(global v16f* out, long a, long b, v16f c)
{
  *out = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(a, b, c, 0, 0, 0);
}

The fp8 operands are packed into i64. The only other amdgcn intrinsics for fp8 types are the cvt ones -- type conversions: https://github.com/llvm/llvm-project/blob/cd3942059eed7b7185f26bc583ac287a995db0d0/clang/include/clang/Basic/BuiltinsAMDGPU.def#L400-L407
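
For reference, the conversion builtins look roughly like the sketch below (hedged: based only on the BuiltinsAMDGPU.def entries linked above; test_cvt_fp8 and its arguments are made up for illustration). The trailing selector arguments choose which fp8 byte / 16-bit word of the packed i32 is read or written.

// Hypothetical example, mirroring the style of the clang builtin tests.
void test_cvt_fp8(global float* out, float a, float b, int old)
{
  // v_cvt_pk_fp8_f32: convert a and b to fp8 and pack them into the low
  // 16-bit word of `old` (the last argument selects the word).
  int packed = __builtin_amdgcn_cvt_pk_fp8_f32(a, b, old, false);
  // v_cvt_f32_fp8: convert the fp8 byte selected by the index (byte 0) back to f32.
  *out = __builtin_amdgcn_cvt_f32_fp8(packed, 0);
}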

FP8 is E4M3 (inference-focused) while BF8 is E5M2 (training-focused): https://www.amd.com/en/products/accelerators/instinct/mi300/mi300a.html.

The AMD CDNA 3 compute units support both variants of the FP8 data type as defined in the OCP 8-bit floating point specification.

OCP 8-bit Floating Point Specification (OFP8)

Related paper with an overview of fp8 types: FP8 FORMATS FOR DEEP LEARNING

Related blog post with overview of fp8 support for H100: https://lambdalabs.com/blog/nvidia-hopper-h100-and-fp8-support
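
To make the two layouts concrete, here is a small reference sketch (not from the thread) that decodes a raw byte under the OCP definitions: E4M3 has bias 7, no infinities, and a single NaN mantissa pattern, with a max finite value of 448; E5M2 keeps the IEEE-style Inf/NaN encodings (bias 15) and reaches 57344.

#include <math.h>
#include <stdint.h>

// OCP FP8 E4M3: 1 sign, 4 exponent (bias 7), 3 mantissa bits.
// No infinities; exponent and mantissa all-ones encodes NaN. Max finite = 448.
float decode_e4m3(uint8_t v) {
  int sign = v >> 7, exp = (v >> 3) & 0xF, man = v & 0x7;
  float s = sign ? -1.0f : 1.0f;
  if (exp == 0xF && man == 0x7) return NAN;
  if (exp == 0) return s * ldexpf(man / 8.0f, -6);   // subnormals
  return s * ldexpf(1.0f + man / 8.0f, exp - 7);     // normals
}

// OCP BF8/E5M2: 1 sign, 5 exponent (bias 15), 2 mantissa bits.
// IEEE-style: exponent all-ones encodes Inf/NaN. Max finite = 57344.
float decode_e5m2(uint8_t v) {
  int sign = v >> 7, exp = (v >> 2) & 0x1F, man = v & 0x3;
  float s = sign ? -1.0f : 1.0f;
  if (exp == 0x1F) return man ? NAN : s * INFINITY;
  if (exp == 0) return s * ldexpf(man / 4.0f, -14);  // subnormals
  return s * ldexpf(1.0f + man / 4.0f, exp - 15);    // normals
}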

kuhar commented on June 8, 2024

FP8 support in LLVM/MLIR:

RFC from Sep '22 by @stellaraccident: https://discourse.llvm.org/t/rfc-add-apfloat-and-mlir-type-support-for-fp8-e5m2/65279.

Since then, the other types plumbed all the way through MLIR are:

  Float8E4M3FNType f8E4M3FNTy;
  Float8E5M2FNUZType f8E5M2FNUZTy;
  Float8E4M3FNUZType f8E4M3FNUZTy;
  Float8E4M3B11FNUZType f8E4M3B11FNUZTy;

(https://github.com/llvm/llvm-project/blob/6af713ae170c34f0561f19e594266ce2a2af343b/mlir/lib/IR/MLIRContext.cpp#L223C27-L227)

      .Case<Float8E5M2Type>([&](Type) { os << "f8E5M2"; })
      .Case<Float8E4M3FNType>([&](Type) { os << "f8E4M3FN"; })
      .Case<Float8E5M2FNUZType>([&](Type) { os << "f8E5M2FNUZ"; })
      .Case<Float8E4M3FNUZType>([&](Type) { os << "f8E4M3FNUZ"; })
      .Case<Float8E4M3B11FNUZType>([&](Type) { os << "f8E4M3B11FNUZ"; })

(https://github.com/llvm/llvm-project/blob/6af713ae170c34f0561f19e594266ce2a2af343b/mlir/lib/IR/AsmPrinter.cpp#L2548C1-L2552C73)

func.func @float_attrs_pass() {
  "test.float_attrs"() {
    // CHECK: float_attr = 2.000000e+00 : f8E5M2
    float_attr = 2. : f8E5M2
  } : () -> ()
  "test.float_attrs"() {
    // CHECK: float_attr = 2.000000e+00 : f8E4M3FN
    float_attr = 2. : f8E4M3FN
  } : () -> ()
  "test.float_attrs"() {
    // CHECK: float_attr = 2.000000e+00 : f8E5M2FNUZ
    float_attr = 2. : f8E5M2FNUZ
  } : () -> ()
  "test.float_attrs"() {
    // CHECK: float_attr = 2.000000e+00 : f8E4M3FNUZ
    float_attr = 2. : f8E4M3FNUZ
  } : () -> ()
  "test.float_attrs"() {
    // CHECK: float_attr = 2.000000e+00 : f8E4M3B11FNUZ
    float_attr = 2. : f8E4M3B11FNUZ
  } : () -> ()
  "test.float_attrs

(https://github.com/llvm/llvm-project/blob/dd047c5b64944bae830b9fecf53f8d11ff41386e/mlir/test/IR/attribute.mlir#L38C1-L59C20)

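// fltSemantics fields, in order: maxExponent, minExponent, precision (mantissa bits
// plus the implicit bit), sizeInBits, then optional non-finite behavior and NaN encoding.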
static constexpr fltSemantics semFloat8E5M2 = {15, -14, 3, 8};
static constexpr fltSemantics semFloat8E5M2FNUZ = {
   15, -15, 3, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero};
static constexpr fltSemantics semFloat8E4M3FN = {
   8, -6, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::AllOnes};
static constexpr fltSemantics semFloat8E4M3FNUZ = {
   7, -7, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero};
static constexpr fltSemantics semFloat8E4M3B11FNUZ = {
   4, -10, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero};

(https://github.com/llvm/llvm-project/blob/dd047c5b64944bae830b9fecf53f8d11ff41386e/llvm/lib/Support/APFloat.cpp#L132-L140)

antiagainst commented on June 8, 2024

FP8 is E4M3 (inference-focused) while BF8 is E5M2 (training-focused): https://www.amd.com/en/products/accelerators/instinct/mi300/mi300a.html.

If a model is trained with 2-bit mantissas (E5M2), how is the 3rd bit of mantissa in E4M3 going to be useful in inference?

This is explained a bit in the NVIDIA doc as linked in my previous comment:

During training neural networks both of these types may be utilized. Typically forward activations and weights require more precision, so E4M3 datatype is best used during forward pass. In the backward pass, however, gradients flowing through the network typically are less susceptible to the loss of precision, but require higher dynamic range. Therefore they are best stored using E5M2 data format. H100 TensorCores provide support for any combination of these types as the inputs, enabling us to store each tensor using its preferred precision.

qedawkins commented on June 8, 2024

Support in MLIR/LLVM/AMDGPU already seems quite promising, so as discussed this morning the plan is to show a very simple example using fp8 in IREE first, something like

module {
  func.func @matmul_static(%arg0: tensor<32x32xi8>, %arg1: tensor<32x32xi8>, %arg2: tensor<32x32xf32>) -> tensor<32x32xf32> {
    %0 = tensor.bitcast %arg0 : tensor<32x32xi8> to tensor<32x32xf8E4M3FNUZ>
    %1 = tensor.bitcast %arg1 : tensor<32x32xi8> to tensor<32x32xf8E4M3FNUZ>
    %2 = linalg.matmul ins(%0, %1 : tensor<32x32xf8E4M3FNUZ>, tensor<32x32xf8E4M3FNUZ>) outs(%arg2 : tensor<32x32xf32>) -> tensor<32x32xf32>
    return %2 : tensor<32x32xf32>
  }
}

or, to avoid the need to also handle mfma at the same time, just something as simple as

#map = affine_map<(d0) -> (d0)>
module {
  func.func @extend_i8(%arg0: tensor<32xi8>) -> tensor<32xf32> {
    %0 = tensor.bitcast %arg0 : tensor<32xi8> to tensor<32xf8E4M3FNUZ>
    %1 = tensor.empty() : tensor<32xf32>
    %2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%0 : tensor<32xf8E4M3FNUZ>) outs(%1 : tensor<32xf32>) {
    ^bb0(%in: f8E4M3FNUZ, %out: f32):
      %3 = arith.extf %in : f8E4M3FNUZ to f32
      linalg.yield %3 : f32
    } -> tensor<32xf32>
    return %2 : tensor<32xf32>
  }
}

kuhar commented on June 8, 2024

amdgcn's fp8 maps to f8E4M3FNUZ while bf8 maps to f8E5M2FNUZ.

bjacob commented on June 8, 2024

FP8 is E4M3 (inference-focused) while BF8 is E5M2 (training-focused): https://www.amd.com/en/products/accelerators/instinct/mi300/mi300a.html.

If a model is trained with 2-bit mantissas (E5M2), how is the 3rd bit of mantissa in E4M3 going to be useful in inference?

antiagainst commented on June 8, 2024

Also, https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html talks a bit about fp8 on NVIDIA GPUs, which is a useful reference.

In general, fp8 is used in a very ad-hoc way right now -- the ISAs just provide conversion and tensor/matrix core ops. For training we also have different fp8 scaling factors for different tensors and need model/framework-level handling there, so that is also quite ad-hoc.

So as we've discussed in the meeting, getting a minimal matmul to exercise fp8 + tensor/matrix cores in IREE/SHARK would be a good start and a foundation for everything else. We can then build the other parts on top.

kuhar commented on June 8, 2024

Explanation of the LLVM fp semantics naming convention:

F is for "finite" (no infinities), N for a special NaN encoding, UZ for unsigned zero.

source: https://github.com/jax-ml/ml_dtypes?tab=readme-ov-file#float8_e5m2fnuz
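
As an illustration of what these suffixes change in practice (a sketch, not from the thread): amdgcn's f8E4M3FNUZ shifts the bias to 8, drops infinities, repurposes the would-be negative zero (0x80) as the only NaN, and tops out at 240 instead of 448.

#include <math.h>
#include <stdint.h>

// f8E4M3FNUZ: 1 sign, 4 exponent (bias 8), 3 mantissa bits.
// FN: finite only, single NaN; UZ: unsigned zero, so the 0x80 pattern
// (which would otherwise be -0) is the NaN encoding.
float decode_e4m3fnuz(uint8_t v) {
  if (v == 0x80) return NAN;                         // the only NaN
  int sign = v >> 7, exp = (v >> 3) & 0xF, man = v & 0x7;
  float s = sign ? -1.0f : 1.0f;
  if (exp == 0) return s * ldexpf(man / 8.0f, -7);   // subnormals
  return s * ldexpf(1.0f + man / 8.0f, exp - 8);     // normals; max finite = 240
}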

MaheshRavishankar commented on June 8, 2024

Looking through the support in MLIR and the lowering into NVVM/ROCDL, it seems to already be there as well.

MFMA to ROCDL intrinsics:

Tensor core instruction lowering:

  • WGMMA instruction support (link)
  • I didn't find support for conversion instructions, so that's strange.

So for the examples in this comment #2054 (comment), the extension/truncation should just pass through and compile on AMD. For the mfma support, it would be great if we could just take a single matmul of the exact mfma shape and have it lower to that operation. Literally all tile sizes would be 1... it should vectorize to vector.contract and lower to amdgpu.mfma -> ROCDL intrinsics...
