capsnet-tensorflow's Introduction

Dynamic Routing Between Capsules

reference: Dynamic routing between capsules by Sara Sabour, Nicholas Frosst, Geoffrey E Hinton

Note: this implementation strictly follows the instructions in the paper; check the paper for details.

Takeaways

The key point of the paper is not how accurate CapsNet is, but the novel idea of representing an image with capsules.

Dependencies

  • The code is tested on TensorFlow 1.3 and Python 2.7, but it should also be compatible with Python 3.x.
  • Other dependencies are as follows:
six>=1.11
matplotlib>=2.0.2
numpy>=1.7.1
scipy>=0.13.2
easydict>=1.6
tqdm>=4.17.1

Install them by running:

$ cd $ROOT
$ pip install -r requirements.txt

Experiments

NOTE: all the experiments were conducted on the checkpoint available from Jbox (SJTU) or Google_Drive.

reconstruction

By running:

$ cd code
$ python eval.py --ckpt 'path/to/ckpt' --mode reconstruct

Reconstruction results (note: the floating-point numbers on the even-numbered rows are the maximum norm among the 10 digit capsules):

1 2
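The "max norm of the 10 digit capsules" mentioned above can be illustrated with a small NumPy sketch. It assumes the digit-capsule output for one image is a 10 x 16 array; the array here is random, hypothetical data, not output from this repository's model.

```python
import numpy as np

# Hypothetical digit-capsule output for one image: 10 capsules, 16-D each.
rng = np.random.RandomState(0)
digit_caps = rng.normal(scale=0.1, size=(10, 16))

# The length (L2 norm) of each capsule vector is read as the class score.
norms = np.linalg.norm(digit_caps, axis=-1)  # shape (10,)

# The predicted digit is the capsule with the largest norm, and that
# largest norm is the kind of number printed under each reconstruction.
predicted_digit = int(np.argmax(norms))
max_norm = float(norms.max())
print(predicted_digit, max_norm)
```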

capsule unit representation

By running:

$ cd code
$ python eval.py --ckpt 'path/to/ckpt' --mode cap_tweak

results:

Note: the x-axis indexes the units (dimensions) of the 16-D capsule vector, and the y-axis corresponds to the tweak range [-0.25, 0.25] in steps of 0.05.

cap_tweak-1 cap_tweak-1
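A minimal sketch of how such a tweak grid could be built, assuming one 16-D capsule vector is perturbed one unit at a time. The `capsule` values are hypothetical, and the decoder that turns each perturbed vector into an image is not shown.

```python
import numpy as np

# 11 tweak values from -0.25 to 0.25 in steps of 0.05 (the y-axis).
tweaks = np.linspace(-0.25, 0.25, 11)

# Hypothetical 16-D capsule vector for the predicted digit.
capsule = np.random.RandomState(0).normal(scale=0.1, size=16)

# grid[row, dim] is a copy of the capsule with only unit `dim` shifted
# by tweaks[row]; feeding every entry to the decoder would give one image
# per (tweak, unit) cell.
grid = np.tile(capsule, (len(tweaks), 16, 1))  # shape (11, 16, 16)
for row, delta in enumerate(tweaks):
    for dim in range(16):
        grid[row, dim, dim] += delta

print(grid.shape)
```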

adversarial test

By running:

$ cd code
$ python eval.py --ckpt 'path/to/ckpt' --mode adversarial

result:

adver-1 adver-2 adver-3

The adversarial result is not as good as I expected; I was hoping that the capsule representation would be more robust to adversarial attacks.

training

Note: all models were trained with batch_size = 100.

Latest commit, with 3 iterations of dynamic routing:

1. implement dynamic routing both with tf.while_loop and in a static (unrolled) way
2. fix the margin loss issue
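For readers new to the routing step, here is a NumPy sketch of routing-by-agreement with 3 iterations, following my reading of the procedure in the paper. It is not this repository's TensorFlow code; the function names and the random `u_hat` prediction vectors are assumptions made for illustration.

```python
import numpy as np

def squash(s, eps=1e-9):
    """Shrink vectors to length < 1 while keeping their direction."""
    sq_norm = np.sum(s ** 2, axis=-1, keepdims=True)
    return (sq_norm / (1.0 + sq_norm)) * s / np.sqrt(sq_norm + eps)

def softmax(x, axis):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def dynamic_routing(u_hat, num_iters=3):
    """u_hat: prediction vectors, shape (num_in, num_out, dim_out)."""
    num_in, num_out, _ = u_hat.shape
    b = np.zeros((num_in, num_out))                # routing logits
    for _ in range(num_iters):
        c = softmax(b, axis=1)                     # coupling coefficients
        s = np.sum(c[:, :, None] * u_hat, axis=0)  # weighted sum per output
        v = squash(s)                              # (num_out, dim_out)
        b = b + np.sum(u_hat * v[None], axis=-1)   # agreement update
    return v

# Hypothetical input: 1152 primary capsules voting for 10 digit capsules.
u_hat = np.random.RandomState(0).normal(scale=0.1, size=(1152, 10, 16))
v = dynamic_routing(u_hat, num_iters=3)
print(v.shape)
```

The loop above is the unrolled form; a tf.while_loop version would express the same three updates as a graph-level loop.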

result:

Iterations    1k     2k     3k     4k     5k
val_acc (%)   98.90  99.16  99.09  99.30  99.24
test_acc (%)  -      -      -      -      99.21

Commit 8e3785d, which has two known bugs:

1. wrong implementation of the margin loss
2. updating `prior` during routing
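Since the margin loss was a source of bugs, a small reference sketch may help. It follows the margin loss as defined in the paper (m+ = 0.9, m- = 0.1, lambda = 0.5) and is written in NumPy rather than copied from this repository; the function name and argument names are assumptions.

```python
import numpy as np

def margin_loss(lengths, labels, m_plus=0.9, m_minus=0.1, lam=0.5):
    """lengths: (batch, 10) capsule norms; labels: (batch,) integer classes."""
    t = np.eye(lengths.shape[1])[labels]  # one-hot targets T_k
    # Penalise the correct class when its capsule is shorter than m_plus.
    present = t * np.maximum(0.0, m_plus - lengths) ** 2
    # Penalise other classes when their capsule is longer than m_minus.
    absent = lam * (1.0 - t) * np.maximum(0.0, lengths - m_minus) ** 2
    return float(np.mean(np.sum(present + absent, axis=1)))

# Every capsule has length 0.5 and the true class is 0:
# 0.4**2 + 0.5 * 9 * 0.4**2 = 0.88
print(margin_loss(np.full((1, 10), 0.5), np.array([0])))
```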

result:

Iterations    2k     4k     5k     7k     9k     10k
val_acc (%)   98.02  98.58  -      98.82  98.96  -
test_acc (%)  -      -      98.89  -      -      99.09

Train

  • Clone the repo and set the parameters in code/config.py.
  • Then run:
$ cd $ROOT/code
$ python train.py --data_dir 'path/to/data' --max_iters 10000 --ckpt 'OPTIONAL:path/to/ckpt' --batch_size 100

Or train with logs by running (NOTE: set the extra arguments in train.sh accordingly):

$ cd $ROOT/code
$ bash train.sh
  • The lower accuracy may be due to the missing 3M parameters (my implementation has 8M, compared with the 11M referred to in the paper). The input size is also different.
  • The model is still under-fitting.
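As a rough sanity check on the 8M figure, the count below assumes the layer sizes I read from the paper for a 28x28 MNIST input: a 9x9 convolution with 256 channels, a 9x9 stride-2 primary-capsule convolution producing 32 x 8 channels on a 6x6 grid, a 1152 x 10 set of 8-to-16 transformation matrices, and a 512-1024-784 decoder. The helper names are hypothetical, and the repository's actual layers may differ.

```python
def conv_params(k, c_in, c_out):
    """Weights plus biases of a k x k convolution."""
    return k * k * c_in * c_out + c_out

def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out

conv1 = conv_params(9, 1, 256)
primary_caps = conv_params(9, 256, 32 * 8)
digit_caps = (6 * 6 * 32) * 10 * 8 * 16  # one 8x16 matrix per (input, output) pair
decoder = (dense_params(16 * 10, 512)
           + dense_params(512, 1024)
           + dense_params(1024, 784))

capsnet_only = conv1 + primary_caps + digit_caps
total = capsnet_only + decoder
print(capsnet_only)  # 6804224
print(total)         # 8215568
```

Under these assumptions the total is about 8.2M, which is consistent with the 8M mentioned above.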

TODO

  • report exclusive experiment results
  • try to fix the inefficacy

capsnet-tensorflow's People

Contributors

innerpeace-wu

capsnet-tensorflow's Issues

Missing weight sharing?

Hi
Thank you for sharing your implementation. It is quite helpful!

However, I think you missed the weight-sharing part (see page 4, under Figure 2: the capsules in each [6, 6] grid share their weights). So my understanding is that the last dimension of your cap_ws should be 32 instead of 1152.

Although I tested both cases, in the early training stage, I didn't see too much difference.
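To make the difference concrete, the sketch below compares the size of the transformation-weight tensor with and without sharing across the 6x6 grid. The exact layout of cap_ws in the repository is not shown here, so the shapes are hypothetical; only the last dimension (1152 vs 32) comes from the issue.

```python
import math

# Hypothetical layout: (in_dim, out_dim, num_digit_caps, num_input_groups).
unshared_shape = (8, 16, 10, 6 * 6 * 32)  # one matrix per primary capsule
shared_shape = (8, 16, 10, 32)            # one matrix per capsule type

unshared_params = math.prod(unshared_shape)
shared_params = math.prod(shared_shape)

print(unshared_params)                   # 1474560
print(shared_params)                     # 40960
print(unshared_params // shared_params)  # 36, i.e. the 6x6 grid positions
```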
