save_torch_mlp_as_tcnn_json

This repository trains a multi-layer perceptron (MLP) in PyTorch and saves its weights in JSON format for tinycudann. The motivation stems from the observation that the tinycudann trainer is less effective than its PyTorch counterpart: the training loss in tinycudann may not decrease as quickly as in PyTorch, and training may fail to converge. To avoid these drawbacks while still benefiting from tinycudann's fast inference, I created this repository, which:

  1. Trains an MLP in PyTorch;
  2. Saves the PyTorch MLP weights as JSON files for tinycudann to load;
  3. Loads the JSON weights in tinycudann for fast inference.
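
Step 2 can be pictured with a minimal sketch like the one below. It is only an illustration, not the code in save_mlp_as_tcnn_json.py: the function name `weights_to_json_string`, the keys `layer_shapes` and `params`, and the row-major flattening order are all hypothetical, and the layout tinycudann actually expects may differ.

```python
import json


def weights_to_json_string(weight_matrices):
    """Serialize bias-free layer weights to a JSON string.

    weight_matrices: one nested list per linear layer, each shaped
    (out_features, in_features). Key names and flattening order are
    assumptions made for this sketch only.
    """
    payload = {
        # Record each layer's shape so the flat list can be split again.
        "layer_shapes": [[len(w), len(w[0])] for w in weight_matrices],
        # Flatten every matrix row by row into one list of floats.
        "params": [v for w in weight_matrices for row in w for v in row],
    }
    return json.dumps(payload)
```

With a PyTorch model, each matrix could be obtained as `layer.weight.detach().cpu().tolist()` for every `nn.Linear` in the MLP, in forward order.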

Pay attention to

  1. The number of input channels must be a multiple of 16, because the hardware matrix multipliers (Tensor Cores) operate on 16x16 matrix chunks (see NVlabs/tiny-cuda-nn#6 for details).
  2. The linear layers in tinycudann have no bias, so set bias=False for nn.Linear when training the PyTorch MLP.
  3. hidden_layers does not mean the same thing for a PyTorch MLP and a tinycudann MLP. Usually, the number of hidden layers in the PyTorch MLP equals the number of hidden layers in the tinycudann MLP plus 1.
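
The three notes above can be checked before training with a small helper. This is a sketch under stated assumptions: the names `round_up_to_16` and `tcnn_weight_shapes` are hypothetical and are not part of this repository or of tinycudann. It assumes a uniform hidden width, and that a tinycudann MLP with `n_hidden_layers` hidden layers has `n_hidden_layers + 1` weight matrices.

```python
def round_up_to_16(n):
    """Return the smallest multiple of 16 that is >= n."""
    if n <= 0:
        raise ValueError("channel count must be positive")
    return ((n + 15) // 16) * 16


def tcnn_weight_shapes(n_input, n_output, n_neurons, n_hidden_layers):
    """List the (out_features, in_features) shape of each bias-free layer.

    Assumption: n_hidden_layers hidden layers imply
    n_hidden_layers + 1 weight matrices.
    """
    if n_hidden_layers < 1:
        raise ValueError("need at least one hidden layer")
    if n_input % 16 != 0:
        raise ValueError("input channels must be a multiple of 16")
    shapes = [(n_neurons, n_input)]                             # input -> hidden
    shapes += [(n_neurons, n_neurons)] * (n_hidden_layers - 1)  # hidden -> hidden
    shapes.append((n_output, n_neurons))                        # hidden -> output
    return shapes
```

Each tuple corresponds to one `nn.Linear(in_features, out_features, bias=False)` on the PyTorch side, whose weight has shape (out_features, in_features). For example, `tcnn_weight_shapes(32, 4, 64, 2)` lists three layers for a tinycudann configuration with two hidden layers.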

Run an example

Run save_mlp_as_tcnn_json.py; the PyTorch MLP and the tinycudann MLP should produce nearly the same outputs.
