From 3db3eabce29346dba1d52c4a162665b8fd6cbade Mon Sep 17 00:00:00 2001 From: Ruofan Kong Date: Mon, 2 Aug 2021 19:31:26 -0700 Subject: [PATCH] No Case: Remove the casting to Float32 in SampleBatch. --- rllib/policy/sample_batch.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/rllib/policy/sample_batch.py b/rllib/policy/sample_batch.py index e1d6625bd7f6..33de608c6e2a 100644 --- a/rllib/policy/sample_batch.py +++ b/rllib/policy/sample_batch.py @@ -104,11 +104,7 @@ def __init__(self, *args, **kwargs): batch_size = len_ lengths.append(len_) if isinstance(v, list): - if len(v) == 0 or not isinstance(v[0], dict): - self[k] = np.array(v, dtype=np.float32) - else: - # If we have field of type dict, let's keep it as np.object - self[k] = np.array(v) + self[k] = np.array(v) if not lengths: raise ValueError("Empty sample batch")