Skip to content

Commit

Permalink
still trying to find the discrepency
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhengyiLuo committed Nov 4, 2022
1 parent f6dcd3e commit 7a580ee
Show file tree
Hide file tree
Showing 5 changed files with 348 additions and 222 deletions.
Binary file added a.pkl
Binary file not shown.
2 changes: 1 addition & 1 deletion config/statear/kin_poly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ data_wild_file: real_annotations
of_file_wild: real_img_feats

mujoco_model: 'humanoid_smpl_neutral_mesh'
seed: 2
seed: 4
fr_num: 100
augment: false

Expand Down
14 changes: 8 additions & 6 deletions kin_poly/core/agent_ar.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,8 @@ def setup_data_loader(self):
self.test_data_loaders = []
self.data_loader = data_loader = StateARDataset(cfg,
cfg.data,
sim=False)
self.test_data_loaders.append(StateARDataset(cfg, "test", sim = True))
sim=True)
self.test_data_loaders.append(StateARDataset(cfg, "test"))

from kin_poly.utils.statear_smpl_config import Config

Expand All @@ -338,10 +338,9 @@ def setup_data_loader(self):
wild=True,
create_dirs=False,
mujoco_path=
"assets/mujoco_models/%s.xml",
"/hdd/zen/dev/copycat/Copycat/assets/mujoco_models/%s.xml",
)
self.test_data_loaders.append(
StateARDataset(cfg_wild, "test", sim=True))
self.test_data_loaders.append(StateARDataset(cfg_wild, "test"))

def load_checkpoint(self, i_iter):
cfg, device, dtype = self.cfg, self.device, self.dtype
Expand Down Expand Up @@ -572,6 +571,9 @@ def sample_worker(self, pid, queue, min_batch_size):
sampling_temp=self.cfg.policy_specs.get("sampling_temp", 0.5),
sampling_freq=self.cfg.policy_specs.get("sampling_freq", 0.9),
)
self.data_loader.curr_key

np.random.random(5)
# context_sample = self.data_loader.sample_seq(freq_dict = self.freq_dict, sampling_temp = self.cfg.policy_specs.get("sampling_temp", 0.5), sampling_freq = self.cfg.policy_specs.get("sampling_freq", 0.9), full_sample = True if self.data_loader.get_seq_len(self.fit_ind) < 1000 else False)
# context_sample = self.data_loader.sample_seq(freq_dict = self.freq_dict, sampling_temp = 0.5)
# context_sample = self.data_loader.sample_seq()
Expand Down Expand Up @@ -621,7 +623,7 @@ def sample_worker(self, pid, queue, min_batch_size):
if flags.debug:
np.set_printoptions(precision=4, suppress=1)
print(c_reward, c_info)

# add end reward
if self.end_reward and info.get("end", False):
reward += self.env.end_reward
Expand Down
Loading

0 comments on commit 7a580ee

Please sign in to comment.