Comments (8)

qiwang067 commented on May 19, 2024

Please refer to the Policy Gradient implementation:
https://github.com/datawhalechina/leedeeprl-notes/tree/master/codes/PolicyGradient

from easy-rl.

Ynjxsjmh commented on May 19, 2024

https://github.com/datawhalechina/leedeeprl-notes/blob/84294e6e26d4c7e05f34ca553928a28ebb17cedf/codes/PolicyGradient/agent.py#L62-L65

Here the loss is computed by hand rather than by calling a ready-made cross-entropy function.
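To illustrate why a hand-written loss can stand in for a cross-entropy call, here is a minimal pure-Python sketch. It assumes, purely for illustration, a two-action policy whose network outputs the probability p of choosing action 1 (check agent.py for the distribution the repository actually uses). Under that assumption the term -log π(a|s) · R is exactly the binary cross-entropy between the taken action and p, scaled by the return R. The helper names `manual_pg_loss` and `bce` are hypothetical.

```python
import math


def manual_pg_loss(p: float, action: int, reward: float) -> float:
    """Hypothetical helper: hand-written policy-gradient term -log pi(a|s) * R
    for a two-action (Bernoulli) policy with pi(a=1|s) = p."""
    prob_of_action = p if action == 1 else 1.0 - p
    return -math.log(prob_of_action) * reward


def bce(p: float, target: int) -> float:
    """Binary cross-entropy between a 0/1 target and a predicted probability p."""
    return -(target * math.log(p) + (1 - target) * math.log(1.0 - p))


# The two agree: the manual loss equals BCE(p, a) weighted by the return.
for p, a, r in [(0.7, 1, 2.0), (0.7, 0, 2.0), (0.1, 1, -1.5)]:
    print(manual_pg_loss(p, a, r), bce(p, a) * r)
```

Writing the term by hand keeps the "log-probability times return" structure of the policy gradient visible, at the cost of not using a fused, numerically hardened library op.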

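Looking at this code, I have another question. self.policy_net is an instance of a class that inherits from torch.nn.Module, and self.policy_net(state) seems to call the forward() method. Is this syntax a form of syntactic sugar? I read the torch.nn.Module documentation and could not find this usage mentioned.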

The relevant code snippets:

https://github.com/datawhalechina/leedeeprl-notes/blob/84294e6e26d4c7e05f34ca553928a28ebb17cedf/codes/PolicyGradient/agent.py#L63

https://github.com/datawhalechina/leedeeprl-notes/blob/84294e6e26d4c7e05f34ca553928a28ebb17cedf/codes/PolicyGradient/agent.py#L23

https://github.com/datawhalechina/leedeeprl-notes/blob/84294e6e26d4c7e05f34ca553928a28ebb17cedf/codes/PolicyGradient/model.py#L14-L27

johnjim0816 commented on May 19, 2024

For self.policy_net(state), see the following part of the PyTorch tutorial:

# Forward pass: compute predicted y by passing x to the model. Module objects
# override the __call__ operator so you can call them like functions. When
# doing so you pass a Tensor of input data to the Module and it produces
# a Tensor of output data.
y_pred = model(xx)

Ynjxsjmh commented on May 19, 2024

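Thanks. A call like self.policy_net(state) requires the self.policy_net object to be callable, and one way to make an object callable is to implement the __call__ method. Both PyTorch's torch.nn.Module and TensorFlow's tf.keras.Model implement __call__ (either directly or in a class they inherit from), so the state argument can be passed in as self.policy_net(state).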
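The mechanism can be sketched in plain Python. When Python evaluates policy_net(state), it looks up __call__ on the object's type and runs type(policy_net).__call__(policy_net, state). The classes below are hypothetical stand-ins, not PyTorch's real implementation: the actual torch.nn.Module.__call__ also runs any registered hooks around forward(), which is why the PyTorch documentation recommends calling the module instance rather than forward() directly.

```python
class TinyModule:
    """Hypothetical, simplified stand-in for torch.nn.Module's call mechanism."""

    def __call__(self, *args, **kwargs):
        # The real torch.nn.Module.__call__ also runs registered hooks;
        # this sketch keeps only the dispatch to forward().
        return self.forward(*args, **kwargs)

    def forward(self, *args, **kwargs):
        raise NotImplementedError("subclasses must implement forward()")


class DoublePolicy(TinyModule):
    """Hypothetical toy 'network' that doubles every input value."""

    def forward(self, state):
        return [2 * x for x in state]


policy_net = DoublePolicy()
state = [1.0, 2.5]
print(policy_net(state))          # [2.0, 5.0]
print(policy_net.forward(state))  # same result here, but bypasses __call__
print(callable(policy_net))       # True
```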

As for the cross-entropy function, I see that Mofan Zhou uses tf.nn.sparse_softmax_cross_entropy_with_logits, but I still don't quite understand how it corresponds to the formula below:

[Image: the cross-entropy formula]

https://github.com/MorvanZhou/Reinforcement-learning-with-tensorflow/blob/1fd1c08a6c8928d027fb75d51ebf6f9441e3dc33/contents/7_Policy_gradient_softmax/RL_brain.py#L78
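One way to see the correspondence, written in our own notation rather than the image's: let z be the logits the network produces for state s_t, let a_t be the action actually taken, and let v_t be the per-step return weight (tf_vt in Mofan's code). With a one-hot label, the cross-entropy sum collapses to a single term, the negative log-probability of the taken action, and the reward-weighted mean of these terms is the policy-gradient loss.

```latex
% Softmax over the logits z for state s_t
p_k = \frac{e^{z_k}}{\sum_{j} e^{z_j}}

% Cross-entropy with one-hot label y = onehot(a_t): only the taken action survives
H(y, p) = -\sum_{k} y_k \log p_k = -\log p_{a_t} = -\log \pi_\theta(a_t \mid s_t)

% Reward-weighted mean over the N steps, mirroring reduce_mean(neg_log_prob * tf_vt)
L(\theta) = \frac{1}{N} \sum_{t=1}^{N} \bigl(-\log \pi_\theta(a_t \mid s_t)\bigr)\, v_t
```

Minimizing L(θ) is equivalent to maximizing the sum over t of v_t log π_θ(a_t | s_t), the usual REINFORCE objective.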

qiwang067 commented on May 19, 2024

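Regarding the cross-entropy function, there are two problems with your understanding.
[Image: the policy gradient loss function]

  1. In Policy Gradient, we use the loss function shown in the image above;
  2. Mofan Zhou's code computes exactly this (see the code comments); for how the function is used, please consult the official documentation.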
# to maximize total reward (log_p * R) is to minimize -(log_p * R), and the tf only have minimize(loss)
neg_log_prob = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=all_act, labels=self.tf_acts)   # this is negative log of chosen action
# or in this way:
# neg_log_prob = tf.reduce_sum(-tf.log(self.all_act_prob)*tf.one_hot(self.tf_acts, self.n_actions), axis=1)
loss = tf.reduce_mean(neg_log_prob * self.tf_vt)  # reward guided loss
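The two lines in the snippet above can be checked against each other numerically. The pure-Python sketch below reimplements, as an illustration rather than TensorFlow's actual code, the quantity tf.nn.sparse_softmax_cross_entropy_with_logits computes per example (-log of the softmax probability of the label, via log-sum-exp), and compares it with the commented-out reduce_sum(-log(prob) * one_hot) form. The logits, actions, and returns are made-up toy numbers; all_act, acts, and vt only mirror the names in Mofan's code.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sparse_softmax_ce(logits, label):
    """-log(softmax(logits)[label]) computed as logsumexp(logits) - logits[label]."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def one_hot_neg_log_prob(probs, label):
    """The commented-out alternative: sum(-log(probs) * one_hot(label))."""
    one_hot = [1.0 if k == label else 0.0 for k in range(len(probs))]
    return sum(-math.log(p) * y for p, y in zip(probs, one_hot))


# Toy batch (hypothetical numbers): logits per step, chosen actions, returns.
all_act = [[1.0, 2.0, 0.5], [0.2, -1.0, 0.3]]
acts = [1, 2]
vt = [1.5, -0.5]

neg_log_prob = [sparse_softmax_ce(z, a) for z, a in zip(all_act, acts)]
alt = [one_hot_neg_log_prob(softmax(z), a) for z, a in zip(all_act, acts)]
# Mirrors tf.reduce_mean(neg_log_prob * self.tf_vt)
loss = sum(n * v for n, v in zip(neg_log_prob, vt)) / len(vt)
print(neg_log_prob, alt, loss)
```

Both forms give the same per-step values; working from logits avoids computing log of a probability that may have underflowed to 0.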

Ynjxsjmh commented on May 19, 2024

Thanks, I had missed the comments; reading them, I understand it a bit better. Regarding

# neg_log_prob = tf.reduce_sum(-tf.log(self.all_act_prob)*tf.one_hot(self.tf_acts, self.n_actions), axis=1)

the part of this line

tf.log(self.all_act_prob)*tf.one_hot(self.tf_acts, self.n_actions)

matches the figure below from Chapter 4, Policy Gradient:

[Image: figure from Chapter 4, Policy Gradient, with a term marked by a red box]

What I don't quite understand is why the part in the red box is written that way. My understanding is

log([0.2, 0.5, 0.3] * [0, 1, 0]) = log(0.5)

where ([0.2, 0.5, 0.3] * [0, 1, 0]) corresponds to "the probability that policy π outputs action a_{t} in state s_{t}", i.e. the formula in the image below:

[Image: formula for the probability of taking action a_{t} in state s_{t} under policy π]

Following this reasoning, the code should be

neg_log_prob = tf.reduce_sum(-tf.log(self.all_act_prob*tf.one_hot(self.tf_acts, self.n_actions)), axis=1)

rather than Mofan's

neg_log_prob = tf.reduce_sum(-tf.log(self.all_act_prob)*tf.one_hot(self.tf_acts, self.n_actions), axis=1)

Where is my understanding going wrong?

qiwang067 commented on May 19, 2024
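  1. The computation of the red-boxed part is shown in the image below:
    [Image: step-by-step computation of the red-boxed part]
  2. Mofan's code is correct: first take the log of the probabilities, then multiply by the one-hot vector.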
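The difference between the two orders can be seen with the example numbers from the question. In this pure-Python sketch, `tf_like_log` is a hypothetical helper that returns -inf for 0 to mimic tf.log (Python's math.log raises an error instead). Masking before the elementwise log sends every unchosen entry to log(0) = -inf, so the sum blows up; taking the log first and then masking picks out log(0.5). Masking, summing, and only then taking a single log is also valid, which is likely what the log(0.5) reasoning in the question had in mind.

```python
import math


def tf_like_log(x):
    """Hypothetical helper: elementwise log returning -inf for 0, like tf.log."""
    return -math.inf if x == 0 else math.log(x)


probs = [0.2, 0.5, 0.3]
one_hot = [0, 1, 0]  # action index 1 was taken

# Mofan's order: log first, then mask with the one-hot vector, then sum.
mofan = sum(-tf_like_log(p) * y for p, y in zip(probs, one_hot))

# Proposed order: mask first, then log elementwise, then sum.
proposed = sum(-tf_like_log(p * y) for p, y in zip(probs, one_hot))

# Also valid: mask, sum to pick out p_a, then take the log once.
picked = -tf_like_log(sum(p * y for p, y in zip(probs, one_hot)))

print(mofan, proposed, picked)
```

One caveat of the log-first order: if an unchosen probability is exactly 0, -inf * 0 evaluates to NaN under IEEE floating point, so computing from logits is the safer route in practice.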

Ynjxsjmh commented on May 19, 2024

Thinking about it now, I've somewhat forgotten why I didn't see at first that tf.log([0.2, 0.5, 0.3]) simply takes the log of each element of the list in turn.

exp_list = [tf.math.exp(tf.constant(1.0)), tf.math.exp(tf.constant(2.0))]

print(tf.math.log(exp_list))

# tf.Tensor([1. 2.], shape=(2,), dtype=float32)

What I now can't figure out is why it is written as a product of two row vectors; I guess it's just a notational convention...
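One way to read the "product of two row vectors" (our interpretation; the figure's notation may intend something else) is as an elementwise (Hadamard) product followed by a sum, which equals the dot product of the log-probability row vector with the one-hot vector written as a column. A pure-Python sketch using the same example numbers:

```python
import math

log_probs = [math.log(p) for p in [0.2, 0.5, 0.3]]  # row vector of log-probabilities
one_hot = [0.0, 1.0, 0.0]                           # action 1 was taken

# Elementwise (Hadamard) product, as tf.log(p) * tf.one_hot(...) computes it.
hadamard = [lp * y for lp, y in zip(log_probs, one_hot)]

# Summing it (tf.reduce_sum(..., axis=1)) equals the dot product row . column.
dot = sum(hadamard)

print(hadamard)  # only the entry for the taken action is nonzero
print(dot)       # equals log(0.5)
```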
