
Siren-Jax

Unofficial implementation of Siren in Jax. This code reproduces the image-related results from the original Siren paper.

What is Siren?

It is a novel neural network proposed in Implicit Neural Representations with Periodic Activation Functions by Sitzmann et al.

Siren uses sine functions as activation functions, and it can represent continuous, differentiable signals better than networks with other activation functions.
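
For intuition, here is a minimal sketch of a single Siren layer in Jax. This is illustrative only; the function names are assumptions, not this repository's actual layer code.

import jax.numpy as jnp
from jax import random

def siren_layer_init(key, in_dim, out_dim, omega=30.0, is_first=False):
    # Initialization scheme from the Siren paper: uniform(-1/n, 1/n) for the
    # first layer, uniform(-sqrt(6/n)/omega, sqrt(6/n)/omega) for the rest.
    w_key, b_key = random.split(key)
    bound = 1.0 / in_dim if is_first else jnp.sqrt(6.0 / in_dim) / omega
    W = random.uniform(w_key, (in_dim, out_dim), minval=-bound, maxval=bound)
    b = random.uniform(b_key, (out_dim,), minval=-bound, maxval=bound)
    return W, b

def siren_layer_apply(params, x, omega=30.0):
    # The sine activation is what distinguishes Siren from a plain MLP.
    W, b = params
    return jnp.sin(omega * (x @ W + b))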

If you want to know more about Siren, please check out the project page.

Why Jax?

Jax is very fast and convenient for calculating derivatives of a network. Indeed, training is much faster than with the Pytorch implementation from the paper: this Jax implementation takes less than 12 minutes per training run, while the original implementation takes 90 minutes for the same experiment, according to the paper.
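
For example, the image-space derivatives needed for gradient and laplacian supervision are short compositions in Jax. This is a sketch: net stands for any scalar-output function of a single 2D coordinate, and the helper name is illustrative.

import jax
import jax.numpy as jnp

def image_derivatives(net, coords):
    # net: scalar function of one (x, y) coordinate; coords: (N, 2) array.
    grad_fn = jax.grad(net)              # first derivatives w.r.t. (x, y)
    hess_fn = jax.jacfwd(jax.grad(net))  # 2x2 Hessian
    grads = jax.vmap(grad_fn)(coords)    # per-pixel gradient
    laps = jax.vmap(lambda c: jnp.trace(hess_fn(c)))(coords)  # laplacian
    return grads, laps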

Speed Test Results

Details

  • 1 Nvidia RTX 2080 Ti
  • Cuda 10.2
  • Ubuntu 18.04
  • Image size: 256 × 256
  • Single Batch
Vanilla Training | Gradient Training | Laplacian Training
110 seconds      | 311 seconds       | 661 seconds

How to use?

1. Install Jax

Please follow the official install guide.
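
At the time of writing, a CPU-only install is typically a one-liner; GPU users should follow the guide to pick the jaxlib wheel matching their CUDA version, so the exact command below may differ for your setup.

$ pip install --upgrade jax jaxlib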

2. Install packages

$ pip install -r requirements.txt

3. Train

This command runs training with the default options. Run python train.py -h to see the other options.

$ python train.py --file reference.jpg

4. Test

$ python test.py --run_name reference

Example Results

This section shows results of implicit image representation and of solving the Poisson equation. Training settings are the same for all experiments unless otherwise mentioned; a sketch of the corresponding training loop follows the list.

  • learning rate: 0.0001
  • single batch
  • 5 layers, 256 units each
  • epochs: 10000
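
As a rough illustration, these settings correspond to a full-batch loop like the following. This is a sketch with assumed names, using plain gradient descent; the repository's actual train.py and optimizer may differ in detail.

import jax
import jax.numpy as jnp
from functools import partial

LEARNING_RATE = 1e-4   # learning rate: 0.0001
NUM_EPOCHS = 10000     # epochs: 10000

def mse_loss(apply_fn, params, coords, pixels):
    # Single batch: the loss sees every pixel of the image at once.
    return jnp.mean((apply_fn(params, coords) - pixels) ** 2)

@partial(jax.jit, static_argnums=0)
def train_step(apply_fn, params, coords, pixels):
    loss, grads = jax.value_and_grad(mse_loss, argnums=1)(
        apply_fn, params, coords, pixels)
    params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g,
                                    params, grads)
    return params, loss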

Reproducing Paper Results

The results were almost the same as in the paper.

Vanilla training with a color image

<ground truth -- vanilla network output>

vanilla color result

Vanilla training with a grayscale image

<ground truth -- vanilla output -- gradient output -- laplacian output>

vanilla gray result

Training with gradient

<ground truth -- gradient output -- vanilla output -- laplacian output>

gradient result

Training with laplacian

<ground truth -- laplacian output -- vanilla output -- gradient output>

laplacian result

Batch Effect

The original paper only tested with a single batch. I was curious about the effect of batch size, so I ran a test (see the splitting sketch below the settings).

  • batch size: 16384 (4 batches per epoch)
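
For a 256 × 256 image there are 65,536 pixels, so a batch size of 16,384 gives exactly 4 batches per epoch. A sketch of the split (illustrative; a real loader would likely shuffle the coordinates first):

import jax.numpy as jnp

ys, xs = jnp.meshgrid(jnp.linspace(-1, 1, 256),
                      jnp.linspace(-1, 1, 256), indexing="ij")
coords = jnp.stack([xs, ys], axis=-1).reshape(-1, 2)  # (65536, 2)
batches = jnp.split(coords, 4)                        # 4 x (16384, 2)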

<trained with vanilla image -- trained with gradient -- trained with laplacian -- ground truth>

batch result

Using batches seems to lead to worse results as the order of the supervised derivative increases. The output of the laplacian-trained network differs from the ground truth image much more than the others do.

Upscale Test

If a network can represent an image in a continuous way, then it might be possible to create a higher-resolution image from it. So I created larger images with the trained networks and compared them with an interpolated image.
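
Concretely, upscaling just means evaluating the trained network on a denser coordinate grid. A sketch, with apply_fn and params as assumed names:

import jax.numpy as jnp

def render(apply_fn, params, height, width):
    ys = jnp.linspace(-1, 1, height)
    xs = jnp.linspace(-1, 1, width)
    grid = jnp.stack(jnp.meshgrid(xs, ys), axis=-1).reshape(-1, 2)
    out = apply_fn(params, grid)            # query every new coordinate
    return out.reshape(height, width, -1)   # e.g. render(..., 512, 512)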

<trained with vanilla image -- trained with gradient -- trained with laplacian -- interpolated (using PIL)>

upscale result

The network outputs are not clearly better than the interpolated image. The gradient-trained image is blurry, and the laplacian-trained image's colors are not accurate. However, the network-generated images are much smoother than the interpolated one.

<trained with vanilla image -- trained with laplacian -- interpolated>

zoomed upscale result

These are zoomed-in crops. The laplacian-trained image is very smooth compared to the interpolated image.

What is more?

If you are interested in implicit representations, please check out Awesome Implicit Neural Representations.

The curiosity that led me to this Jax reimplementation was whether it is possible to create a high-resolution image without any dataset other than the source image. Apparently, simply generating a larger image with a network trained on a small image does not work. I'm trying some other approaches, but I'm not sure whether they will work.


siren-jax's Issues

"'tuple' object has no attribute 'total'" error when running train.py

Hi! I am trying to train the siren and ran into this bug:

python train.py --file reference.jpg

Traceback (most recent call last):
  File "train.py", line 113, in <module>
    main(args)
  File "train.py", line 62, in main
    model = Model(layers, args.nc, args.omega)
  File "/ubc/cs/research/kmyi/svsamban/research/siren-jax/siren/model.py", line 22, in __init__
    self.net = self.create_network(layers, n_channel, omega)
  File "/ubc/cs/research/kmyi/svsamban/research/siren-jax/siren/model.py", line 51, in create_network
    return Siren(2, layers, n_channel, omega)
  File "/ubc/cs/research/kmyi/svsamban/research/siren-jax/siren/network.py", line 45, in __init__
    net_params, net_apply = create_mlp(input_dim, layers, output_dim, omega)
  File "/ubc/cs/research/kmyi/svsamban/research/siren-jax/siren/network.py", line 30, in create_mlp
    out_shape, net_params = net_init_random(rng, in_shape)
  File "/scratch/soft/anaconda3/envs/nerfies-siren/lib/python3.8/site-packages/jax/experimental/stax.py", line 295, in init_fun
    input_shape, param = init_fun(layer_rng, input_shape)
  File "/ubc/cs/research/kmyi/svsamban/research/siren-jax/siren/layer.py", line 14, in init_fun
    W, b = W_init(k1, (input_shape[-1], out_dim)), b_init(
  File "/ubc/cs/research/kmyi/svsamban/research/siren-jax/siren/initializer.py", line 18, in init
    fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
  File "/scratch/soft/anaconda3/envs/nerfies-siren/lib/python3.8/site-packages/jax/_src/nn/initializers.py", line 49, in _compute_fans
    receptive_field_size = shape.total / shape[in_axis] / shape[out_axis]
AttributeError: 'tuple' object has no attribute 'total'
(nerfies-siren) svsamban@salty:/ubc/cs/research/kmyi/svsamban/research/siren-jax$ 

I am using jax version 0.2.20 and jaxlib version 0.1.69+cuda111 with tensorflow 2.4.0. I have an Ubuntu 20.04 system with Cuda 11.4. If you have any insight on how to circumvent this error, please let me know!

Thanks!
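
For readers hitting the same error: in that jax version, the internal _compute_fans helper expects a NamedShape with a .total attribute, while this repository's custom initializer passes a plain tuple. One possible workaround (an untested sketch) is to compute the fans locally instead of relying on jax's internal helper:

import numpy as np

def compute_fans(shape, in_axis=-2, out_axis=-1):
    # Same arithmetic as jax's fan computation, but on a plain tuple.
    receptive_field_size = np.prod(shape) / shape[in_axis] / shape[out_axis]
    fan_in = shape[in_axis] * receptive_field_size
    fan_out = shape[out_axis] * receptive_field_size
    return fan_in, fan_out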
