RobustDARTS

Code accompanying the paper:

Understanding and Robustifying Differentiable Architecture Search
Arber Zela, Thomas Elsken, Tonmoy Saikia, Yassine Marrakchi, Thomas Brox and Frank Hutter.
In: International Conference on Learning Representations (ICLR 2020).

Codebase

The code is based on the original DARTS implementation.

Requirements

Python >= 3.5.5, PyTorch == 0.3.1, torchvision == 0.2.0

As we show in our paper, DARTS will start assigning a large weight to skip connections as the search progresses, while at the same time the dominant eigenvalue of the validation loss Hessian starts increasing.

Figure (test_error_eigenvalues): Snapshot of the normal cells and the dominant eigenvalue of the Hessian of the validation loss w.r.t. the architectural parameters over time.
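
For reference, this dominant eigenvalue can be approximated without forming the full Hessian, e.g. by power iteration on Hessian-vector products. The sketch below is an illustrative re-implementation written against a recent PyTorch API; the function name and iteration count are assumptions, and this is not the analyser code shipped in this repository.

  import torch

  def dominant_eigenvalue(val_loss, arch_params, n_iter=20):
      # Approximate the largest eigenvalue of the Hessian of the validation
      # loss w.r.t. the architecture parameters via power iteration on
      # Hessian-vector products (illustrative sketch only).
      grads = torch.autograd.grad(val_loss, arch_params, create_graph=True)
      flat_grad = torch.cat([g.reshape(-1) for g in grads])

      v = torch.randn_like(flat_grad)   # random start vector
      v = v / v.norm()

      eigenvalue = 0.0
      for _ in range(n_iter):
          # Hessian-vector product: d/d(alpha) [ grad(L_val) . v ]
          hv = torch.autograd.grad(flat_grad @ v, arch_params, retain_graph=True)
          hv = torch.cat([h.reshape(-1) for h in hv])
          eigenvalue = (v @ hv).item()  # Rayleigh quotient estimate
          v = hv / (hv.norm() + 1e-12)  # re-normalize for the next step
      return eigenvalue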

Architecture search

To carry out the DARTS (2nd order) architecture search on all search spaces and datasets used throughout the paper, run:

./scripts/start_search.sh

NOTE: We used the Slurm workload manager to run our jobs, but the scripts can easily be adapted to other job scheduling systems.

To carry out the DARTS-ADA and DARTS-ES (2nd order) architecture search on all search spaces and datasets, run:

./scripts/start_search_ADA.sh

Since DARTS-ES and DARTS-ADA use the same stopping criterion, we do not need to run them separately: we start DARTS-ADA, log the architecture at the first rollback iteration (the point at which DARTS-ES would have stopped early), and then continue with the adaptive regularization.
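
For illustration, the shared criterion can be sketched roughly as follows. This is a simplified sketch with placeholder window size and threshold (the actual values are set via the arguments in search/src/args.py, e.g. the factor parameter), and the rollback/regularization bookkeeping is omitted.

  # Illustrative sketch of the shared early-stopping / rollback criterion.
  # WINDOW and FACTOR are placeholder values, not the defaults used in the code.
  WINDOW = 5
  FACTOR = 1.3

  def should_rollback(eigenvalues, window=WINDOW, factor=FACTOR):
      # Trigger when the smoothed dominant eigenvalue has grown by more than
      # `factor` relative to its smoothed value `window` epochs earlier.
      if len(eigenvalues) < 2 * window:
          return False
      current = sum(eigenvalues[-window:]) / window
      previous = sum(eigenvalues[-2 * window:-window]) / window
      return current > factor * previous

  # DARTS-ES: when triggered, stop and return the architecture from `window`
  # epochs earlier. DARTS-ADA: roll back, increase the regularization
  # (e.g. drop path probability or weight decay), and continue the search.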

To start Random Search with Weight Sharing on all search spaces and datasets used throughout the paper, run:

./scripts/start_search_RandomNAS.sh

Architecture evaluation

To start evaluating all the architectures logged by the search runs, run:

./scripts/start_eval.sh

Make sure to set --archs_config_file to the correct .yaml file where the architecture genotypes are saved.
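
As a quick sanity check before launching the evaluation, the genotype file can be loaded manually. The snippet below is hypothetical and assumes the .yaml file simply maps architecture names to genotype strings, which may not match the exact layout used by the scripts.

  import yaml

  # Hypothetical layout: {architecture_name: genotype_string, ...}
  with open('path/to/archs_config.yaml') as f:
      archs = yaml.safe_load(f)

  for name, genotype in archs.items():
      print(name, genotype)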

Citation

@inproceedings{zela2020understanding,
	title={Understanding and Robustifying Differentiable Architecture Search},
	author={Arber Zela and Thomas Elsken and Tonmoy Saikia and Yassine Marrakchi and Thomas Brox and Frank Hutter},
	booktitle={International Conference on Learning Representations},
	year={2020},
	url={https://openreview.net/forum?id=H1gDNyrKDS}
}

robustdarts's Issues

Different default parameter values than those mentioned in the paper

Hi

Thanks for the paper and code.

I noticed that a few default argument values in search/src/args.py, such as max_weight_decay and mul_factor, differ from what is mentioned in the paper.

Also, if I understood correctly, most plots in the paper seem to indicate that overfitting is reduced with drop_path_prob >= 0.5, yet the default value in the code is 0.2.

The value of the "factor" parameter (the threshold on the eigenvalues) also differs between the paper and the code.

Could you please clarify the recommended/default values of these three parameters according to your experiments?

Thank You!

Please check the code for the "discrete == True" case in model_search.py

Hi, thank you for sharing your work, including the paper and this open-source code.

I have a question about measuring the "performance drop" in Section 4.2 of the paper (Figure 5(b)).

Let me explain my concern with an example. My understanding of RobustDARTS is as follows:

  1. The validation accuracy of the models is measured after the search phase (i.e., after 50 epochs of search).
  2. The weighting in the mixed operation (i.e., the softmax output of alpha) is not one-hot before discretizing alpha, and is converted into a one-hot vector after discretizing so that only one operation per edge remains.
    ex) [one edge] alpha after the 50-epoch search = [0.1 0.4 0.2 0.1]
    --> softmax(alpha) = [0.2245 0.3030 0.2481 0.2245]
    --> this is used as the weighting of the mixed op before discretizing.
    • after discretizing --> the weighting of the mixed op is converted into a one-hot vector: [0 1 0 0].

Please check if my understanding is right.

If it is, the code for the "discrete == True" case in forward() of the Network class in model_search.py seems to have an error.
In your code, the weighting is [0.1 0.4 0.2 0.1] when discrete (no softmax activation), and [0.2245 0.3030 0.2481 0.2245] (with softmax activation) otherwise.
In my opinion, if discrete, the weighting should be [0 1 0 0].
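
To make the distinction concrete, here is a small illustrative snippet (not the repository's forward() code) contrasting the weightings mentioned above:

  import torch
  import torch.nn.functional as F

  alpha = torch.tensor([0.1, 0.4, 0.2, 0.1])  # architecture parameters of one edge

  soft = F.softmax(alpha, dim=-1)             # continuous weighting during search
  # -> approx. [0.2245, 0.3030, 0.2481, 0.2245]

  one_hot = torch.zeros_like(alpha)           # weighting after discretization
  one_hot[alpha.argmax()] = 1.0
  # -> [0., 1., 0., 0.]  (only the strongest operation per edge is kept)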

Please check this issue.
If I am misunderstanding the discretizing process, I would like to ask for a clearer explanation of the discretizing step and the measurement of "performance drop."

Thank you.

question about analyser.compute_dw

Hi,

I noticed that you also monitor the 1st-order derivatives of the loss w.r.t. the architecture parameters when compute_hessian is activated. Is this to make sure that the 1st-order derivatives are actually close to zero, so that the eigenvalues of the Hessian can be interpreted as the global curvatures?

If so, what mechanism do you apply to judge whether the 1st-order derivatives are close to zero? Thanks very much!
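
For context, one simple way such a check could be implemented (a sketch only, with hypothetical names and threshold, not necessarily what the analyser does) is to monitor the norm of the first-order derivative:

  import torch

  def arch_grad_norm(val_loss, arch_params):
      # Norm of dL_val/d(alpha); if it is close to zero, the Hessian
      # eigenvalues can be read as curvatures at a (near-)stationary point.
      grads = torch.autograd.grad(val_loss, arch_params, retain_graph=True)
      return torch.cat([g.reshape(-1) for g in grads]).norm().item()

  # e.g. treat the point as approximately stationary if the norm is below
  # some small tolerance such as 1e-3 (hypothetical threshold).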

Best,
Bolian

Is it a bug about la_tracker.ev?

Hi, I found something weird in src/search/train_search.

In line 161, ev is assigned -1 if args.compute_hessian is True.
But I think it should be

if not args.compute_hessian:
    ev = -1
else:
    ev = la_tracker.ev[-1]

because la_tracker.ev is only appended to (in line 421) when args.compute_hessian is True.
If I run your code directly, a "list index out of range" error occurs.

I am not sure if I understand it right.
Thank you!

Cutout variable

Hi

In train_search.py, there seems to be an intention to strengthen the data augmentation as the architecture search progresses, by increasing the probability of activating Cutout.

  if args.drop_path_prob != 0:
    model.drop_path_prob = args.drop_path_prob * epoch / (args.epochs - 1)
    train_transform.transforms[-1].cutout_prob = args.cutout_prob * epoch / (args.epochs - 1)

However, the attribute defined in Cutout is prob rather than cutout_prob; therefore, it looks like Cutout is always activated if args.cutout is set to True. Should the code be revised? Thanks.

  class Cutout(object):
    def __init__(self, length, prob=1.0):
      self.length = length
      self.prob = prob
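
If the intention described above is right, a one-line change along these lines (an illustrative suggestion, not an official patch) would restore the schedule:

  # Assign to the attribute that Cutout actually defines ("prob", not "cutout_prob"):
  train_transform.transforms[-1].prob = args.cutout_prob * epoch / (args.epochs - 1)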
