Comments (7)

titu1994 commented on June 12, 2024

@fehiepsi I achieved that score by doing a few things. I first trained for 50 iterations with Adam at a learning rate of 1e-3, then 100 iterations at 1e-4, and a final 100 iterations at 1e-5.

I am sure that with SGD, you can set your learning rate scheduler to do similar things, but since I was using Adam and I didn't know of any direct method to set the learning rate, I did this manually. I believe you can set a learning rate scheduler via a callback for SGD. I haven't tried it for Adam though.
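The manual three-stage schedule described above could be written as a function of the epoch index. This is a minimal sketch, assuming the "iterations" mentioned are epochs; the names `STAGES` and `step_schedule` are illustrative and not from the repository. Whether it behaves well with Adam is untested here.

```python
# Hypothetical helper: piecewise-constant learning rate matching the
# manual schedule above (50 epochs at 1e-3, 100 at 1e-4, 100 at 1e-5).
# Assumes "iterations" in the comment means epochs.
STAGES = [(50, 1e-3), (100, 1e-4), (100, 1e-5)]  # (num_epochs, learning_rate)


def step_schedule(epoch):
    """Return the learning rate for a zero-based epoch index."""
    boundary = 0
    for num_epochs, lr in STAGES:
        boundary += num_epochs
        if epoch < boundary:
            return lr
    # Past the last stage: keep the final (smallest) learning rate.
    return STAGES[-1][1]
```

With Keras installed, a function of this shape could be passed as `keras.callbacks.LearningRateScheduler(step_schedule)` in the `callbacks` list given to `model.fit`.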

Note that, as with all deep learning models, it is not possible to reproduce the score exactly (due to the random initialization of the network weights and the adaptive gradients of Adam). Even the pre-trained model's accuracy is slightly lower than the accuracy reported in the paper.

from densenet.

titu1994 commented on June 12, 2024

@ahundt May I ask which model type you were training? The DN-40-12 model in the cifar-10 script?

I haven't trained the DenseNet models on TensorFlow, so I don't know what the cause is (they hit an OOM error on TensorFlow, whereas Theano with GC enabled lets me train them).

EDIT:
Also, two runs? Did you make sure to lower the learning rate of Adam in the second run to 1e-5 or 2e-5?

ahundt commented on June 12, 2024

Here it is:
https://github.com/ahundt/DenseNet/tree/tf

Ah yes, the problem might be that I forgot to lower the learning rate. I bet there is also a way to have the rate decay automatically in Keras.
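One form of automatic decay is a time-based rule, where the rate shrinks with the number of updates. A minimal sketch, assuming the rule lr_t = lr_0 / (1 + decay * t), which is the form applied by the `decay` argument of older Keras optimizers; the function name and any values passed to it are illustrative only.

```python
def time_based_decay(initial_lr, decay, iteration):
    """Learning rate after `iteration` updates under time-based decay.

    Implements lr_t = lr_0 / (1 + decay * t). With decay = 0 the rate
    stays constant at initial_lr.
    """
    return initial_lr / (1.0 + decay * iteration)
```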

ahundt commented on June 12, 2024

You're definitely right: it is already at 91% validation accuracy after just a few minutes. Good call, and thanks! I'll see if I can submit a pull request that loads either the tf or the th model depending on which backend is running.

titu1994 commented on June 12, 2024

@ahundt No problem. Glad I could help.

On the issue of loading the correct model, you could do something like this:

from keras import backend as K

....

# K.backend() returns the name of the active backend, e.g. 'theano' or 'tensorflow'
if K.backend() == 'theano':
    model.load_weights('path/to/theano/weights')
else:
    model.load_weights('path/to/tensorflow/weights')

fehiepsi commented on June 12, 2024

@titu1994 Indeed, I get low validation accuracy even with Theano using the current cifar10.py script. At 200 epochs, the accuracy is 88.06%. When I use your pretrained model, I get 94.84% accuracy. In another comment, you mentioned decreasing the learning rate for Adam to 1e-5. Is this the reason your pretrained model performs so well? Would you please help me reproduce your result? Thank you very much!

fehiepsi commented on June 12, 2024

Thank you so much @titu1994 !!! Your comment is very helpful to me. I am trying to use the callback in keras with your suggestions.
