Jumanji Benchmarks

This project contains benchmarks for the jumanji 2048 environment.

Dashboard

To view a dashboard displaying all benchmark results, run the following command and navigate to localhost:8050.

docker run -it --rm -p 8050:8050 ghcr.io/aar65537/jumanji-benchmarks:main

Improvements

        no vmap     vmap 10^3   vmap 10^6
cpu     64.36%      201.80%     392.29%
cuda    900.12%     1923.08%    706.87%

The table above shows the total performance increase from all changes combined, measured as the percent increase in steps/sec. The no vmap environments were wrapped with AutoResetWrapper. The cpu vmap environments were wrapped with VmapAutoResetWrapper. The cuda vmap environments were wrapped with AutoResetWrapper and then VmapWrapper, because I found that VmapAutoResetWrapper performed poorly on the GPU.
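The totals appear to come from compounding the per-commit changes listed in the sections below, by multiplying the speedup factors rather than adding the percentages. The short Python check below uses only numbers copied from this README's tables. The names `PER_COMMIT`, `REPORTED_TOTAL` and `compound` are illustrative.

```python
import math

# Per-commit throughput changes (percent increase in steps/sec), copied from the
# per-commit tables, in commit order: 2e9f0186, ca2e4ba5, 8f9a67bd, 77fccd6e, f5c34b3c.
PER_COMMIT = {
    ("cpu", "no vmap"): [-3.05, -12.21, -11.71, 20.96, 80.83],
    ("cpu", "vmap 10^3"): [74.75, 58.71, -57.18, -8.01, 176.21],
    ("cpu", "vmap 10^6"): [47.83, 22.45, -40.42, 33.46, 242.01],
    ("cuda", "no vmap"): [-0.06, -9.24, 337.66, 33.44, 88.79],
    ("cuda", "vmap 10^3"): [51.00, 13.10, 219.20, 63.62, 126.83],
    ("cuda", "vmap 10^6"): [45.00, 30.53, 25.32, 28.01, 165.74],
}

# Totals reported in the summary table.
REPORTED_TOTAL = {
    ("cpu", "no vmap"): 64.36,
    ("cpu", "vmap 10^3"): 201.80,
    ("cpu", "vmap 10^6"): 392.29,
    ("cuda", "no vmap"): 900.12,
    ("cuda", "vmap 10^3"): 1923.08,
    ("cuda", "vmap 10^6"): 706.87,
}


def compound(percents):
    """Chain percent changes: each commit scales the throughput left by the previous one."""
    factor = math.prod(1 + p / 100 for p in percents)
    return (factor - 1) * 100


for key, percents in PER_COMMIT.items():
    print(key, f"compounded {compound(percents):.2f}%, reported {REPORTED_TOTAL[key]:.2f}%")
```

Take cpu without vmap as an example. Adding the five per-commit percentages would give 74.82%. Compounding them gives about 64.37%, which matches the reported 64.36%. All six cells agree to within a few tenths of a percentage point, consistent with rounding of the per-commit figures.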

The improvements fall into three main categories: minimizing conditional logic, preferring jax.vmap over jax.lax control flow, and algorithmic improvements. Minimizing conditional logic matters because, when a function is wrapped with jax.vmap and the predicate varies across the batch, all branches of a conditional expression are evaluated. Using jax.vmap instead of jax.lax control flow seems to reduce overhead when running on the GPU. The algorithmic improvements are an optimized move algorithm and a can_move algorithm that doesn't mutate the board.
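You can check the claim that a vectorized conditional evaluates every branch by inspecting the jaxpr with jax.make_jaxpr. The sketch below is minimal. `branch_a` and `branch_b` are hypothetical stand-ins for expensive branches and do not come from the environment.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in branches; any two different computations would do.
def branch_a(x):
    return jnp.sin(x) * 2.0


def branch_b(x):
    return jnp.cos(x) + 1.0


def with_cond(x):
    return jax.lax.cond(x > 0, branch_a, branch_b, x)


def with_switch(i, x):
    return jax.lax.switch(i, [branch_a, branch_b], x)


xs = jnp.linspace(-1.0, 1.0, 8)
idx = jnp.arange(8) % 2

# Unbatched: the jaxpr contains a `cond`, so only the selected branch runs.
print(jax.make_jaxpr(with_cond)(0.5))

# Batched predicate/index: the `cond` disappears; sin and cos are both computed
# for every element and the results are combined with a select.
print(jax.make_jaxpr(jax.vmap(with_cond))(xs))
print(jax.make_jaxpr(jax.vmap(with_switch))(idx, xs))
```

jax.lax.switch behaves the same way because it is built on the same `cond` primitive. The exact name of the select primitive in the printed jaxpr may vary between JAX versions.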

2e9f0186: Remove call to move inside jax.lax.switch

        no vmap     vmap 10^3   vmap 10^6
cpu     -3.05%      74.75%      47.83%
cuda    -0.06%      51.00%      45.00%

The current environment selects the correct move in a step using the following switch statement.

updated_board, additional_reward = jax.lax.switch(
    action,
    [move_up, move_right, move_down, move_left],
    state.board,
)

The problem is that, when vectorized, all branches of the switch are evaluated, so each call to step performs all four moves rather than just the selected one. The solution is to use the switch only to transform the board, and to perform the move once, outside of the switch.

updated_board, additional_reward = move(state.board, action)

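def move(board, action, final_shift=True):
    # Rotate/flip the board so the requested action becomes an "up" move.
    board = transform_board(board, action)
    # The expensive move runs exactly once, outside of any switch.
    board, additional_reward = move_up(board, final_shift)
    # Each transform is its own inverse, so applying it again restores the orientation.
    board = transform_board(board, action)
    return board, additional_reward

def transform_board(board, action):
    # The switch now only selects between cheap array transformations.
    return jax.lax.switch(
        action,
        [
            lambda: board,                           # 0: up
            lambda: jnp.flip(jnp.transpose(board)),  # 1: right
            lambda: jnp.flip(board),                 # 2: down
            lambda: jnp.transpose(board),            # 3: left
        ],
    )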

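This implementation avoids calling the expensive move_up inside the switch statement: only cheap board transformations are selected there, and move_up runs exactly once. Calling transform_board a second time undoes the first call because each transformation (identity, flip, transpose, and flip of the transpose) is its own inverse.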
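The sketch below checks that self-inverse property and the orientation used for action 1. It reuses transform_board from above with an arbitrary 4x4 example board.

```python
import jax
import jax.numpy as jnp


def transform_board(board, action):
    # Same transformation as above: 0 = up, 1 = right, 2 = down, 3 = left.
    return jax.lax.switch(
        action,
        [
            lambda: board,
            lambda: jnp.flip(jnp.transpose(board)),
            lambda: jnp.flip(board),
            lambda: jnp.transpose(board),
        ],
    )


board = jnp.arange(16).reshape(4, 4)

for action in range(4):
    once = transform_board(board, action)
    twice = transform_board(once, action)
    # Each transformation is its own inverse, so applying it twice restores the board.
    print(action, bool(jnp.array_equal(twice, board)))

# For "right" (action 1), column j of the transformed board is row 3 - j of the
# original read from right to left, so sliding tiles up there moves them right.
t = transform_board(board, 1)
print(t[:, 0], board[3, ::-1])
```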

ca2e4ba5: Remove call to set inside jax.lax.cond

        no vmap     vmap 10^3   vmap 10^6
cpu     -12.21%     58.71%      22.45%
cuda    -9.24%      13.10%      30.53%

Currently, the environment shifts column elements using the following conditional statement.

def shift_nonzero_element(col, j, i):
    col = col.at[j].set(col[i])
    return col, j + 1

col, j = jax.lax.cond(
    col[i] != 0,
    shift_nonzero_element,
    lambda col, j, i: (col, j),
    col, j, i
)

Again, there is unnecessary work inside a conditional, but in this case the source of the slowdown is more subtle. The problem is the line col = col.at[j].set(col[i]). This update is supposed to happen in place. However, both branches of the conditional must be computed, and the other branch still needs the unmodified col, so a new copy of col is created instead. This copy can be avoided by performing the update outside of the conditional and using the conditional only to choose the value to write and the next index.

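new_col_j, new_j = jax.lax.cond(
    col[i] != 0,
    lambda col, j, i: (col[i], j + 1),  # nonzero: copy col[i] into slot j, advance j
    lambda col, j, i: (col[j], j),      # zero: write col[j] back unchanged (a no-op)
    col, j, i
)
col = col.at[j].set(new_col_j)  # single unconditional update, outside the cond
j = new_j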

A similar problem was also present when merging tiles.
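The sketch below drops the fixed pattern above into a complete compaction loop. It slides nonzero tiles to the front without merging and is applied to every row with jax.vmap. `shift_nonzero` and the jax.lax.fori_loop around it are hypothetical and are not the environment's actual code.

```python
import jax
import jax.numpy as jnp


def shift_nonzero(col):
    """Move nonzero tiles to the front of `col`, keeping their order (no merging)."""
    n = col.shape[0]

    def body(i, carry):
        col, j = carry
        # The cond only picks the value to write and the next write index...
        new_col_j, new_j = jax.lax.cond(
            col[i] != 0,
            lambda col, j, i: (col[i], j + 1),
            lambda col, j, i: (col[j], j),
            col, j, i,
        )
        # ...and the single array update happens unconditionally, outside the cond.
        col = col.at[j].set(new_col_j)
        return col, new_j

    col, j = jax.lax.fori_loop(0, n, body, (col, jnp.int32(0)))
    # Positions j..n-1 still hold their original values, so clear them.
    return jnp.where(jnp.arange(n) < j, col, 0)


board = jnp.array([[2, 0, 4, 0],
                   [0, 0, 0, 8],
                   [2, 2, 0, 2],
                   [0, 0, 0, 0]])
print(jax.vmap(shift_nonzero)(board))
```

Reading col[i] before writing col[j] is safe because the write index j never overtakes the read index i.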

8f9a67bd: Change jax.lax.scan to jax.vmap

        no vmap     vmap 10^3   vmap 10^6
cpu     -11.71%     -57.18%     -40.42%
cuda    337.66%     219.20%     25.32%

Currently, the environment moves each column of the board with the following scan expression.

(board, additional_reward), _ = jax.lax.scan(
    f=functools.partial(move_up_col, final_shift=final_shift),
    init=(board, 0.0),
    xs=jnp.arange(board.shape[0]),
)

However, a scan isn't really necessary because each column is moved independently of the others. In this case, the function can be rewritten to move a single row (moving left instead of up) and wrapped with jax.vmap, which I found reduced overhead on the GPU.

board, additional_reward = jax.vmap(move_left_row, (0, None))(board, final_shift)
additional_reward = additional_reward.sum()  # vmap returns one reward per row
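When each iteration of a scan reads and writes only its own row, the scan is effectively a map and can be replaced with jax.vmap. The sketch below is minimal and shows that both forms produce the same board. `compact_row` is a hypothetical per-row function that slides tiles left without merging.

```python
import jax
import jax.numpy as jnp


def compact_row(row):
    """Hypothetical per-row step: slide nonzero tiles to the left without merging."""
    n = row.shape[0]
    # Unique sort keys: nonzero tiles keep their order, empty cells go to the end.
    keys = jnp.where(row == 0, n, 0) + jnp.arange(n)
    return row[jnp.argsort(keys)]


board = jnp.array([[0, 2, 0, 2],
                   [4, 0, 0, 8],
                   [0, 0, 0, 0],
                   [2, 4, 8, 16]])


# Scan version: visit the rows one after another, threading the whole board as carry.
def scan_step(board, i):
    return board.at[i].set(compact_row(board[i])), None


board_scan, _ = jax.lax.scan(scan_step, board, jnp.arange(board.shape[0]))

# vmap version: each row depends only on itself, so all rows are processed at once.
board_vmap = jax.vmap(compact_row)(board)

print(bool(jnp.array_equal(board_scan, board_vmap)))
```

The scan carries the entire board through a sequential loop. The vmap version has no loop-carried state at all, which is consistent with the lower GPU overhead reported above.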

77fccd6e: Implement new move algorithm

        no vmap     vmap 10^3   vmap 10^6
cpu     20.96%      -8.01%      33.46%
cuda    33.44%      63.62%      28.01%

I also implemented a new optimized move algorithm that only uses a single while loop.
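The README does not show the new algorithm itself. The sketch below is a hypothetical single-pass version built around one jax.lax.while_loop per row, and it is not necessarily the author's algorithm. It assumes tiles are stored as exponents (k means tile 2**k, 0 means empty) and that each merge scores the value of the new tile. The updates use jnp.where rather than lax.cond, following the advice above about conditionals.

```python
import jax
import jax.numpy as jnp


def move_left_row(row):
    """Hypothetical single-pass move-left for one row; returns (new_row, reward)."""
    n = row.shape[0]

    def cond_fun(carry):
        return carry[0] < n

    def body_fun(carry):
        i, j, can_merge, out, reward = carry
        v = row[i]
        # The last placed tile sits at j - 1 and may absorb v if it has not merged yet.
        merge = (v != 0) & can_merge & (out[jnp.maximum(j - 1, 0)] == v)
        place = (v != 0) & ~merge
        idx = jnp.where(merge, j - 1, j)
        # Branch-free write: merge bumps the exponent, place copies v,
        # and an empty cell writes back the value that is already there.
        out = out.at[idx].set(jnp.where(merge, v + 1, jnp.where(place, v, out[idx])))
        reward = reward + jnp.where(merge, 2.0 ** (v + 1), 0.0)
        j = j + place.astype(jnp.int32)
        # A freshly placed tile may merge next; a merged tile may not; empty cells change nothing.
        can_merge = jnp.where(v == 0, can_merge, place)
        return i + 1, j, can_merge, out, reward

    init = (jnp.int32(0), jnp.int32(0), jnp.array(False), jnp.zeros_like(row), jnp.float32(0.0))
    _, _, _, out, reward = jax.lax.while_loop(cond_fun, body_fun, init)
    return out, reward


def move_left(board):
    rows, rewards = jax.vmap(move_left_row)(board)
    return rows, rewards.sum()


board = jnp.array([[1, 1, 2, 0],
                   [1, 1, 1, 1],
                   [0, 0, 0, 0],
                   [1, 0, 1, 2]])
print(move_left(board))
```

Writing into a fresh zero row means the result is already packed to the left, so this sketch needs no separate final shift. That is why it takes no final_shift argument, unlike the move functions above.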

f5c34b3c: Implement can_move algorithm

        no vmap     vmap 10^3   vmap 10^6
cpu     80.83%      176.21%     242.01%
cuda    88.79%      126.83%     165.74%

Currently, the environment checks if an action is valid by performing the action and seeing if any tiles changed.

jnp.any(move_up(board, final_shift=False)[0] != board)

Even with the optimization of skipping the final shift, this is fairly expensive because it still performs a full move. I implemented a can_move algorithm that validates an action without mutating the board.
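The README does not include the can_move code. The sketch below is a hypothetical check for a left move that compares neighbouring cells and writes nothing. A row changes under a left move exactly when an empty cell has a tile directly to its right, or when two equal nonzero tiles are adjacent. If neither condition holds, the row is already packed to the left and contains no mergeable pair. The names `can_move_left_row` and `can_move_left` are illustrative.

```python
import jax
import jax.numpy as jnp


def can_move_left_row(row):
    """Hypothetical check: would this row change if tiles moved left?"""
    left, right = row[:-1], row[1:]
    # A tile can slide into an empty cell directly to its left...
    slide = (left == 0) & (right != 0)
    # ...or merge with an equal, nonzero neighbour.
    merge = (left != 0) & (left == right)
    return jnp.any(slide | merge)


def can_move_left(board):
    return jnp.any(jax.vmap(can_move_left_row)(board))


stuck = jnp.array([[1, 2, 3, 4],
                   [2, 1, 0, 0],
                   [0, 0, 0, 0],
                   [3, 0, 0, 0]])
print(bool(can_move_left(stuck)))
```

Other directions can be reduced to this one by transposing or flipping the board first. For example, can_move_left(board.T) checks an up move, because the rows of board.T are the original columns.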

Contributors

aar65537
