

NAS-Bench-x11

NAS-Bench-x11 and the Power of Learning Curves
Shen Yan, Colin White, Yash Savani, Frank Hutter.
NeurIPS 2021.

Surrogate NAS benchmarks for multi-fidelity algorithms

We present a method to create surrogate neural architecture search (NAS) benchmarks, NAS-Bench-111, NAS-Bench-311, and NAS-Bench-NLP11, that output the full training information for each architecture rather than just the final validation accuracy. This makes it possible to benchmark multi-fidelity techniques such as successive halving and learning curve extrapolation (LCE). We then present a framework for converting popular single-fidelity algorithms into LCE-based algorithms.
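Because each query returns a whole learning curve, a multi-fidelity method such as successive halving can be simulated by reading intermediate epochs instead of retraining. Below is a minimal sketch, assuming a hypothetical `query_curve(arch)` callable that returns per-epoch validation accuracies; `toy_curve` is a made-up stand-in, not the NAS-Bench-x11 API.

```python
import math
from typing import Callable, Dict, List, Sequence


def successive_halving(
    archs: Sequence[str],
    query_curve: Callable[[str], List[float]],
    min_epochs: int = 1,
    eta: int = 2,
) -> str:
    """Keep the top 1/eta of candidates at each rung, multiplying the epoch budget by eta.

    Rungs read intermediate epochs from full learning curves instead of retraining.
    """
    candidates = list(archs)
    curves: Dict[str, List[float]] = {a: query_curve(a) for a in candidates}
    max_epochs = min(len(c) for c in curves.values())
    budget = min_epochs
    while len(candidates) > 1 and budget <= max_epochs:
        # Rank survivors by validation accuracy at the current epoch budget
        ranked = sorted(candidates, key=lambda a: curves[a][budget - 1], reverse=True)
        candidates = ranked[: max(1, len(candidates) // eta)]
        budget *= eta
    # If the budget ran out with several survivors, pick the best at the last epoch
    return max(candidates, key=lambda a: curves[a][max_epochs - 1])


def toy_curve(arch: str, epochs: int = 16) -> List[float]:
    # Hypothetical stand-in for a surrogate query: saturating curves, one asymptote per arch
    finals = {"a": 90.0, "b": 94.0, "c": 80.0, "d": 92.0}
    return [finals[arch] * (1 - math.exp(-0.5 * (t + 1))) for t in range(epochs)]


print(successive_halving(["a", "b", "c", "d"], toy_curve))  # → b
```

With a surrogate, every rung costs only a lookup, so many halving schedules (different `min_epochs` and `eta`) can be compared cheaply.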


Installation

Clone this repository and install its requirements.

git clone https://github.com/automl/nas-bench-x11
cd nas-bench-x11
cat requirements.txt | xargs -L 1 pip install
pip install -e .

Download the pretrained surrogate models and place them into checkpoints/. The current models are v0.5. We will continue to improve the surrogate model by adding the sliding window noise model.

NAS-Bench-311 and NAS-Bench-NLP11 will work as is. To use NAS-Bench-111, first install NAS-Bench-101.
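Before querying, it can help to confirm that the checkpoints and NAS-Bench-101 are in place. This is a hypothetical helper (`check_setup` is not part of the repository); the folder names are assumptions taken from the API example and the issue reports on this page, so adjust them to match the archives you downloaded.

```python
import importlib.util
from pathlib import Path
from typing import Dict, Iterable

# Assumed folder names (from the example path and issue tracebacks); not an official list.
ASSUMED_MODEL_DIRS = ("nb111-v0.5", "nb211-v0.5", "nb311-v0.5")


def check_setup(checkpoint_dir: str = "checkpoints",
                model_dirs: Iterable[str] = ASSUMED_MODEL_DIRS) -> Dict[str, bool]:
    """Return {name: present?} for each surrogate folder, plus the NAS-Bench-101 package."""
    root = Path(checkpoint_dir)
    status = {name: (root / name).is_dir() for name in model_dirs}
    # NAS-Bench-101 is imported as `nasbench`; it is only needed for NAS-Bench-111.
    # find_spec locates a top-level package without importing it.
    status["nasbench"] = importlib.util.find_spec("nasbench") is not None
    return status


for name, present in check_setup().items():
    print(f"{'found  ' if present else 'missing'}  {name}")
```

Run it from the repository root so that the relative `checkpoints` path resolves.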

Using the API

The API is located in nas_bench_x11/api.py.

Here is an example of how to use the API:

from collections import namedtuple

from nas_bench_x11.api import load_ensemble

# load the surrogate
nb311_surrogate_model = load_ensemble('path/to/nb311-v0.5')

# define a genotype as in the original DARTS repository
Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
arch = Genotype(
    normal=[('sep_conv_3x3', 0), ('sep_conv_5x5', 1), ('skip_connect', 1), ('max_pool_3x3', 2),
            ('sep_conv_3x3', 0), ('dil_conv_5x5', 1), ('sep_conv_5x5', 2), ('dil_conv_5x5', 4)],
    normal_concat=[2, 3, 4, 5, 6],
    reduce=[('dil_conv_5x5', 0), ('skip_connect', 1), ('avg_pool_3x3', 0), ('sep_conv_5x5', 1),
            ('avg_pool_3x3', 0), ('max_pool_3x3', 2), ('sep_conv_3x3', 1), ('max_pool_3x3', 3)],
    reduce_concat=[4, 5, 6])

# query the surrogate to output the learning curve
learning_curve = nb311_surrogate_model.predict(config=arch, representation="genotype", with_noise=True)
print(learning_curve)
# outputs: [34.50166741 44.77032749 50.62796474 ... 93.47724664]
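Since `predict` returns the full curve, you can also try learning curve extrapolation: observe only the first few epochs and predict a later value. The paper's LCE method is not reproduced here. As a simple stand-in, this sketch fits a saturating power law `acc(t) = a - b * t**(-c)`, grid-searching `c` and solving for `a` and `b` by linear least squares. The function name and model form are assumptions for illustration.

```python
import numpy as np


def extrapolate_power_law(partial_curve, final_epoch, c_grid=np.linspace(0.05, 2.0, 40)):
    """Fit acc(t) = a - b * t**(-c) to epochs 1..k and evaluate it at `final_epoch`.

    For each fixed c the model is linear in (a, b), so it is solved by least squares;
    the c with the smallest squared residual is kept.
    """
    y = np.asarray(partial_curve, dtype=float)
    t = np.arange(1, len(y) + 1, dtype=float)  # epochs are 1-indexed
    best = None
    for c in c_grid:
        # Design matrix columns: constant term (a) and -t**(-c) (b)
        X = np.column_stack([np.ones_like(t), -t ** (-c)])
        coef, *_ = np.linalg.lstsq(X, y, rcond=None)
        sse = float(np.sum((X @ coef - y) ** 2))
        if best is None or sse < best[0]:
            best = (sse, coef, c)
    _, (a, b), c = best
    return float(a - b * final_epoch ** (-c))


# Synthetic check: a curve drawn from the same model family is recovered.
epochs = np.arange(1, 21)
observed = 95.0 - 60.0 * epochs ** -0.5  # only the first 20 epochs are "observed"
print(round(extrapolate_power_law(observed, final_epoch=100), 2))  # → 89.0
```

With the surrogate output above, a hypothetical usage would be `extrapolate_power_law(learning_curve[:20], final_epoch=len(learning_curve))`, which can then be compared against the surrogate's actual final value.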

Run NAS experiments from our paper

You will also need to download the NAS-Bench-301 runtime model lgb_runtime_v1.0 and place it in a folder called nb_models.

# Supported optimizers: (rs re ls bananas)-{svr, lce}, hb, bohb 

bash naslib/benchmarks/nas/run_nb311.sh 
bash naslib/benchmarks/nas/run_nb201.sh 
bash naslib/benchmarks/nas/run_nb201_cifar100.sh 
bash naslib/benchmarks/nas/run_nb201_imagenet16-200.sh
bash naslib/benchmarks/nas/run_nb111.sh 
bash naslib/benchmarks/nas/run_nbnlp.sh 
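The optimizer comment above uses compact shorthand: each of rs, re, ls, and bananas is paired with either svr or lce, and hb and bohb stand alone, for ten variants in total. This sketch only expands that notation; the hyphenated strings it produces are illustrative and may not match the exact names the scripts accept.

```python
from itertools import product

# rs: random search, re: regularized evolution, ls: local search, bananas: BANANAS
BASE_ALGORITHMS = ["rs", "re", "ls", "bananas"]
# The two learning-curve suffixes named in the shorthand
PREDICTORS = ["svr", "lce"]


def expand_optimizers():
    """Expand '(rs re ls bananas)-{svr, lce}, hb, bohb' into one illustrative name per variant."""
    variants = [f"{alg}-{pred}" for alg, pred in product(BASE_ALGORITHMS, PREDICTORS)]
    # hb (Hyperband) and bohb (BOHB) are listed without a suffix
    return variants + ["hb", "bohb"]


print(expand_optimizers())
```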

Results will be saved in results/.

Citation

@inproceedings{yan2021bench,
  title={NAS-Bench-x11 and the Power of Learning Curves},
  author={Yan, Shen and White, Colin and Savani, Yash and Hutter, Frank},
  booktitle={Thirty-Fifth Conference on Neural Information Processing Systems},
  year={2021}
}


nas-bench-x11's Issues

no setup.py

python3 -m pip install -e .
ERROR: File "setup.py" or "setup.cfg" not found. Directory cannot be installed in editable mode: /xxx/Documents/nas-bench-x11

Surrogate_model.model from nb211-v0.5 is not able to be loaded

Hi, I downgraded TF to the version you specified and have nasbench installed. The nb311-v0.5 surrogate model loads, but nb111 and nb211 do not.

When I load nb211, I run into a KeyError:
Traceback (most recent call last):
File "example.py", line 5, in <module>
nb311_surrogate_model = load_ensemble('/Users/yao.a.yang/Documents/nas-bench-x11/checkpoints/nb211-v0.5')
File "/Users/yao.a.yang/Documents/nas-bench-x11/nas_bench_x11/api.py", line 68, in load_ensemble
surrogate_model.load(model_paths=ensemble_member_dirs)
File "/Users/yao.a.yang/Documents/nas-bench-x11/nas_bench_x11/ensemble.py", line 142, in load
ens_mem.load(os.path.join(member_logdir, 'surrogate_model.model'))
File "/Users/yao.a.yang/Documents/nas-bench-x11/nas_bench_x11/models/svd_lgb.py", line 118, in load
if len(joblib.load(model_path)) == 5:
File "/usr/local/anaconda3/envs/py36/lib/python3.6/site-packages/joblib/numpy_pickle.py", line 587, in load
obj = _unpickle(fobj, filename, mmap_mode)
File "/usr/local/anaconda3/envs/py36/lib/python3.6/site-packages/joblib/numpy_pickle.py", line 506, in _unpickle
obj = unpickler.load()
File "/usr/local/anaconda3/envs/py36/lib/python3.6/pickle.py", line 1050, in load
dispatch[key[0]](self)
KeyError: 239

For surrogate_model.model in nb111-v0.5, I get this error:
/usr/local/anaconda3/envs/py36/lib/python3.6/site-packages/sklearn/base.py:315: UserWarning: Trying to unpickle estimator RegressorChain from version 0.23.2 when using version 0.24.2. This might lead to breaking code or invalid results. Use at your own risk.
UserWarning)
/usr/local/anaconda3/envs/py36/lib/python3.6/site-packages/sklearn/base.py:315: UserWarning: Trying to unpickle estimator StandardScaler from version 0.23.2 when using version 0.24.2. This might lead to breaking code or invalid results. Use at your own risk.
UserWarning)
<nas_bench_x11.models.svd_lgb.SVDLGBModel object at 0x7fa5ac487f28>
[<nas_bench_x11.models.svd_lgb.SVDLGBModel object at 0x7fa5ac487f28>]
Traceback (most recent call last):
File "example.py", line 17, in <module>
learning_curve = nb311_surrogate_model.predict(config=arch, representation="genotype", with_noise=True)
File "/Users/yao.a.yang/Documents/nas-bench-x11/nas_bench_x11/api.py", line 113, in predict
pred = self.model.query(config_dict, search_space=search_space)
File "/Users/yao.a.yang/Documents/nas-bench-x11/nas_bench_x11/ensemble.py", line 290, in query
use_noise=True)
File "/Users/yao.a.yang/Documents/nas-bench-x11/nas_bench_x11/models/svd_lgb.py", line 161, in query
comp = self.model.predict(X)
File "/usr/local/anaconda3/envs/py36/lib/python3.6/site-packages/sklearn/multioutput.py", line 549, in predict
Y_pred_chain[:, chain_idx] = estimator.predict(X_aug)
File "/usr/local/anaconda3/envs/py36/lib/python3.6/site-packages/lightgbm/sklearn.py", line 800, in predict
raise ValueError("Number of features of the model must "
ValueError: Number of features of the model must match the input. Model n_features_ is 30 and input n_features is 56

Any guidance? Thanks!

Help for predicting the learning curve of training losses

Hi guys,
I'm practicing with the example code and found that the API only supports predicting validation accuracy, but your paper shows that it can also predict losses. Could you provide an example of using the API to predict the train/validation losses?
Thank you for a great paper.
