
Comments (6)

dragen1860 commented on July 21, 2024

hi,
model.state_dict() only works for a standard nn.Module subclass. However, my model class is not a standard one, so you can't save the weights with model.state_dict(). You can manually save the weight tensors with pickle.
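A minimal sketch of that manual approach, assuming the model keeps its weight tensors in a plain Python list (a hypothetical model.vars here) rather than as registered parameters:

    import pickle
    import torch

    def save_weights(model, path):
        # model.vars is assumed to be a plain list of weight tensors
        with open(path, 'wb') as f:
            pickle.dump([w.detach().cpu() for w in model.vars], f)

    def load_weights(model, path):
        with open(path, 'rb') as f:
            saved = pickle.load(f)
        with torch.no_grad():
            for w, w_saved in zip(model.vars, saved):
                w.copy_(w_saved)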


JaminFong commented on July 21, 2024

I know what you mean, but I do it as follows:

		# parameters() returns a generator, which cannot be deepcopied; materialise it first
		weights_origin = copy.deepcopy(list(self.net.parameters()))
		for i in range(batchsz): # batchsz==self.meta_batchsz
			pred_q = self.net(query_x[i], bns=None, training=training)
			pred_q = F.softmax(pred_q, dim=1).argmax(dim=1)
			correct = torch.eq(pred_q, query_y[i]).sum().item()
			corrects[0] = corrects[0] + correct

			# 1. run the i-th task and compute loss for k=0
			self.inner_optim.zero_grad()
			pred = self.net(support_x[i])
			loss = F.cross_entropy(pred, support_y[i])
			loss.backward()
			nn.utils.clip_grad_norm_(self.net.parameters(), 5)
			self.inner_optim.step()

			pred_q = self.net(query_x[i], bns=None, training=training)
			loss_q = F.cross_entropy(pred_q, query_y[i])

			pred_q = F.softmax(pred_q, dim=1).argmax(dim=1)
			correct = torch.eq(pred_q, query_y[i]).sum().item()
			corrects[1] = corrects[1] + correct

			for k in range(1, self.K):
				# 2. run the i-th task and compute loss for k=1~K-1
				self.inner_optim.zero_grad()
				pred = self.net(support_x[i], bns=None, training=training)
				loss = F.cross_entropy(pred, support_y[i])
				loss.backward()
				nn.utils.clip_grad_norm_(self.net.parameters(), 5)
				self.inner_optim.step()

				pred_q = self.net(query_x[i], bns=None, training=training)
				loss_q = F.cross_entropy(pred_q, query_y[i])
				pred_q = F.softmax(pred_q, dim=1).argmax(dim=1)
				correct = torch.eq(pred_q, query_y[i]).sum().item() # .item() gives a Python int
				corrects[k+1] = corrects[k+1] + correct

			# 3. record the last step's query loss for task i
			losses_q.append(loss_q)
			# restore the original weights so the next task starts from the same meta-parameters
			for p, p_o in zip(self.net.parameters(), weights_origin):
				p.data.copy_(p_o)

		# end of all tasks
		# sum over all losses across all tasks
		loss_q = torch.stack(losses_q).sum(0)

I reload the original weights before every task, but I update the weights with an optimizer rather than computing fast weights. However, the accuracy just hovers around 20%, and I don't know what the problem is.
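For comparison, a rough sketch of what the fast-weights inner loop would look like (inner_adapt, inner_lr and the step count are my own names, and torch.func.functional_call needs a recent PyTorch):

    import torch
    import torch.nn.functional as F
    from torch.func import functional_call

    def inner_adapt(net, support_x, support_y, inner_lr=0.01, steps=5):
        # start each task from the current meta-parameters of net
        fast_weights = {name: p for name, p in net.named_parameters()}
        for _ in range(steps):
            pred = functional_call(net, fast_weights, (support_x,))
            loss = F.cross_entropy(pred, support_y)
            # create_graph=True keeps the inner-loop graph so the later
            # meta-update can backpropagate through these K steps
            grads = torch.autograd.grad(loss, list(fast_weights.values()),
                                        create_graph=True)
            fast_weights = {name: p - inner_lr * g
                            for (name, p), g in zip(fast_weights.items(), grads)}
        return fast_weights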


flexibility2 commented on July 21, 2024

@JaminFong hi, I think you are wrong in

	for p, p_o in zip(self.net.parameters(), weights_origin):
		p.data.copy_(p_o)

This means that for the next batch, the "net"'s parameters are just the original parameters, not the updated ones, so the K update steps can't take effect.


JaminFong commented on July 21, 2024

@flexibility2 Indeed, I think every meta-batch update is based on the original parameters, as the paper illustrates.
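In code, that outer loop would look roughly like the following, reusing the hypothetical inner_adapt / functional_call from the sketch above (meta_optim and meta_batchsz are likewise my own names): every task adapts from the same meta-parameters, and only the summed query losses update them.

    # builds on the inner_adapt / functional_call sketch above
    meta_optim.zero_grad()
    losses_q = []
    for i in range(meta_batchsz):
        # each task adapts starting from the same meta-parameters of net
        fast_weights = inner_adapt(net, support_x[i], support_y[i])
        pred_q = functional_call(net, fast_weights, (query_x[i],))
        losses_q.append(F.cross_entropy(pred_q, query_y[i]))
    # one meta step on the summed query loss; gradients flow back through
    # the inner-loop updates to the original parameters
    torch.stack(losses_q).sum().backward()
    meta_optim.step()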


flexibility2 commented on July 21, 2024

@JaminFong
OK, I think you are right~ Could you share the complete code so that we can discuss it?
In addition, do you know what "K" stands for in line 5 of the algorithm? I can't completely understand it, and I hope to get your explanation.


ArnoutDevos commented on July 21, 2024

@flexibility2 K is the number of "shots" you sample for task T_i (line 5), which are grouped into the set D. I think that for every gradient evaluation on task T_i (line 6), the whole set D of task T_i is used.
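As a toy illustration of that sampling (the function name and data layout are made up, not from the paper):

    import random

    def sample_support_set(examples_by_class, K):
        # D for task T_i: K labelled examples per class, i.e. "K-shot"
        support = []
        for label, examples in examples_by_class.items():
            support += [(x, label) for x in random.sample(examples, K)]
        return support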

