
nasnet-pytorch's Introduction

A neat PyTorch implementation of NASNet.

The performance of the ported models on ImageNet (Accuracy):

Model Checkpoint       Million Parameters   Val Top-1 (%)   Val Top-5 (%)
NASNet-A_Mobile_224    5.3                  70.2            89.4
NASNet-A_large_331     88.9                 82.3            96.0

The slight performance drop relative to the original TensorFlow models may be caused by the different spatial padding methods used by TensorFlow and PyTorch.
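As a rough illustration of the padding difference mentioned above, the sketch below compares the per-axis padding that TensorFlow's 'SAME' mode computes with the symmetric padding of a PyTorch `Conv2d(padding=1)`. The layer shape (3x3 kernel, stride 2, 224-pixel input) and the helper names are assumptions chosen for illustration, not taken from this repository.

```python
import math


def tf_same_padding(in_size, kernel, stride):
    """(before, after) padding that TensorFlow 'SAME' applies along one axis."""
    out_size = math.ceil(in_size / stride)
    total = max((out_size - 1) * stride + kernel - in_size, 0)
    before = total // 2  # any odd leftover pixel goes to the end
    return before, total - before


def symmetric_padding(pad):
    """PyTorch Conv2d(padding=pad) pads the same amount on both sides."""
    return pad, pad


def conv_out_size(in_size, kernel, stride, before, after):
    """Output length of a convolution along one axis."""
    return (in_size + before + after - kernel) // stride + 1


# Hypothetical layer: 3x3 convolution, stride 2, 224-pixel axis.
tf_pad = tf_same_padding(224, 3, 2)
pt_pad = symmetric_padding(1)
print("TensorFlow SAME:", tf_pad, conv_out_size(224, 3, 2, *tf_pad))
print("PyTorch pad=1: ", pt_pad, conv_out_size(224, 3, 2, *pt_pad))
```

Both variants give the same output size here, but the windows are shifted by one pixel relative to each other, so ported weights see slightly different inputs. This kind of mismatch could plausibly account for a small accuracy gap.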

The porting is done by tensorflow_dump.py and pytorch_load.py, both modified from Cadene's project. Note that NASNet models with the original performance can be found there.

You can evaluate the models by running imagenet_eval.py. For example, to evaluate the NASNet-A_Mobile_224 ported model:

python imagenet_eval.py --nas-type mobile --resume /path/to/modelfile --gpus 0 --data /path/to/imagenet_root_dir

The ported model files are provided: NASNet-A_Mobile_224, NASNet-A_large_331.

Future work:

  • add drop path for training
  • more nasnet model settings

nasnet-pytorch's People

Contributors

wandering007

nasnet-pytorch's Issues

Cannot reproduce reported imagenet validation accuracy

Hi,
I am seeking help because I am unable to reproduce your reported results. Here is my setup:
Python version 3.6.5
PyTorch version 1.1.0

The ImageNet validation dataset was downloaded from the ImageNet website and placed in a val folder with the following structure:

imagenet_dataset
  val/
    n01440764/
      ...images...
    n01443537/
      ...images...
    .... 1000 such folders

The nasnet_a_large.pth weights were downloaded from your Dropbox link. I run your script as follows:

python imagenet_eval.py --nas-type large --resume ../nasnet_a_large.pth --gpus 0 1 --data ../../data/imagenet_dataset --batch-size 4

I get Top1 = 17% and Top5 = 4%:

Test set[50000]: Top1: 17.00%, Top5: 4.00%, Average loss: 0.8056

Please help. Thank you.
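As a sanity check on figures like these, top-k accuracy can be sketched as follows. This is a minimal pure-Python version with made-up scores and labels, not the computation used in imagenet_eval.py. By construction, a correctly computed top-5 value can never be lower than top-1, so a report where Top5 is below Top1 suggests the numbers may be swapped or mis-scaled somewhere.

```python
def topk_accuracy(scores, labels, k):
    """Fraction of samples whose true label is among the k highest scores."""
    hits = 0
    for row, label in zip(scores, labels):
        # Class indices sorted from highest to lowest score.
        ranked = sorted(range(len(row)), key=lambda c: row[c], reverse=True)
        hits += label in ranked[:k]
    return hits / len(labels)


# Toy data (hypothetical): 3 samples, 3 classes.
scores = [
    [0.1, 0.7, 0.2],
    [0.5, 0.3, 0.2],
    [0.2, 0.3, 0.5],
]
labels = [1, 1, 0]

top1 = topk_accuracy(scores, labels, 1)
top2 = topk_accuracy(scores, labels, 2)
print(f"top-1: {top1:.2f}, top-2: {top2:.2f}")
```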

different shapes between `x_path_1` and `x_path_2`

Hi, I just ran the nasnet code but got an error at this line:
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 8 but got size 7 for tensor number 1 in the list.
I think x_path_1 is one element longer than x_path_2 along the last two dimensions, because the previous line strips the first element of x_path_2. What do you think?

Thanks a lot for your help.

Best,
Xianzhe
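The size arithmetic behind this report can be sketched as follows. The sketch assumes, without confirmation from this repository's code, that both paths apply a kernel-1, stride-2 pooling, and that path 2 first drops the first row and column, optionally after zero-padding one row and column at the end. All function names are hypothetical.

```python
def pool_out(size, kernel=1, stride=2):
    """Output length of an unpadded pooling layer along one axis."""
    return (size - kernel) // stride + 1


def path_sizes(size, pad_before_strip):
    """Spatial sizes of (x_path_1, x_path_2) for a given input size."""
    path_1 = pool_out(size)
    # Path 2: optionally pad one element at the end, then drop the first.
    shifted = size + (1 if pad_before_strip else 0) - 1
    path_2 = pool_out(shifted)
    return path_1, path_2


for size in (14, 15):
    print(size, "no pad:", path_sizes(size, False),
          "with pad:", path_sizes(size, True))
```

Under these assumptions, an odd input size without the padding step gives sizes 8 and 7, which matches the error message. Padding before stripping keeps the two paths aligned.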

Low & Wrong Accuracy of NASNet-A-Mobile_224

I used the model file below with ImageNet (year: 2012):
The ported model files are provided: NASNet-A_Mobile_224
However, I got top-1 accuracy 53.97% and top-5 32.50%, which looks like overfitting.
Are there any settings I should check when running this?

error while loading pretrained weights

I get this error while loading the weights:

RuntimeError                              Traceback (most recent call last)
<ipython-input-32-d82b0689f3c8> in <module>()
----> 1 model=model.load_state_dict(weights)

/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
    767         if len(error_msgs) > 0:
    768             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
--> 769                                self.__class__.__name__, "\n\t".join(error_msgs)))
    770 
    771     def _named_members(self, get_members_fn, prefix='', recurse=True):

RuntimeError: Error(s) in loading state_dict for NASNet:
	Missing key(s) in state_dict: "aux_features.2.weight", "aux_features.3.weight", "aux_features.3.bias", "aux_features.3.running_mean", "aux_features.3.running_var", "aux_features.5.weight", "aux_features.6.weight", "aux_features.6.bias", "aux_features.6.running_mean", "aux_features.6.running_var", "aux_linear.weight", "aux_linear.bias". 

NasNet Missing key(s) in state_dict

I defined a new function
def build_and_load_model(path, nas_type):
    model = NASNetALarge(1001) if nas_type == 'large' else NASNetAMobile(1001)
    filename_model = os.path.join(path, 'pytorch', nas_type == 'large' and 'nasnet_a_large.pth' or 'nasnet_a_mobile.pth')
    state_dict = torch.load(filename_model, map_location='cpu')
    model.load_state_dict(state_dict)
    return model
Here, the "tf-models/pytorch/nasnet_a_mobile.pth" file is the .pth file you provided:

The ported model files are provided: NASNet-A_Mobile_224, NASNet-A_large_331.

Now I get the following error when calling build_and_load_model():
RuntimeError: Error(s) in loading state_dict for NASNet: Missing key(s) in state_dict: "aux_features.2.weight", "aux_features.3.running_var", "aux_features.3.bias", "aux_features.3.weight", "aux_features.3.running_mean", "aux_features.5.weight", "aux_features.6.running_var", "aux_features.6.bias", "aux_features.6.weight", "aux_features.6.running_mean", "aux_linear.bias", "aux_linear.weight".
Does the provided nasnet_a_mobile.pth not match the model that NASNetAMobile(1001) defines?
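One way to narrow down a mismatch like this is to compare the key sets of the model and the checkpoint before loading. The sketch below does this with plain lists of key names; the helper names and the "conv0.weight" key are hypothetical, and in practice the lists would come from `model.state_dict().keys()` and the loaded checkpoint's keys.

```python
def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, as load_state_dict reports them."""
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)        # model expects, file lacks
    unexpected = sorted(checkpoint_keys - model_keys)     # file has, model lacks
    return missing, unexpected


def all_have_prefix(keys, prefixes):
    """True if every key starts with one of the given prefixes."""
    return all(key.startswith(tuple(prefixes)) for key in keys)


# Hypothetical key names, shortened for illustration.
model_keys = ["conv0.weight", "aux_features.2.weight",
              "aux_linear.weight", "aux_linear.bias"]
checkpoint_keys = ["conv0.weight"]

missing, unexpected = diff_state_dict_keys(model_keys, checkpoint_keys)
print("missing:", missing)
print("only auxiliary head missing:",
      all_have_prefix(missing, ["aux_features.", "aux_linear."]))
```

If only the auxiliary-head keys are missing, as in the error above, one option may be PyTorch's `model.load_state_dict(state_dict, strict=False)`, which skips missing keys and leaves those parameters at their initial values. Whether that is appropriate for these checkpoints is not confirmed by the source.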
