
drmad's Introduction

DrMAD

Source code for the paper: Distilling Reverse-Mode Automatic Differentiation for Optimizing Hyperparameters of Deep Neural Networks.

What's DrMAD?

DrMAD is a hyperparameter tuning method based on automatic differentiation (AD), which is one of the most criminally underused tools in machine learning.

DrMAD can tune thousands of continuous hyperparameters (e.g. L1 norms for every single neuron) for deep models on GPUs.

AD vs Bayesian optimization (BO)

BO, as a global optimization approach, can hardly support tuning more than 20 hyperparameters, because it treats the learning algorithm being tuned as a black box and only gets a feedback signal after training has converged. Also, a learning rate tuned by BO is fixed for all iterations.

AD is different from symbolic differentiation (used by Mathematica) and from numerical differentiation. AD, as a local optimization method based on gradient information, can make use of the feedback signal after every iteration and can tune thousands of hyperparameters by using (hyper-)gradients with respect to those hyperparameters. Check out this paper if you want to understand AD techniques deeply.

Therefore, AD can tune hundreds of thousands of constant (e.g. L1 norm for every neuron) or dynamic (e.g. learning rates for every neuron at every iteration) hyperparameters.
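To make the contrast with numerical differentiation concrete, here is a minimal sketch of AD using dual numbers. It is forward-mode AD, chosen only because it fits in a few lines; DrMAD itself is about reverse-mode AD. The `Dual` class and the test function `f` are hypothetical names made up for this illustration.

```python
import math

class Dual:
    """A value paired with its derivative (val + dot * eps, with eps**2 == 0)."""
    def __init__(self, val, dot=0.0):
        self.val = val
        self.dot = dot

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + other.val, self.dot + other.dot)
    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # Product rule applied to the derivative part.
        return Dual(self.val * other.val,
                    self.val * other.dot + self.dot * other.val)
    __rmul__ = __mul__

def sin(x):
    """sin that works on both plain floats and Dual numbers."""
    if isinstance(x, Dual):
        return Dual(math.sin(x.val), math.cos(x.val) * x.dot)
    return math.sin(x)

def f(x):
    return x * x * x + 2.0 * sin(x)

def ad_derivative(func, x):
    # Seed the input derivative with 1 and read the output derivative.
    return func(Dual(x, 1.0)).dot

def numerical_derivative(func, x, h=1e-5):
    # Central finite difference: approximate, and sensitive to the choice of h.
    return (func(x + h) - func(x - h)) / (2.0 * h)

x0 = 1.3
print(ad_derivative(f, x0), numerical_derivative(f, x0))
```

The AD result follows the chain rule exactly through each operation, while the finite-difference result carries a truncation error that depends on `h`.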

The standard way of computing these (hyper-)gradients involves a forward pass and a backward pass. However, the backward pass usually consumes an unaffordable amount of memory (e.g. TBs of RAM for the MNIST dataset), because it has to store all the intermediate variables in order to exactly reverse the forward training procedure.
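A back-of-the-envelope sketch of why storing the whole trajectory gets expensive. The sizes below are hypothetical round numbers chosen for illustration, not figures from the paper, and the helper `trajectory_bytes` is not part of this repository.

```python
def trajectory_bytes(n_weights, n_iterations, bytes_per_value=8):
    """Bytes needed to keep one full copy of the weights per training iteration."""
    return n_weights * n_iterations * bytes_per_value

# Hypothetical sizes: one million weights, one hundred thousand iterations,
# 8 bytes per value (float64).
total = trajectory_bytes(1_000_000, 100_000)
print(f"{total / 1e12:.1f} TB")  # 0.8 TB
```

The cost grows with the product of model size and training length, which is what makes an exact reversal impractical.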

Hypergradient with an approximate backward pass

We propose a simple but effective method, DrMAD, to distill the knowledge of the forward pass into a shortcut path, through which we approximately reverse the training trajectory. When run on CPUs, DrMAD is at least 45 times faster and consumes 100 times less memory than state-of-the-art methods for optimizing hyperparameters, with almost no loss of effectiveness on the small-scale problems used in previous studies.
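A minimal sketch of the idea on a toy problem, not the repository's implementation. It assumes the shortcut path is a straight line between the initial and final weights, w_t ≈ (1 - β_t) w_0 + β_t w_T with β_t = t / T. It also assumes plain SGD without momentum, a ridge-regression training loss, and a single L2 penalty `lam` as the only hyperparameter. All function names are hypothetical.

```python
import numpy as np

def make_problem(seed=0, n=40, m=20, d=5):
    """Small synthetic regression problem with a training and a validation split."""
    rng = np.random.default_rng(seed)
    w_true = rng.normal(size=d)
    X, Xv = rng.normal(size=(n, d)), rng.normal(size=(m, d))
    y = X @ w_true + 0.1 * rng.normal(size=n)
    yv = Xv @ w_true + 0.1 * rng.normal(size=m)
    return X, y, Xv, yv

def train_grad(w, lam, X, y):
    # Gradient of 0.5*mean((Xw - y)^2) + 0.5*lam*||w||^2 with respect to w.
    return X.T @ (X @ w - y) / len(y) + lam * w

def val_loss(w, Xv, yv):
    return 0.5 * np.mean((Xv @ w - yv) ** 2)

def train(w0, lam, X, y, alpha, T, keep_trajectory=False):
    """Forward pass: T steps of plain SGD; optionally store every w_t."""
    w, traj = w0.copy(), []
    for _ in range(T):
        if keep_trajectory:
            traj.append(w.copy())
        w = w - alpha * train_grad(w, lam, X, y)
    return w, traj

def hypergradient(w0, wT, lam, X, y, Xv, yv, alpha, T, traj=None):
    """d val_loss / d lam via a backward pass over the training steps.

    With traj given, the stored weights are used (exact reversal).
    With traj=None, w_t is approximated by interpolating between w0 and wT.
    """
    hessian = X.T @ X / len(y) + lam * np.eye(len(w0))  # constant for this loss
    d_w = Xv.T @ (Xv @ wT - yv) / len(yv)               # grad of val loss at w_T
    d_lam = 0.0
    for t in reversed(range(T)):
        if traj is not None:
            w_t = traj[t]
        else:
            beta = t / T
            w_t = (1.0 - beta) * w0 + beta * wT
        # Step was w_{t+1} = w_t - alpha * (data_grad(w_t) + lam * w_t),
        # so d w_{t+1} / d lam = -alpha * w_t.
        d_lam += d_w @ (-alpha * w_t)
        d_w = d_w - alpha * (hessian @ d_w)
    return d_lam

X, y, Xv, yv = make_problem()
w0, lam, alpha, T = np.zeros(5), 0.1, 0.05, 50
wT, traj = train(w0, lam, X, y, alpha, T, keep_trajectory=True)
exact = hypergradient(w0, wT, lam, X, y, Xv, yv, alpha, T, traj=traj)
approx = hypergradient(w0, wT, lam, X, y, Xv, yv, alpha, T)
print("exact:", exact, "approximate:", approx)
```

The approximate call needs only `w0` and `wT`, so its memory use does not grow with `T`. How close the two values are depends on the problem; this toy example makes no claim about the accuracy reported in the paper.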

CPU code for reproducing

To reproduce the original results in the paper, please refer to the CPU version.

In the original paper, we set the momentum to a small value (0.1). We have since found that setting it to 0.9 or even 0.95 gives better performance.

GPU code

We've implemented a GPU version with Theano, which is not included in the original paper. The existence of this GPU implementation does NOT mean the method is practical for large-scale models.


Citation

@article{drmad2016,
  title={DrMAD: Distilling Reverse-Mode Automatic Differentiation for Optimizing Hyperparameters of Deep Neural Networks},
  author={Fu, Jie and Luo, Hongyin and Feng, Jiashi and Low, Kian Hsiang and Chua, Tat-Seng},
  journal={IJCAI},
  year={2016}
}

Contact

If you have any problems or suggestions, please contact: jie.fu A_T u.nus.education

drmad's People

Contributors

bigaidream, taineleau

drmad's Issues

about learning rate

Hi,

I don't understand the learning rate update in the optimizer sgd4_mad, where N_safe_sampling = len(alphas) determines the number of loop iterations.

Does alphas represent how many learning rates in the net we want to tune? For example, each layer has its own learning rate, or even each neuron has its own learning rate.

I don't understand why d_alphas is calculated as below:

d_alphas[i] = np.dot(d_x, v)
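One possible reading of that line, sketched under an assumption: if the forward step has the form v' = beta*v - (1-beta)*g followed by x' = x + alpha*v', then x' depends on alpha only through the term alpha*v', so by the chain rule the gradient with respect to alpha is the dot product of d_x with v'. The `step` and `loss` functions below are hypothetical stand-ins, not the code of sgd4_mad.

```python
import numpy as np

def step(x, v, g, alpha, beta):
    """One momentum-SGD step in the assumed form (not copied from sgd4_mad)."""
    v_new = beta * v - (1.0 - beta) * g
    x_new = x + alpha * v_new
    return x_new, v_new

def loss(x):
    # Stand-in for whatever loss is evaluated on the updated weights.
    return 0.5 * np.sum(x ** 2)

rng = np.random.default_rng(0)
x, v, g = rng.normal(size=4), rng.normal(size=4), rng.normal(size=4)
alpha, beta = 0.1, 0.9

x_new, v_new = step(x, v, g, alpha, beta)
d_x = x_new                    # gradient of this loss with respect to x_new
d_alpha = np.dot(d_x, v_new)   # chain rule: d x_new / d alpha = v_new

# Check against a central finite difference in alpha.
eps = 1e-6
fd = (loss(step(x, v, g, alpha + eps, beta)[0])
      - loss(step(x, v, g, alpha - eps, beta)[0])) / (2 * eps)
print(d_alpha, fd)
```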

run "experiments/safe/safe.py" crash

Trying to run safe.py gives the error below:

$ python safe.py
Traceback (most recent call last):
  File "safe.py", line 180, in <module>
    results = run()
  File "safe.py", line 42, in run
    all_data = mnist.load_data_as_dict()
  File "/home/liang/workspace/drmad/cpu_ver/hypergrad/mnist.py", line 55, in load_data_as_dict
    X, T = load_data(normalize)[:2]
  File "/home/liang/workspace/drmad/cpu_ver/hypergrad/mnist.py", line 37, in load_data
    with open(datapath("mnist_data.pkl")) as f:
IOError: [Errno 2] No such file or directory: '/home/jie/d2/bitbucket/hypergradient_bo/data/mnist/mnist_data.pkl'

I also tried to run cifar10.py and got the error below:

$ python cifar10.py
Traceback (most recent call last):
  File "cifar10.py", line 179, in <module>
    results = run()
  File "cifar10.py", line 41, in run
    all_data = mnist.load_data_as_dict()
  File "/home/liang/workspace/drmad/cpu_ver/hypergrad/mnist.py", line 55, in load_data_as_dict
    X, T = load_data(normalize)[:2]
  File "/home/liang/workspace/drmad/cpu_ver/hypergrad/mnist.py", line 37, in load_data
    with open(datapath("mnist_data.pkl")) as f:
IOError: [Errno 2] No such file or directory: '/home/jie/d2/bitbucket/hypergradient_bo/data/mnist/mnist_data.pkl'

Why does running cifar10.py load the MNIST data?

doubt in predictions

In make_nn_funs, in the predictions function,
is the last layer meant to output the probabilities for each class?

At the last layer the code does:

if i == (N_iter - 1):
     cur_units = cur_units - logsumexp(cur_units, axis=1)

Does this give the normalized log-probabilities?
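A quick way to check this numerically, as a sketch. Subtracting the log-sum-exp of each row from that row is the log-softmax, so exponentiating the result should give rows that sum to 1. The `logsumexp` below is a hypothetical stand-in that keeps the reduced axis so the subtraction broadcasts per row; the helper used in make_nn_funs may be defined differently.

```python
import numpy as np

def logsumexp(X, axis):
    """Hypothetical helper: stable log(sum(exp(X))) that keeps the reduced axis."""
    max_X = np.max(X, axis=axis, keepdims=True)
    return max_X + np.log(np.sum(np.exp(X - max_X), axis=axis, keepdims=True))

# Two example rows of last-layer activations (one row per data point).
cur_units = np.array([[1.0, 2.0, 3.0],
                      [0.5, 0.5, 0.5]])
log_probs = cur_units - logsumexp(cur_units, axis=1)

# Exponentiating the log-probabilities gives one distribution per row.
print(np.exp(log_probs).sum(axis=1))
```

Note that a `logsumexp` that drops the reduced axis would return shape `(N,)`, and the subtraction from an `(N, C)` array would then fail or broadcast along the wrong axis.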

plot error

Uncommenting the plot() call gives the error below.

Traceback (most recent call last):
  File "safe.py", line 186, in <module>
    plot()
  File "safe.py", line 161, in plot
    ax = fig.add_subplot(213)
  File "/Users/liang/anaconda/lib/python2.7/site-packages/matplotlib/figure.py", line 1005, in add_subplot
    a = subplot_class_factory(projection_class)(self, *args, **kwargs)
  File "/Users/liang/anaconda/lib/python2.7/site-packages/matplotlib/axes/_subplots.py", line 64, in __init__
    maxn=rows*cols, num=num))
ValueError: num must be 1 <= num <= 2, not 3

torch autograd

hey, saw you guys in someone else's starred repository stream. fyi, we have a torch version of autograd at github.com/twitter/autograd, and this would be a pretty great application of it. i'd be happy to help out however you needed to get you going.
