
cute-gemm's People

Contributors

luliyucoordinate, reed-lau


cute-gemm's Issues

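How do I determine the optimal Swizzle M, S, B?

The code has two Swizzle configurations: the first is the Swizzle for the A and B shared memory (M,S,B = 3,3,3), and the second is the Swizzle for the C shared memory (M,S,B = 3,3,2).

I don't quite understand why these two differ, so I changed the first one (A and B) to M,S,B = 3,3,2 and profiled the kernel before and after the change.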

    static constexpr int kShmLoadSwizzleM = 3;
    static constexpr int kShmLoadSwizzleS = 3;
    static constexpr int kShmLoadSwizzleB = 2; // originally 3

    using SmemLayoutAtom = decltype(composition(
        Swizzle<kShmLoadSwizzleB, kShmLoadSwizzleM, kShmLoadSwizzleS>{},
        make_layout(make_shape(Int<8>{}, Int<kTileK>{}),
                    make_stride(Int<kTileK>{}, Int<1>{}))));

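Before the change, the profiled kernel avg duration was 121.5191 μs.

After the change, the profiled kernel avg duration was 121.0709 μs.

I also sampled the bank-conflict metrics for both configurations; overall, the metrics look somewhat better after the change (B=2).

reed, could you discuss this question?

Attachment: profile files

ncu-profiles.zip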
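
For reference when comparing the two configurations, the XOR that `Swizzle<B, M, S>` applies to a linear offset can be modeled in plain host C++. This is a minimal sketch, assuming a non-negative shift `S` that is at least `B`; `swizzle_offset` is a hypothetical helper name used only here, not part of CuTe, and offsets are counted in elements of the layout being swizzled.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical host-side model of the Swizzle<B, M, S> index transform:
// the B bits starting at bit (M + S) are XORed into the B bits starting
// at bit M. The lowest M bits are never touched.
// Assumption: S >= B, so the source and destination bit fields do not overlap.
inline std::uint32_t swizzle_offset(std::uint32_t offset, int B, int M, int S) {
  assert(B >= 0 && M >= 0 && S >= B);
  const std::uint32_t src_mask = ((1u << B) - 1u) << (M + S);
  return offset ^ ((offset & src_mask) >> S);
}
```

Under this model, `swizzle_offset(256, 3, 3, 3)` gives 288, while `swizzle_offset(256, 2, 3, 3)` leaves 256 unchanged, because bit 8 takes part in the XOR only when B = 3. Offsets below 256 map identically under both settings.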

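Multi-stage pipelining on the Ampere architecture

Could you explain multi-stage pipelining for GEMM? Why are at least 3 pipeline stages needed on Ampere? What difference does it make to use 2 stages, or more than 3?

I'm not sure whether my understanding is correct.
The async copy introduced with Ampere lets g2s, s2r, and mma run in parallel with no dependencies among them, so my guess is that on Ampere at least 3 stages are needed for the three to overlap fully, along these lines:

  • mma for stage (i + 0)
  • g2s for stage (i + 2)
  • s2r for stage (i + 1)

If so, 2 stages are not enough, because s2r and mma, or s2r and g2s, would have to run serially.
But what is the point of adding even more stages? With 5 stages, for example, is the goal to issue more asynchronous g2s requests ahead of time so that s2r also stays busy?

I would appreciate an answer.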
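
The schedule described in this question can be written down as a small host-side model. This is only a sketch of the reasoning above, not the repository's kernel: `Step` and `pipeline_schedule` are hypothetical names, tiles are tracked at whole-tile granularity, and the prologue is assumed to have already issued g2s for tiles 0 through `num_stages - 2`.

```cpp
#include <cassert>
#include <vector>

// Which k-tile each unit works on during one main-loop iteration.
// A value of -1 means nothing is issued (past the last tile).
struct Step {
  int mma_tile;
  int s2r_tile;
  int g2s_tile;
};

// Hypothetical model: at iteration i the kernel computes tile i, stages
// tile i + 1 into registers, and issues the async global-to-shared copy
// for tile i + num_stages - 1.
inline std::vector<Step> pipeline_schedule(int num_k_tiles, int num_stages) {
  assert(num_k_tiles >= 1 && num_stages >= 2);
  std::vector<Step> steps;
  for (int i = 0; i < num_k_tiles; ++i) {
    const int s2r = i + 1;
    const int g2s = i + num_stages - 1;
    steps.push_back({i,
                     s2r < num_k_tiles ? s2r : -1,
                     g2s < num_k_tiles ? g2s : -1});
  }
  return steps;
}
```

With `num_stages = 3` the first step is mma on tile 0, s2r on tile 1, and g2s on tile 2, which is the relationship listed in the question. With `num_stages = 2`, s2r and g2s target the same tile in the same iteration, so one has to wait for the other. With `num_stages = 5`, g2s runs four tiles ahead, so more copies are in flight at once.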

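What does MMA_P_T mean in make_tiled_mma?

In the simple matrix multiplication implementation, I can understand the val layout as the warp's repeated computation along each direction;
but in the efficient matrix multiplication article: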

    using mma_atom_shape = mma_traits::Shape_MNK;
    static constexpr int kMmaPM = 1 * kMmaEURepeatM * get<0>(mma_atom_shape{});
    static constexpr int kMmaPN = 2 * kMmaEURepeatN * get<1>(mma_atom_shape{});
    static constexpr int kMmaPK = 1 * kMmaEURepeatK * get<2>(mma_atom_shape{});

    using MMA_EU_RepeatT = decltype(make_layout(make_shape(
        Int<kMmaEURepeatM>{}, Int<kMmaEURepeatN>{}, Int<kMmaEURepeatK>{})));
    using MMA_P_T = Tile<Int<kMmaPM>, Int<kMmaPN>, Int<kMmaPK>>;

    using MMA = decltype(make_tiled_mma(mma_atom{}, MMA_EU_RepeatT{}, MMA_P_T{}));

...

    // epilogue: register to global via shared memory
    using SmemLayoutAtomC = decltype(composition(
        Swizzle<2, 3, 3>{}, make_layout(make_shape(Int<kMmaPM>{}, Int<kMmaPN>{}),
                                        make_stride(Int<kMmaPN>{}, Int<1>{}))));
    using SmemLayoutC = decltype(tile_to_shape(
        SmemLayoutAtomC{},
        make_shape(Int<kMmaPM>{}, Int<kMmaPN>{}, Int<kSmemLayoutCBatch>{})));

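Could you explain what kMmaPM, kMmaPN, and kMmaPK mean?
When I print them, the dimensions are 32x32x16, which seems very different from the 1x2x1 written in the article.
And how are they used in SmemLayoutAtomC?

Thanks!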
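
The arithmetic behind the printed 32x32x16 can be reproduced on the host. This is a sketch under stated assumptions: the atom shape 16x8x16 and the repeat counts 2x2x1 are read off the TiledMMA printout quoted in a later issue on this page, the leading factors 1, 2, 1 are taken from the code above and interpreted as a per-mode value repeat (which appears to be the 1x2x1 the article mentions), and `mma_tile_extent` is a hypothetical helper name.

```cpp
#include <cassert>

// Extent of the MMA_P_T tile along one mode:
// (value repeat) * (thread-level atom repeat) * (atom extent).
constexpr int mma_tile_extent(int value_repeat, int thread_repeat, int atom_extent) {
  return value_repeat * thread_repeat * atom_extent;
}

// Assumed values, taken from the printed TiledMMA rather than from the
// kernel source: a 16x8x16 atom repeated 2x2x1.
constexpr int kAtomM = 16, kAtomN = 8, kAtomK = 16;
constexpr int kMmaEURepeatM = 2, kMmaEURepeatN = 2, kMmaEURepeatK = 1;

constexpr int kMmaPM = mma_tile_extent(1, kMmaEURepeatM, kAtomM);
constexpr int kMmaPN = mma_tile_extent(2, kMmaEURepeatN, kAtomN);
constexpr int kMmaPK = mma_tile_extent(1, kMmaEURepeatK, kAtomK);

static_assert(kMmaPM == 32 && kMmaPN == 32 && kMmaPK == 16,
              "agrees with the printed 32x32x16");
```

Under these assumptions the 1x2x1 factors and the 32x32x16 extents are not in conflict: the former are multipliers, the latter are the resulting tile sizes in elements.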

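CopyC reg->shm in gemm-multi-stage appears to contain a redundant type conversion

The reg->shm part of CopyC has an inner copy loop that creates a temporary tensor to convert the data type of tCrC_r2sx. I am unsure what the data type of tCrC_r2sx is before the conversion.

The original code is: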

    for (int j = 0; j < step; ++j) {
      // we add a temp tensor to cope with accumulator and output data type
      // difference
      auto t = make_tensor_like<T>(tCrC_r2sx(_, i + j));
      cute::copy(tCrC_r2sx(_, i + j), t);

      cute::copy(r2s_tiled_copy_c, t, tCsC_r2s(_, 0, 0, j));
    }

So I modified the code as follows to test it, copying directly without the conversion:

    for (int j = 0; j < step; ++j) {
      // we add a temp tensor to cope with accumulator and output data type
      // difference
      // auto t = make_tensor_like<T>(tCrC_r2sx(_, i + j));
      // cute::copy(tCrC_r2sx(_, i + j), t);
      // cute::copy(r2s_tiled_copy_c, t, tCsC_r2s(_, 0, 0, j));

      cute::copy(r2s_tiled_copy_c, tCrC_r2sx(_, i + j), tCsC_r2s(_, 0, 0, j));
    }

After compiling, the run output is:

cuBLAS version: 120205
cublasLt version: 120205
TiledMMA
  ThrLayoutVMNK:  (_32,_2,_2,_1):(_1,_32,_64,_0)
  PermutationMNK: (_32,_32,_16)
MMA_Atom
  ThrID:      _32:_1
  LayoutA_TV: ((_4,_8),(_2,_2,_2)):((_32,_1),(_16,_8,_128))
  LayoutB_TV: ((_4,_8),(_2,_2)):((_16,_1),(_8,_64))
  LayoutC_TV: ((_4,_8),(_2,_2)):((_32,_1),(_16,_8))
block = (128, 1), gird = (2, 640), shm = 49152
err = 0, str = no error
check ok, max_error = 0.000000
check ok, max_error = 0.000000
M = 81920, N = 256, K = 256
our-impl:
ptr[16b](0x7f4d6c8da010) o (8,8):(256,1):
   -0.92    1.16   -1.99   -2.67    3.73    8.20   -1.80   13.30
    0.26   -8.27    3.47   -5.67    3.22    5.18   -3.02   -1.94
    0.69    0.55   -8.59   -0.18    1.38    1.71    6.43   -3.21
   -2.21    3.14   10.21   -4.23  -11.81   -4.04    4.39   -0.45
   -3.13    7.17   -4.35    7.47    2.07   -1.28    2.77   -0.46
   -5.50    2.15   -1.03    2.75   -1.45    7.04    5.41   -6.82
    6.31    1.05   -5.47   -9.59   -6.00    4.25   -5.16   -0.61
    4.91   -0.43    5.77    5.55   -3.62   -3.80   -5.21    3.45
cublas:
ptr[16b](0x7f4d6a0d9010) o (8,8):(256,1):
   -0.92    1.16   -1.99   -2.67    3.73    8.20   -1.80   13.30
    0.26   -8.27    3.47   -5.67    3.22    5.18   -3.02   -1.94
    0.69    0.55   -8.59   -0.18    1.38    1.71    6.43   -3.21
   -2.21    3.14   10.21   -4.23  -11.81   -4.04    4.39   -0.45
   -3.13    7.17   -4.35    7.47    2.07   -1.28    2.77   -0.46
   -5.50    2.15   -1.03    2.75   -1.45    7.04    5.41   -6.82
    6.31    1.05   -5.47   -9.59   -6.00    4.25   -5.16   -0.61
    4.91   -0.43    5.77    5.55   -3.62   -3.80   -5.21    3.45
cublaslt:
ptr[16b](0x7f4d678d8010) o (8,8):(256,1):
   -0.92    1.16   -1.99   -2.67    3.73    8.20   -1.80   13.30
    0.26   -8.27    3.47   -5.67    3.22    5.18   -3.02   -1.94
    0.69    0.55   -8.59   -0.18    1.38    1.71    6.43   -3.21
   -2.21    3.14   10.21   -4.23  -11.81   -4.04    4.39   -0.45
   -3.13    7.17   -4.35    7.47    2.07   -1.28    2.77   -0.46
   -5.50    2.15   -1.03    2.75   -1.45    7.04    5.41   -6.82
    6.31    1.05   -5.47   -9.59   -6.00    4.25   -5.16   -0.61
    4.91   -0.43    5.77    5.55   -3.62   -3.80   -5.21    3.45

The output shows
check ok, max_error = 0.000000
check ok, max_error = 0.000000
Does this prove that tCrC_r2sx(_, i + j) and tCsC_r2s(_, 0, 0, j) both already have data type T, so no conversion is needed?
reed, please advise.
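
A passing numerical check is consistent with the two tensors sharing an element type, but it does not establish that on its own; a compile-time check answers the type question directly. The sketch below is a host-only illustration, assuming the tensor type exposes a nested `value_type`; `FakeTensor` and `has_element_type` are hypothetical names used only for this example.

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical stand-in for a tensor type that exposes its element type
// as a nested value_type.
template <class T>
struct FakeTensor {
  using value_type = T;
};

// True only when the tensor's element type is exactly Expected.
template <class Tensor, class Expected>
constexpr bool has_element_type() {
  return std::is_same<typename Tensor::value_type, Expected>::value;
}

static_assert(has_element_type<FakeTensor<float>, float>(), "same type");
static_assert(!has_element_type<FakeTensor<float>, double>(), "different type");
```

In the kernel, the analogous check would presumably be a `static_assert` comparing `typename decltype(tCrC_r2sx)::value_type` with `T`. If it compiles, the temporary tensor is not changing the element type for this configuration.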
