Comments (6)
hi,
`model.state_dict()` only works for a standard `nn.Module` subclass. However, my model class is not a standard `nn.Module` subclass, hence you can't save its weights via `model.state_dict()`. You can manually save the weight tensors with pickle.
from maml-pytorch.
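As a minimal sketch of the pickle approach described above (the `RawModel` class and file name here are illustrative stand-ins, not the repo's actual model):

```python
import pickle
import torch

# Hypothetical non-standard "model" that just holds a list of tensors
# instead of registering them on an nn.Module.
class RawModel:
    def __init__(self):
        self.vars = [torch.randn(4, 3), torch.zeros(4)]  # weight, bias

model = RawModel()

# Save the raw tensors with pickle (detach first so no autograd graph is stored).
with open("weights.pkl", "wb") as f:
    pickle.dump([v.detach() for v in model.vars], f)

# Restore them later.
with open("weights.pkl", "rb") as f:
    restored = pickle.load(f)
```

`torch.save`/`torch.load` (which use pickle under the hood) work the same way and also handle device mapping.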
I know what you mean, but I do it as follows:
```python
weights_origin = copy.deepcopy(self.net.parameters())
for i in range(batchsz):  # batchsz == self.meta_batchsz
    pred_q = self.net(query_x[i], bns=None, training=training)
    pred_q = F.softmax(pred_q, dim=1).argmax(dim=1)
    correct = torch.eq(pred_q, query_y[i]).sum().item()
    corrects[0] = corrects[0] + correct

    # 1. run the i-th task and compute loss for k=0
    self.inner_optim.zero_grad()
    pred = self.net(support_x[i])
    loss = F.cross_entropy(pred, support_y[i])
    loss.backward()
    nn.utils.clip_grad_norm_(self.net.parameters(), 5)
    self.inner_optim.step()

    pred_q = self.net(query_x[i], bns=None, training=training)
    loss_q = F.cross_entropy(pred_q, query_y[i])
    pred_q = F.softmax(pred_q, dim=1).argmax(dim=1)
    correct = torch.eq(pred_q, query_y[i]).sum().item()
    corrects[1] = corrects[1] + correct

    for k in range(1, self.K):
        # 2. run the i-th task and compute loss for k=1~K-1
        self.inner_optim.zero_grad()
        pred = self.net(support_x[i], bns=None, training=training)
        loss = F.cross_entropy(pred, support_y[i])
        loss.backward()
        nn.utils.clip_grad_norm_(self.net.parameters(), 5)
        self.inner_optim.step()

        pred_q = self.net(query_x[i], bns=None, training=training)
        loss_q = F.cross_entropy(pred_q, query_y[i])
        pred_q = F.softmax(pred_q, dim=1).argmax(dim=1)
        correct = torch.eq(pred_q, query_y[i]).sum().item()
        corrects[k + 1] = corrects[k + 1] + correct

    # 4. record last step's loss for task i
    losses_q.append(loss_q)
    for p, p_o in zip(self.net.parameters(), weights_origin):
        p.data.copy_(p_o)

# end of all tasks
# sum over all losses across all tasks
loss_q = torch.stack(losses_q).sum(0)
```
I reload the original weights before every task, but update the weights with an optimizer rather than using fast weights. However, the accuracy hovered around 20%. I don't know what the problem is.
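One thing worth double-checking in the snippet above: `self.net.parameters()` returns a generator, and `copy.deepcopy` on a generator typically raises a `TypeError` in Python 3, so `weights_origin` may not hold what you expect. A sketch of copying and restoring the parameter tensors explicitly, using a stand-in `nn.Linear` in place of `self.net` (illustrative, not the repo's code):

```python
import torch
import torch.nn as nn

net = nn.Linear(3, 2)  # stand-in for self.net

# parameters() is a generator; materialize the tensors before copying.
weights_origin = [p.detach().clone() for p in net.parameters()]

# ... inner-loop optimizer steps mutate net's parameters in place ...
with torch.no_grad():
    for p in net.parameters():
        p.add_(1.0)  # fake "update" for demonstration

# Restore the original weights after finishing the task.
with torch.no_grad():
    for p, p_o in zip(net.parameters(), weights_origin):
        p.copy_(p_o)
```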
@JaminFong hi, I think you are wrong in

```python
for p, p_o in zip(self.net.parameters(), weights_origin):
    p.data.copy_(p_o)
```

This means that for the next meta-batch, the net's parameters are just the original parameters, not the updated ones, so the K inner-loop updates can't take effect.
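The underlying issue being debated is how MAML keeps the inner-loop updates without permanently overwriting the meta-parameters: the reference approach builds "fast weights" functionally, so the meta-gradient can flow through the inner update while the slow weights stay untouched. A minimal sketch of one inner step on a toy linear model (all names here are illustrative, not from the repo):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Slow weights that MAML meta-learns (toy linear classifier: 3 features, 5 classes).
w = torch.randn(3, 5, requires_grad=True)
b = torch.zeros(5, requires_grad=True)

support_x = torch.randn(10, 3)
support_y = torch.randint(0, 5, (10,))
inner_lr = 0.1

# Inner step: compute grads w.r.t. the slow weights and build fast weights
# WITHOUT in-place mutation; create_graph=True keeps the update differentiable
# so the outer (meta) gradient can flow through it.
logits = support_x @ w + b
loss = F.cross_entropy(logits, support_y)
gw, gb = torch.autograd.grad(loss, (w, b), create_graph=True)
fast_w, fast_b = w - inner_lr * gw, b - inner_lr * gb

# Query loss evaluated with the fast weights; backprop reaches w and b.
query_x = torch.randn(10, 3)
query_y = torch.randint(0, 5, (10,))
loss_q = F.cross_entropy(query_x @ fast_w + fast_b, query_y)
loss_q.backward()
```

With an in-place optimizer step instead (as in the snippet above), the graph between the slow weights and the query loss is cut, which is one plausible reason the accuracy stays near chance.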
@flexibility2 Indeed, I think every meta-batch update is based on the original parameters, as the paper illustrates.
@JaminFong
ok, and I think you are right~ Could you share the complete code so that we can discuss it?
In addition, do you know what "K" stands for in line 5 of the Algorithm? I can't understand it completely, and I hope I can get your explanation.
@flexibility2 `K` is the number of "shots" you sample for task T_i (line 5), which are grouped in the set D. I think that for every gradient evaluation of task T_i (line 6), the whole set D of task T_i is used.
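To make "K shots grouped in the set D" concrete, here is a toy sketch (pure Python; the function name and dataset layout are illustrative assumptions) of sampling an N-way K-shot support set D and a query set for one task T_i:

```python
import random

random.seed(0)

def sample_task(dataset, n_way=5, k_shot=1, k_query=15):
    """Sample one N-way K-shot task.

    dataset: dict mapping class label -> list of examples.
    Returns (support, query), each a list of (example, relabeled_class)
    pairs; support is the set D of K shots per sampled class.
    """
    classes = random.sample(sorted(dataset), n_way)
    support, query = [], []
    for new_label, c in enumerate(classes):
        examples = random.sample(dataset[c], k_shot + k_query)
        support += [(x, new_label) for x in examples[:k_shot]]
        query += [(x, new_label) for x in examples[k_shot:]]
    return support, query

# Toy dataset: 20 classes with 20 examples each.
dataset = {c: [f"img_{c}_{j}" for j in range(20)] for c in range(20)}
D, D_query = sample_task(dataset, n_way=5, k_shot=1, k_query=15)
```

All K·N support examples in D are used for each inner gradient evaluation, which matches the reading of line 6 above.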