import numpy as np
# taken from https://github.com/openai/baselines/blob/master/baselines/common/vec_env/vec_normalize.py
class RunningMeanStd(object):
    """Running mean/variance tracker whose ``update`` always reduces over axis 0.

    ``update`` reduces its argument along axis 0 and merges the result as a
    single sample.  When it is handed one 1-D observation, that reduction
    collapses the whole vector to one scalar mean and one scalar variance,
    which are then broadcast into every entry of ``mean`` and ``var``.
    """

    def __init__(self, epsilon=1e-4, shape=()):
        """Start from zero mean, unit variance and a pseudo-count of *epsilon*.

        Args:
            epsilon: initial value stored in ``count``.
            shape: shape of the ``mean`` and ``var`` arrays.
        """
        self.count = epsilon
        self.var = np.ones(shape, dtype='float64')
        self.mean = np.zeros(shape, dtype='float64')

    def update(self, x):
        """Fold *x* into the running statistics, reducing over axis 0.

        The reduced moments are merged with a fixed sample count of 1,
        regardless of how many entries *x* has along axis 0.
        """
        self.update_from_moments(np.mean(x, axis=0), np.var(x, axis=0), 1)

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        """Merge precomputed batch moments into ``mean``, ``var`` and ``count``."""
        merged = update_mean_var_count_from_moments(
            self.mean, self.var, self.count,
            batch_mean, batch_var, batch_count,
        )
        self.mean, self.var, self.count = merged
def update_mean_var_count_from_moments(mean, var, count, batch_mean, batch_var, batch_count):
    """Merge the moments of a new batch into existing running moments.

    Args:
        mean: running mean accumulated so far.
        var: running (population) variance accumulated so far.
        count: number of samples (or pseudo-count) behind ``mean``/``var``.
        batch_mean: mean of the new batch.
        batch_var: variance of the new batch.
        batch_count: number of samples in the new batch.

    Returns:
        A ``(new_mean, new_var, new_count)`` tuple describing the union of
        the old samples and the batch.
    """
    total = count + batch_count
    shift = batch_mean - mean
    # Pooled sum of squared deviations: each side's own spread plus a
    # correction term for the distance between the two means.
    pooled_m2 = (
        var * count
        + batch_var * batch_count
        + np.square(shift) * count * batch_count / total
    )
    return mean + shift * batch_count / total, pooled_m2 / total, total
# Demo: feed one two-element observation to the RunningMeanStd defined above
# and print its mean and variance before and after the update.
print("incorrect RunningMeanStd uses the same mean across data array")
state = np.array([0.52191359, 0.24749929])
ob_rms = RunningMeanStd(shape=(2,))
print(ob_rms.mean, ob_rms.var, sep="\n")
ob_rms.update(state)
print(ob_rms.mean, ob_rms.var, sep="\n")
class RunningMeanStd(object):
    """Running per-dimension mean and variance of a stream of observations.

    ``mean`` and ``var`` keep the shape given at construction; ``count`` is
    the number of samples merged so far (seeded with ``epsilon``).  The
    merge itself is delegated to the module-level
    ``update_mean_var_count_from_moments``.
    """

    def __init__(self, epsilon=1e-4, shape=()):
        """Start from zero mean, unit variance and a pseudo-count of *epsilon*.

        Args:
            epsilon: initial value stored in ``count``.
            shape: shape of a single observation, and therefore of
                ``mean`` and ``var``.
        """
        self.mean = np.zeros(shape, 'float64')
        self.var = np.ones(shape, 'float64')
        self.count = epsilon

    def update(self, x):
        """Fold one observation, or a batch of observations, into the statistics.

        Args:
            x: either a single observation, or a batch whose shape is
                ``(N,) + mean.shape``.  A batch is reduced over its leading
                axis and merged as ``N`` samples; anything else is treated
                as exactly one sample.  An empty batch leaves the
                statistics untouched.
        """
        samples = np.asarray(x)
        stat_shape = np.shape(self.mean)
        is_batch = (
            samples.ndim == len(stat_shape) + 1
            and samples.shape[1:] == stat_shape
        )
        if not is_batch:
            # Single observation: add a leading axis of length 1 so the
            # axis-0 reduction below keeps every dimension separate.
            samples = samples[np.newaxis]
        batch_count = samples.shape[0]
        if batch_count == 0:
            # Mean/variance of zero rows would be NaN and poison the state.
            return
        self.update_from_moments(
            np.mean(samples, axis=0), np.var(samples, axis=0), batch_count
        )

    def update_from_moments(self, batch_mean, batch_var, batch_count):
        """Merge precomputed batch moments into ``mean``, ``var`` and ``count``."""
        self.mean, self.var, self.count = update_mean_var_count_from_moments(
            self.mean, self.var, self.count, batch_mean, batch_var, batch_count)
# Demo: repeat the same single-observation update with the RunningMeanStd
# defined directly above and print its statistics before and after.
print("correct RunningMeanStd uses different means across dimensions")
state = np.array([0.52191359, 0.24749929])
ob_rms = RunningMeanStd(shape=(2,))
print(ob_rms.mean, ob_rms.var, sep="\n")
ob_rms.update(state)
print(ob_rms.mean, ob_rms.var, sep="\n")
# Output printed by the two demos above:
#
# incorrect RunningMeanStd uses the same mean across data array
# [0. 0.]
# [1. 1.]
# [0.38466797 0.38466797]
# [0.01893871 0.01893871]
# correct RunningMeanStd uses different means across dimensions
# [0. 0.]
# [1. 1.]
# [0.5218614 0.24747454]
# [0.00012722 0.00010611]