Comments (9)
I wrote a piece of code for testing, shown below. print_grad is run in 2 processes, and each process adds i+1 to the data of the network parameter. Since data is shared, the result is 2+1+2=5 in both processes. The gradient, however, behaves differently: each process has its own gradient initialized to 0, and the results 0+1=1 and 0+2=2 differ between the two processes.
If I understand correctly, the gradient is allocated separately for each process, as mentioned in the following post. I think the point is that since grad is still None after we call share_memory(), the gradient allocation inside each process ends up separate. One can instead set grad to 0 before calling share_memory(); in that case, the gradient will be shared.
from __future__ import print_function
import os
import torch.multiprocessing as mp
import torch
from torch import nn
from torch.autograd import Variable

os.environ['OMP_NUM_THREADS'] = '1'

def print_grad(shared_model, i):
    for p in shared_model.parameters():
        if p._grad is None:
            p._grad = Variable(torch.FloatTensor([0]))
        p._grad += i + 1
        p.data += i + 1
        print(p.data)
        print(p.grad)

class TestNet(nn.Module):
    def __init__(self):
        super(TestNet, self).__init__()
        self.x = nn.Parameter(torch.Tensor([2]))

    def forward(self):
        return self.x

model = TestNet()
model.share_memory()
processes = [mp.Process(target=print_grad, args=(model, i)) for i in range(0, 2)]
[p.start() for p in processes]
5
[torch.FloatTensor of size 1]
Variable containing:
1
[torch.FloatTensor of size 1]
5
[torch.FloatTensor of size 1]
Variable containing:
2
[torch.FloatTensor of size 1]
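The claim above, that initializing grad before share_memory() makes the gradient shared as well, can be checked without spawning processes by looking at is_shared(). This is a minimal sketch assuming a reasonably recent PyTorch, where Module.share_memory() also moves an already-allocated .grad into shared memory:

```python
import torch
from torch import nn

class TestNet(nn.Module):
    def __init__(self):
        super(TestNet, self).__init__()
        self.x = nn.Parameter(torch.tensor([2.0]))

model = TestNet()

# Allocate .grad BEFORE share_memory(), so share_memory() can move it too.
for p in model.parameters():
    p.grad = torch.zeros_like(p.data)

model.share_memory()

for p in model.parameters():
    print(p.data.is_shared(), p.grad.is_shared())  # both expected to be True
```

If the loop that sets p.grad is skipped, p.grad stays None after share_memory(), and each forked process later allocates its own private gradient, which is exactly the behavior seen in the output above.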
from pytorch-a3c.
I guess that if shared_param.grad is not None, then some other thread must be updating the network, and the current thread should not update it until the others complete. But I have a question. As I understand it, if the grad is not None, the code just returns, which means the gradient of the current thread is simply discarded. Is this really the case? Then only one of the threads can update the network, and the others that finish at the same time just run in vain?
@xuehy It seems that shared_param._grad = param.grad makes shared_param.grad reference the same tensor as param.grad. Therefore, once shared_param._grad is not None, it always has the same values as param.grad.
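The aliasing can be seen with plain tensors, independent of multiprocessing; this sketch just illustrates that shared_param._grad = param.grad copies a reference, not the values:

```python
import torch

param_grad = torch.zeros(1)   # stands in for the worker's param.grad
shared_grad = param_grad      # shared_param._grad = param.grad: same tensor, not a copy

param_grad.add_(3.0)          # worker accumulates into its gradient in place

assert shared_grad is param_grad   # one object, two names
assert shared_grad.item() == 3.0   # updates are visible through the alias
```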
@boscotsang I am still confused.
"Once the shared_param._grad is not None, it always has the same values as param.grad."
But there are many threads, each owning a different param.grad. Assume there are two threads, A and B. What if shared_param._grad is assigned A's param.grad? Then for thread B, shared_param.grad is always not None?
@xuehy It seems that grad or _grad is not shared among processes with global_network.share_memory(). Only the weights are shared. Therefore, each process has its own shared_param.grad.
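This is easy to verify in a single process: after Module.share_memory(), the parameter storage is in shared memory but .grad is still None, so whatever gradient a worker allocates afterwards is private to it. A small sketch (assuming PyTorch is available):

```python
import torch
from torch import nn

net = nn.Linear(2, 1)
net.share_memory()  # moves parameter storages into shared memory

for p in net.parameters():
    assert p.data.is_shared()  # weights will be visible across forked workers
    assert p.grad is None      # nothing grad-related has been shared yet
```

Since .grad is None at fork time, each worker's later backward() allocates a fresh, process-local gradient.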
@hugemicrobe
The document says: […]. Does it mean that shared_param.grad is also shared?
@xuehy
I think the fact that shared_param.grad is shared is exactly why this function works; otherwise shared_param.grad would always be None. So it seems to me that when a process detects that some other process has already copied its local grad into shared_param.grad, it gives up its own update, since it directly returns.
What do you think?
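For reference, the function being discussed is ensure_shared_grads from pytorch-a3c; it looks roughly like this (reproduced from memory, so treat it as a sketch rather than the exact repository code):

```python
import torch
from torch import nn

def ensure_shared_grads(model, shared_model):
    # Alias each worker gradient into the shared model, but only once:
    # if the shared model already has grads, leave them alone and return.
    for param, shared_param in zip(model.parameters(),
                                   shared_model.parameters()):
        if shared_param.grad is not None:
            return
        shared_param._grad = param.grad

local = nn.Linear(2, 1)
shared = nn.Linear(2, 1)
shared.share_memory()

local(torch.ones(1, 2)).sum().backward()
ensure_shared_grads(local, shared)

for p, sp in zip(local.parameters(), shared.parameters()):
    assert sp.grad is p.grad  # shared grads alias the worker's grads
```

Note that since .grad itself is not in shared memory, after a fork each worker ends up aliasing its own local grads into its own view of shared_model, which is what the comments above are debating.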
If you are not confident in A3C, I've just made my A2C code public: https://github.com/ikostrikov/pytorch-a2c .
@SYTMTHU Yes, I understand how it works. But I think that this way the processes waste a lot of time doing nothing. Over the same period of time, is the number of parameter updates with A3C actually the same as with a non-distributed version? Am I right that the only difference is that the updates in A3C come from different environments, while the updates of a non-distributed algorithm come from only one running environment?