Comments (7)
@fehiepsi I achieved that score in stages. I first trained for 50 iterations with Adam at a learning rate of 1e-3, then 100 iterations at 1e-4, and a final 100 iterations at 1e-5.
I am sure that with SGD you can set up a learning rate scheduler to do something similar, but since I was using Adam and didn't know of any direct method to set the learning rate, I did this manually. I believe you can set a learning rate scheduler via a callback for SGD. I haven't tried it for Adam though.
Note that, as with all deep learning models, it is not possible to perfectly reproduce the score (due to the random initialization of the network weights and the adaptive gradients of Adam). Even the pre-trained model's accuracy is slightly lower than the accuracy reported in the paper.
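The staged schedule described above can be sketched as a plain function of the epoch index. This is only an illustration, not the script actually used: it assumes the "iterations" above are epochs, and the name `staged_lr` is hypothetical.

```python
def staged_lr(epoch):
    """Return the learning rate for a 0-based epoch index.

    Assumed stages (taken from the comment above, reading
    "iterations" as epochs): 50 epochs at 1e-3, then 100 epochs
    at 1e-4, then 1e-5 for the remaining epochs.
    """
    if epoch < 50:
        return 1e-3
    if epoch < 150:
        return 1e-4
    return 1e-5


if __name__ == "__main__":
    # Print the rate at each stage boundary to check the schedule.
    for epoch in (0, 49, 50, 149, 150, 249):
        print(epoch, staged_lr(epoch))
```

In Keras, a function like this could be passed to `keras.callbacks.LearningRateScheduler` and supplied through the `callbacks` argument of `model.fit`. That callback sets the optimizer's learning rate at the start of each epoch, so it should apply to Adam as well as SGD, though that was not tested here.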
from densenet.
@ahundt May I ask which model type you were training? The DN-40-12 model in the cifar-10 script?
I haven't trained the DenseNet models on TensorFlow, so I don't know what the cause is. (They hit an OOM on TensorFlow for me; Theano with GC lets me train, though.)
EDIT:
Also, two runs? Did you make sure to lower the learning rate of Adam in the second run to 1e-5 or 2e-5?
from densenet.
Here it is:
https://github.com/ahundt/DenseNet/tree/tf
Ah yes, the problem might be that I forgot to lower the learning rate. I bet there is also a way to get automatic rate decay in Keras.
from densenet.
You're definitely right: it is already at 91% validation accuracy after just a few minutes. Good call, and thanks! I'll see if I can submit a pull request that loads either the tf or the th model depending on which backend is running.
from densenet.
@ahundt No problem. Glad I could help.
On the issue of loading the correct model, you could do something like this:
    from keras import backend as K
    # ...
    if K.backend() == 'theano':
        model.load_weights('path/to/theano/weights')
    else:
        model.load_weights('path/to/tensorflow/weights')
from densenet.
@titu1994 Indeed, I get low validation accuracy even with Theano using the current cifar10.py script. At 200 epochs, the accuracy is 88.06%. When I use your pretrained model, I get 94.84% accuracy. In one of your comments, you mentioned decreasing the learning rate for Adam to 1e-5. Is this the reason your pretrained model performs so well? Would you please help me reproduce your result? Thank you very much!
from densenet.
Thank you so much @titu1994 !!! Your comment is very helpful to me. I am trying to use a Keras callback based on your suggestions.
from densenet.