
maml-pytorch's Introduction

MAML-Pytorch

PyTorch implementation of the supervised learning experiments from the paper: Model-Agnostic Meta-Learning (MAML).

Version 1.0: Both the MiniImagenet and Omniglot datasets are supported. Have fun!

Version 2.0: Rewrote the meta learner and the basic learner, fixing some serious bugs from version 1.0.

For a TensorFlow implementation, please visit the official one HERE and a simpler version HERE.

For a first-order approximation implementation, namely Reptile, please visit HERE.


Platform

  • Python: 3.x
  • PyTorch: 0.4+

MiniImagenet

Howto

The 5-way 1-shot experiment allocates nearly 6 GB of GPU memory.

  1. Download the MiniImagenet dataset from here, and the split files train/val/test.csv from here.
  2. Extract it like this:
miniimagenet/
├── images
	├── n0210891500001298.jpg  
	├── n0287152500001298.jpg 
	...
├── test.csv
├── val.csv
└── train.csv

  3. Modify the path in miniimagenet_train.py:
        mini = MiniImagenet('miniimagenet/', mode='train', n_way=args.n_way, k_shot=args.k_spt,
                    k_query=args.k_qry,
                    batchsz=10000, resize=args.imgsz)
		...
        mini_test = MiniImagenet('miniimagenet/', mode='test', n_way=args.n_way, k_shot=args.k_spt,
                    k_query=args.k_qry,
                    batchsz=100, resize=args.imgsz)

to your actual data path.

  4. Run python miniimagenet_train.py.

If your reproduced performance is lower than expected, try increasing the number of training epochs so that training runs longer. MAML is notoriously hard to train, so this implementation only provides a basic starting point for your research. The performance reported below is real and was achieved on my machine.
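Before launching training, the directory layout from step 2 can be sanity-checked with a few lines of Python. This is a minimal sketch, not part of the repository: the function name `missing_miniimagenet_items` is hypothetical, and the check assumes only the layout shown above (an images folder with .jpg files plus the three CSV split files).

```python
from pathlib import Path


def missing_miniimagenet_items(root):
    """Return the expected entries that are absent under `root`.

    Hypothetical helper: it only checks the layout listed in the README.
    """
    root = Path(root)
    expected = ["images", "train.csv", "val.csv", "test.csv"]
    missing = [name for name in expected if not (root / name).exists()]

    # The images folder should also contain at least one .jpg file.
    images = root / "images"
    if images.is_dir() and not any(images.glob("*.jpg")):
        missing.append("images/*.jpg")
    return missing
```

An empty list means the layout matches; anything else names the entries to fix before pointing MiniImagenet(...) at the folder.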

Benchmark

Model           Fine Tune   5-way Acc.          20-way Acc.
                            1-shot    5-shot    1-shot    5-shot
Matching Nets   N           43.56%    55.31%    17.31%    22.69%
Meta-LSTM                   43.44%    60.60%    16.70%    26.06%
MAML            Y           48.7%     63.11%    16.49%    19.29%
Ours            Y           46.2%     60.3%     -         -

Omniglot

Howto

Run python omniglot_train.py; the program will download the Omniglot dataset automatically.

Decrease the value of args.task_num to fit your GPU memory capacity.

The 5-way 1-shot experiment allocates nearly 3 GB of GPU memory.

Refer to this Repo

@misc{MAML_Pytorch,
  author = {Liangqu Long},
  title = {MAML-Pytorch Implementation},
  year = {2018},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/dragen1860/MAML-Pytorch}},
  commit = {master}
}

maml-pytorch's People

Contributors

arnoutdevos, dragen1860, lmzintgraf


maml-pytorch's Issues

Only use Conv and Linear

Hi, has anyone tried using only Conv and Linear layers without BN? I commented out BN in the config, and the result is a random guess ...

For 5-ways, the acc=0.5 ...

I checked the code, I cannot figure out how this happens ...

I'm really confused with the v2.0 meta.py

  • First, the comment says the index of losses_q is tasks index.

    losses_q = [0 for _ in range(self.update_step + 1)] # losses_q[i], i is tasks idx

    However, in each task i , the whole list is updated.
    losses_q[0] += loss_q

    losses_q[1] += loss_q

    losses_q[k + 1] += loss_q

  • Second, I haven't seen the sum of loss_q?

    MAML-Pytorch/meta.py

    Lines 134 to 135 in fc20b31

    # sum over all losses on query set across all tasks
    loss_q = losses_q[-1] / task_num

    losses_q[-1] seems to be the last step's loss for the last task?

  • Third, if update_step == 1, there will be only one inner update. However, the loss after first update is computed under torch.no_grad(), so I think there is no backward update information on the query set.

    MAML-Pytorch/meta.py

    Lines 100 to 109 in fc20b31

    # this is the loss and accuracy after the first update
    with torch.no_grad():
        # [setsz, nway]
        logits_q = self.net(x_qry[i], fast_weights, bn_training=True)
        loss_q = F.cross_entropy(logits_q, y_qry[i])
        losses_q[1] += loss_q
        # [setsz]
        pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
        correct = torch.eq(pred_q, y_qry[i]).sum().item()
        corrects[1] = corrects[1] + correct
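One possible reading of the quoted lines is that the index of losses_q is the number of inner updates rather than the task index, and that the sum over tasks happens through the repeated `+=`. The sketch below reproduces only that bookkeeping with made-up numbers; `accumulate_query_losses` and its input format are hypothetical and no real network is involved.

```python
def accumulate_query_losses(per_task_losses, update_step):
    """per_task_losses[i][k] is the query loss of task i after k inner updates."""
    losses_q = [0.0 for _ in range(update_step + 1)]
    for task_losses in per_task_losses:         # outer loop over tasks
        for k, loss in enumerate(task_losses):  # k = number of inner updates so far
            losses_q[k] += loss                 # running sum across tasks
    task_num = len(per_task_losses)
    meta_loss = losses_q[-1] / task_num         # mean final-step query loss
    return losses_q, meta_loss


# Two hypothetical tasks with two inner updates each (made-up loss values).
losses_q, meta_loss = accumulate_query_losses(
    [[2.0, 1.5, 1.0], [3.0, 2.5, 2.0]], update_step=2
)
```

Under this reading, losses_q[-1] already holds the final-step loss summed over all tasks, so dividing by task_num gives the average rather than the loss of the last task alone.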

Previous version

Hi, I was trying to use code from your previous version. I was wondering what were the problems in the previous version and were you able to successfully replicate the omniglot experiment?

Error when running

Hi, I was trying to run the code on Omniglot dataset, with the task of 5-way classification with 1 shot. However, I encountered the following error when it reached the line loss.backward() in the function inner_train() in csmlv0.py:

THCudaCheck FAIL file=/opt/conda/conda-bld/pytorch_1524584710464/work/aten/src/T
HC/generic/THCTensorMath.cu line=26 error=59 : device-side assert triggered
Process SpawnProcess-1:
Traceback (most recent call last):
  File "/home/cgy/anaconda2/envs/maml-py3/lib/python3.6/multiprocessing/process.
py", line 258, in _bootstrap
    self.run()
  File "/home/cgy/anaconda2/envs/maml-py3/lib/python3.6/multiprocessing/process.
py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/cgy/Desktop/ztr/MAML-Pytorch-master_new_2/csmlv0.py", line 162, in
 inner_train
    loss.backward()
  File "/home/cgy/anaconda2/envs/maml-py3/lib/python3.6/site-packages/torch/tens
or.py", line 93, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/cgy/anaconda2/envs/maml-py3/lib/python3.6/site-packages/torch/auto
grad/__init__.py", line 89, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: cuda runtime error (59) : device-side assert triggered at /opt/con
da/conda-bld/pytorch_1524584710464/work/aten/src/THC/generic/THCTensorMath.cu:26
terminate called after throwing an instance of 'std::runtime_error'
  what():  cuda runtime error (59) : device-side assert triggered at /opt/conda/
conda-bld/pytorch_1524584710464/work/aten/src/THC/generic/THCStorage.c:184

It seems that each value of the argument Target of the function torch.nn.CrossEntropyLoss() is supposed to be in the range [0, C-1] (see the official doc here). However, the variable support_yb still uses the original image labels ranging from 0 to 1199 for train set, which might be the cause of the problem.

I am using anaconda Python 3.6, PyTorch 0.4, with CUDA 8.0. Thanks!
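A common way to satisfy the [0, C-1] requirement is to relabel the classes inside each episode. The following is a minimal sketch of that idea, not the repository's code: `relabel_episode` is a hypothetical helper and the label values are invented for illustration.

```python
def relabel_episode(labels):
    """Map arbitrary class ids to 0..C-1 in order of first appearance."""
    mapping = {}
    relabeled = []
    for label in labels:
        if label not in mapping:
            mapping[label] = len(mapping)  # next unused episode-local id
        relabeled.append(mapping[label])
    return relabeled, mapping


# Hypothetical 5-way episode whose original labels come from a 1200-class set.
support_labels = [1187, 42, 42, 903, 1187, 7, 512]
query_labels = [7, 512, 42]

support_relabeled, mapping = relabel_episode(support_labels)
# The query set must reuse the mapping built from the support set.
query_relabeled = [mapping[label] for label in query_labels]
```

With five distinct classes in the episode, every relabeled target falls in the range 0 to 4, which is what a 5-way classifier head expects.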

Measure the model performance

Hi,

Thanks for the simpler implementation of MAML.

As per the MAML paper: "At the end of meta-training, new tasks are sampled from p(T), and meta-performance is measured by the model's performance after learning from K samples. Generally, tasks used for meta-testing are held out during meta-training."

Has anybody tried fine-tuning the model with a small number (0 to 10) of samples for a new class that was not in the training dataset, and measured the performance?

Is that part of the code already available in this repository?

Thank you,
KK

only reproduce 42.94 in miniimagenet

Hi, Thanks for your excellent work!
I could only reproduce 42.94 using the default parameters. I guess the reason may be that the code does not include the second derivatives. The original MAML authors use:
if FLAGS.stop_grad: grads = [tf.stop_gradient(grad) for grad in grads]

Is this the reason for the error?

miniImagenet experiments

Hi, thanks for sharing your code, BTW, I tried with the most recent version, but I couldn't get the same results as reported, 5-way 1-shot. I could get up to ~43% accs, do you have any idea?

In addition, I want to increase num_filter from 32 to 64, but that causes an out-of-memory error. Do you have any idea why a four-layer network takes so much memory?

Some questions I hope can be answered

Hi, it's a pleasure to read your code. However, I have some doubts, and I would be very grateful if you could answer them.

1. For a task i, why repeat training K times? Doesn't that lead to overfitting?
(What I saw in the original paper was "Sample K datapoints ..." (Algorithm 2, line 5), rather than repeating the same sample K times.)

2. In the BatchNorm layer, why do you move the "mean" and "variance" parameters?

Some questions about the implementation details

First, thanks for your easy-to-understand code.

Question 1:
In your meta.py program:
`
for k in range(1, self.update_step_test):
    # 1. run the i-th task and compute loss for k=1~K-1
    logits = net(x_spt, fast_weights, bn_training=True)
    loss = F.cross_entropy(logits, y_spt)
    # 2. compute grad on theta_pi
    grad = torch.autograd.grad(loss, fast_weights)
    # 3. theta_pi = theta_pi - train_lr * grad
    fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights)))

    logits_q = net(x_qry, fast_weights, bn_training=True)
    # loss_q will be overwritten and just keep the loss_q on last update step.
    loss_q = F.cross_entropy(logits_q, y_qry)

    with torch.no_grad():
        pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
        correct = torch.eq(pred_q, y_qry).sum().item()  # convert to numpy
        corrects[k + 1] = corrects[k + 1] + correct

`
Why does this fast_weights backpropagation part need to be updated update_step times?

Question 2:
In omniglotNShot.py, in the definition of load_data_cache(), there is:
for sample in range(10): # num of episodes
Does that mean an episode contains batchsz (here 32) tasks?
Isn't a task equal to an episode?

Also, in the function next(), I can't understand why you set:
if self.indexes[mode] >= len(self.datasets_cache[mode]):
    self.indexes[mode] = 0
    self.datasets_cache[mode] = self.load_data_cache(self.datasets[mode])
I think it means that whenever the epoch reaches a multiple of 10, next() generates 10 batches of 32 tasks, and every epoch uses one of those 10 batches (32 tasks).
So why not remove it, remove the 'for' loop mentioned above, and just call load_data_cache() inside next()?
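The caching pattern described in Question 2 can be sketched independently of Omniglot. This is an illustrative reconstruction under assumptions, not the repository's code: `EpisodeCache`, `build_batch`, and the `reloads` counter are hypothetical names, and a "batch" here is just a list of random integers standing in for 32 tasks.

```python
import random


class EpisodeCache:
    """Pre-builds `cache_size` batches and rebuilds them once all are consumed."""

    def __init__(self, build_batch, cache_size=10):
        self.build_batch = build_batch
        self.cache_size = cache_size
        self.index = 0
        self.reloads = 0
        self.cache = self._load_cache()

    def _load_cache(self):
        self.reloads += 1
        return [self.build_batch() for _ in range(self.cache_size)]

    def next(self):
        # Refill only when every cached batch has been handed out once.
        if self.index >= len(self.cache):
            self.index = 0
            self.cache = self._load_cache()
        batch = self.cache[self.index]
        self.index += 1
        return batch


rng = random.Random(0)
cache = EpisodeCache(lambda: [rng.randrange(1000) for _ in range(32)], cache_size=10)
batches = [cache.next() for _ in range(25)]
```

After 25 calls the cache has been built three times (once up front, then at calls 11 and 21). One plausible motivation for this design is amortizing the cost of building batches, although the source does not state the reason.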

Question 3:
Doesn't every fine-tuning run produce a new fine-tuned model?
I thought fine-tuning just uses the support set to compute the loss and backpropagate, which yields the fine-tuned model, and then evaluates it on the query set.
I see that it uses the same parameter-update strategy as the meta procedure and just calculates the accuracy on the query set.

Some questions about the codes

  1. About validation accuracy of MiniImageNet in MAML original code, the evaluation is performed on the validation set instead of the test set, but in this code, evaluation is performed on the test set.

  2. There is no test phase in this code.

  3. For the optimization, I think the auto-grad seems right. During the inner loop, it is a first-order optimization, and after evaluating on query data, the "net.parameters" is updated with Adam optimizer, which is the same as that in MAML code.

Question about variable, losses_q

Hi authors, thanks for your repo. It's elegant and has helped me a lot, but I have a question after reading the code.

In the Maml forward function, losses_q is a list that stores several losses, but I found that only the last element is finally used.

# end of all tasks
# sum over all losses on query set across all tasks
loss_q = losses_q[-1] / task_num

Then I read the shit-like official implementation (https://github.com/cbfinn/maml/blob/master/maml.py#L135) and found that it sums over all losses:

self.total_loss1 = total_loss1 = tf.reduce_sum(lossesa) / tf.to_float(FLAGS.meta_batch_size)

I am very confused about it, and hope for a reply.

Some ideas about parameters updating.

Thanks for your good implementation of MAML. However, I think that using state_dict() and load_state_dict() may be much easier than modifying all the weights (in learner.py forward). Can I first deepcopy the net parameters (state_dict()), compute the fast weights with an optimizer instead of list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, self.net.parameters()))), and then load the original parameters back to update the meta learner? Thanks.

2nd Order or 1st Order Approximation?

Is this implementation a first-order approximation of MAML?
In meta.py, when you call autograd.grad, you do not specify create_graph=True, which means that the gradient operation is not included in the computation graph.

Thus, although the design here tries to calculate second-order derivatives, the gradient operation is not part of the graph, so the result is only a first-order approximation.
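The difference between the two variants can be worked through on a one-parameter problem where every derivative is available in closed form. This is a hand-derived sketch, not PyTorch code: the quadratic losses and the constants a, s, q, and alpha are assumptions chosen for illustration. Treating the inner gradient as a constant drops the factor (1 - alpha * a), which is the Hessian term that a second-order implementation keeps.

```python
def support_grad(theta, a=2.0, s=1.0):
    """Gradient of the support loss 0.5 * a * (theta - s) ** 2."""
    return a * (theta - s)


def query_loss(theta, q=3.0):
    """Query loss 0.5 * (theta - q) ** 2."""
    return 0.5 * (theta - q) ** 2


def meta_gradients(theta, alpha=0.1, a=2.0, s=1.0, q=3.0):
    theta_fast = theta - alpha * support_grad(theta, a, s)  # one inner step
    first_order = theta_fast - q                     # d(query loss) / d(theta_fast)
    second_order = first_order * (1.0 - alpha * a)   # chain rule through the inner step
    return first_order, second_order


def numeric_meta_gradient(theta, alpha=0.1, h=1e-5):
    """Central finite difference of the full meta objective, used as a reference."""
    def objective(t):
        return query_loss(t - alpha * support_grad(t))
    return (objective(theta + h) - objective(theta - h)) / (2 * h)


first, second = meta_gradients(0.0)
numeric = numeric_meta_gradient(0.0)
```

In this toy setting the second-order value agrees with the finite-difference reference, while the first-order value is larger in magnitude by the factor 1 / (1 - alpha * a) = 1.25.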

question about Hessian-vector products

In the original paper, the authors state that MAML needs second-order gradients and Hessian-vector products. Could you explain how you implement this, or does PyTorch do it automatically? Thanks!

OmniglotNShot

I used Omniglot to train this model, but I found that the train_data and test_data used in fine-tuning contain the same classes. For example, in 5-way, train_data are [0, 1, 2, 3, 4] and test_data are also [0, 1, 2, 3, 4].

Performance on Omniglot is slightly lower than the paper reports

Hi, thank you for implementing MAML in PyTorch.

I have tried your code and got some results. For the Omniglot dataset, the accuracy I got is lower than that of the original implementation. For 5-way 5-shot, the accuracy on the test set is around 96%, while the paper reports that the convolutional network can reach 99.9%. The same thing happens for 5-way 1-shot.

I checked the code and found that the model parameters are the same as in the original code. Do you have any idea about this?

miniImagenet experiments

Hi, thanks for sharing your code, BTW, I tried with the most recent version, but I couldn't get the same results as reported, 5-way 1-shot. I could get up to ~45% accs, do you have any idea?

reproduce result on mini-imagenet

Thanks for sharing the code.
I am trying to reproduce the result you report on the mini-imagenet by running
python miniimagenet_train.py

The training log shows that the accuracy stops at around 0.30. Any clues for this problem?
By the way, is it possible for you to share your training logs?
Much appreciated for your help.

bug in 2nd order?

I see in your code, just using
self.net(x_spt[i], fast_weights, bn_training=True)

however the torch.autograd.grad() method contain the following parameter:

create_graph (bool, optional) – If True, graph of the derivative will be constructed, allowing to compute higher order derivative products. Default: False.

Does that mean your code calculates only the first-order derivative?

Thank you!

Test data leaked into meta training

In the forward() function of maml.py (see here), the meta optimizer steps are taken no matter the input data is for training or for testing.
This results in the leakage of testing examples to the meta-learner which needs to be guarded by the training flag.
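The guard suggested in this issue can be illustrated with a toy learner that has a single scalar parameter. This is a sketch of the idea only: `ToyMetaLearner` and its `training` argument are hypothetical and do not mirror the repository's maml.py.

```python
class ToyMetaLearner:
    """One-parameter stand-in for a meta-learner with a squared-error loss."""

    def __init__(self, theta=0.0, meta_lr=0.1):
        self.theta = theta
        self.meta_lr = meta_lr

    def forward(self, target, training=True):
        loss = 0.5 * (self.theta - target) ** 2
        if training:
            # Only training data may change the meta-parameters.
            grad = self.theta - target
            self.theta -= self.meta_lr * grad
        return loss


learner = ToyMetaLearner()
test_loss = learner.forward(1.0, training=False)   # evaluation: theta stays at 0.0
theta_after_eval = learner.theta
train_loss = learner.forward(1.0, training=True)   # training: theta moves toward 1.0
theta_after_train = learner.theta
```

Evaluating with training=False returns the loss without touching theta, so test episodes cannot influence later meta-updates.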

accuracy format

Hello,

First of all, I would like to thank you for your work. I have a question about the accuracy format: I do not understand why there is more than one value.
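Judging from the code quoted in other issues, where corrects[1] and corrects[k + 1] are incremented separately, one plausible explanation is that the reported array holds one accuracy per number of inner updates. The sketch below assumes that normalisation; `accuracy_per_step` and the counts are hypothetical.

```python
def accuracy_per_step(corrects, query_size, task_num):
    """corrects[k] = number of correct query predictions after k inner updates."""
    total = query_size * task_num
    return [count / total for count in corrects]


# Made-up counts for one 5-way task with 15 query images per class (75 in total).
accs = accuracy_per_step([30, 45, 60], query_size=75, task_num=1)
```

Read this way, the first entry is the accuracy before any adaptation and the last entry is the accuracy after the final inner update, which is the number usually compared with the paper.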

Running the code fails at db = DataLoader(mini, args.task_num, shuffle=True, num_workers=1, pin_memory=True)

Traceback (most recent call last):
File "", line 1, in
File "D:\ProgramData\Anaconda3\lib\multiprocessing\spawn.py", line 105, in spawn_main
exitcode = _main(fd)
File "D:\ProgramData\Anaconda3\lib\multiprocessing\spawn.py", line 115, in _main
self = reduction.pickle.load(from_parent)
EOFError: Ran out of input
Traceback (most recent call last):
File "D:\Program Files\JetBrains\PyCharm 2018.2.3\helpers\pydev\pydevd.py", line 1664, in
main()
File "D:\Program Files\JetBrains\PyCharm 2018.2.3\helpers\pydev\pydevd.py", line 1658, in main
globals = debugger.run(setup['file'], None, None, is_module)
File "D:\Program Files\JetBrains\PyCharm 2018.2.3\helpers\pydev\pydevd.py", line 1068, in run
pydev_imports.execfile(file, globals, locals) # execute the script
File "D:\Program Files\JetBrains\PyCharm 2018.2.3\helpers\pydev_pydev_imps_pydev_execfile.py", line 18, in execfile
exec(compile(contents+"\n", file, 'exec'), glob, loc)
File "F:/NLP/meta-learning/MAML-Pytorch/miniimagenet_train.py", line 110, in
main()
File "F:/NLP/meta-learning/MAML-Pytorch/miniimagenet_train.py", line 68, in main
for step, item in enumerate(db):
File "D:\ProgramData\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 819, in iter
return _DataLoaderIter(self)
File "D:\ProgramData\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py", line 560, in init
w.start()
File "D:\ProgramData\Anaconda3\lib\multiprocessing\process.py", line 105, in start
self._popen = self._Popen(self)
File "D:\ProgramData\Anaconda3\lib\multiprocessing\context.py", line 223, in _Popen
return _default_context.get_context().Process._Popen(process_obj)
File "D:\ProgramData\Anaconda3\lib\multiprocessing\context.py", line 322, in _Popen
return Popen(process_obj)
File "D:\ProgramData\Anaconda3\lib\multiprocessing\popen_spawn_win32.py", line 65, in init
reduction.dump(process_obj, to_child)
File "D:\ProgramData\Anaconda3\lib\multiprocessing\reduction.py", line 60, in dump
ForkingPickler(file, protocol).dump(obj)
AttributeError: Can't pickle local object 'MiniImagenet.__init__.<locals>.<lambda>'
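The final line of the traceback points at pickling: on Windows, worker processes are spawned, so the dataset object has to be pickled and lambdas cannot be. The sketch below demonstrates the limitation with the standard library only. It assumes the dataset keeps a lambda among its transforms; `is_picklable` and the two example transforms are hypothetical stand-ins.

```python
import operator
import pickle


def is_picklable(obj):
    """Return True if `obj` survives pickle.dumps, False otherwise."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False


# A lambda has no importable name, so pickle cannot serialise it.
lambda_transform = lambda text: text.upper()

# operator.methodcaller builds an equivalent callable that pickle can handle.
named_transform = operator.methodcaller("upper")

lambda_ok = is_picklable(lambda_transform)
named_ok = is_picklable(named_transform)
```

Replacing the lambda with a module-level function or a picklable callable is one way around the error. Setting num_workers=0 in the DataLoader is another, since no worker process is started and nothing needs to be pickled.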

Question about autograd of pytorch.

if True: # TODO: this is a potential problems.

Hi dragen:
I have a question about the autograd mechanism of PyTorch. It seems to me that PyTorch doesn't support higher-order gradients, since it builds the graph during the forward pass and computes the gradients by traversing the graph in reverse without expanding it. The gradient variable therefore has no link to the original variable, so we can't further compute its gradient w.r.t. the original variable. However, in the above part of the code, where the meta update is conducted, no special mechanism seems to be used, so I think the update doesn't consider the higher-order gradient formed by the K inner-loop steps.
To further demonstrate the problem, I wrote a simple demo.

Re-sampling tasks after each epoch increases the performance

The create_batch function is only called once when the MiniImagenet dataset object is created, which means the tasks sampled are the same in every epoch.

I changed the code to second-order (according to #32) and call create_batch in every epoch, the performance can achieve 47.17%.
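The change described here amounts to calling the task-sampling routine once per epoch instead of once per dataset object. The sketch below shows that control flow with plain Python; `ResamplingTaskSet` is a hypothetical stand-in for the dataset class, and a task is reduced to a list of class ids.

```python
import random


class ResamplingTaskSet:
    """Holds `batchsz` tasks, each a list of `n_way` distinct class ids."""

    def __init__(self, class_ids, n_way, batchsz, seed=0):
        self.class_ids = list(class_ids)
        self.n_way = n_way
        self.batchsz = batchsz
        self.rng = random.Random(seed)
        self.tasks = []
        self.create_batch()

    def create_batch(self):
        # Draw a fresh set of tasks; earlier tasks are discarded.
        self.tasks = [
            self.rng.sample(self.class_ids, self.n_way) for _ in range(self.batchsz)
        ]


task_set = ResamplingTaskSet(range(64), n_way=5, batchsz=100)
tasks_per_epoch = []
for epoch in range(3):
    if epoch > 0:
        task_set.create_batch()  # re-sample before every epoch after the first
    tasks_per_epoch.append(list(task_set.tasks))
```

Without the create_batch() call inside the loop, every epoch would iterate over the same 100 tasks.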

Multiple GPUs

I am trying to use multiple GPUs. In your back folder, I see some references to multiple GPUs. Have you been able to run the code on multiple GPUs? If so, can you please guide me on how to use them with MAML?

Conv-ReLU-BN issue

I just found that this code uses conv-relu-bn. However, it should be conv-bn-relu. Could you please fix it?

Pytorch implementation of meta updating

Hi dragen,
I have read your code. It's a good job.
When I tried to implement MAML, I used parameters_origin = model.state_dict() and model.load_state_dict(parameters_origin) to restore the weights after the inner update, and then carried out the meta update using the summed loss. But this did not seem to work, and the accuracy hovered around 20.
I don't know what is going wrong. Have you reproduced the result on miniimagenet?
Best!

Maybe there is some error in learner.py?

The author tries to define the running state of the BatchNorm manually. However, I found there is some error in learner.py, in line 66, and line 67.

learner.py

running_mean = nn.Parameter(torch.zeros(param[0]), requires_grad=False) # line 66 #
running_var = nn.Parameter(torch.ones(param[0]), requires_grad=False) # line 67 #
#######################
The running_mean and running_var in BatchNorm are not parameters of the network itself, since they do not require gradients.
They are the running statistics of the model and only record the mean and variance of each minibatch during iterations.
We can change lines 66 and 67 as follows to avoid calculating gradients:
###############
self.register_buffer('running_mean', XXXXXXXX)
self.register_buffer('running_var', XXXXXXXXXX)
###############
reference:
https://pytorch.org/docs/stable/_modules/torch/nn/modules/batchnorm.html#BatchNorm2d
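The point that running statistics are tracked rather than trained can be illustrated without any framework. The sketch below applies an exponential moving average with momentum 0.1 to a single feature. The use of the unbiased batch variance for the running estimate is an assumption of this sketch, and `update_running_stats` is a hypothetical helper.

```python
def update_running_stats(running_mean, running_var, batch, momentum=0.1):
    """Exponential-moving-average update of BatchNorm-style statistics.

    No gradient is involved: the new values depend only on the old
    statistics and on the current batch.
    """
    n = len(batch)
    batch_mean = sum(batch) / n
    # Assumption of this sketch: the running variance uses the unbiased estimate.
    batch_var = sum((x - batch_mean) ** 2 for x in batch) / (n - 1)
    new_mean = (1 - momentum) * running_mean + momentum * batch_mean
    new_var = (1 - momentum) * running_var + momentum * batch_var
    return new_mean, new_var


# Start from the usual initial values (mean 0, variance 1) and feed one batch.
mean, var = update_running_stats(0.0, 1.0, [1.0, 2.0, 3.0])
```

Because the update is a plain average, storing these values as buffers rather than parameters keeps them out of the optimizer while still saving them with the model state.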

questions with respect to the no_grad()

Hi, @dragen1860, thanks for your implementation. I have a problem running the code: when I set the number of inner updates to 1, an exception occurs: "RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn". I wonder whether this is caused by the "with torch.no_grad():" that you set for every inner update, but I am not sure what this line of code is for. Is it to avoid updating the mean and variance in batch_norm?

Hopefully you can help me figure this out, thanks,

miniImageNet training takes too much time

Hi

Thank you for sharing your code.

I cloned your code and ran the miniImageNet training, and it took 2497 seconds, which is about 42 minutes, for 1000 steps with GTX 1080 Ti GPU. I also reduced the task nums to 2. The training time might be too expensive for me.

I noticed the GPU-Util increased to 20-40% for about 1 second and decreased to 0% for a few seconds when training.

So I am wondering if the training time is acceptable.

RuntimeError: Expected object of scalar type Long but got scalar type Int for argument #2 'target'

Traceback (most recent call last):
File "C:/Users/Dingd/Documents/GitHub/MAML-Pytorch-master/omniglot_train.py", line 95, in
main(args)
File "C:/Users/Dingd/Documents/GitHub/MAML-Pytorch-master/omniglot_train.py", line 55, in main
accs = maml(x_spt, y_spt, x_qry, y_qry)
File "C:\Users\Dingd\AppData\Local\Programs\Python\Python36\lib\site-packages\torch\nn\modules\module.py", line 489, in call
result = self.forward(*input, **kwargs)
File "C:\Users\Dingd\Documents\GitHub\MAML-Pytorch-master\meta.py", line 85, in forward
loss = F.cross_entropy(logits, y_spt[i])
File "C:\Users\Dingd\AppData\Local\Programs\Python\Python36\lib\site-packages\torch\nn\functional.py", line 1970, in cross_entropy
return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
File "C:\Users\Dingd\AppData\Local\Programs\Python\Python36\lib\site-packages\torch\nn\functional.py", line 1790, in nll_loss
ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: Expected object of scalar type Long but got scalar type Int for argument #2 'target'

What's the difference between task_num and N-way?

Hello, thank you so much for your code. It is very clear and readable.
I am new to meta-learning.
I do not understand the meaning of N in N-way. I think the number of tasks is the same as N. But in your code, it is clearly not.
So I hope you could explain and give me your advice.
Thank you so much.
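The two quantities can be told apart by looking at how a meta-batch is assembled: N-way is the number of classes inside one task, while task_num is how many such tasks are grouped into one meta-update. The sketch below is illustrative only; `sample_meta_batch` and the toy data are hypothetical and not taken from the repository.

```python
import random


def sample_meta_batch(class_to_items, task_num, n_way, k_shot, k_query, rng):
    """Build `task_num` tasks; each task is an `n_way`-class classification problem."""
    batch = []
    for _ in range(task_num):
        classes = rng.sample(sorted(class_to_items), n_way)
        support, query = [], []
        for label, cls in enumerate(classes):  # labels are local to the task
            items = rng.sample(class_to_items[cls], k_shot + k_query)
            support += [(item, label) for item in items[:k_shot]]
            query += [(item, label) for item in items[k_shot:]]
        batch.append((support, query))
    return batch


# Toy data: 8 classes with 20 items each (integers stand in for images).
toy_data = {name: list(range(20)) for name in "abcdefgh"}
meta_batch = sample_meta_batch(
    toy_data, task_num=4, n_way=5, k_shot=1, k_query=15, rng=random.Random(0)
)
```

Here one meta-update sees 4 tasks, and each task is a 5-way 1-shot problem with 5 support examples and 75 query examples.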
