From e69b6074dacd479ae56c9ca3131c1b4c148e82a7 Mon Sep 17 00:00:00 2001 From: n1colas Date: Tue, 11 Mar 2025 12:47:54 +0100 Subject: [PATCH 1/3] checkpoint --- .idea/.gitignore | 8 + Diff-Control/BCZ_LSTM_engine.py | 11 -- Diff-Control/BCZ_engine.py | 13 +- Diff-Control/config.py | 2 + Diff-Control/config/duck_controlnet.yaml | 10 +- Diff-Control/config/duck_diffusion.yaml | 6 +- Diff-Control/config/square_ph_diffusion.yaml | 176 ++++++++++++++++++ Diff-Control/controlnet_engine.py | 32 ++-- Diff-Control/dataset/__init__.py | 1 - Diff-Control/dataset/square_ph.py | 44 +++++ Diff-Control/dataset/tomato_pick_and_place.py | 0 Diff-Control/model/__init__.py | 2 +- Diff-Control/prebuild_engine.py | 135 +++++++++++--- Diff-Control/tomato_engine.py | 11 -- Diff-Control/train.py | 50 ++--- 15 files changed, 394 insertions(+), 107 deletions(-) create mode 100644 .idea/.gitignore create mode 100644 Diff-Control/config/square_ph_diffusion.yaml create mode 100644 Diff-Control/dataset/square_ph.py mode change 100755 => 100644 Diff-Control/dataset/tomato_pick_and_place.py diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..13566b8 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/Diff-Control/BCZ_LSTM_engine.py b/Diff-Control/BCZ_LSTM_engine.py index 3bb56ed..0a9414b 100644 --- a/Diff-Control/BCZ_LSTM_engine.py +++ b/Diff-Control/BCZ_LSTM_engine.py @@ -1,22 +1,11 @@ -import argparse -import logging -import os -import numpy as np -import torch -import torch.nn as nn -from einops import rearrange, repeat -import clip from model import BCZ_LSTM from dataset.lid_pick_and_place import * from dataset.tomato_pick_and_place import * from dataset.pick_duck import * from dataset.drum_hit import * -from optimizer import build_optimizer from optimizer import build_lr_scheduler from torch.utils.tensorboard import SummaryWriter -import copy import time -import random import pickle import torch.optim as optim diff --git a/Diff-Control/BCZ_engine.py b/Diff-Control/BCZ_engine.py index 2e4ce59..1bc7b4d 100644 --- a/Diff-Control/BCZ_engine.py +++ b/Diff-Control/BCZ_engine.py @@ -1,23 +1,12 @@ -import argparse -import logging -import os -import numpy as np -import torch -import torch.nn as nn -from einops import rearrange, repeat -import clip from model import ModAttn_ImageBC -from dataset import UR5_dataloader +#from dataset import UR5_dataloader from dataset.lid_pick_and_place import * from dataset.tomato_pick_and_place import * from dataset.pick_duck import * from dataset.drum_hit import * -from optimizer import build_optimizer from optimizer import build_lr_scheduler from torch.utils.tensorboard import SummaryWriter -import copy import time -import random import pickle import torch.optim as optim diff --git a/Diff-Control/config.py b/Diff-Control/config.py index bd7a94e..e480e59 100644 --- a/Diff-Control/config.py +++ b/Diff-Control/config.py @@ -94,6 +94,8 @@ cfg.test.eigen_crop = False # crops according to Eigen NIPS14 cfg.test.garg_crop = False # crops according to Garg ECCV16 + + cfg.network = CN() cfg.network.name = "alpha-MDF" cfg.network.encoder = "" diff --git a/Diff-Control/config/duck_controlnet.yaml b/Diff-Control/config/duck_controlnet.yaml index a244dca..bb61ad5 100644 --- a/Diff-Control/config/duck_controlnet.yaml +++ b/Diff-Control/config/duck_controlnet.yaml @@ -23,8 +23,8 @@ 
train: input_size_2: '' input_size_3: '' data_path: [ - '/tf/datasets/pick_duck/pick_duck_1', - '/tf/datasets/pick_duck/pick_duck_2', + '/home/ssdArray/datasets/diff_control_data/pick_duck/pick_duck_1', + '/home/ssdArray/datasets/diff_control_data/pick_duck/pick_duck_2' ] batch_size: 64 num_epochs: 1000 @@ -45,7 +45,7 @@ test: win_size: 24 model_name: 'duck_controlnet' data_path: [ - '/tf/datasets/pick_duck/pick_duck_1', + '/home/ssdArray/datasets/diff_control_data/pick_duck/pick_duck_2' ] sensor_len: 2 channel_img_1: 3 @@ -53,8 +53,8 @@ test: input_size_1: 77 input_size_2: '' input_size_3: '' - checkpoint_path_1: './experiments/duck_model/ema-model-98000' - checkpoint_path_2: './experiments/duck_model/sensor_model-98000' + checkpoint_path_1: './experiments/duck_diffusion/v1.0-ema-model-19600' + checkpoint_path_2: './experiments/duck_diffusion/v1.0-sensor_model-19600' dataset: 'Duck' optim: optim: 'adamw' diff --git a/Diff-Control/config/duck_diffusion.yaml b/Diff-Control/config/duck_diffusion.yaml index 2b53b51..133dd09 100644 --- a/Diff-Control/config/duck_diffusion.yaml +++ b/Diff-Control/config/duck_diffusion.yaml @@ -23,8 +23,8 @@ train: input_size_2: '' input_size_3: '' data_path: [ - '/tf/datasets/pick_duck/pick_duck_1', - '/tf/datasets/pick_duck/pick_duck_2', + '/home/ssdArray/datasets/diff_control_data/pick_duck/pick_duck_1', + '/home/ssdArray/datasets/diff_control_data/pick_duck/pick_duck_2' ] batch_size: 64 num_epochs: 3000 @@ -45,7 +45,7 @@ test: win_size: 24 model_name: 'duck_diffusion' data_path: [ - '/tf/datasets/pick_duck/pick_duck_2', + '/home/ssdArray/datasets/diff_control_data/pick_duck/pick_duck_2' ] sensor_len: 2 channel_img_1: 3 diff --git a/Diff-Control/config/square_ph_diffusion.yaml b/Diff-Control/config/square_ph_diffusion.yaml new file mode 100644 index 0000000..da62210 --- /dev/null +++ b/Diff-Control/config/square_ph_diffusion.yaml @@ -0,0 +1,176 @@ +mode: + mode: 'train' + model_zoo: 'diffusion-model' + multiprocessing_distributed: False + dist_url: tcp://127.0.0.1:2345 + num_threads: 1 + do_online_eval: True + parameter_path: '' +train: + dim_x: 10 + dim_z: '' + dim_a: 10 + dim_gt: 10 + num_ensemble: 32 + win_size: 24 + seed: 0 + model_name: 'square_ph_diffusion' + dataset: 'SquarePhDataset' + sensor_len: 2 + channel_img_1: 3 + channel_img_2: '' + input_size_1: 77 + input_size_2: '' + input_size_3: '' + #data_path: [ + # '/home/ssdArray/datasets/diff_control_data/pick_duck/pick_duck_1', + # '/home/ssdArray/datasets/diff_control_data/pick_duck/pick_duck_2' + #] + data_path: /home/ssdArray/datasets/diffusion_policy_data/robomimic/robomimic/datasets/square/ph/image_abs.hdf5 + batch_size: 64 + num_epochs: 3000 + learning_rate: 1e-4 + weight_decay: 1e-6 + adam_eps: 1e-3 + log_freq: 10 + eval_freq: 200 + save_freq: 200 + log_directory: './experiments' + loss: 'mse' +task: + abs_action: true + dataset: + _target_: diffusion_policy.dataset.robomimic_replay_image_dataset.RobomimicReplayImageDataset + abs_action: true + dataset_path: /home/ssdArray/datasets/diffusion_policy_data/robomimic/robomimic/datasets/square/ph/image_abs.hdf5 + horizon: 24 # was 16 originally + n_obs_steps: 2 + pad_after: 7 + pad_before: 1 + rotation_rep: rotation_6d + seed: 42 + shape_meta: + action: + shape: + - 10 + obs: + agentview_image: + shape: + - 3 + - 84 + - 84 + type: rgb + robot0_eef_pos: + shape: + - 3 + robot0_eef_quat: + shape: + - 4 + robot0_eye_in_hand_image: + shape: + - 3 + - 84 + - 84 + type: rgb + robot0_gripper_qpos: + shape: + - 2 + use_cache: true + val_ratio: 0.02 + 
dataset_path: /home/ssdArray/datasets/diffusion_policy_data/robomimic/robomimic/datasets/square/ph/image_abs.hdf5 + dataset_type: ph + env_runner: + _target_: diffusion_policy.env_runner.robomimic_image_runner.RobomimicImageRunnerSquarePh + abs_action: true + crf: 22 + dataset_path: /home/ssdArray/datasets/diffusion_policy_data/robomimic/robomimic/datasets/square/ph/image_abs.hdf5 + fps: 10 + max_steps: 400 + n_action_steps: 8 + n_envs: 28 + n_obs_steps: 2 + n_test: 50 + n_test_vis: 4 + n_train: 6 + n_train_vis: 2 + past_action: false + render_obs_key: agentview_image + shape_meta: + action: + shape: + - 10 + obs: + agentview_image: + shape: + - 3 + - 84 + - 84 + type: rgb + robot0_eef_pos: + shape: + - 3 + robot0_eef_quat: + shape: + - 4 + robot0_eye_in_hand_image: + shape: + - 3 + - 84 + - 84 + type: rgb + robot0_gripper_qpos: + shape: + - 2 + test_start_seed: 100000 + tqdm_interval_sec: 1.0 + train_start_idx: 0 + name: square_image + shape_meta: + action: + shape: + - 10 + obs: + agentview_image: + shape: + - 3 + - 84 + - 84 + type: rgb + robot0_eef_pos: + shape: + - 3 + robot0_eef_quat: + shape: + - 4 + robot0_eye_in_hand_image: + shape: + - 3 + - 84 + - 84 + type: rgb + robot0_gripper_qpos: + shape: + - 2 +test: + dim_x: 10 + dim_z: '' + dim_a: 10 + dim_gt: 10 + num_ensemble: 32 + win_size: 24 + model_name: 'square_ph_diffusion' + data_path: [ + '/home/ssdArray/datasets/diff_control_data/pick_duck/pick_duck_2' + ] + sensor_len: 2 + channel_img_1: 3 + channel_img_2: '' + input_size_1: 77 + input_size_2: '' + input_size_3: '' + checkpoint_path_1: '' + checkpoint_path_2: '' + dataset: 'SquarePhDataset' +optim: + optim: 'adamw' + lr_scheduler: 'polynomial_decay' diff --git a/Diff-Control/controlnet_engine.py b/Diff-Control/controlnet_engine.py index dceadb0..f8d0a09 100644 --- a/Diff-Control/controlnet_engine.py +++ b/Diff-Control/controlnet_engine.py @@ -1,11 +1,4 @@ -import argparse -import logging -import os -import numpy as np -import torch -import torch.nn as nn -from einops import rearrange, repeat -import clip +from dataset.square_ph import SquarePhDataset from model import ( UNetwithControl, SensorModel, @@ -14,24 +7,22 @@ StatefulUNet, ) from dataset.lid_pick_and_place import * -from dataset.tomato_pick_and_place import * +#from config.tomato_pick_and_place import * from dataset.pick_duck import * from dataset.drum_hit import * from optimizer import build_optimizer from optimizer import build_lr_scheduler from torch.utils.tensorboard import SummaryWriter -import copy import time -import random import pickle -from diffusers.schedulers.scheduling_ddpm import DDPMScheduler +#from diffusers.schedulers.scheduling_ddpm import DDPMScheduler +from diffusers import DDPMScheduler from diffusers.training_utils import EMAModel -from diffusers.optimization import get_scheduler class Engine: - def __init__(self, args, logger): + def __init__(self, args, logger, diff_pol_dataset=None): self.args = args self.logger = logger self.batch_size = self.args.train.batch_size @@ -61,13 +52,19 @@ def __init__(self, args, logger): self.data_path = self.args.train.data_path else: self.data_path = self.args.test.data_path - self.dataset = Tomato(self.data_path) + #self.dataset = Tomato(self.data_path) elif self.args.train.dataset == "Duck": if self.mode == "train": self.data_path = self.args.train.data_path else: self.data_path = self.args.test.data_path self.dataset = Duck(self.data_path) + elif self.args.train.dataset == "SquarePhDataset": + if self.mode == "train": + self.data_path = 
self.args.train.data_path + else: + self.data_path = self.args.test.data_path + self.dataset = SquarePhDataset(diff_pol_dataset) elif self.args.train.dataset == "Drum": if self.mode == "train": self.data_path = self.args.train.data_path @@ -208,7 +205,8 @@ def train(self): batch_size=self.batch_size, shuffle=True, num_workers=8, - collate_fn=tomato_pad_collate_xy_lang, + #collate_fn=tomato_pad_collate_xy_lang, + collate_fn=pad_collate_xy_lang, ) pytorch_total_params = sum( p.numel() for p in self.model.parameters() if p.requires_grad @@ -398,7 +396,7 @@ def online_test(self): self.data_path = self.args.train.data_path else: self.data_path = self.args.test.data_path - test_dataset = Tomato(self.data_path) + #test_dataset = Tomato(self.data_path) elif self.args.train.dataset == "Duck": if self.mode == "train": self.data_path = self.args.train.data_path diff --git a/Diff-Control/dataset/__init__.py b/Diff-Control/dataset/__init__.py index bc134d4..349fa43 100644 --- a/Diff-Control/dataset/__init__.py +++ b/Diff-Control/dataset/__init__.py @@ -1,4 +1,3 @@ from dataset import lid_pick_and_place -from dataset import tomato_pick_and_place from dataset import pick_duck from dataset import drum_hit diff --git a/Diff-Control/dataset/square_ph.py b/Diff-Control/dataset/square_ph.py new file mode 100644 index 0000000..9a6b7a5 --- /dev/null +++ b/Diff-Control/dataset/square_ph.py @@ -0,0 +1,44 @@ +import torch +from torchvision import transforms +from einops import rearrange +import clip + +class SquarePhDataset(torch.utils.data.Dataset): + def __init__(self, square_ph_dataset, image_size=(224, 224)): + self.square_ph_dataset = square_ph_dataset + self.image_size = image_size + + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.clip_model, _ = clip.load("ViT-B/32", device=self.device) + self.dummy_sentence = "push the square to its goal" + self.transform = transforms.Resize(self.image_size) + + def __len__(self): + return len(self.square_ph_dataset) + + def __getitem__(self, idx): + transition_history = 12 + data = self.square_ph_dataset.__getitem__(idx) + if idx < transition_history: + idx_pre = 0 + else: + idx_pre = idx - transition_history + + prev_data = self.square_ph_dataset.__getitem__(idx_pre) + + # Extract image and action from the robomimic square dataset + image = data['obs']['agentview_image'][0] + action = rearrange(data['action'], "Ta Da -> Da Ta") + + prior_action = rearrange(prev_data['action'], "Ta Da -> Da Ta") + # Resize the image to match the Duck dataset input size + image = self.transform(image) + + # Tokenize dummy sentence using CLIP + sentence = clip.tokenize([self.dummy_sentence])[0] + + return image, prior_action, action, sentence + + def get_validation_dataset(self): + val_set = self.square_ph_dataset.get_validation_dataset() + return SquarePhDataset(val_set, image_size=self.image_size) diff --git a/Diff-Control/dataset/tomato_pick_and_place.py b/Diff-Control/dataset/tomato_pick_and_place.py old mode 100755 new mode 100644 diff --git a/Diff-Control/model/__init__.py b/Diff-Control/model/__init__.py index b1300a1..fec3560 100644 --- a/Diff-Control/model/__init__.py +++ b/Diff-Control/model/__init__.py @@ -4,4 +4,4 @@ from model.stateful_module import StatefulUNet from model.film_model import Backbone as ModAttn_ImageBC from model.film_lstm import Backbone as BCZ_LSTM -from model.stateful_module import StatefulControlNet +from model.stateful_module import StatefulControlNet \ No newline at end of file diff --git a/Diff-Control/prebuild_engine.py
b/Diff-Control/prebuild_engine.py index 6a4bdae..7324fcd 100644 --- a/Diff-Control/prebuild_engine.py +++ b/Diff-Control/prebuild_engine.py @@ -1,31 +1,23 @@ -import argparse -import logging -import os -import numpy as np -import torch -import torch.nn as nn -from einops import rearrange, repeat -import clip +from dataset.square_ph import SquarePhDataset +from diffusion_policy.policy.base_image_policy import BaseImagePolicy from model import UNetwithControl, SensorModel, StatefulUNet from dataset.lid_pick_and_place import * -from dataset.tomato_pick_and_place import * +#from config.tomato_pick_and_place import * from dataset.pick_duck import * from dataset.drum_hit import * from optimizer import build_optimizer from optimizer import build_lr_scheduler from torch.utils.tensorboard import SummaryWriter -import copy import time -import random import pickle from diffusers.schedulers.scheduling_ddpm import DDPMScheduler from diffusers.training_utils import EMAModel -from diffusers.optimization import get_scheduler -class Engine: - def __init__(self, args, logger): +class Engine(BaseImagePolicy): + def __init__(self, args, logger, diff_pol_dataset=None): + super().__init__() self.args = args self.logger = logger self.batch_size = self.args.train.batch_size @@ -68,6 +60,14 @@ def __init__(self, args, logger): else: self.data_path = self.args.test.data_path self.dataset = Drum(self.data_path) + elif self.args.train.dataset == "SquarePhDataset": + if self.mode == "train": + self.data_path = self.args.train.data_path + else: + self.data_path = self.args.test.data_path + + self.dataset = SquarePhDataset(diff_pol_dataset) + self.test_dataset = self.dataset.get_validation_dataset() if self.args.train.dataset == "Drum": self.model = StatefulUNet(dim_x=self.dim_x, window_size=self.win_size) @@ -87,8 +87,8 @@ def __init__(self, args, logger): raise TypeError("model must be an instance of nn.Module") self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if torch.cuda.is_available(): - self.model.cuda() - self.sensor_model.cuda() + self.model.cuda(self.device) + self.sensor_model.cuda(self.device) # -----------------------------------------------------------------------------# # ------------------------------- use pre-trained? 
---------------------------# @@ -142,6 +142,11 @@ def train(self): ) print("Total number of parameters: ", pytorch_total_params) + + test_dataloader = torch.utils.data.DataLoader( + self.test_dataset, batch_size=self.batch_size, shuffle=True, num_workers=8 + ) + # Create optimizer optimizer_ = build_optimizer( [self.model, self.sensor_model], @@ -177,6 +182,8 @@ def train(self): # -----------------------------------------------------------------------------# # --------------------------------- train ------------------------------# # -----------------------------------------------------------------------------# + previous_val_loss = float("inf") + all_val_losses = [] while epoch < self.args.train.num_epochs: step = 0 for data in dataloader: @@ -255,7 +262,8 @@ def train(self): scheduler.step(self.global_step) # Save a model based on a chosen save frequency - if self.global_step != 0 and (epoch + 1) % self.args.train.save_freq == 0: + #if self.global_step != 0 and (epoch + 1) % self.args.train.save_freq == 0: + if self.global_step != 0: self.ema_nets = self.model checkpoint = { "global_step": self.global_step, @@ -288,20 +296,98 @@ def train(self): if ( self.args.mode.do_online_eval and self.global_step != 0 - and (epoch + 1) % self.args.train.eval_freq == 0 + #and (epoch + 1) % self.args.train.eval_freq == 0 ): time.sleep(0.1) self.ema_nets = self.model self.ema_nets.eval() self.sensor_model.eval() self.online_test() - self.ema_nets.train() - self.sensor_model.train() - self.ema.copy_to(self.ema_nets.parameters()) + ############################################################### + val_loss = None + with torch.no_grad(): + val_losses = list() + for data in test_dataloader: + data = [item.to(self.device) for item in data] + (images, prior_action, action, sentence) = data + optimizer_.zero_grad() + + text_features = self.clip_model.encode_text(sentence) + text_features = text_features.clone().detach() + text_features = text_features.to(torch.float32) + + # sample noise to add to actions + noise = torch.randn(action.shape, device=self.device) + # sample a diffusion iteration for each data point + timesteps = torch.randint( + 0, + self.noise_scheduler.config.num_train_timesteps, + (action.shape[0],), + device=self.device, + ).long() + # add noise to the clean actions according to the noise magnitude at each diffusion iteration + # (this is the forward diffusion process) + noisy_actions = self.noise_scheduler.add_noise(action, noise, timesteps) + + # forward + img_emb = self.sensor_model(images) + predicted_noise = self.model( + noisy_actions, img_emb, text_features, timesteps + ) + loss = self.criterion(noise, predicted_noise) + val_losses.append(loss.cpu().item()) + if len(val_losses) > 0: + val_loss = torch.mean(torch.tensor(val_losses)).item() + all_val_losses.append((epoch, val_loss)) + + ############################################################### + self.ema_nets.train() + self.sensor_model.train() + self.ema.copy_to(self.ema_nets.parameters()) + + if val_loss > previous_val_loss: + return + + previous_val_loss = val_loss # Update epoch epoch += 1 + def set_normalizer(self, normalizer): + pass + + def predict_action(self, obs_dict): + images = obs_dict["images"] + sentence = obs_dict["sentence"] + + with torch.no_grad(): + text_features = self.clip_model.encode_text(sentence) + text_features = text_features.clone().detach() + text_features = text_features.to(torch.float32) + + img_emb = self.sensor_model(images) + + # initialize action from Gaussian noise + noisy_action = torch.randn((1,
self.dim_x, self.win_size)).to( + self.device + ) + + # init scheduler + self.noise_scheduler.set_timesteps(50) + + for k in self.noise_scheduler.timesteps: + # predict noise + t = torch.stack([k]).to(self.device) + predicted_noise = self.ema_nets( + noisy_action, img_emb, text_features, t + ) + + # inverse diffusion step (remove noise) + noisy_action = self.noise_scheduler.step( + model_output=predicted_noise, timestep=k, sample=noisy_action + ).prev_sample + return noisy_action + # -----------------------------------------------------------------------------# # --------------------------------- test ------------------------------# # -----------------------------------------------------------------------------# @@ -331,6 +417,13 @@ def online_test(self): else: self.data_path = self.args.test.data_path test_dataset = Drum(self.data_path) + elif self.args.train.dataset == "SquarePhDataset": + if self.mode == "train": + self.data_path = self.args.train.data_path + else: + self.data_path = self.args.test.data_path + test_dataset = self.dataset.get_validation_dataset() + test_dataloader = torch.utils.data.DataLoader( test_dataset, batch_size=1, shuffle=True, num_workers=8 ) diff --git a/Diff-Control/tomato_engine.py b/Diff-Control/tomato_engine.py index 6da5716..3e8ee7b 100644 --- a/Diff-Control/tomato_engine.py +++ b/Diff-Control/tomato_engine.py @@ -1,24 +1,13 @@ -import argparse -import logging -import os -import numpy as np -import torch -import torch.nn as nn -from einops import rearrange, repeat -import clip from model import UNetwithControl, SensorModel from dataset.tomato_pick_and_place import * from optimizer import build_optimizer from optimizer import build_lr_scheduler from torch.utils.tensorboard import SummaryWriter -import copy import time -import random import pickle from diffusers.schedulers.scheduling_ddpm import DDPMScheduler from diffusers.training_utils import EMAModel -from diffusers.optimization import get_scheduler class Engine: diff --git a/Diff-Control/train.py b/Diff-Control/train.py index ff9c418..5489c98 100644 --- a/Diff-Control/train.py +++ b/Diff-Control/train.py @@ -4,7 +4,7 @@ import yaml from sys import argv from config import cfg -import controlnet_engine, BCZ_engine, BCZ_LSTM_engine, prebuild_engine, tomato_engine +import controlnet_engine, BCZ_engine, BCZ_LSTM_engine, prebuild_engine#, tomato_engine # mini_controlnet_engine, BCZ_LSTM_engine import warnings @@ -18,7 +18,7 @@ style="%", ) logging.basicConfig(**logging_kwargs) -logger = logging.getLogger("diffusion-pilicy") +logger = logging.getLogger("diffusion-policy") def parse_args(): @@ -38,43 +38,43 @@ def parse_args(): def main(): - cfg, config_file = parse_args() - cfg.freeze() + cfg_diff_ctrl, config_file = parse_args() + cfg_diff_ctrl.freeze() ####### check all the parameter settings ####### - logger.info("{}".format(cfg)) - logger.info("check mode - {}".format(cfg.mode.mode)) + logger.info("{}".format(cfg_diff_ctrl)) + logger.info("check mode - {}".format(cfg_diff_ctrl.mode.mode)) # Create directory for logs and experiment name - if not os.path.exists(cfg.train.log_directory): - os.mkdir(cfg.train.log_directory) - if not os.path.exists(os.path.join(cfg.train.log_directory, cfg.train.model_name)): - os.mkdir(os.path.join(cfg.train.log_directory, cfg.train.model_name)) + if not os.path.exists(cfg_diff_ctrl.train.log_directory): + os.mkdir(cfg_diff_ctrl.train.log_directory) + if not os.path.exists(os.path.join(cfg_diff_ctrl.train.log_directory, cfg_diff_ctrl.train.model_name)): + 
os.mkdir(os.path.join(cfg_diff_ctrl.train.log_directory, cfg_diff_ctrl.train.model_name)) os.mkdir( - os.path.join(cfg.train.log_directory, cfg.train.model_name, "summaries") + os.path.join(cfg_diff_ctrl.train.log_directory, cfg_diff_ctrl.train.model_name, "summaries") ) else: logger.warning( "This logging directory already exists: {}. Over-writing current files".format( - os.path.join(cfg.train.log_directory, cfg.train.model_name) + os.path.join(cfg_diff_ctrl.train.log_directory, cfg_diff_ctrl.train.model_name) ) ) ####### start the training ####### - if cfg.mode.model_zoo == "controlnet": - train_engine = controlnet_engine.Engine(args=cfg, logger=logger) - elif cfg.mode.model_zoo == "diffusion-model": - train_engine = prebuild_engine.Engine(args=cfg, logger=logger) - elif cfg.mode.model_zoo == "BCZ": - train_engine = BCZ_engine.Engine(args=cfg, logger=logger) - elif cfg.mode.model_zoo == "BCZ_LSTM": - train_engine = BCZ_LSTM_engine.Engine(args=cfg, logger=logger) - elif cfg.mode.model_zoo == "tomato-model": - train_engine = tomato_engine.Engine(args=cfg, logger=logger) + if cfg_diff_ctrl.mode.model_zoo == "controlnet": + train_engine = controlnet_engine.Engine(args=cfg_diff_ctrl, logger=logger) + elif cfg_diff_ctrl.mode.model_zoo == "diffusion-model": + train_engine = prebuild_engine.Engine(args=cfg_diff_ctrl, logger=logger) + elif cfg_diff_ctrl.mode.model_zoo == "BCZ": + train_engine = BCZ_engine.Engine(args=cfg_diff_ctrl, logger=logger) + elif cfg_diff_ctrl.mode.model_zoo == "BCZ_LSTM": + train_engine = BCZ_LSTM_engine.Engine(args=cfg_diff_ctrl, logger=logger) + elif cfg_diff_ctrl.mode.model_zoo == "tomato-model": + train_engine = tomato_engine.Engine(args=cfg_diff_ctrl, logger=logger) - if cfg.mode.mode == "train": + if cfg_diff_ctrl.mode.mode == "train": train_engine.train() - if cfg.mode.mode == "pretrain": + if cfg_diff_ctrl.mode.mode == "pretrain": train_engine.train() - if cfg.mode.mode == "test": + if cfg_diff_ctrl.mode.mode == "test": train_engine.test() From c87ce47d0f44e452bf8623f27b783e83e4ec3e89 Mon Sep 17 00:00:00 2001 From: n1colas Date: Thu, 13 Mar 2025 20:42:08 +0100 Subject: [PATCH 2/3] checkpoint --- Diff-Control/dataset/square_ph.py | 3 ++- Diff-Control/prebuild_engine.py | 28 ++++++++++++++++++++++++---- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/Diff-Control/dataset/square_ph.py b/Diff-Control/dataset/square_ph.py index 9a6b7a5..5087334 100644 --- a/Diff-Control/dataset/square_ph.py +++ b/Diff-Control/dataset/square_ph.py @@ -8,7 +8,8 @@ def __init__(self, square_ph_dataset, image_size=(224, 224)): self.square_ph_dataset = square_ph_dataset self.image_size = image_size - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu" + ) self.clip_model, _ = clip.load("ViT-B/32", device=self.device) self.dummy_sentence = "push the square to its goal" self.transform = transforms.Resize(self.image_size) diff --git a/Diff-Control/prebuild_engine.py b/Diff-Control/prebuild_engine.py index 7324fcd..8ad3f47 100644 --- a/Diff-Control/prebuild_engine.py +++ b/Diff-Control/prebuild_engine.py @@ -1,3 +1,6 @@ +import cv2 +import wandb + from dataset.square_ph import SquarePhDataset from diffusion_policy.policy.base_image_policy import BaseImagePolicy from model import UNetwithControl, SensorModel, StatefulUNet @@ -16,7 +19,7 @@ class Engine(BaseImagePolicy): - def __init__(self, args, logger, diff_pol_dataset=None): + def __init__(self, args, logger, 
diff_pol_dataset=None, env_runner=None): super().__init__() self.args = args self.logger = logger @@ -35,6 +38,7 @@ def __init__(self, args, logger, diff_pol_dataset=None): self.win_size = self.args.train.win_size self.global_step = 0 self.mode = self.args.mode.mode + self.env_runner = env_runner if self.args.train.dataset == "OpenLid": if self.mode == "train": @@ -142,7 +146,6 @@ def train(self): ) print("Total number of parameters: ", pytorch_total_params) - test_dataloader = torch.utils.data.DataLoader( self.test_dataset, batch_size=self.batch_size, shuffle=True, num_workers=8 ) @@ -292,6 +295,23 @@ def train(self): ), ) + if self.global_step != 0: + runner_log = self.env_runner.run(self) + videos = {k: v for k, v in runner_log.items() if isinstance(v, wandb.Video)} + for k, v in videos: + cap = cv2.VideoCapture(v._path) + frames = [] + while cap.isOpened(): + ret, frame = cap.read() + if not ret: + break + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) # Convert from OpenCV BGR to RGB + frames.append(frame) + cap.release() + video_tensor = torch.tensor(np.array(frames), dtype=torch.float32).permute(0, 3, 1, 2) / 255.0 + # Add batch dimension (1, T, C, H, W) + video_tensor = video_tensor.unsqueeze(0) + self.writer.add_video(k, video_tensor, self.global_step, fps=10) # online evaluation if ( self.args.mode.do_online_eval @@ -357,8 +377,8 @@ def set_normalizer(self, normalizer): pass def predict_action(self, obs_dict): - images = obs_dict["images"] - sentence = obs_dict["sentence"] + images = obs_dict["image"] + sentence = obs_dict["sentence"].to(self.device).unsqueeze(0) with torch.no_grad(): text_features = self.clip_model.encode_text(sentence) From 55a386cd436c4b4051f0eb2f1c3e2acb2d9032a2 Mon Sep 17 00:00:00 2001 From: n1colas Date: Mon, 14 Apr 2025 18:00:12 +0200 Subject: [PATCH 3/3] backup --- .idea/Diff-Control.iml | 14 ++ Diff-Control/config/pusht_diffusion.yaml | 108 +++++++++++ Diff-Control/config/square_ph_controlnet.yaml | 176 ++++++++++++++++++ Diff-Control/config/square_ph_diffusion.yaml | 4 +- Diff-Control/controlnet_engine.py | 148 ++++++++++++++- Diff-Control/dataset/pusht.py | 45 +++++ Diff-Control/prebuild_engine.py | 46 +++-- 7 files changed, 520 insertions(+), 21 deletions(-) create mode 100644 .idea/Diff-Control.iml create mode 100644 Diff-Control/config/pusht_diffusion.yaml create mode 100644 Diff-Control/config/square_ph_controlnet.yaml create mode 100644 Diff-Control/dataset/pusht.py diff --git a/.idea/Diff-Control.iml b/.idea/Diff-Control.iml new file mode 100644 index 0000000..db5c449 --- /dev/null +++ b/.idea/Diff-Control.iml @@ -0,0 +1,14 @@ + + + + + + + + + + + + \ No newline at end of file diff --git a/Diff-Control/config/pusht_diffusion.yaml b/Diff-Control/config/pusht_diffusion.yaml new file mode 100644 index 0000000..78508b9 --- /dev/null +++ b/Diff-Control/config/pusht_diffusion.yaml @@ -0,0 +1,108 @@ +mode: + mode: 'train' + model_zoo: 'diffusion-model' + multiprocessing_distributed: False + dist_url: tcp://127.0.0.1:2345 + num_threads: 1 + do_online_eval: True + parameter_path: '' +train: + dim_x: 10 + dim_z: '' + dim_a: 10 + dim_gt: 10 + num_ensemble: 32 + win_size: 24 + seed: 0 + model_name: 'pusht_diffusion' + dataset: 'PushTDataset' + sensor_len: 2 + channel_img_1: 3 + channel_img_2: '' + input_size_1: 77 + input_size_2: '' + input_size_3: '' + #data_path: [ + # '/home/ssdArray/datasets/diff_control_data/pick_duck/pick_duck_1', + # '/home/ssdArray/datasets/diff_control_data/pick_duck/pick_duck_2' + #] + #data_path: 
/home/ssdArray/datasets/diffusion_policy_data/robomimic/robomimic/datasets/square/ph/image_abs.hdf5 + data_path: /home/nicolasl/diffusion_imitation_learning/diffusion_policy/data/pusht/pusht_cchi_v7_replay.zarr + batch_size: 64 + num_epochs: 3000 + learning_rate: 1e-4 + weight_decay: 1e-6 + adam_eps: 1e-3 + log_freq: 10 + eval_freq: 50 + save_freq: 50 + log_directory: './experiments' + loss: 'mse' +task: + dataset: + _target_: diffusion_policy.dataset.pusht_image_dataset.PushTImageDataset + horizon: 24 # was 16 + max_train_episodes: 90 + pad_after: 7 + pad_before: 1 + seed: 42 + val_ratio: 0.02 + zarr_path: data/pusht/pusht_cchi_v7_replay.zarr + env_runner: + _target_: diffusion_policy.env_runner.pusht_image_runner_diff_ctrl.PushTImageRunnerDiffCtrl + fps: 10 + legacy_test: true + max_steps: 300 + n_action_steps: 8 + n_envs: null + n_obs_steps: 2 + n_test: 50 + n_test_vis: 4 + n_train: 6 + n_train_vis: 2 + past_action: false + test_start_seed: 100000 + train_start_seed: 0 + image_shape: + - 3 + - 96 + - 96 + name: pusht_image + shape_meta: + action: + shape: + - 2 + obs: + agent_pos: + shape: + - 2 + type: low_dim + image: + shape: + - 3 + - 96 + - 96 + type: rgb +test: + dim_x: 10 + dim_z: '' + dim_a: 10 + dim_gt: 10 + num_ensemble: 32 + win_size: 24 + model_name: 'pusht_diffusion' + data_path: [ + '/home/ssdArray/datasets/diff_control_data/pick_duck/pick_duck_2' + ] + sensor_len: 2 + channel_img_1: 3 + channel_img_2: '' + input_size_1: 77 + input_size_2: '' + input_size_3: '' + checkpoint_path_1: '' + checkpoint_path_2: '' + dataset: 'PushTDataset' +optim: + optim: 'adamw' + lr_scheduler: 'polynomial_decay' diff --git a/Diff-Control/config/square_ph_controlnet.yaml b/Diff-Control/config/square_ph_controlnet.yaml new file mode 100644 index 0000000..e37e63f --- /dev/null +++ b/Diff-Control/config/square_ph_controlnet.yaml @@ -0,0 +1,176 @@ +mode: + mode: 'train' + model_zoo: 'controlnet' + multiprocessing_distributed: False + dist_url: tcp://127.0.0.1:2345 + num_threads: 1 + do_online_eval: True + parameter_path: '' +train: + dim_x: 10 + dim_z: '' + dim_a: 10 + dim_gt: 10 + num_ensemble: 32 + win_size: 24 + seed: 0 + model_name: 'square_ph_controlnet' + dataset: 'SquarePhDataset' + sensor_len: 2 + channel_img_1: 3 + channel_img_2: '' + input_size_1: 77 + input_size_2: '' + input_size_3: '' + #data_path: [ + # '/home/ssdArray/datasets/diff_control_data/pick_duck/pick_duck_1', + # '/home/ssdArray/datasets/diff_control_data/pick_duck/pick_duck_2' + #] + data_path: /home/ssdArray/datasets/diffusion_policy_data/robomimic/robomimic/datasets/square/ph/image_abs.hdf5 + batch_size: 64 + num_epochs: 3000 + learning_rate: 1e-4 + weight_decay: 1e-6 + adam_eps: 1e-3 + log_freq: 10 + eval_freq: 50 + save_freq: 50 + log_directory: './experiments' + loss: 'mse' +task: + abs_action: true + dataset: + _target_: diffusion_policy.dataset.robomimic_replay_image_dataset.RobomimicReplayImageDataset + abs_action: true + dataset_path: /home/ssdArray/datasets/diffusion_policy_data/robomimic/robomimic/datasets/square/ph/image_abs.hdf5 + horizon: 24 # was 16 originally + n_obs_steps: 2 + pad_after: 7 + pad_before: 1 + rotation_rep: rotation_6d + seed: 42 + shape_meta: + action: + shape: + - 10 + obs: + agentview_image: + shape: + - 3 + - 84 + - 84 + type: rgb + robot0_eef_pos: + shape: + - 3 + robot0_eef_quat: + shape: + - 4 + robot0_eye_in_hand_image: + shape: + - 3 + - 84 + - 84 + type: rgb + robot0_gripper_qpos: + shape: + - 2 + use_cache: true + val_ratio: 0.02 + dataset_path: 
/home/ssdArray/datasets/diffusion_policy_data/robomimic/robomimic/datasets/square/ph/image_abs.hdf5 + dataset_type: ph + env_runner: + _target_: diffusion_policy.env_runner.robomimic_image_runner.RobomimicImageRunnerSquarePh + abs_action: true + crf: 22 + dataset_path: /home/ssdArray/datasets/diffusion_policy_data/robomimic/robomimic/datasets/square/ph/image_abs.hdf5 + fps: 10 + max_steps: 400 + n_action_steps: 8 + n_envs: 28 + n_obs_steps: 2 + n_test: 50 + n_test_vis: 4 + n_train: 6 + n_train_vis: 2 + past_action: false + render_obs_key: agentview_image + shape_meta: + action: + shape: + - 10 + obs: + agentview_image: + shape: + - 3 + - 84 + - 84 + type: rgb + robot0_eef_pos: + shape: + - 3 + robot0_eef_quat: + shape: + - 4 + robot0_eye_in_hand_image: + shape: + - 3 + - 84 + - 84 + type: rgb + robot0_gripper_qpos: + shape: + - 2 + test_start_seed: 100000 + tqdm_interval_sec: 1.0 + train_start_idx: 0 + name: square_image + shape_meta: + action: + shape: + - 10 + obs: + agentview_image: + shape: + - 3 + - 84 + - 84 + type: rgb + robot0_eef_pos: + shape: + - 3 + robot0_eef_quat: + shape: + - 4 + robot0_eye_in_hand_image: + shape: + - 3 + - 84 + - 84 + type: rgb + robot0_gripper_qpos: + shape: + - 2 +test: + dim_x: 10 + dim_z: '' + dim_a: 10 + dim_gt: 10 + num_ensemble: 32 + win_size: 24 + model_name: 'square_ph_controlnet' + data_path: [ + '/home/ssdArray/datasets/diff_control_data/pick_duck/pick_duck_2' + ] + sensor_len: 2 + channel_img_1: 3 + channel_img_2: '' + input_size_1: 77 + input_size_2: '' + input_size_3: '' + checkpoint_path_1: './experiments/square_ph_diffusion/v1.0-ema-model-104000' + checkpoint_path_2: './experiments/square_ph_diffusion/v1.0-sensor_model-104000' + dataset: 'SquarePhDataset' +optim: + optim: 'adamw' + lr_scheduler: 'polynomial_decay' diff --git a/Diff-Control/config/square_ph_diffusion.yaml b/Diff-Control/config/square_ph_diffusion.yaml index da62210..aeb2868 100644 --- a/Diff-Control/config/square_ph_diffusion.yaml +++ b/Diff-Control/config/square_ph_diffusion.yaml @@ -33,8 +33,8 @@ train: weight_decay: 1e-6 adam_eps: 1e-3 log_freq: 10 - eval_freq: 200 - save_freq: 200 + eval_freq: 50 + save_freq: 50 log_directory: './experiments' loss: 'mse' task: diff --git a/Diff-Control/controlnet_engine.py b/Diff-Control/controlnet_engine.py index f8d0a09..f2f3877 100644 --- a/Diff-Control/controlnet_engine.py +++ b/Diff-Control/controlnet_engine.py @@ -1,4 +1,8 @@ +import cv2 +import wandb + from dataset.square_ph import SquarePhDataset +from diffusion_policy.policy.base_image_policy import BaseImagePolicy from model import ( UNetwithControl, SensorModel, @@ -21,8 +25,9 @@ from diffusers.training_utils import EMAModel -class Engine: - def __init__(self, args, logger, diff_pol_dataset=None): +class Engine(BaseImagePolicy): + def __init__(self, args, logger, diff_pol_dataset=None, env_runner=None): + super().__init__() self.args = args self.logger = logger self.batch_size = self.args.train.batch_size @@ -40,6 +45,7 @@ def __init__(self, args, logger, diff_pol_dataset=None): self.win_size = self.args.train.win_size self.global_step = 0 self.mode = self.args.mode.mode + self.env_runner = env_runner if self.args.train.dataset == "OpenLid": if self.mode == "train": @@ -65,6 +71,8 @@ def __init__(self, args, logger, diff_pol_dataset=None): else: self.data_path = self.args.test.data_path self.dataset = SquarePhDataset(diff_pol_dataset) + self.test_dataset = self.dataset.get_validation_dataset() + elif self.args.train.dataset == "Drum": if self.mode == "train": self.data_path = 
self.args.train.data_path @@ -213,6 +221,11 @@ def train(self): ) print("Total number of parameters: ", pytorch_total_params) + if self.env_runner: + test_dataloader = torch.utils.data.DataLoader( + self.test_dataset, batch_size=1, shuffle=True, num_workers=8 + ) + """ only build optimizer for the trainable parts """ @@ -257,6 +270,7 @@ def train(self): # -----------------------------------------------------------------------------# # --------------------------------- train ------------------------------# # -----------------------------------------------------------------------------# + previous_val_loss = float("inf") while epoch < self.args.train.num_epochs: step = 0 for data in dataloader: @@ -349,6 +363,7 @@ def train(self): # Save a model based on a chosen save frequency if self.global_step != 0 and (epoch + 1) % self.args.train.save_freq == 0: + #if self.global_step != 0: self.ema_nets = self.model checkpoint = { "global_step": self.global_step, @@ -364,22 +379,138 @@ def train(self): ), ) + if self.global_step != 0 and self.env_runner and (epoch + 1) % self.args.train.eval_freq == 0: + #if self.global_step != 0 and self.env_runner: + print(f"saving step {self.global_step}, epoch: {epoch}") + runner_log = self.env_runner.run(self) + test_mean_score = runner_log['test/mean_score'] + train_mean_score = runner_log['train/mean_score'] + + self.writer.add_scalar( + "test/mean_score", test_mean_score, self.global_step + ) + self.writer.add_scalar( + "train/mean_score", train_mean_score, self.global_step + ) + videos = {k: v for k, v in runner_log.items() if isinstance(v, wandb.Video)} + for k in videos: + cap = cv2.VideoCapture(videos[k]._path) + frames = [] + while cap.isOpened(): + ret, frame = cap.read() + if not ret: + break + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) # Convert from OpenCV BGR to RGB + frames.append(frame) + cap.release() + video_tensor = torch.tensor(np.array(frames), dtype=torch.float32).permute(0, 3, 1, 2) / 255.0 + # Add batch dimension (1, T, C, H, W) + video_tensor = video_tensor.unsqueeze(0) + self.writer.add_video(k, video_tensor, self.global_step, fps=10) + # online evaluation if ( self.args.mode.do_online_eval and self.global_step != 0 - and (epoch + 1) % self.args.train.eval_freq == 0 + and (epoch + 1) % 25 == 0 ): time.sleep(0.1) self.ema_nets = self.model self.ema_nets.eval() self.online_test() - self.ema_nets.train() - self.ema.copy_to(self.ema_nets.parameters()) + ############################################################## + + val_loss = None + with torch.no_grad(): + val_losses = list() + for data in test_dataloader: + data = [item.to(self.device) for item in data] + (images, prior_action, action, sentence) = data + optimizer_.zero_grad() + + text_features = self.clip_model.encode_text(sentence) + text_features = text_features.clone().detach() + text_features = text_features.to(torch.float32) + + # sample noise to add to actions + noise = torch.randn(action.shape, device=self.device) + # sample a diffusion iteration for each data point + timesteps = torch.randint( + 0, + self.noise_scheduler.config.num_train_timesteps, + (action.shape[0],), + device=self.device, + ).long() + # add noise to the clean actions according to the noise magnitude at each diffusion iteration + # (this is the forward diffusion process) + noisy_actions = self.noise_scheduler.add_noise(action, noise, timesteps) + + # forward + img_emb = self.sensor_model(images) + predicted_noise = self.model( + noisy_actions, img_emb, text_features, prior_action, timesteps + ) + loss =
self.criterion(noise, predicted_noise) + val_losses.append(loss.cpu().item()) + if len(val_losses) > 0: + val_loss = torch.mean(torch.tensor(val_losses)).item() + self.writer.add_scalar( + "val_loss", val_loss, self.global_step + ) + ############################################################## + self.ema_nets.train() + self.ema.copy_to(self.ema_nets.parameters()) + + if (epoch + 1) % self.args.train.eval_freq == 0: + if val_loss > previous_val_loss and epoch >= 300: + return + + previous_val_loss = val_loss # Update epoch epoch += 1 + + def set_normalizer(self, normalizer): + pass + + def predict_action(self, obs_dict): + images = obs_dict["image"] + prior_action = obs_dict["prior_action"] + sentence = obs_dict["sentence"].to(self.device).unsqueeze(0) + + with torch.no_grad(): + text_features = self.clip_model.encode_text(sentence) + text_features = text_features.clone().detach() + text_features = text_features.to(torch.float32) + text_features = torch.stack([text_features] * 28, dim=0) + text_features = text_features.squeeze(1) + + img_emb = self.sensor_model(images) + + # initialize action from Gaussian noise + noisy_action = torch.randn((28, self.dim_x, self.win_size)).to( + self.device + ) + + # TODO: add previous action + + # init scheduler + self.noise_scheduler.set_timesteps(50) + + for k in self.noise_scheduler.timesteps: + # predict noise + t = [torch.stack([k]).to(self.device) for _ in range(28)] # shape (1,) per tensor + t = torch.cat(t, dim=0) + predicted_noise = self.ema_nets( + noisy_action, img_emb, text_features, prior_action, t + ) + + # inverse diffusion step (remove noise) + noisy_action = self.noise_scheduler.step( + model_output=predicted_noise, timestep=k, sample=noisy_action + ).prev_sample + return noisy_action # -----------------------------------------------------------------------------# # --------------------------------- test ------------------------------# # -----------------------------------------------------------------------------# @@ -409,6 +540,13 @@ def online_test(self): else: self.data_path = self.args.test.data_path test_dataset = Drum(self.data_path) + elif self.args.train.dataset == "SquarePhDataset": + if self.mode == "train": + self.data_path = self.args.train.data_path + else: + self.data_path = self.args.test.data_path + test_dataset = self.dataset.get_validation_dataset() + test_dataloader = torch.utils.data.DataLoader( test_dataset, batch_size=1, shuffle=True, num_workers=8 ) diff --git a/Diff-Control/dataset/pusht.py b/Diff-Control/dataset/pusht.py new file mode 100644 index 0000000..5d61704 --- /dev/null +++ b/Diff-Control/dataset/pusht.py @@ -0,0 +1,45 @@ +import torch +from torchvision import transforms +from einops import rearrange +import clip + +class PushTDataset(torch.utils.data.Dataset): + def __init__(self, pusht_dataset, image_size=(224, 224)): + self.pusht_dataset = pusht_dataset + self.image_size = image_size + + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu" + ) + self.clip_model, _ = clip.load("ViT-B/32", device=self.device) + self.dummy_sentence = "push the T to its goal" + self.transform = transforms.Resize(self.image_size) + + def __len__(self): + return len(self.pusht_dataset) + + def __getitem__(self, idx): + transition_history = 12 + data = self.pusht_dataset.__getitem__(idx) + if idx < transition_history: + idx_pre = 0 + else: + idx_pre = idx - transition_history + + prev_data = self.pusht_dataset.__getitem__(idx_pre) + + # Extract image and action from the PushT dataset + image =
data['obs']['image'][0] + action = rearrange(data['action'], "Ta Da -> Da Ta") + + prior_action = rearrange(prev_data['action'], "Ta Da -> Da Ta") + # Resize the image to match the Duck dataset input size + image = self.transform(image) + + # Tokenize dummy sentence using CLIP + sentence = clip.tokenize([self.dummy_sentence])[0] + + return image, prior_action, action, sentence + + def get_validation_dataset(self): + val_set = self.pusht_dataset.get_validation_dataset() + return PushTDataset(val_set, image_size=self.image_size) diff --git a/Diff-Control/prebuild_engine.py b/Diff-Control/prebuild_engine.py index 8ad3f47..295d257 100644 --- a/Diff-Control/prebuild_engine.py +++ b/Diff-Control/prebuild_engine.py @@ -146,9 +146,10 @@ def train(self): ) print("Total number of parameters: ", pytorch_total_params) - test_dataloader = torch.utils.data.DataLoader( - self.test_dataset, batch_size=self.batch_size, shuffle=True, num_workers=8 - ) + if self.env_runner: + test_dataloader = torch.utils.data.DataLoader( + self.test_dataset, batch_size=1, shuffle=True, num_workers=8 + ) # Create optimizer optimizer_ = build_optimizer( @@ -265,8 +266,8 @@ def train(self): scheduler.step(self.global_step) # Save a model based on a chosen save frequency - #if self.global_step != 0 and (epoch + 1) % self.args.train.save_freq == 0: - if self.global_step != 0: + if self.global_step != 0 and (epoch + 1) % self.args.train.save_freq == 0: + #if self.global_step != 0: self.ema_nets = self.model checkpoint = { "global_step": self.global_step, @@ -295,11 +296,21 @@ def train(self): ), ) - if self.global_step != 0: + if self.global_step != 0 and self.env_runner and (epoch + 1) % self.args.train.eval_freq == 0: + print(f"saving step {self.global_step}, epoch: {epoch}") runner_log = self.env_runner.run(self) + test_mean_score = runner_log['test/mean_score'] + train_mean_score = runner_log['train/mean_score'] + + self.writer.add_scalar( + "test/mean_score", test_mean_score, self.global_step + ) + self.writer.add_scalar( + "train/mean_score", train_mean_score, self.global_step + ) videos = {k: v for k, v in runner_log.items() if isinstance(v, wandb.Video)} - for k, v in videos: - cap = cv2.VideoCapture(v._path) + for k in videos: + cap = cv2.VideoCapture(videos[k]._path) frames = [] while cap.isOpened(): ret, frame = cap.read() @@ -316,7 +327,7 @@ def train(self): if ( self.args.mode.do_online_eval and self.global_step != 0 - #and (epoch + 1) % self.args.train.eval_freq == 0 + and (epoch + 1) % 25 == 0 ): time.sleep(0.1) self.ema_nets = self.model @@ -358,6 +369,9 @@ def train(self): val_losses.append(loss.cpu().item()) if len(val_losses) > 0: val_loss = torch.mean(torch.tensor(val_losses)).item() + self.writer.add_scalar( + "val_loss", val_loss, self.global_step + ) all_val_losses.append((epoch, val_loss)) @@ -365,10 +379,11 @@ def train(self): self.sensor_model.train() self.ema.copy_to(self.ema_nets.parameters()) - if val_loss > previous_val_loss: - return + if (epoch + 1) % self.args.train.eval_freq == 0: + if val_loss > previous_val_loss and epoch >= 300: + return - previous_val_loss = val_loss + previous_val_loss = val_loss # Update epoch epoch += 1 @@ -384,11 +399,13 @@ def predict_action(self, obs_dict): text_features = self.clip_model.encode_text(sentence) text_features = text_features.clone().detach() text_features = text_features.to(torch.float32) + text_features = torch.stack([text_features] * 28, dim=0) + text_features =
text_features.squeeze(1) + img_emb = self.sensor_model(images) # initialize action from Gaussian noise - noisy_action = torch.randn((1, self.dim_x, self.win_size)).to( + noisy_action = torch.randn((28, self.dim_x, self.win_size)).to( self.device ) @@ -397,7 +414,8 @@ def predict_action(self, obs_dict): for k in self.noise_scheduler.timesteps: # predict noise - t = torch.stack([k]).to(self.device) + t = [torch.stack([k]).to(self.device) for _ in range(28)] # shape (1,) per tensor + t = torch.cat(t, dim=0) predicted_noise = self.ema_nets( noisy_action, img_emb, text_features, t )
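Note on wiring: the patches add diff_pol_dataset and env_runner parameters to both engines, but never show where train.py obtains them. A minimal sketch of the assumed wiring, resolving the _target_ entries in the new task blocks with hydra.utils.instantiate the way diffusion_policy's own scripts do; the config path and output_dir value below are illustrative, not from the patch:

import hydra
import yaml

from dataset.square_ph import SquarePhDataset

# Load the new config and build the diffusion_policy objects it points at.
with open("Diff-Control/config/square_ph_diffusion.yaml") as f:
    cfg = yaml.safe_load(f)

# hydra resolves _target_ to RobomimicReplayImageDataset / the square-ph runner.
diff_pol_dataset = hydra.utils.instantiate(cfg["task"]["dataset"])
env_runner = hydra.utils.instantiate(
    cfg["task"]["env_runner"], output_dir="./experiments/square_ph_diffusion"
)

# The wrapper turns each robomimic sample into (image, prior_action, action, sentence),
# matching the tuple the Duck pipeline already feeds the engines.
dataset = SquarePhDataset(diff_pol_dataset)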
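Both predict_action implementations above build the per-step timestep batch by concatenating a Python list of torch.stack([k]) tensors and hardcode the runner's 28 parallel environments. A minimal equivalent sketch with the batch built by torch.full and the count passed in; sample_actions and its signature are illustrative stand-ins, while the scheduler calls mirror the diffusers DDPMScheduler API the engines already use:

import torch


def sample_actions(model, scheduler, img_emb, text_features, batch_size,
                   dim_x, win_size, device, num_inference_steps=50):
    scheduler.set_timesteps(num_inference_steps)
    # One action window per parallel environment, initialized from Gaussian noise.
    noisy_action = torch.randn((batch_size, dim_x, win_size), device=device)
    for k in scheduler.timesteps:
        # One copy of the current timestep per batch element, shape (batch_size,).
        t = torch.full((batch_size,), int(k), dtype=torch.long, device=device)
        predicted_noise = model(noisy_action, img_emb, text_features, t)
        # Inverse diffusion step (remove noise).
        noisy_action = scheduler.step(
            model_output=predicted_noise, timestep=k, sample=noisy_action
        ).prev_sample
    return noisy_action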
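The wandb.Video-to-TensorBoard conversion is now duplicated verbatim between prebuild_engine.train and controlnet_engine.train. A hedged sketch of the same loop as a shared helper; log_runner_videos is an illustrative name, and v._path is the private attribute the engines themselves rely on for the rendered file on disk:

import cv2
import numpy as np
import torch
import wandb


def log_runner_videos(writer, runner_log, global_step, fps=10):
    videos = {k: v for k, v in runner_log.items() if isinstance(v, wandb.Video)}
    for k, v in videos.items():
        cap = cv2.VideoCapture(v._path)
        frames = []
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            # OpenCV decodes to BGR; TensorBoard expects RGB.
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        cap.release()
        if not frames:
            continue
        # (T, H, W, C) uint8 -> (1, T, C, H, W) float in [0, 1].
        video_tensor = torch.tensor(np.array(frames), dtype=torch.float32)
        video_tensor = video_tensor.permute(0, 3, 1, 2).unsqueeze(0) / 255.0
        writer.add_video(k, video_tensor, global_step, fps=fps)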