Fast FLOPS

... and where to find them 🐇

Ever needed to report FLOPs for your Torch/JAX code? We've got you covered! (More importantly, here's a survey on why bunnies flop.)

Borrowed from https://gitlab.com/NERSC/roofline-on-nvidia-gpus/-/tree/roofline-hackathon-2020

Requires an Nsight Compute installation. Make sure the GPU performance counters are enabled.
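Since the FLOP counts come from Nsight Compute hardware counters, it may help to see how such counters are usually turned into FLOPs. The sketch below assumes the convention from the NERSC roofline material this project borrows from: per precision, FLOPs = adds + multiplies + 2 × fused multiply-adds. The helper names `metric` and `flops_from_counters` are hypothetical (not part of fast_flops), and the exact `ncu` metric names are assumptions based on that methodology. Only FP64 and FP32 are covered, and tensor-core instructions are ignored.

```python
# Hypothetical helpers, not the fast_flops API.
# Assumption: ncu SASS counter names follow the NERSC roofline convention.

PRECISIONS = {"d": "FP64", "f": "FP32"}


def metric(prefix: str, op: str) -> str:
    """Build an (assumed) ncu metric name, e.g. sm__sass_thread_inst_executed_op_ffma_pred_on.sum."""
    return f"sm__sass_thread_inst_executed_op_{prefix}{op}_pred_on.sum"


def flops_from_counters(counters: dict[str, float]) -> dict[str, float]:
    """Convert per-thread instruction counts into FLOPs per precision.

    Missing counters are treated as zero.
    """
    result = {}
    for prefix, name in PRECISIONS.items():
        add = counters.get(metric(prefix, "add"), 0.0)
        mul = counters.get(metric(prefix, "mul"), 0.0)
        fma = counters.get(metric(prefix, "fma"), 0.0)
        # A fused multiply-add performs one multiply and one add: 2 FLOPs.
        result[name] = add + mul + 2.0 * fma
    return result


# Illustrative counter values (made up, not measurements).
counters = {
    metric("f", "fma"): 1_000.0,
    metric("f", "add"): 10.0,
}
print(flops_from_counters(counters))  # {'FP64': 0.0, 'FP32': 2010.0}
```

Dividing such a FLOP count by the measured kernel time gives FLOP/s.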

[Image: bunny]

(Image Credit: Tuong Phung)

Workflow (Warning: Extremely clunky right now)

  • Define the function you want to profile

    • JAX

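      import jax
      import jax.numpy as jnp

      SIZE = 500
      # Two independent PRNG keys so x and y differ
      x = jax.random.normal(jax.random.PRNGKey(0), (SIZE, SIZE))
      y = jax.random.normal(jax.random.PRNGKey(1), (SIZE, SIZE))
      
      def func(x, y):
          # Plain matrix multiply written as an einsum
          return jnp.einsum("ij,jk->ik", x, y)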
    • Torch

      import torch

      SIZE = 500
      # Allocate the inputs directly on the GPU
      x = torch.randn(SIZE, SIZE, device='cuda')
      y = torch.randn(SIZE, SIZE, device='cuda')
      
      def func(x, y):
          # Plain matrix multiply written as an einsum
          return torch.einsum("ij,jk->ik", x, y)
  • Wrap the function in the flops_counter decorator

    from fast_flops import flops_counter
    
    @flops_counter
    def func_flops(func, x, y):
        return func(x, y)
  • JIT-compile the function and execute it (don't worry, warmup runs keep the JIT overhead out of the measurement)!

    • JAX

      func = jax.jit(func)
      func_flops(func, x, y)
    • Torch

      func = torch.compile(func, fullgraph=True, mode='max-autotune')
      func_flops(func, x, y)
  • The pipeline can be executed using

    bash launch_profiler.sh examples/matmul/test_matmul_torch.py

    which produces output that looks something like

    Measured Time: 3.196823494576654e-05
    Measured GFLOP/s: 2353.1630879907675
    Measured FLOPS: 67108864.0
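The warmup idea and the GFLOP/s arithmetic behind the output above can be illustrated with a framework-free sketch. This is not how fast_flops works internally: it uses an analytic FLOP count (2·M·K·N for an (M, K) @ (K, N) matmul) instead of Nsight Compute counters, and `timed_with_warmup`, `matmul_flops` and `naive_matmul` are hypothetical names. It also does not synchronize a device. For JAX or Torch GPU code you would need to wait for the result (e.g. `jax.block_until_ready(out)` or `torch.cuda.synchronize()`) before stopping the clock.

```python
import time
from functools import wraps


def matmul_flops(m: int, k: int, n: int) -> int:
    """Analytic FLOPs of an (m, k) @ (k, n) matmul: one multiply and one add per term."""
    return 2 * m * k * n


def timed_with_warmup(warmup: int = 3, iters: int = 10):
    """Hypothetical decorator: run `warmup` untimed calls, then average `iters` timed calls."""
    if iters < 1:
        raise ValueError("iters must be >= 1")

    def decorator(fn):
        @wraps(fn)
        def wrapper(*args, **kwargs):
            # Untimed calls; with JAX/Torch this is where JIT compilation would happen.
            for _ in range(warmup):
                fn(*args, **kwargs)
            out = None
            start = time.perf_counter()
            for _ in range(iters):
                out = fn(*args, **kwargs)
            seconds = (time.perf_counter() - start) / iters  # mean time per call
            return out, seconds

        return wrapper

    return decorator


def naive_matmul(a, b):
    """Pure-Python matmul on nested lists (slow, for illustration only)."""
    rows, inner, cols = len(a), len(b), len(b[0])
    return [[sum(a[i][p] * b[p][j] for p in range(inner)) for j in range(cols)]
            for i in range(rows)]


SIZE = 32
a = [[1.0] * SIZE for _ in range(SIZE)]
b = [[2.0] * SIZE for _ in range(SIZE)]

timed_matmul = timed_with_warmup(warmup=2, iters=5)(naive_matmul)
out, seconds = timed_matmul(a, b)
gflops = matmul_flops(SIZE, SIZE, SIZE) / seconds / 1e9
print(f"Measured Time: {seconds}")
print(f"Measured GFLOP/s: {gflops}")
```

For the 500 × 500 example above, the analytic count would be `matmul_flops(500, 500, 500)`, i.e. 250,000,000 FLOPs.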

TODOs

  • Only turn on ncu during the hot loop in addition to nvtx
  • Axe run_profiler.py and postprocess.py. Ideally, it should work with python decorated_file.py
  • Plotting utilities and CSV plumbing

Contributors

mitkotak
