Comments (2)
PPG has a fairly large GPU memory requirement because it stores the full batch of observations for the auxiliary phase. To make it usable on GPUs with less memory, we experiment with gradient accumulation: each auxiliary minibatch is split into args.n_aux_grad_accum chunks, each chunk's loss is scaled by 1 / args.n_aux_grad_accum before backward(), and the gradients are compared only after all chunks have been accumulated.
Below are the gradient accumulation results for the value loss:
In [74]: optimizer.zero_grad()
    ...: args.aux_batch_size = int(args.batch_size * args.n_iteration)
    ...: args.aux_minibatch_size = int(args.aux_batch_size // args.n_aux_minibatch)
    ...: aux_inds = np.arange(args.aux_batch_size)
    ...: print("aux phase starts")
    ...: for i, start in enumerate(range(0, args.aux_batch_size, args.aux_minibatch_size)):
    ...:     end = start + args.aux_minibatch_size
    ...:     aux_minibatch_ind = aux_inds[start:end]
    ...:     m_aux_obs = aux_obs[aux_minibatch_ind].to(device)
    ...:     m_aux_returns = aux_returns[aux_minibatch_ind].to(device)
    ...:
    ...:     new_values = agent.get_value(m_aux_obs).view(-1)
    ...:     new_aux_values = agent.get_aux_value(m_aux_obs).view(-1)
    ...:     kl_loss = td.kl_divergence(agent.get_pi(m_aux_obs), old_agent.get_pi(m_aux_obs)).mean()
    ...:
    ...:     real_value_loss = 0.5 * ((new_values - m_aux_returns) ** 2).mean()
    ...:     aux_value_loss = 0.5 * ((new_aux_values - m_aux_returns) ** 2).mean()
    ...:     joint_loss = aux_value_loss + args.beta_clone * kl_loss
    ...:
    ...:     real_value_loss.backward()
    ...:     # nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
    ...:     print(agent.critic[-1].weight.grad.sum())
    ...:     break
    ...:
aux phase starts
tensor(-6.6915, device='cuda:0')
In [75]: optimizer.zero_grad()
    ...: args.aux_batch_size = int(args.batch_size * args.n_iteration)
    ...: args.aux_minibatch_size = int(args.aux_batch_size // (args.n_aux_minibatch * args.n_aux_grad_accum))
    ...: aux_inds = np.arange(args.aux_batch_size)
    ...: print("aux phase starts")
    ...: for i, start in enumerate(range(0, args.aux_batch_size, args.aux_minibatch_size)):
    ...:     end = start + args.aux_minibatch_size
    ...:     aux_minibatch_ind = aux_inds[start:end]
    ...:     m_aux_obs = aux_obs[aux_minibatch_ind].to(device)
    ...:     m_aux_returns = aux_returns[aux_minibatch_ind].to(device)
    ...:
    ...:     new_values = agent.get_value(m_aux_obs).view(-1)
    ...:     new_aux_values = agent.get_aux_value(m_aux_obs).view(-1)
    ...:     kl_loss = td.kl_divergence(agent.get_pi(m_aux_obs), old_agent.get_pi(m_aux_obs)).mean()
    ...:
    ...:     real_value_loss = 0.5 * ((new_values - m_aux_returns) ** 2).mean()
    ...:     aux_value_loss = 0.5 * ((new_aux_values - m_aux_returns) ** 2).mean()
    ...:     joint_loss = aux_value_loss + args.beta_clone * kl_loss
    ...:
    ...:     # scale each chunk's loss so the accumulated gradient matches the full minibatch
    ...:     loss = real_value_loss / args.n_aux_grad_accum
    ...:     loss.backward()
    ...:     # nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
    ...:     if (i + 1) % args.n_aux_grad_accum == 0:
    ...:         # a full minibatch's worth of chunks has now been accumulated
    ...:         print(agent.critic[-1].weight.grad.sum())
    ...:         break
    ...:
...:
aux phase starts
tensor(-6.6915, device='cuda:0')
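For intuition on why the two sums match exactly: with a mean-reduced loss and equal-size chunks, scaling each chunk's loss by 1/n_chunks and accumulating backward() reproduces the full-minibatch gradient. A self-contained sketch (not from the original session; the linear model and sizes are made up):

import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Linear(4, 1)
x, y = torch.randn(64, 4), torch.randn(64, 1)

# full-minibatch gradient
net.zero_grad()
(0.5 * ((net(x) - y) ** 2).mean()).backward()
full_grad = net.weight.grad.clone()

# accumulated gradient over 4 equal-size chunks
net.zero_grad()
n_chunks = 4
for xc, yc in zip(x.chunk(n_chunks), y.chunk(n_chunks)):
    (0.5 * ((net(xc) - yc) ** 2).mean() / n_chunks).backward()

print(torch.allclose(full_grad, net.weight.grad, atol=1e-6))  # True up to float32 noise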
And below are the gradient accumulation results for the KL loss:
optimizer.zero_grad()
args.aux_batch_size = int(args.batch_size * args.n_iteration)
args.aux_minibatch_size = int(args.aux_batch_size // args.n_aux_minibatch)
aux_inds = np.arange(args.aux_batch_size)
print("aux phase starts")
for i, start in enumerate(range(0, args.aux_batch_size, args.aux_minibatch_size)):
    end = start + args.aux_minibatch_size
    aux_minibatch_ind = aux_inds[start:end]
    m_aux_obs = aux_obs[aux_minibatch_ind].to(device)
    m_aux_returns = aux_returns[aux_minibatch_ind].to(device)

    new_values = agent.get_value(m_aux_obs).view(-1)
    new_aux_values = agent.get_aux_value(m_aux_obs).view(-1)
    kl_loss = td.kl_divergence(agent.get_pi(m_aux_obs), old_agent.get_pi(m_aux_obs)).mean()

    real_value_loss = 0.5 * ((new_values - m_aux_returns) ** 2).mean()
    aux_value_loss = 0.5 * ((new_aux_values - m_aux_returns) ** 2).mean()
    joint_loss = aux_value_loss + args.beta_clone * kl_loss

    kl_loss.backward()
    # nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
    print(agent.actor.weight.grad.sum())
    break
aux phase starts
tensor(5.4805e-07, device='cuda:0')
optimizer.zero_grad()
args.aux_batch_size = int(args.batch_size * args.n_iteration)
args.aux_minibatch_size = int(args.aux_batch_size // (args.n_aux_minibatch * args.n_aux_grad_accum))
aux_inds = np.arange(args.aux_batch_size)
print("aux phase starts")
for i, start in enumerate(range(0, args.aux_batch_size, args.aux_minibatch_size)):
    end = start + args.aux_minibatch_size
    aux_minibatch_ind = aux_inds[start:end]
    m_aux_obs = aux_obs[aux_minibatch_ind].to(device)
    m_aux_returns = aux_returns[aux_minibatch_ind].to(device)

    new_values = agent.get_value(m_aux_obs).view(-1)
    new_aux_values = agent.get_aux_value(m_aux_obs).view(-1)
    kl_loss = td.kl_divergence(agent.get_pi(m_aux_obs), old_agent.get_pi(m_aux_obs)).mean()

    real_value_loss = 0.5 * ((new_values - m_aux_returns) ** 2).mean()
    aux_value_loss = 0.5 * ((new_aux_values - m_aux_returns) ** 2).mean()
    joint_loss = aux_value_loss + args.beta_clone * kl_loss

    # scale each chunk's loss so the accumulated gradient matches the full minibatch
    loss = kl_loss / args.n_aux_grad_accum
    loss.backward()
    # nn.utils.clip_grad_norm_(agent.parameters(), args.max_grad_norm)
    if (i + 1) % args.n_aux_grad_accum == 0:
        # a full minibatch's worth of chunks has now been accumulated
        print(agent.actor.weight.grad.sum())
        break
aux phase starts
tensor(5.8156e-07, device='cuda:0')
So gradient accumulation reproduces the value-loss gradient exactly, but something in td.kl_divergence throws off the calculation for the policy-related losses. For the moment, let's live with this discrepancy given the benefits of gradient accumulation.
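Note that both gradient sums are on the order of 1e-7, i.e. essentially zero (the policy has not yet moved away from old_agent, and the KL of a distribution with itself has zero gradient), so one plausible, unverified explanation for the mismatch is simply float32 reduction order rather than a real bug: averaging in different chunkings accumulates rounding error differently. A self-contained sketch of that effect, with made-up sizes:

import torch

torch.manual_seed(0)
x = torch.randn(8192)  # stand-in for per-sample KL values

full = x.mean()
n_chunks = 4
accum = sum(chunk.mean() / n_chunks for chunk in x.chunk(n_chunks))

print((full - accum).abs().item())  # small but generally nonzero in float32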
I noticed some nit differences between the PPO implementation in your repos and openai/baselines's PPO implementation. For example, the policy_head and value_head of PPG's PPO implementation are both initialized with scale 0.1, while openai/baselines's PPO implementation uses scale 0.01 for the policy_head and 1 for the value_head.
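For reference, a minimal sketch of the scaled orthogonal initialization convention being described (the layer_init helper follows common CleanRL usage, but the layer sizes here are illustrative, not from the issue):

import torch
import torch.nn as nn

def layer_init(layer, std=2**0.5, bias_const=0.0):
    # orthogonal weights scaled by `std`, constant bias, as in openai/baselines
    nn.init.orthogonal_(layer.weight, std)
    nn.init.constant_(layer.bias, bias_const)
    return layer

hidden = 256  # illustrative hidden size
policy_head = layer_init(nn.Linear(hidden, 15), std=0.01)  # baselines: 0.01; PPG: 0.1
value_head = layer_init(nn.Linear(hidden, 1), std=1.0)     # baselines: 1;    PPG: 0.1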
Because of these nit differences, some implementation decisions are not clear to me as far as how to implement PPG. For this reason, I'm closing this issue now.