CapsNet-Keras

Validation accuracy is now above 99.5%. A Keras implementation of CapsNet from the paper:
Sara Sabour, Nicholas Frosst, Geoffrey E Hinton. Dynamic Routing Between Capsules. NIPS 2017

Recent updates:

  • Released v0.1 for use. No frequent changes will be made.

TODO

  • Keep debugging to improve the accuracy. The learning rate decay can still be tuned.
  • The model has 8M parameters, while the paper said it should be 11M. I'll figure out what the problem is.
  • It is time to do something with CapsuleNet...(LOL)
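
The 8M figure in the list above can be checked by hand. The sketch below counts trainable parameters layer by layer. It assumes the layer sizes described in the paper (a 9x9 convolution with 256 filters, 32 channels of 8-D primary capsules on a 6x6 grid, ten 16-D digit capsules, a 512-1024-784 decoder) and no bias terms in the routing layer. These sizes are assumptions taken from the paper, not values read from capsulenet.py.

```python
# Parameter count under the layer sizes assumed above (not read from capsulenet.py).

def conv_params(kernel, in_ch, out_ch):
    """Weights plus one bias per output channel for a square convolution."""
    return kernel * kernel * in_ch * out_ch + out_ch

def dense_params(n_in, n_out):
    """Weights plus one bias per output unit for a fully connected layer."""
    return n_in * n_out + n_out

conv1 = conv_params(9, 1, 256)               # 9x9 conv on a 1-channel MNIST image
primary_caps = conv_params(9, 256, 32 * 8)   # 32 capsule channels, 8-D capsules
num_primary = 6 * 6 * 32                     # 1152 primary capsules
digit_caps = num_primary * 10 * 8 * 16       # one 8x16 matrix per capsule pair, no bias
decoder = (dense_params(16 * 10, 512)
           + dense_params(512, 1024)
           + dense_params(1024, 784))

total = conv1 + primary_caps + digit_caps + decoder
print(f"without decoder: {total - decoder:,}")   # 6,804,224
print(f"with decoder:    {total:,}")             # 8,215,568
```

Under these assumptions the count comes to roughly 8.2M, in line with the 8M reported above.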

Contribution

  • Contributions to the repo are welcome. Open an issue or contact me at [email protected] or on WeChat (ID: wenlong-guo).

Requirements

Usage

Training

Step 1. Install Keras:

$ pip install keras

Step 2. Clone this repository with git.

$ git clone https://github.com/xifengguo/CapsNet-Keras.git
$ cd CapsNet-Keras

Step 3. Training:

$ python capsulenet.py

To train with one routing iteration (the default is 3):

$ python capsulenet.py --num_routing 1
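
To make clear what the number of routing iterations controls, here is a minimal NumPy sketch of routing-by-agreement as described in the cited paper. It is an illustration, not the code in capsulenet.py. The function names, the tensor layout (input capsules, output capsules, output dimension) and the random input are assumptions made for this example.

```python
import numpy as np

def squash(s, axis=-1, eps=1e-8):
    """Shrink short vectors toward zero and long vectors toward unit length."""
    sq_norm = np.sum(np.square(s), axis=axis, keepdims=True)
    return (sq_norm / (1.0 + sq_norm)) * s / np.sqrt(sq_norm + eps)

def softmax(x, axis):
    e = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e / np.sum(e, axis=axis, keepdims=True)

def dynamic_routing(u_hat, num_routing=3):
    """u_hat holds prediction vectors with shape (num_in, num_out, dim_out)."""
    num_in, num_out, _ = u_hat.shape
    b = np.zeros((num_in, num_out))                  # routing logits start at zero
    for i in range(num_routing):
        c = softmax(b, axis=1)                       # coupling of each input over the outputs
        s = np.sum(c[:, :, None] * u_hat, axis=0)    # weighted sum, shape (num_out, dim_out)
        v = squash(s)
        if i < num_routing - 1:                      # the last update would not be used
            b = b + np.sum(u_hat * v[None, :, :], axis=-1)  # agreement between u_hat and v
    return v

rng = np.random.default_rng(0)
u_hat = rng.normal(size=(1152, 10, 16))              # random stand-in for real predictions
v1 = dynamic_routing(u_hat, num_routing=1)
v3 = dynamic_routing(u_hat, num_routing=3)
print(v1.shape, v3.shape)
```

With a single iteration the coupling coefficients stay uniform, so every output capsule is just the squashed average of its predictions. Further iterations shift weight toward the outputs that agree with each input capsule.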

Other parameters, including batch_size, epochs, lam_recon, shift_fraction and save_dir, can be passed on the command line in the same way. Please refer to capsulenet.py for details.
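
As a rough picture of how such options could be wired up, the sketch below declares the flags named in this README with argparse. It is not copied from capsulenet.py. The default of 3 for num_routing comes from the text above, save_dir is inferred from the result/ path used for the saved model, and all other defaults are placeholders chosen for illustration.

```python
import argparse

def build_parser():
    """Command-line options named in this README (illustrative sketch)."""
    parser = argparse.ArgumentParser(description="CapsNet options (sketch)")
    parser.add_argument("--batch_size", type=int, default=100)        # placeholder default
    parser.add_argument("--epochs", type=int, default=30)             # placeholder default
    parser.add_argument("--lam_recon", type=float, default=0.0005)    # placeholder default
    parser.add_argument("--num_routing", type=int, default=3)         # default stated in this README
    parser.add_argument("--shift_fraction", type=float, default=0.1)  # placeholder default
    parser.add_argument("--save_dir", default="result")               # inferred from result/trained_model.h5
    parser.add_argument("--is_training", type=int, default=1)         # placeholder default
    parser.add_argument("--weights", default=None)                    # path to saved weights, if any
    return parser

# Same effect as: python capsulenet.py --num_routing 1 --batch_size 64
args = build_parser().parse_args(["--num_routing", "1", "--batch_size", "64"])
print(args.num_routing, args.batch_size, args.save_dir)
```

Any flag left out of the command keeps its default, which is why the plain `python capsulenet.py` call works without arguments.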

Testing

If you have trained a model using the command above, the trained model is saved to result/trained_model.h5. Launch the following command to get the test results.

$ python capsulenet.py --is_training 0 --weights result/trained_model.h5

It will output the testing accuracy and show the reconstructed images. The testing data is the same as the validation data. Testing on new data is easy: just change the code as you want (of course you can do it!).
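
If you swap in your own test data, the accuracy can be recomputed with a few lines of NumPy. The sketch below assumes the classification output is an array of capsule lengths with shape (N, 10) and that the labels are one-hot encoded. The function name and the three sample rows are made up for illustration.

```python
import numpy as np

def capsule_accuracy(y_true_onehot, y_pred_lengths):
    """Fraction of samples whose longest output capsule matches the label."""
    predicted = np.argmax(y_pred_lengths, axis=1)
    actual = np.argmax(y_true_onehot, axis=1)
    return float(np.mean(predicted == actual))

# Three made-up samples over 10 classes, for illustration only.
y_true = np.eye(10)[[3, 7, 1]]
y_pred = np.full((3, 10), 0.05)
y_pred[0, 3] = 0.9   # correct: class 3
y_pred[1, 7] = 0.8   # correct: class 7
y_pred[2, 4] = 0.7   # wrong: predicts class 4, label is class 1
print(capsule_accuracy(y_true, y_pred))
```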

If, sadly, you do not have a computer good enough to train the model, you can download a model I trained from https://pan.baidu.com/s/1hsF2bvY

Results

Main result
Obtained by launching python capsulenet.py. Epoch 1 means the result is evaluated after training for one epoch. In the saved log file, the epoch count starts from 0.

Epoch            1      5      10     15     20
train_acc (%)    90.65  98.95  99.36  99.63  99.75
vali_acc (%)     98.51  99.30  99.34  99.49  99.59

Losses and accuracies:

Results with one routing iteration
Obtained by launching python capsulenet.py --num_routing 1

Epoch            1      5      10     15     20
train_acc (%)    89.64  99.02  99.42  99.66  99.73
vali_acc (%)     98.55  99.33  99.43  99.57  99.58

Each epoch takes about 110 s on a single GTX 1070 GPU.

NOTE: The model is still under-fitting, so you are welcome to try training it yourself.
The learning rate decay is not fine-tuned; I only tried one setting. You can tune it.
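
This README does not say which decay is used, so the sketch below simply assumes a per-epoch exponential decay. The initial rate 0.001 and the factor 0.9 are hypothetical starting points for tuning, not the values in capsulenet.py.

```python
def exponential_decay(initial_lr, decay_rate):
    """Return a schedule mapping a 0-based epoch index to a learning rate."""
    def schedule(epoch):
        return initial_lr * (decay_rate ** epoch)
    return schedule

# Hypothetical values; try a few decay rates and compare validation accuracy.
schedule = exponential_decay(initial_lr=0.001, decay_rate=0.9)
for epoch in (0, 5, 10, 19):
    print(epoch, f"{schedule(epoch):.6f}")
```

A function of this shape, taking the epoch index and returning a learning rate, is what the Keras LearningRateScheduler callback expects, so different decay rates can be tried without touching the rest of the training code.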

Testing result
The result obtained by launching
$ python capsulenet.py --is_training 0 --weights result/trained_model.h5

The model structure:

Other Implementations
