
pytorch-capsulenet's Introduction

Pytorch-CapsuleNet

A flexible and easy-to-follow Pytorch implementation of Hinton's Capsule Network.

There are already many repos implementing CapsNet. However, most of them are too rigid to customize. And as we all know, Hinton's original paper only evaluated the network on the MNIST dataset. We clearly want to do more.

This repo is designed to support other datasets and configurations. Most importantly, we want the code to be flexible, so that the network can be tailored to our needs.

Currently, the code supports both MNIST and CIFAR-10 datasets.

Requirements

  • Python 3.x
  • Pytorch 0.3.0 or above
  • Numpy
  • tqdm (for a nicer progress display; of course you can replace it with 'print')

Run

Just run python test_capsnet.py in your terminal. That's all. To switch between datasets (MNIST or CIFAR-10), simply set the dataset variable.

It is better to run the code on a server with GPUs, since capsule networks demand serious compute. For instance, on my device (an Nvidia K80), one epoch on MNIST takes about 5 minutes (batch size = 100).

More details

There are 3 .py files:

  • capsnet.py: the main class for capsule network
  • data_loader.py: the class that loads the supported datasets
  • test_capsnet.py: the training and testing code

The results on your device may look like the following picture:

Acknowledgements

pytorch-capsulenet's People

Contributors

jindongwang, raoashish10, vinhtq115


pytorch-capsulenet's Issues

Would this capsule net be able to handle 2 classes?

Would this capsule net be able to handle 2 classes? I'm experimenting with a dataset similar to CIFAR-10, but with only 2 classes.

I've amended the code where the digits capsule takes in 2 inputs.

Wherever torch.eye is called, I've changed the argument from 10 to 2.

And I've amended the linear layer in the decoder to nn.Linear(16 * 2, 512)

It runs but I just want to make sure this would be correct...
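For reference, those are indeed the two places the class count typically appears in CapsNet implementations. A minimal sketch (the variable names here are illustrative, not necessarily this repo's exact code):

```python
import torch
import torch.nn as nn

NUM_CLASSES = 2  # was 10 for CIFAR-10

# one-hot mask built from torch.eye, used when masking the class capsules
labels = torch.tensor([0, 1, 1])
mask = torch.eye(NUM_CLASSES).index_select(dim=0, index=labels)

# decoder input: one 16-dim pose vector per class capsule
decoder_in = nn.Linear(16 * NUM_CLASSES, 512)

print(mask.shape)              # torch.Size([3, 2])
print(decoder_in.in_features)  # 32
```

If both the eye-mask size and the decoder's first Linear layer agree with the new class count, the forward pass should be shape-consistent.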

How to change the number of output classes to apply the network to your own dataset?

File "C:\Users\23671\Desktop\CNN\self_conv\DDR\CNN_2d.py", line 405, in forward
u = u.view(x.size(0), self.num_routes, -1)
RuntimeError: shape '[2, 1152, -1]' is invalid for input of size 1125888

When I applied the code to my own dataset, I first got an error about the input size relative to the convolutional kernel size, so I changed the padding to 1 and adjusted the kernel size; after that, the error above appeared. I suspect the problem is that I haven't changed the number of classes (mine is 5), but I don't know where to change it.
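The 1152 in that error is the primary-capsule count (num_routes), which is tied to the convolutional feature-map size and must be recomputed whenever the kernel size, stride, or padding changes; the number of classes is a separate setting (the torch.eye calls and the decoder's first Linear layer). A sketch under the usual CapsNet conv settings (9×9 kernels, strides 1 and 2, 32 primary-capsule maps — assumptions that may not match this repo exactly):

```python
# The primary-capsule count follows from the conv feature-map size;
# recompute it instead of hard-coding 1152 when you change kernels/padding.
def conv_out(size, kernel, stride=1, padding=0):
    return (size + 2 * padding - kernel) // stride + 1

input_size = 28                        # e.g. MNIST
s1 = conv_out(input_size, kernel=9)    # first 9x9 conv, stride 1 -> 20
s2 = conv_out(s1, kernel=9, stride=2)  # primary-caps 9x9 conv, stride 2 -> 6
num_capsule_maps = 32                  # assumed number of primary-capsule maps
num_routes = num_capsule_maps * s2 * s2
print(num_routes)  # 1152
```

Whenever the input resolution or conv hyperparameters change, num_routes must change with them, or the view() in forward() will fail exactly as shown in the traceback.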

Question about stopping the gradient of u_hat in the first N-1 routing iterations

All the other capsule-network implementations I have seen stop the gradient of u_hat during the first N-1 routing iterations and only let gradients flow in the final iteration. This version is the only one that does not do this. Does the author know the reason behind this choice?

After a few Epochs, Loss Would Be NaN

Hello, I tried to test your code, but after a few epochs (8 at most) the loss of the network becomes NaN. Has anyone else faced this issue while testing? How can I fix it?

Undefined Gradients due to division by zero.

squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)

To make it numerically stable, the following is suggested:

def squash(self, input_tensor, epsilon=1e-7):
    squared_norm = (input_tensor ** 2 + epsilon).sum(-1, keepdim=True)
    output_tensor = squared_norm * input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm))
    return output_tensor
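As a self-contained check: note that adding epsilon to every element before the sum (as in the snippet above) inflates the squared norm by dim × epsilon; an alternative with the same stabilizing effect is to put epsilon under the square root only. A sketch of that variant:

```python
import torch

def squash(input_tensor, epsilon=1e-7):
    # squared L2 norm over the last dimension, kept for broadcasting
    squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
    # epsilon under the sqrt guards the division for all-zero vectors
    return squared_norm * input_tensor / ((1. + squared_norm) * torch.sqrt(squared_norm + epsilon))

out = squash(torch.zeros(2, 5))
print(torch.isnan(out).any())  # tensor(False)
```

For a zero vector the numerator is zero and the denominator is sqrt(epsilon), so the output is cleanly zero instead of NaN, and the gradient stays finite.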

mistake in Decoder wrt to non square images.

Hello Jin Dong Wang,

I noticed in the:

class Decoder(nn.Module):
    def __init__(self, input_width=28, input_height=28, input_channel=1):
        super(Decoder, self).__init__()
        self.input_width = input_width
        self.input_height = input_height
        self.input_channel = input_channel
        self.reconstraction_layers = nn.Sequential(
            nn.Linear(16 * LABEL_SIZE, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, self.input_height * self.input_height * self.input_channel),
            nn.Sigmoid()
        )

The line: nn.Linear(1024, self.input_height * self.input_height * self.input_channel),

Should be: nn.Linear(1024, self.input_height * self.input_width * self.input_channel),

Do you agree?
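A quick way to confirm the point: with a non-square input, only the corrected line produces the right number of output pixels. A trimmed-down sketch of the decoder head (LABEL_SIZE assumed to be 10 here):

```python
import torch
import torch.nn as nn

# deliberately non-square input so the bug would be visible
input_width, input_height, input_channel = 32, 28, 1
LABEL_SIZE = 10  # assumed class count

decoder_head = nn.Sequential(
    nn.Linear(16 * LABEL_SIZE, 512),
    nn.ReLU(inplace=True),
    nn.Linear(512, 1024),
    nn.ReLU(inplace=True),
    # corrected: height * width, not height * height
    nn.Linear(1024, input_height * input_width * input_channel),
    nn.Sigmoid(),
)

out = decoder_head(torch.randn(4, 16 * LABEL_SIZE))
print(out.shape)  # torch.Size([4, 896]) == (batch, 28 * 32 * 1)
```

With height * height the layer would emit 784 values, which could not be reshaped into a 28×32 image.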

Missing License

Hello,
I really like the implementation, but would it be possible to add a license (e.g. MIT) to the repository, so that the code can be reused?

Low test accuracy of 0.1

Why does your CapsuleNet code result in a test accuracy of only 0.1?

(abridged log — test accuracy climbs steadily to about 0.68 by epoch 16, then the loss turns NaN during epoch 17 and accuracy collapses to 0.1 for the rest of the run)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar/cifar-10-python.tar.gz
170500096it [00:02, 74718714.99it/s]
Files already downloaded and verified
Epoch: [1/30], train loss: 0.008047
Epoch: [1/30], test accuracy: 0.431800, loss: 0.797698
Epoch: [5/30], train loss: 0.005918
Epoch: [5/30], test accuracy: 0.590500, loss: 0.580702
Epoch: [10/30], train loss: 0.005005
Epoch: [10/30], test accuracy: 0.656500, loss: 0.507344
Epoch: [16/30], train loss: 0.004415
Epoch: [16/30], test accuracy: 0.683000, loss: 0.482916
Epoch: [17/30], Batch: [401/500], train accuracy: 0.800000, loss: 0.003930
Epoch: [17/30], train loss: nan
Epoch: [17/30], test accuracy: 0.100000, loss: nan
...
Epoch: [30/30], train loss: nan
Epoch: [30/30], test accuracy: 0.100000, loss: nan

Mistake in Decoder Class in capsnet.py

nn.Linear(1024, self.input_height * self.input_height * self.input_channel),

Shouldn't this line be nn.Linear(1024, self.input_height * self.input_width * self.input_channel)? It happens to work for CIFAR-10 because the images are square, but to generalise the code you would have to frame it this way. Let me know if I am wrong.

squash function does not seem to reduce the 16-value vector to a single length value

In file capsnet.py => function squash,

def squash(self, input_tensor):
        print(f'input_tensor.shape: {input_tensor.shape}')
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        print(f'squared_norm.shape: {squared_norm.shape}')

Print out :

input_tensor.shape: torch.Size([100, 1, 10, 16, 1])
squared_norm.shape: torch.Size([100, 1, 10, 16, 1])

In my understanding, squared_norm should be the squared length of the vector, where the vector is 16-dimensional (dim 3). Thus, after the ** 2 and .sum operations, it should reduce to a single number per capsule; I expected the output shape to be [100, 1, 10, 1, 1]. However, the code sums over the last dimension, which seems incorrect.

Do I misunderstand?
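The observation is easy to reproduce in isolation (note that whether this is actually a bug depends on the tensor layout at each call site — another path through the network may legitimately put the pose dimension last):

```python
import torch

u = torch.randn(100, 1, 10, 16, 1)

# summing over the last dim (size 1) reduces nothing: shape is unchanged
norm_last = (u ** 2).sum(-1, keepdim=True)
print(norm_last.shape)  # torch.Size([100, 1, 10, 16, 1])

# summing over the 16-dim pose vector yields one squared length per capsule
norm_caps = (u ** 2).sum(-2, keepdim=True)
print(norm_caps.shape)  # torch.Size([100, 1, 10, 1, 1])
```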

How can I get the score of each class?

Hello, I have run this model on my own dataset, but I find that it only returns three values: output, reconstructions, and masked. I want to calculate the AUC, which needs a score for each class. How can I get that? Thanks.
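In a capsule network the per-class score is conventionally the length (L2 norm) of each class capsule's output vector, which is also what argmax-based accuracy uses. Assuming output has shape (batch, num_classes, 16) — the actual layout in this repo may include extra singleton dimensions to squeeze out first — a sketch:

```python
import torch
import torch.nn.functional as F

batch, num_classes, caps_dim = 4, 10, 16
# stand-in for the digit-capsule output; the real layout may differ
output = torch.randn(batch, num_classes, caps_dim)

# capsule length is the conventional per-class score
scores = output.norm(dim=-1)   # shape (batch, num_classes)
# normalize to pseudo-probabilities for metrics such as AUC
probs = F.softmax(scores, dim=-1)

print(scores.shape)  # torch.Size([4, 10])
```

The probs tensor can then be passed to a metric implementation that expects per-class scores, e.g. a one-vs-rest AUC.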
