Code Monkey home page Code Monkey logo

n3net's Introduction

Neural Nearest Neighbors Networks (NeurIPS*2018)

Official implementation of the denoising (PyTorch) and correspondence classification (Tensorflow) N3NET, that will be published in our NeurIPS paper:

Tobias Plötz and Stefan Roth, Neural Nearest Neighbors Networks, Advances in Neural Information Processing Systems (NeurIPS), 2018

Preprint: https://arxiv.org/abs/1810.12575

NeurIPS Proceedings: https://papers.nips.cc/paper/7386-neural-nearest-neighbors-networks

Contact: Tobias Plötz ([email protected])

Installation

The denoising code is tested with Python 3.6, PyTorch 0.4.1 and Cuda 8.0 but is likely to run with newer versions of PyTorch and Cuda.

To install PyInn run

```
pip install git+https://github.com/szagoruyko/pyinn.git@master
```

Further requirements can be installed with

```
pip install -r requirements.txt
```

Update March 28, 2019 Since tensor comprehensions is not maintained anymore, we provide a memory and time efficient implementation of an indexed matrix multiplication. To build the corresponding cuda kernel please cd into lib and run

```
python setup.py install
```

Please download the BSDS500, Urban100 and Set12 datasets by cd'ing into datasets/ and using the scripts provided therein. If you want to train your own Poisson-Gaussian denoising model, please additionally download the DIV2k dataset and the Waterloo dataset.

For setting up the correspondence classification code, please clone the repository (https://github.com/vcg-uvic/learned-correspondence-release), follow their instructions to setup your environment and copy the files located in src_correspondence/. This can also conveniently be done using the script src_correspondence/clone_CNNet.sh.

Running the denoising code

To run the following commands, please cd into src_denoising.

To train a new model run:

```
python main.py <options>
```

To test your model run:

```
python main.py --eval --eval_epoch <epoch> --evaldir <dir>
```

To test our pretrained networks run:

```
python main.py --eval --eval_epoch 51 --evaldir pretrained_sigma<25|50|70>
```

A note on pretrained models. The pretrained models will give slightly different results than in the paper on Set5 due to differences in the random seed. On the other, larger datasets, results are as in the paper. Furthermore, in the paper we used a strong learning rate decay. Training our model again with a slower decay yields better results and we will add new pretrained models soon.

Running the correspondence classification code

To run the following commands, please cd into src_correspondence/CNNet.

To train a new model run:

```
python main.py --run_mode=train --net_arch=nips_2018_nl <options>
```

We provide two pretrained models. One is trained on the Brown indoor dataset, the other is trained on the St. Peters outdoor dataset. To test our pretrained models run:

```
python main.py --run_mode=test --net_arch=nips_2018_nl  --data_te=<"brown_bm_3_05"|"st_peters"|"reichstag"> --data_va=<"brown_bm_3_05"|"st_peters"|"reichstag">
--log_dir="<pretrained_brown|pretrained_stpeters>"
--test_log_dir=<dir>
```

Using the Neural Nearest Neighbors Block for Your Project

The core of the PyTorch implementation is located in src_denoising/models/non_local.py which provides classes for neural nearest neighbors selection (NeuralNearestNeighbors), a domain agnostic N3Block (N3AggregationBase) and a N3Block tailored towards image data (N3Aggregation2D). The file src_denoising/models/n3net.py contains the N3Net module that uses the 2D N3Block as non-local processing layer.

The core of the Tensorflow implementation is located in src_correspondence/non_local.py which provides analogous functionality as above.

Citation

@inproceedings{Ploetz:2018:NNN,
  title     = {Neural Nearest Neighbors Networks},
  author    = {Pl\"otz, Tobias and Roth, Stefan},
  booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
  year = {2018}
}

n3net's People

Contributors

tobiasploetz avatar

Watchers

James Cloos avatar

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.