Skip to content

Commit 50ec092

Browse files
author
Jan Michelfeit
committed
#625 refactor RunningMeanAndVar
1 parent dad72a2 commit 50ec092

File tree

3 files changed

+17
-22
lines changed

3 files changed

+17
-22
lines changed

src/imitation/policies/replay_buffer_wrapper.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -175,7 +175,7 @@ def sample(self, *args, **kwargs):
175175

176176
# Normalize to have mean of 0 and standard deviation of 1
177177
self.entropy_stats.update(entropies)
178-
entropies -= self.entropy_stats.mean
178+
entropies -= self.entropy_stats.running_mean
179179
entropies /= self.entropy_stats.std
180180

181181
entropies_th = (

src/imitation/util/util.py

Lines changed: 14 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -370,30 +370,25 @@ def __init__(
370370
device: Optional[str] = None,
371371
) -> None:
372372
"""Initialize blank mean, variance, count."""
373-
self.mean = th.zeros(shape, device=device)
373+
self.running_mean = th.zeros(shape, device=device)
374374
self.M2 = th.zeros(shape, device=device)
375375
self.count = 0
376376

377-
def update(self, x: th.Tensor) -> None:
377+
def update(self, batch: th.Tensor) -> None:
378378
"""Update the mean and variance with a batch `batch`."""
379379
with th.no_grad():
380-
batch_mean = th.mean(x, dim=0)
381-
batch_var = th.var(x, dim=0, unbiased=False)
382-
batch_count = x.shape[0]
383-
batch_M2 = batch_var * batch_count
384-
if self.count == 0:
385-
self.count = batch_count
386-
self.mean = batch_mean
387-
self.M2 = batch_M2
388-
return
389-
390-
delta = batch_mean - self.mean
391-
total_count = self.count + batch_count
392-
self.mean += delta * batch_count / total_count
393-
394-
self.M2 += batch_M2 + delta * delta * self.count * batch_count / total_count
395-
396-
self.count = total_count
380+
batch_mean = th.mean(batch, dim=0)
381+
batch_var = th.var(batch, dim=0, unbiased=False)
382+
batch_count = batch.shape[0]
383+
384+
delta = batch_mean - self.running_mean
385+
tot_count = self.count + batch_count
386+
self.running_mean += delta * batch_count / tot_count
387+
388+
self.M2 += batch_var * batch_count
389+
self.M2 += th.square(delta) * self.count * batch_count / tot_count
390+
391+
self.count += batch_count
397392

398393
@property
399394
def var(self) -> th.Tensor:

tests/util/test_util.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -127,7 +127,7 @@ def test_RunningMeanAndVar():
127127
first_half = data[:10]
128128
running_stats.update(first_half)
129129
np.testing.assert_allclose(
130-
running_stats.mean,
130+
running_stats.running_mean,
131131
first_half.mean(dim=0),
132132
atol=1e-5,
133133
rtol=1e-4,
@@ -141,7 +141,7 @@ def test_RunningMeanAndVar():
141141

142142
running_stats.update(data[10:])
143143
np.testing.assert_allclose(
144-
running_stats.mean,
144+
running_stats.running_mean,
145145
data.mean(dim=0),
146146
atol=1e-5,
147147
rtol=1e-4,

0 commit comments

Comments
 (0)