

CapsNet Guide in PyTorch


This is my PyTorch implementation of CapsNet from Hinton's paper Dynamic Routing Between Capsules [1]. I try to implement it in a style that helps newcomers understand the architecture of CapsNet and the idea of capsules. I therefore do not wrap the code into capsule layer APIs, I declare constants more often than passing parameters to functions, and I do not optimize the code for speed. The classes and functions are supplemented with detailed comments. To read and understand the code, simply start from the comments in main.py and follow the order list at the head of the file.

The figure below clearly illustrates the core idea of capsules:

(Figure: capsule vs. neuron comparison)
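
The length of a capsule's output vector encodes the probability that the entity it detects is present. The paper enforces this with the squash nonlinearity, which shrinks short vectors toward zero and pushes long vectors toward unit length. A minimal PyTorch sketch (illustrative, not necessarily the exact code in this repo):

    import torch

    def squash(s, dim=-1, eps=1e-8):
        # Squash from the paper: v = (||s||^2 / (1 + ||s||^2)) * (s / ||s||).
        # Short vectors shrink toward 0, long vectors approach unit length,
        # so the output length can be read as a probability.
        sq_norm = (s ** 2).sum(dim=dim, keepdim=True)
        return sq_norm / (1.0 + sq_norm) * s / torch.sqrt(sq_norm + eps)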

As I am busy these days, I may not have time to check out and fix every issue, but contributions are highly welcome. If you find any bugs or errors in the code, please do not hesitate to open an issue or a pull request.

Requirements

  • pytorch 0.4.1
  • torchvision
  • pytorch-extras (For one-hot vector conversion)
  • tensorboard-pytorch
  • tqdm

All code is tested under Python 3.6.

Get Started

After cloning the repository, simply run the command:

python main.py

The code will automatically download the MNIST dataset (if it does not exist) into ./data and start training and testing. Tensorboard logs are automatically saved in ./runs, and model checkpoints are saved in ./ckpt by default.

Default parameters are defined in get_opts() in utils.py. They are listed below and can be changed by passing arguments (e.g. python main.py -batch_size 128 -epochs 30); a sketch of such a parser follows the list.

-batch_size     64      # Data batch size for training and testing
-lr             1e-3    # Learning rate
-epochs         10      # Train epochs
-r              3       # Number of iterations for dynamic routing
-use_cuda       True    # Use GPU
-print_every    10      # Interval of batches to print out losses
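
For reference, here is a minimal sketch of what such an option parser might look like with argparse. This is an assumption for illustration only; the actual get_opts() in utils.py may differ:

    import argparse

    def get_opts():
        # Hypothetical parser mirroring the defaults listed above.
        parser = argparse.ArgumentParser(description='CapsNet Guide in PyTorch')
        parser.add_argument('-batch_size', type=int, default=64,
                            help='data batch size for training and testing')
        parser.add_argument('-lr', type=float, default=1e-3, help='learning rate')
        parser.add_argument('-epochs', type=int, default=10,
                            help='number of training epochs')
        parser.add_argument('-r', type=int, default=3,
                            help='number of iterations for dynamic routing')
        # bool('False') is truthy, so parse the string explicitly.
        parser.add_argument('-use_cuda', type=lambda s: s.lower() == 'true',
                            default=True, help='use GPU')
        parser.add_argument('-print_every', type=int, default=10,
                            help='interval of batches between loss printouts')
        return parser.parse_args()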

Based on my own experiments, one training epoch takes about 6 minutes on a GTX 1080 Ti with default settings. (I set 10 epochs as the default just for showcasing; you should try 30 or more.)
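
The -r option above sets the number of iterations of the routing-by-agreement loop at the heart of CapsNet. Below is a self-contained sketch of the procedure from the paper, with MNIST-like shapes (1152 primary capsules routed to 10 digit capsules of dimension 16); variable names are illustrative, not the repo's:

    import torch
    import torch.nn.functional as F

    def dynamic_routing(u_hat, r=3, eps=1e-8):
        # u_hat: prediction vectors from the primary capsules,
        # shape (batch, n_in, n_out, dim_out), e.g. (B, 1152, 10, 16) on MNIST.
        B, n_in, n_out, _ = u_hat.shape
        b = torch.zeros(B, n_in, n_out, device=u_hat.device)  # routing logits
        for _ in range(r):
            c = F.softmax(b, dim=2)                   # couplings over output capsules
            s = (c.unsqueeze(-1) * u_hat).sum(dim=1)  # weighted sum, (B, n_out, dim_out)
            sq = (s ** 2).sum(dim=-1, keepdim=True)   # squash s into v (see above)
            v = sq / (1.0 + sq) * s / torch.sqrt(sq + eps)
            b = b + (u_hat * v.unsqueeze(1)).sum(dim=-1)  # reward agreement with v
        return v                                      # digit capsules, (B, n_out, dim_out)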

Results

I have not fine-tuned the parameters; the results below are obtained with the default configuration in the code. You can find this Tensorboard log in ./runs/sample/. Please do make a pull request if you find better parameters :)

Train

(Figure: training curves)

Test

(Figure: test curves)

Reconstructed images

(Figures: reconstructed images after each epoch, e1 through e10, with per-epoch losses 28.29, 30.37, 27.20, 28.29, 26.02, 25.26, 25.02, 24.86, 24.83, 24.78.)
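
For reference, the training objective in the paper combines a margin loss over the digit capsule lengths with a reconstruction loss scaled by 0.0005. A minimal sketch of the margin loss (variable names are illustrative, not the repo's):

    import torch

    def margin_loss(v, target, m_pos=0.9, m_neg=0.1, lam=0.5):
        # v: digit capsule outputs, (B, 10, 16); target: one-hot labels, (B, 10).
        # The true class capsule should be long (> m_pos), all others short (< m_neg).
        lengths = v.norm(dim=-1)                                        # (B, 10)
        pos = target * torch.clamp(m_pos - lengths, min=0) ** 2
        neg = lam * (1 - target) * torch.clamp(lengths - m_neg, min=0) ** 2
        return (pos + neg).sum(dim=1).mean()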

Reference

[1] Sabour, Sara, Nicholas Frosst, and Geoffrey E. Hinton. "Dynamic Routing Between Capsules." arXiv preprint arXiv:1710.09829 (2017).


Issues

About the input dimension of fc1 in the decoder

Hello, thanks for your code, which has helped strengthen my understanding of CapsNet. I have a question as follows.

For all the capsules, it seems better to share the fc1 parameters of the decoder (so that the input dimension is 16 instead of 10 * 16), and this repo does use an input dimension of 16. What was your consideration?
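
For context, the decoder in the paper masks out all but the target capsule and feeds the flattened 10 * 16 = 160-dimensional vector into fc1; sharing fc1 across classes (input dimension 16) is the alternative raised above. A sketch of the paper-style variant (illustrative, not this repo's code):

    import torch
    import torch.nn as nn

    class Decoder(nn.Module):
        # Paper-style decoder: fc1 takes the flattened, masked digit capsules
        # (10 capsules x 16 dims = 160 inputs), then reconstructs the 28x28 image.
        def __init__(self, n_classes=10, dim=16):
            super().__init__()
            self.net = nn.Sequential(
                nn.Linear(n_classes * dim, 512), nn.ReLU(inplace=True),
                nn.Linear(512, 1024), nn.ReLU(inplace=True),
                nn.Linear(1024, 784), nn.Sigmoid(),
            )

        def forward(self, v, target):
            # v: digit capsules, (B, 10, 16); target: one-hot labels, (B, 10).
            masked = v * target.unsqueeze(-1)  # zero out non-target capsules
            return self.net(masked.view(masked.size(0), -1))  # (B, 784) reconstruction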

There is a bit of a glitch with batch_size

When I set batch_size == 128, performance degraded significantly: accuracy dropped to about 15%. I wonder if there is a problem with the network implementation.
