diff --git a/python/magent/builtin/mx_model/a2c.py b/python/magent/builtin/mx_model/a2c.py index 3e5f53e18..b2baa3076 100644 --- a/python/magent/builtin/mx_model/a2c.py +++ b/python/magent/builtin/mx_model/a2c.py @@ -189,10 +189,9 @@ def train(self, sample_buffer, print_every=1000): # calc buffer size n = 0 for episode in sample_buffer.episodes(): - if episode.terminal: - n += len(episode.rewards) - else: - n += len(episode.rewards) - 1 + n += len(episode.rewards) + if not episode.terminal: + n -= 1 if n == 0: return [0.0, 0.0, 0.0], 0.0