Comments (8)
请参考 Policy Gradient 的代码实现:
https://github.com/datawhalechina/leedeeprl-notes/tree/master/codes/PolicyGradient
from easy-rl.
这里他是手动算的,不是调用现成交叉熵函数。
看到这个代码,我还有一个疑惑是 self.policy_net
是一个继承 torch.nn.Module
对象的实例,self.policy_net(state)
感觉像是调用了 forward()
方法,这种写法是一个语法糖么?我读了下 torch.nn.Module
的文档,没有发现提到这种写法。
相关代码片如下:
from easy-rl.
self.policy_net(state)
可以参考pytorch tutorial中的以下部分:
# Forward pass: compute predicted y by passing x to the model. Module objects
# override the __call__ operator so you can call them like functions. When
# doing so you pass a Tensor of input data to the Module and it produces
# a Tensor of output data.
y_pred = model(xx)
from easy-rl.
谢谢,类似 self.policy_net(state)
这种形式要求 self.policy_net
对象是 callable 的,让对象 callable 的一种方式就是实现 __call__
方法。无论是 Pytorch 的 torch.nn.Module 类,还是 TensorFlow 的 tf.keras.Model 类,在它们的继承的类中都实现了 __call__
方法,所以可以通过 self.policy_net(state)
这种形式传入 state 参数。
至于交叉熵函数,我看 Mofan Zhou 使用的是 tf.nn.sparse_softmax_cross_entropy_with_logits
,不过对背后怎么对应成下面这个公式还有些没明白:
from easy-rl.
- 在 Policy Gradient 中,我们使用的是如上图所示的损失函数;
- Mofan Zhou 的代码就是这样计算的(请看代码注释),具体函数使用请搜索官方文档。
# to maximize total reward (log_p * R) is to minimize -(log_p * R), and the tf only have minimize(loss)
neg_log_prob = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=all_act, labels=self.tf_acts) # this is negative log of chosen action
# or in this way:
# neg_log_prob = tf.reduce_sum(-tf.log(self.all_act_prob)*tf.one_hot(self.tf_acts, self.n_actions), axis=1)
loss = tf.reduce_mean(neg_log_prob * self.tf_vt) # reward guided loss
from easy-rl.
谢谢,没注意到注释的部分,结合注释看理解了一些。关于
# neg_log_prob = tf.reduce_sum(-tf.log(self.all_act_prob)*tf.one_hot(self.tf_acts, self.n_actions), axis=1)
这一行中的
tf.log(self.all_act_prob)*tf.one_hot(self.tf_acts, self.n_actions)
是和 第四章 策略梯度 中下图一致
我不太理解的是图中红框部分为什么是那样表示,我的理解是
log([0.2, 0.5, 0.3] * [0, 1, 0]) = log(0.5)
其中 ([0.2, 0.5, 0.3] * [0, 1, 0])
这部分对应 「使用 π 策略,在 s_{t} 状态下输出 a_{t} 动作的概率」,即下图中的公式:
照我这样想的话,代码应该是
neg_log_prob = tf.reduce_sum(-tf.log(self.all_act_prob*tf.one_hot(self.tf_acts, self.n_actions)), axis=1)
而不是 mofan 的
neg_log_prob = tf.reduce_sum(-tf.log(self.all_act_prob)*tf.one_hot(self.tf_acts, self.n_actions), axis=1)
请问是我的理解哪里有问题么?
from easy-rl.
from easy-rl.
现在一想,有点忘了为啥当初没理解 tf.log[0.2, 0.5, 0.3] 计算就是依次对 list 里的每个元素求 log。
exp_list = [tf.math.exp(tf.constant(1.0)), tf.math.exp(tf.constant(2.0))]
print(tf.math.log(exp_list))
# tf.Tensor([1. 2.], shape=(2,), dtype=float32)
现在反倒没想明白的是为啥是两个行向量相乘,估计只是一种表现形式吧...
from easy-rl.
Related Issues (20)
- 1.7.1 Gym示例 返回值增多了 HOT 3
- 第四章图4.10标注是不是有误? HOT 1
- Edit problem in Chapter3 HOT 1
- 随书代码在哪 HOT 6
- 第五章勘误 HOT 1
- 内容勘误? HOT 3
- 添加参考文献 HOT 1
- SAC代码问题 HOT 2
- 4.3 REINFORCE:蒙特卡洛策略梯度 HOT 1
- 最新的版本,可以出PDF吗 HOT 2
- value_iteration 算法不收敛 ? HOT 1
- 错别字 HOT 2
- DuelingDQN.ipynb中可能存在的两个BUG~
- 我在运行DQN代码时,初始的state总会多一个值。
- 图6.8左下角标识应该是“动作价值(Q)”? HOT 1
- DDPG算法实现出现问题
- 关于书中DDPG算法的疑问
- PPO算法的实现, 为啥要给概率取对数? HOT 2
- 连续动作空间的PPO算法 HOT 2
- dqn算法问题
Recommend Projects
-
React
A declarative, efficient, and flexible JavaScript library for building user interfaces.
-
Vue.js
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
-
Typescript
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
-
TensorFlow
An Open Source Machine Learning Framework for Everyone
-
Django
The Web framework for perfectionists with deadlines.
-
Laravel
A PHP framework for web artisans
-
D3
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
-
Recommend Topics
-
javascript
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
-
web
Some thing interesting about web. New door for the world.
-
server
A server is a program made to process requests and deliver data to clients.
-
Machine learning
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
-
Visualization
Some thing interesting about visualization, use data art
-
Game
Some thing interesting about game, make everyone happy.
Recommend Org
-
Facebook
We are working to build community through open source technology. NB: members must have two-factor auth.
-
Microsoft
Open source projects and samples from Microsoft.
-
Google
Google ❤️ Open Source for everyone.
-
Alibaba
Alibaba Open Source for everyone
-
D3
Data-Driven Documents codes.
-
Tencent
China tencent open source team.
from easy-rl.