diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/Diff-Control.iml b/.idea/Diff-Control.iml
new file mode 100644
index 0000000..db5c449
--- /dev/null
+++ b/.idea/Diff-Control.iml
@@ -0,0 +1,14 @@
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/Diff-Control/BCZ_LSTM_engine.py b/Diff-Control/BCZ_LSTM_engine.py
index 3bb56ed..0a9414b 100644
--- a/Diff-Control/BCZ_LSTM_engine.py
+++ b/Diff-Control/BCZ_LSTM_engine.py
@@ -1,22 +1,11 @@
-import argparse
-import logging
-import os
-import numpy as np
-import torch
-import torch.nn as nn
-from einops import rearrange, repeat
-import clip
 from model import BCZ_LSTM
 from dataset.lid_pick_and_place import *
 from dataset.tomato_pick_and_place import *
 from dataset.pick_duck import *
 from dataset.drum_hit import *
-from optimizer import build_optimizer
 from optimizer import build_lr_scheduler
 from torch.utils.tensorboard import SummaryWriter
-import copy
 import time
-import random
 import pickle
 import torch.optim as optim
diff --git a/Diff-Control/BCZ_engine.py b/Diff-Control/BCZ_engine.py
index 2e4ce59..1bc7b4d 100644
--- a/Diff-Control/BCZ_engine.py
+++ b/Diff-Control/BCZ_engine.py
@@ -1,23 +1,12 @@
-import argparse
-import logging
-import os
-import numpy as np
-import torch
-import torch.nn as nn
-from einops import rearrange, repeat
-import clip
 from model import ModAttn_ImageBC
-from dataset import UR5_dataloader
+#from dataset import UR5_dataloader
 from dataset.lid_pick_and_place import *
 from dataset.tomato_pick_and_place import *
 from dataset.pick_duck import *
 from dataset.drum_hit import *
-from optimizer import build_optimizer
 from optimizer import build_lr_scheduler
 from torch.utils.tensorboard import SummaryWriter
-import copy
 import time
-import random
 import pickle
 import torch.optim as optim
diff --git a/Diff-Control/config.py b/Diff-Control/config.py
index bd7a94e..e480e59 100644
--- a/Diff-Control/config.py
+++ b/Diff-Control/config.py
@@ -94,6 +94,8 @@
 cfg.test.eigen_crop = False  # crops according to Eigen NIPS14
 cfg.test.garg_crop = False  # crops according to Garg ECCV16
 
+
+
 cfg.network = CN()
 cfg.network.name = "alpha-MDF"
 cfg.network.encoder = ""
diff --git a/Diff-Control/config/duck_controlnet.yaml b/Diff-Control/config/duck_controlnet.yaml
index a244dca..bb61ad5 100644
--- a/Diff-Control/config/duck_controlnet.yaml
+++ b/Diff-Control/config/duck_controlnet.yaml
@@ -23,8 +23,8 @@ train:
   input_size_2: ''
   input_size_3: ''
   data_path: [
-    '/tf/datasets/pick_duck/pick_duck_1',
-    '/tf/datasets/pick_duck/pick_duck_2',
+    '/home/ssdArray/datasets/diff_control_data/pick_duck/pick_duck_1',
+    '/home/ssdArray/datasets/diff_control_data/pick_duck/pick_duck_2'
 ]
   batch_size: 64
   num_epochs: 1000
@@ -45,16 +45,16 @@ test:
   win_size: 24
   model_name: 'duck_controlnet'
   data_path: [
-    '/tf/datasets/pick_duck/pick_duck_1',
+    '/home/ssdArray/datasets/diff_control_data/pick_duck/pick_duck_2'
 ]
   sensor_len: 2
   channel_img_1: 3
   channel_img_2: ''
   input_size_1: 77
   input_size_2: ''
   input_size_3: ''
-  checkpoint_path_1: './experiments/duck_model/ema-model-98000'
-  checkpoint_path_2: './experiments/duck_model/sensor_model-98000'
+  checkpoint_path_1: './experiments/duck_diffusion/v1.0-ema-model-19600'
+  checkpoint_path_2: './experiments/duck_diffusion/v1.0-sensor_model-19600'
   dataset: 'Duck'
 optim:
   optim: 'adamw'
diff --git a/Diff-Control/config/duck_diffusion.yaml b/Diff-Control/config/duck_diffusion.yaml
index 2b53b51..133dd09 100644
--- a/Diff-Control/config/duck_diffusion.yaml
+++ b/Diff-Control/config/duck_diffusion.yaml
@@ -23,8 +23,8 @@ train:
   input_size_2: ''
   input_size_3: ''
   data_path: [
-    '/tf/datasets/pick_duck/pick_duck_1',
-    '/tf/datasets/pick_duck/pick_duck_2',
+    '/home/ssdArray/datasets/diff_control_data/pick_duck/pick_duck_1',
+    '/home/ssdArray/datasets/diff_control_data/pick_duck/pick_duck_2'
 ]
   batch_size: 64
   num_epochs: 3000
@@ -45,7 +45,7 @@ test:
   win_size: 24
   model_name: 'duck_diffusion'
   data_path: [
-    '/tf/datasets/pick_duck/pick_duck_2',
+    '/home/ssdArray/datasets/diff_control_data/pick_duck/pick_duck_2'
 ]
   sensor_len: 2
   channel_img_1: 3
diff --git a/Diff-Control/config/pusht_diffusion.yaml b/Diff-Control/config/pusht_diffusion.yaml
new file mode 100644
index 0000000..78508b9
--- /dev/null
+++ b/Diff-Control/config/pusht_diffusion.yaml
@@ -0,0 +1,108 @@
+mode:
+  mode: 'train'
+  model_zoo: 'diffusion-model'
+  multiprocessing_distributed: False
+  dist_url: tcp://127.0.0.1:2345
+  num_threads: 1
+  do_online_eval: True
+  parameter_path: ''
+train:
+  dim_x: 10
+  dim_z: ''
+  dim_a: 10
+  dim_gt: 10
+  num_ensemble: 32
+  win_size: 24
+  seed: 0
+  model_name: 'pusht_diffusion'
+  dataset: 'PushTDataset'
+  sensor_len: 2
+  channel_img_1: 3
+  channel_img_2: ''
+  input_size_1: 77
+  input_size_2: ''
+  input_size_3: ''
+  #data_path: [
+  #  '/home/ssdArray/datasets/diff_control_data/pick_duck/pick_duck_1',
+  #  '/home/ssdArray/datasets/diff_control_data/pick_duck/pick_duck_2'
+  #]
+  #data_path: /home/ssdArray/datasets/diffusion_policy_data/robomimic/robomimic/datasets/square/ph/image_abs.hdf5
+  data_path: /home/nicolasl/diffusion_imitation_learning/diffusion_policy/data/pusht/pusht_cchi_v7_replay.zarr
+  batch_size: 64
+  num_epochs: 3000
+  learning_rate: 1e-4
+  weight_decay: 1e-6
+  adam_eps: 1e-3
+  log_freq: 10
+  eval_freq: 50
+  save_freq: 50
+  log_directory: './experiments'
+  loss: 'mse'
+task:
+  dataset:
+    _target_: diffusion_policy.dataset.pusht_image_dataset.PushTImageDataset
+    horizon: 24 # was 16
+    max_train_episodes: 90
+    pad_after: 7
+    pad_before: 1
+    seed: 42
+    val_ratio: 0.02
+    zarr_path: data/pusht/pusht_cchi_v7_replay.zarr
+  env_runner:
+    _target_: diffusion_policy.env_runner.pusht_image_runner_diff_ctrl.PushTImageRunnerDiffCtrl
+    fps: 10
+    legacy_test: true
+    max_steps: 300
+    n_action_steps: 8
+    n_envs: null
+    n_obs_steps: 2
+    n_test: 50
+    n_test_vis: 4
+    n_train: 6
+    n_train_vis: 2
+    past_action: false
+    test_start_seed: 100000
+    train_start_seed: 0
+  image_shape:
+  - 3
+  - 96
+  - 96
+  name: pusht_image
+  shape_meta:
+    action:
+      shape:
+      - 2
+    obs:
+      agent_pos:
+        shape:
+        - 2
+        type: low_dim
+      image:
+        shape:
+        - 3
+        - 96
+        - 96
+        type: rgb
+test:
+  dim_x: 10
+  dim_z: ''
+  dim_a: 10
+  dim_gt: 10
+  num_ensemble: 32
+  win_size: 24
+  model_name: 'pusht_diffusion'
+  data_path: [
+    '/home/ssdArray/datasets/diff_control_data/pick_duck/pick_duck_2'
+  ]
+  sensor_len: 2
+  channel_img_1: 3
+  channel_img_2: ''
+  input_size_1: 77
+  input_size_2: ''
+  input_size_3: ''
+  checkpoint_path_1: ''
+  checkpoint_path_2: ''
+  dataset: 'PushTDataset'
+optim:
+  optim: 'adamw'
+  lr_scheduler: 'polynomial_decay'
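The `task.dataset._target_` entry above follows the Hydra convention used throughout diffusion_policy: the dotted path names a class, and the sibling keys become its constructor arguments. A minimal sketch of how such a block could be instantiated — assuming hydra-core and omegaconf are installed and the diffusion_policy package is importable:

    from omegaconf import OmegaConf
    import hydra.utils

    # Load this config and build the dataset named by `_target_`; the remaining
    # keys (zarr_path, horizon, pad_before, ...) are passed as keyword arguments.
    cfg = OmegaConf.load("Diff-Control/config/pusht_diffusion.yaml")
    pusht_dataset = hydra.utils.instantiate(cfg.task.dataset)
    print(len(pusht_dataset))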
diff --git a/Diff-Control/config/square_ph_controlnet.yaml b/Diff-Control/config/square_ph_controlnet.yaml
new file mode 100644
index 0000000..e37e63f
--- /dev/null
+++ b/Diff-Control/config/square_ph_controlnet.yaml
@@ -0,0 +1,176 @@
+mode:
+  mode: 'train'
+  model_zoo: 'controlnet'
+  multiprocessing_distributed: False
+  dist_url: tcp://127.0.0.1:2345
+  num_threads: 1
+  do_online_eval: True
+  parameter_path: ''
+train:
+  dim_x: 10
+  dim_z: ''
+  dim_a: 10
+  dim_gt: 10
+  num_ensemble: 32
+  win_size: 24
+  seed: 0
+  model_name: 'square_ph_controlnet'
+  dataset: 'SquarePhDataset'
+  sensor_len: 2
+  channel_img_1: 3
+  channel_img_2: ''
+  input_size_1: 77
+  input_size_2: ''
+  input_size_3: ''
+  #data_path: [
+  #  '/home/ssdArray/datasets/diff_control_data/pick_duck/pick_duck_1',
+  #  '/home/ssdArray/datasets/diff_control_data/pick_duck/pick_duck_2'
+  #]
+  data_path: /home/ssdArray/datasets/diffusion_policy_data/robomimic/robomimic/datasets/square/ph/image_abs.hdf5
+  batch_size: 64
+  num_epochs: 3000
+  learning_rate: 1e-4
+  weight_decay: 1e-6
+  adam_eps: 1e-3
+  log_freq: 10
+  eval_freq: 50
+  save_freq: 50
+  log_directory: './experiments'
+  loss: 'mse'
+task:
+  abs_action: true
+  dataset:
+    _target_: diffusion_policy.dataset.robomimic_replay_image_dataset.RobomimicReplayImageDataset
+    abs_action: true
+    dataset_path: /home/ssdArray/datasets/diffusion_policy_data/robomimic/robomimic/datasets/square/ph/image_abs.hdf5
+    horizon: 24 # was 16 originally
+    n_obs_steps: 2
+    pad_after: 7
+    pad_before: 1
+    rotation_rep: rotation_6d
+    seed: 42
+    shape_meta:
+      action:
+        shape:
+        - 10
+      obs:
+        agentview_image:
+          shape:
+          - 3
+          - 84
+          - 84
+          type: rgb
+        robot0_eef_pos:
+          shape:
+          - 3
+        robot0_eef_quat:
+          shape:
+          - 4
+        robot0_eye_in_hand_image:
+          shape:
+          - 3
+          - 84
+          - 84
+          type: rgb
+        robot0_gripper_qpos:
+          shape:
+          - 2
+    use_cache: true
+    val_ratio: 0.02
+  dataset_path: /home/ssdArray/datasets/diffusion_policy_data/robomimic/robomimic/datasets/square/ph/image_abs.hdf5
+  dataset_type: ph
+  env_runner:
+    _target_: diffusion_policy.env_runner.robomimic_image_runner.RobomimicImageRunnerSquarePh
+    abs_action: true
+    crf: 22
+    dataset_path: /home/ssdArray/datasets/diffusion_policy_data/robomimic/robomimic/datasets/square/ph/image_abs.hdf5
+    fps: 10
+    max_steps: 400
+    n_action_steps: 8
+    n_envs: 28
+    n_obs_steps: 2
+    n_test: 50
+    n_test_vis: 4
+    n_train: 6
+    n_train_vis: 2
+    past_action: false
+    render_obs_key: agentview_image
+    shape_meta:
+      action:
+        shape:
+        - 10
+      obs:
+        agentview_image:
+          shape:
+          - 3
+          - 84
+          - 84
+          type: rgb
+        robot0_eef_pos:
+          shape:
+          - 3
+        robot0_eef_quat:
+          shape:
+          - 4
+        robot0_eye_in_hand_image:
+          shape:
+          - 3
+          - 84
+          - 84
+          type: rgb
+        robot0_gripper_qpos:
+          shape:
+          - 2
+    test_start_seed: 100000
+    tqdm_interval_sec: 1.0
+    train_start_idx: 0
+  name: square_image
+  shape_meta:
+    action:
+      shape:
+      - 10
+    obs:
+      agentview_image:
+        shape:
+        - 3
+        - 84
+        - 84
+        type: rgb
+      robot0_eef_pos:
+        shape:
+        - 3
+      robot0_eef_quat:
+        shape:
+        - 4
+      robot0_eye_in_hand_image:
+        shape:
+        - 3
+        - 84
+        - 84
+        type: rgb
+      robot0_gripper_qpos:
+        shape:
+        - 2
+test:
+  dim_x: 10
+  dim_z: ''
+  dim_a: 10
+  dim_gt: 10
+  num_ensemble: 32
+  win_size: 24
+  model_name: 'square_ph_controlnet'
+  data_path: [
+    '/home/ssdArray/datasets/diff_control_data/pick_duck/pick_duck_2'
+  ]
+  sensor_len: 2
+  channel_img_1: 3
+  channel_img_2: ''
+  input_size_1: 77
+  input_size_2: ''
+  input_size_3: ''
+  checkpoint_path_1: './experiments/square_ph_diffusion/v1.0-ema-model-104000'
+  checkpoint_path_2: './experiments/square_ph_diffusion/v1.0-sensor_model-104000'
+  dataset: 'SquarePhDataset'
+optim:
+  optim: 'adamw'
+  lr_scheduler: 'polynomial_decay'
diff --git a/Diff-Control/config/square_ph_diffusion.yaml b/Diff-Control/config/square_ph_diffusion.yaml
new file mode 100644
index 0000000..aeb2868
--- /dev/null
+++ b/Diff-Control/config/square_ph_diffusion.yaml
@@ -0,0 +1,176 @@
+mode:
+  mode: 'train'
+  model_zoo: 'diffusion-model'
+  multiprocessing_distributed: False
+  dist_url: tcp://127.0.0.1:2345
+  num_threads: 1
+  do_online_eval: True
+  parameter_path: ''
+train:
+  dim_x: 10
+  dim_z: ''
+  dim_a: 10
+  dim_gt: 10
+  num_ensemble: 32
+  win_size: 24
+  seed: 0
+  model_name: 'square_ph_diffusion'
+  dataset: 'SquarePhDataset'
+  sensor_len: 2
+  channel_img_1: 3
+  channel_img_2: ''
+  input_size_1: 77
+  input_size_2: ''
+  input_size_3: ''
+  #data_path: [
+  #  '/home/ssdArray/datasets/diff_control_data/pick_duck/pick_duck_1',
+  #  '/home/ssdArray/datasets/diff_control_data/pick_duck/pick_duck_2'
+  #]
+  data_path: /home/ssdArray/datasets/diffusion_policy_data/robomimic/robomimic/datasets/square/ph/image_abs.hdf5
+  batch_size: 64
+  num_epochs: 3000
+  learning_rate: 1e-4
+  weight_decay: 1e-6
+  adam_eps: 1e-3
+  log_freq: 10
+  eval_freq: 50
+  save_freq: 50
+  log_directory: './experiments'
+  loss: 'mse'
+task:
+  abs_action: true
+  dataset:
+    _target_: diffusion_policy.dataset.robomimic_replay_image_dataset.RobomimicReplayImageDataset
+    abs_action: true
+    dataset_path: /home/ssdArray/datasets/diffusion_policy_data/robomimic/robomimic/datasets/square/ph/image_abs.hdf5
+    horizon: 24 # was 16 originally
+    n_obs_steps: 2
+    pad_after: 7
+    pad_before: 1
+    rotation_rep: rotation_6d
+    seed: 42
+    shape_meta:
+      action:
+        shape:
+        - 10
+      obs:
+        agentview_image:
+          shape:
+          - 3
+          - 84
+          - 84
+          type: rgb
+        robot0_eef_pos:
+          shape:
+          - 3
+        robot0_eef_quat:
+          shape:
+          - 4
+        robot0_eye_in_hand_image:
+          shape:
+          - 3
+          - 84
+          - 84
+          type: rgb
+        robot0_gripper_qpos:
+          shape:
+          - 2
+    use_cache: true
+    val_ratio: 0.02
+  dataset_path: /home/ssdArray/datasets/diffusion_policy_data/robomimic/robomimic/datasets/square/ph/image_abs.hdf5
+  dataset_type: ph
+  env_runner:
+    _target_: diffusion_policy.env_runner.robomimic_image_runner.RobomimicImageRunnerSquarePh
+    abs_action: true
+    crf: 22
+    dataset_path: /home/ssdArray/datasets/diffusion_policy_data/robomimic/robomimic/datasets/square/ph/image_abs.hdf5
+    fps: 10
+    max_steps: 400
+    n_action_steps: 8
+    n_envs: 28
+    n_obs_steps: 2
+    n_test: 50
+    n_test_vis: 4
+    n_train: 6
+    n_train_vis: 2
+    past_action: false
+    render_obs_key: agentview_image
+    shape_meta:
+      action:
+        shape:
+        - 10
+      obs:
+        agentview_image:
+          shape:
+          - 3
+          - 84
+          - 84
+          type: rgb
+        robot0_eef_pos:
+          shape:
+          - 3
+        robot0_eef_quat:
+          shape:
+          - 4
+        robot0_eye_in_hand_image:
+          shape:
+          - 3
+          - 84
+          - 84
+          type: rgb
+        robot0_gripper_qpos:
+          shape:
+          - 2
+    test_start_seed: 100000
+    tqdm_interval_sec: 1.0
+    train_start_idx: 0
+  name: square_image
+  shape_meta:
+    action:
+      shape:
+      - 10
+    obs:
+      agentview_image:
+        shape:
+        - 3
+        - 84
+        - 84
+        type: rgb
+      robot0_eef_pos:
+        shape:
+        - 3
+      robot0_eef_quat:
+        shape:
+        - 4
+      robot0_eye_in_hand_image:
+        shape:
+        - 3
+        - 84
+        - 84
+        type: rgb
+      robot0_gripper_qpos:
+        shape:
+        - 2
+test:
+  dim_x: 10
+  dim_z: ''
+  dim_a: 10
+  dim_gt: 10
+  num_ensemble: 32
+  win_size: 24
+  model_name: 'square_ph_diffusion'
+  data_path: [
+    '/home/ssdArray/datasets/diff_control_data/pick_duck/pick_duck_2'
+  ]
+  sensor_len: 2
+  channel_img_1: 3
+  channel_img_2: ''
+  input_size_1: 77
+  input_size_2: ''
+  input_size_3: ''
+  checkpoint_path_1: ''
+  checkpoint_path_2: ''
+  dataset: 'SquarePhDataset'
+optim:
+  optim: 'adamw'
+  lr_scheduler: 'polynomial_decay'
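Both square_ph configs set `rotation_rep: rotation_6d` with a 10-dimensional action. A hedged sketch of that decomposition — 3 position values, 6 rotation values (the first two rows of the rotation matrix, the continuous representation of Zhou et al. 2019 used by diffusion_policy's rotation transformer), and 1 gripper value; the exact packing order is an assumption here:

    import numpy as np

    pos = np.zeros(3)                 # absolute EEF position
    R = np.eye(3)                     # absolute EEF rotation matrix
    rot6d = R[:2, :].reshape(-1)      # 6D continuity representation (2 rows x 3)
    grip = np.array([0.0])            # gripper command
    action = np.concatenate([pos, rot6d, grip])
    assert action.shape == (10,)      # matches shape_meta.action.shape above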
diff --git a/Diff-Control/controlnet_engine.py b/Diff-Control/controlnet_engine.py
index dceadb0..f2f3877 100644
--- a/Diff-Control/controlnet_engine.py
+++ b/Diff-Control/controlnet_engine.py
@@ -1,11 +1,8 @@
-import argparse
-import logging
-import os
-import numpy as np
-import torch
-import torch.nn as nn
-from einops import rearrange, repeat
-import clip
+import cv2
+import wandb
+
+from dataset.square_ph import SquarePhDataset
+from diffusion_policy.policy.base_image_policy import BaseImagePolicy
 from model import (
     UNetwithControl,
     SensorModel,
@@ -14,24 +11,23 @@ StatefulUNet,
 )
 from dataset.lid_pick_and_place import *
-from dataset.tomato_pick_and_place import *
+#from dataset.tomato_pick_and_place import *
 from dataset.pick_duck import *
 from dataset.drum_hit import *
 from optimizer import build_optimizer
 from optimizer import build_lr_scheduler
 from torch.utils.tensorboard import SummaryWriter
-import copy
 import time
-import random
 import pickle
-from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
+#from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
+from diffusers import DDPMScheduler
 from diffusers.training_utils import EMAModel
-from diffusers.optimization import get_scheduler
 
 
-class Engine:
-    def __init__(self, args, logger):
+class Engine(BaseImagePolicy):
+    def __init__(self, args, logger, diff_pol_dataset=None, env_runner=None):
+        super().__init__()
         self.args = args
         self.logger = logger
         self.batch_size = self.args.train.batch_size
@@ -49,6 +45,7 @@ def __init__(self, args, logger):
         self.win_size = self.args.train.win_size
         self.global_step = 0
         self.mode = self.args.mode.mode
+        self.env_runner = env_runner
 
         if self.args.train.dataset == "OpenLid":
             if self.mode == "train":
@@ -61,13 +58,21 @@ def __init__(self, args, logger):
                 self.data_path = self.args.train.data_path
             else:
                 self.data_path = self.args.test.data_path
-            self.dataset = Tomato(self.data_path)
+            #self.dataset = Tomato(self.data_path)
         elif self.args.train.dataset == "Duck":
             if self.mode == "train":
                 self.data_path = self.args.train.data_path
             else:
                 self.data_path = self.args.test.data_path
             self.dataset = Duck(self.data_path)
+        elif self.args.train.dataset == "SquarePhDataset":
+            if self.mode == "train":
+                self.data_path = self.args.train.data_path
+            else:
+                self.data_path = self.args.test.data_path
+            self.dataset = SquarePhDataset(diff_pol_dataset)
+            self.test_dataset = self.dataset.get_validation_dataset()
+
         elif self.args.train.dataset == "Drum":
             if self.mode == "train":
                 self.data_path = self.args.train.data_path
@@ -208,13 +213,19 @@ def train(self):
             batch_size=self.batch_size,
             shuffle=True,
             num_workers=8,
-            collate_fn=tomato_pad_collate_xy_lang,
+            #collate_fn=tomato_pad_collate_xy_lang,
+            collate_fn=pad_collate_xy_lang,
         )
         pytorch_total_params = sum(
             p.numel() for p in self.model.parameters() if p.requires_grad
         )
         print("Total number of parameters: ", pytorch_total_params)
 
+        if self.env_runner:
+            test_dataloader = torch.utils.data.DataLoader(
+                self.test_dataset, batch_size=1, shuffle=True, num_workers=8
+            )
+
         """
         only build optimizer for the trainable parts
         """
@@ -259,6 +270,7 @@ def train(self):
     # -----------------------------------------------------------------------------#
     # ---------------------------------    train    ------------------------------#
     # -----------------------------------------------------------------------------#
+        previous_val_loss = float("inf")
         while epoch < self.args.train.num_epochs:
             step = 0
             for data in dataloader:
@@ -351,6 +363,7 @@ def train(self):
             # Save a model based on a chosen save frequency
             if self.global_step != 0 and (epoch + 1) % self.args.train.save_freq == 0:
+                #if self.global_step != 0:
                 self.ema_nets = self.model
                 checkpoint = {
                     "global_step": self.global_step,
@@ -366,22 +379,138 @@ def train(self):
                     ),
                 )
 
+                if self.global_step != 0 and self.env_runner and (epoch + 1) % self.args.train.eval_freq == 0:
+                    #if self.global_step != 0 and self.env_runner:
+                    print(f"saving step {self.global_step}, epoch: {epoch}")
+                    runner_log = self.env_runner.run(self)
+                    test_mean_score = runner_log['test/mean_score']
+                    train_mean_score = runner_log['train/mean_score']
+
+                    self.writer.add_scalar(
+                        "test/mean_score", test_mean_score, self.global_step
+                    )
+                    self.writer.add_scalar(
+                        "train/mean_score", train_mean_score, self.global_step
+                    )
+                    videos = {k: v for k, v in runner_log.items() if isinstance(v, wandb.Video)}
+                    for k in videos:
+                        cap = cv2.VideoCapture(videos[k]._path)
+                        frames = []
+                        while cap.isOpened():
+                            ret, frame = cap.read()
+                            if not ret:
+                                break
+                            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # Convert from OpenCV BGR to RGB
+                            frames.append(frame)
+                        cap.release()
+                        video_tensor = torch.tensor(np.array(frames), dtype=torch.float32).permute(0, 3, 1, 2) / 255.0
+                        # Add batch dimension (1, T, C, H, W)
+                        video_tensor = video_tensor.unsqueeze(0)
+                        self.writer.add_video(k, video_tensor, self.global_step, fps=10)
 
             # online evaluation
             if (
                 self.args.mode.do_online_eval
                 and self.global_step != 0
-                and (epoch + 1) % self.args.train.eval_freq == 0
+                and (epoch + 1) % 25 == 0
             ):
                 time.sleep(0.1)
                 self.ema_nets = self.model
                 self.ema_nets.eval()
                 self.online_test()
-                self.ema_nets.train()
-                self.ema.copy_to(self.ema_nets.parameters())
+                ##############################################################
+
+                val_loss = None
+                with torch.no_grad():
+                    val_losses = list()
+                    for data in test_dataloader:
+                        data = [item.to(self.device) for item in data]
+                        (images, prior_action, action, sentence) = data
+                        optimizer_.zero_grad()
+
+                        text_features = self.clip_model.encode_text(sentence)
+                        text_features = text_features.clone().detach()
+                        text_features = text_features.to(torch.float32)
+
+                        # sample noise to add to actions
+                        noise = torch.randn(action.shape, device=self.device)
+                        # sample a diffusion iteration for each data point
+                        timesteps = torch.randint(
+                            0,
+                            self.noise_scheduler.config.num_train_timesteps,
+                            (action.shape[0],),
+                            device=self.device,
+                        ).long()
+                        # add noise to the clean images according to the noise magnitude at each diffusion iteration
+                        # (this is the forward diffusion process)
+                        noisy_actions = self.noise_scheduler.add_noise(action, noise, timesteps)
+
+                        # forward
+                        img_emb = self.sensor_model(images)
+                        predicted_noise = self.model(
+                            noisy_actions, img_emb, text_features, prior_action, timesteps
+                        )
+                        loss = self.criterion(noise, predicted_noise)
+                        val_losses.append(loss.cpu().item())
+                if len(val_losses) > 0:
+                    val_loss = torch.mean(torch.tensor(val_losses)).item()
+                    self.writer.add_scalar(
+                        "val_loss", val_loss, self.global_step
+                    )
+                ##############################################################
+                self.ema_nets.train()
+                self.ema.copy_to(self.ema_nets.parameters())
+
+                if (epoch + 1) % self.args.train.eval_freq == 0:
+                    if val_loss > previous_val_loss and epoch >= 300:
+                        return
+
+                    previous_val_loss = val_loss
 
             # Update epoch
             epoch += 1
 
+    def set_normalizer(self, normalizer):
+        pass
+
+    def predict_action(self, obs_dict):
+        images = obs_dict["image"]
+        prior_action = obs_dict["prior_action"]
+        sentence = obs_dict["sentence"].to(self.device).unsqueeze(0)
+
+        with torch.no_grad():
+            text_features = self.clip_model.encode_text(sentence)
+            text_features = text_features.clone().detach()
+            text_features = text_features.to(torch.float32)
+            text_features = torch.stack([text_features] * 28, dim=0)
+            text_features = text_features.squeeze(1)
+
+            img_emb = self.sensor_model(images)
+
+            # initialize action from Gaussian noise
+            noisy_action = torch.randn((28, self.dim_x, self.win_size)).to(
+                self.device
+            )
+
+            # TODO: add previous action
+
+            # init scheduler
+            self.noise_scheduler.set_timesteps(50)
+
+            for k in self.noise_scheduler.timesteps:
+                # predict noise
+                t = [torch.stack([k]).to(self.device) for _ in range(28)]  # Shape (1,1) per tensor
+                t = torch.cat(t, dim=0)
+                predicted_noise = self.ema_nets(
+                    noisy_action, img_emb, text_features, prior_action, t
+                )
+
+                # inverse diffusion step (remove noise)
+                noisy_action = self.noise_scheduler.step(
+                    model_output=predicted_noise, timestep=k, sample=noisy_action
+                ).prev_sample
+        return noisy_action
 
     # -----------------------------------------------------------------------------#
     # ---------------------------------    test     ------------------------------#
     # -----------------------------------------------------------------------------#
@@ -398,7 +527,7 @@ def online_test(self):
                 self.data_path = self.args.train.data_path
             else:
                 self.data_path = self.args.test.data_path
-            test_dataset = Tomato(self.data_path)
+            #test_dataset = Tomato(self.data_path)
         elif self.args.train.dataset == "Duck":
             if self.mode == "train":
                 self.data_path = self.args.train.data_path
@@ -411,6 +540,13 @@ def online_test(self):
             else:
                 self.data_path = self.args.test.data_path
                 test_dataset = Drum(self.data_path)
+        elif self.args.train.dataset == "SquarePhDataset":
+            if self.mode == "train":
+                self.data_path = self.args.train.data_path
+            else:
+                self.data_path = self.args.test.data_path
+            test_dataset = self.dataset.get_validation_dataset()
+
         test_dataloader = torch.utils.data.DataLoader(
             test_dataset, batch_size=1, shuffle=True, num_workers=8
         )
diff --git a/Diff-Control/dataset/__init__.py b/Diff-Control/dataset/__init__.py
index bc134d4..349fa43 100644
--- a/Diff-Control/dataset/__init__.py
+++ b/Diff-Control/dataset/__init__.py
@@ -1,4 +1,3 @@
 from dataset import lid_pick_and_place
-from dataset import tomato_pick_and_place
 from dataset import pick_duck
 from dataset import drum_hit
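predict_action above runs the standard diffusers reverse-diffusion loop. A self-contained sketch of the same loop with a stand-in denoiser (an all-zeros "model") so it runs without the repo's networks; the scheduler calls are the ones the engine uses:

    import torch
    from diffusers import DDPMScheduler

    scheduler = DDPMScheduler(num_train_timesteps=100)
    sample = torch.randn(1, 10, 24)      # (batch, dim_x, win_size)
    scheduler.set_timesteps(50)          # fewer inference steps than training steps
    for t in scheduler.timesteps:
        predicted_noise = torch.zeros_like(sample)  # stand-in for the U-Net
        # one reverse step: remove the predicted noise for timestep t
        sample = scheduler.step(
            model_output=predicted_noise, timestep=t, sample=sample
        ).prev_sample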
diff --git a/Diff-Control/dataset/pusht.py b/Diff-Control/dataset/pusht.py
new file mode 100644
index 0000000..5d61704
--- /dev/null
+++ b/Diff-Control/dataset/pusht.py
@@ -0,0 +1,45 @@
+import torch
+from torchvision import transforms
+from einops import rearrange
+import clip
+
+
+class PushTDataset(torch.utils.data.Dataset):
+    def __init__(self, pusht_dataset, image_size=(224, 224)):
+        self.pusht_dataset = pusht_dataset
+        self.image_size = image_size
+
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.clip_model, _ = clip.load("ViT-B/32", device=self.device)
+        self.dummy_sentence = "push the T to its goal"
+        self.transform = transforms.Resize(self.image_size)
+
+    def __len__(self):
+        return len(self.pusht_dataset)
+
+    def __getitem__(self, idx):
+        transition_history = 12
+        data = self.pusht_dataset.__getitem__(idx)
+        if idx < transition_history:
+            idx_pre = 0
+        else:
+            idx_pre = idx - transition_history
+
+        prev_data = self.pusht_dataset.__getitem__(idx_pre)
+
+        # Extract image and action from the PushT dataset
+        # (PushTImageDataset exposes the camera frame under the 'image' key)
+        image = data['obs']['image']
+        action = rearrange(data['action'], "Ta Da -> Da Ta")
+        prior_action = rearrange(prev_data['action'], "Ta Da -> Da Ta")
+        # Resize and normalize the image to match Duck dataset
+        image = self.transform(image)
+
+        # Tokenize dummy sentence using CLIP
+        sentence = clip.tokenize([self.dummy_sentence])[0]
+
+        return image, prior_action, action, sentence
+
+    def get_validation_dataset(self):
+        val_set = self.pusht_dataset.get_validation_dataset()
+        return PushTDataset(val_set, image_size=self.image_size)
diff --git a/Diff-Control/dataset/square_ph.py b/Diff-Control/dataset/square_ph.py
new file mode 100644
index 0000000..5087334
--- /dev/null
+++ b/Diff-Control/dataset/square_ph.py
@@ -0,0 +1,45 @@
+import torch
+from torchvision import transforms
+from einops import rearrange
+import clip
+
+
+class SquarePhDataset(torch.utils.data.Dataset):
+    def __init__(self, square_ph_dataset, image_size=(224, 224)):
+        self.square_ph_dataset = square_ph_dataset
+        self.image_size = image_size
+
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        self.clip_model, _ = clip.load("ViT-B/32", device=self.device)
+        self.dummy_sentence = "push the square to its goal"
+        self.transform = transforms.Resize(self.image_size)
+
+    def __len__(self):
+        return len(self.square_ph_dataset)
+
+    def __getitem__(self, idx):
+        transition_history = 12
+        data = self.square_ph_dataset.__getitem__(idx)
+        if idx < transition_history:
+            idx_pre = 0
+        else:
+            idx_pre = idx - transition_history
+
+        prev_data = self.square_ph_dataset.__getitem__(idx_pre)
+
+        # Extract image and action from the robomimic square dataset
+        image = data['obs']['agentview_image'][0]
+        action = rearrange(data['action'], "Ta Da -> Da Ta")
+        prior_action = rearrange(prev_data['action'], "Ta Da -> Da Ta")
+        # Resize and normalize the image to match Duck dataset
+        image = self.transform(image)
+
+        # Tokenize dummy sentence using CLIP
+        sentence = clip.tokenize([self.dummy_sentence])[0]
+
+        return image, prior_action, action, sentence
+
+    def get_validation_dataset(self):
+        val_set = self.square_ph_dataset.get_validation_dataset()
+        return SquarePhDataset(val_set, image_size=self.image_size)
diff --git a/Diff-Control/dataset/tomato_pick_and_place.py b/Diff-Control/dataset/tomato_pick_and_place.py
old mode 100755
new mode 100644
diff --git a/Diff-Control/model/__init__.py b/Diff-Control/model/__init__.py
index b1300a1..fec3560 100644
--- a/Diff-Control/model/__init__.py
+++ b/Diff-Control/model/__init__.py
@@ -4,4 +4,4 @@
 from model.stateful_module import StatefulUNet
 from model.film_model import Backbone as ModAttn_ImageBC
 from model.film_lstm import Backbone as BCZ_LSTM
-from model.stateful_module import StatefulControlNet
+from model.stateful_module import StatefulControlNet
\ No newline at end of file
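The PushTDataset and SquarePhDataset wrappers above pair each action window with the window from 12 steps earlier (`transition_history`), clamping early indices to 0, and transpose windows from (time, action-dim) to (action-dim, time). A toy illustration of just that indexing and transpose, with random tensors standing in for the underlying dataset:

    import torch
    from einops import rearrange

    horizon, action_dim, transition_history = 24, 10, 12
    windows = [torch.randn(horizon, action_dim) for _ in range(100)]  # stand-in data

    idx = 5
    idx_pre = 0 if idx < transition_history else idx - transition_history  # clamps to 0
    action = rearrange(windows[idx], "Ta Da -> Da Ta")            # shape (10, 24)
    prior_action = rearrange(windows[idx_pre], "Ta Da -> Da Ta")  # conditioning window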
diff --git a/Diff-Control/prebuild_engine.py b/Diff-Control/prebuild_engine.py
index 6a4bdae..295d257 100644
--- a/Diff-Control/prebuild_engine.py
+++ b/Diff-Control/prebuild_engine.py
@@ -1,31 +1,26 @@
-import argparse
-import logging
-import os
-import numpy as np
-import torch
-import torch.nn as nn
-from einops import rearrange, repeat
-import clip
+import cv2
+import wandb
+
+from dataset.square_ph import SquarePhDataset
+from diffusion_policy.policy.base_image_policy import BaseImagePolicy
 from model import UNetwithControl, SensorModel, StatefulUNet
 from dataset.lid_pick_and_place import *
-from dataset.tomato_pick_and_place import *
+#from dataset.tomato_pick_and_place import *
 from dataset.pick_duck import *
 from dataset.drum_hit import *
 from optimizer import build_optimizer
 from optimizer import build_lr_scheduler
 from torch.utils.tensorboard import SummaryWriter
-import copy
 import time
-import random
 import pickle
 from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
 from diffusers.training_utils import EMAModel
-from diffusers.optimization import get_scheduler
 
 
-class Engine:
-    def __init__(self, args, logger):
+class Engine(BaseImagePolicy):
+    def __init__(self, args, logger, diff_pol_dataset=None, env_runner=None):
+        super().__init__()
         self.args = args
         self.logger = logger
         self.batch_size = self.args.train.batch_size
@@ -43,6 +38,7 @@ def __init__(self, args, logger):
         self.win_size = self.args.train.win_size
         self.global_step = 0
         self.mode = self.args.mode.mode
+        self.env_runner = env_runner
 
         if self.args.train.dataset == "OpenLid":
             if self.mode == "train":
@@ -68,6 +64,14 @@ def __init__(self, args, logger):
             else:
                 self.data_path = self.args.test.data_path
             self.dataset = Drum(self.data_path)
+        elif self.args.train.dataset == "SquarePhDataset":
+            if self.mode == "train":
+                self.data_path = self.args.train.data_path
+            else:
+                self.data_path = self.args.test.data_path
+
+            self.dataset = SquarePhDataset(diff_pol_dataset)
+            self.test_dataset = self.dataset.get_validation_dataset()
 
         if self.args.train.dataset == "Drum":
             self.model = StatefulUNet(dim_x=self.dim_x, window_size=self.win_size)
@@ -87,8 +91,8 @@ def __init__(self, args, logger):
             raise TypeError("model must be an instance of nn.Module")
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         if torch.cuda.is_available():
-            self.model.cuda()
-            self.sensor_model.cuda()
+            self.model.cuda(self.device)
+            self.sensor_model.cuda(self.device)
 
     # -----------------------------------------------------------------------------#
    # ------------------------------- use pre-trained? ---------------------------#
@@ -142,6 +146,11 @@ def train(self):
         )
         print("Total number of parameters: ", pytorch_total_params)
 
+        if self.env_runner:
+            test_dataloader = torch.utils.data.DataLoader(
+                self.test_dataset, batch_size=1, shuffle=True, num_workers=8
+            )
+
         # Create optimizer
         optimizer_ = build_optimizer(
             [self.model, self.sensor_model],
@@ -177,6 +186,8 @@ def train(self):
     # -----------------------------------------------------------------------------#
     # ---------------------------------    train    ------------------------------#
     # -----------------------------------------------------------------------------#
+        previous_val_loss = float("inf")
+        all_val_losses = []
         while epoch < self.args.train.num_epochs:
             step = 0
             for data in dataloader:
@@ -256,6 +267,7 @@ def train(self):
 
             # Save a model based on a chosen save frequency
             if self.global_step != 0 and (epoch + 1) % self.args.train.save_freq == 0:
+                #if self.global_step != 0:
                 self.ema_nets = self.model
                 checkpoint = {
                     "global_step": self.global_step,
@@ -284,24 +296,136 @@ def train(self):
                     ),
                 )
 
+                if self.global_step != 0 and self.env_runner and (epoch + 1) % self.args.train.eval_freq == 0:
+                    print(f"saving step {self.global_step}, epoch: {epoch}")
+                    runner_log = self.env_runner.run(self)
+                    test_mean_score = runner_log['test/mean_score']
+                    train_mean_score = runner_log['train/mean_score']
+
+                    self.writer.add_scalar(
+                        "test/mean_score", test_mean_score, self.global_step
+                    )
+                    self.writer.add_scalar(
+                        "train/mean_score", train_mean_score, self.global_step
+                    )
+                    videos = {k: v for k, v in runner_log.items() if isinstance(v, wandb.Video)}
+                    for k in videos:
+                        cap = cv2.VideoCapture(videos[k]._path)
+                        frames = []
+                        while cap.isOpened():
+                            ret, frame = cap.read()
+                            if not ret:
+                                break
+                            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # Convert from OpenCV BGR to RGB
+                            frames.append(frame)
+                        cap.release()
+                        video_tensor = torch.tensor(np.array(frames), dtype=torch.float32).permute(0, 3, 1, 2) / 255.0
+                        # Add batch dimension (1, T, C, H, W)
+                        video_tensor = video_tensor.unsqueeze(0)
+                        self.writer.add_video(k, video_tensor, self.global_step, fps=10)
 
             # online evaluation
             if (
                 self.args.mode.do_online_eval
                 and self.global_step != 0
-                and (epoch + 1) % self.args.train.eval_freq == 0
+                and (epoch + 1) % 25 == 0
             ):
                 time.sleep(0.1)
                 self.ema_nets = self.model
                 self.ema_nets.eval()
                 self.sensor_model.eval()
                 self.online_test()
-                self.ema_nets.train()
-                self.sensor_model.train()
-                self.ema.copy_to(self.ema_nets.parameters())
+                ###############################################################
+                val_loss = None
+                with torch.no_grad():
+                    val_losses = list()
+                    for data in test_dataloader:
+                        data = [item.to(self.device) for item in data]
+                        (images, prior_action, action, sentence) = data
+                        optimizer_.zero_grad()
+
+                        text_features = self.clip_model.encode_text(sentence)
+                        text_features = text_features.clone().detach()
+                        text_features = text_features.to(torch.float32)
+
+                        # sample noise to add to actions
+                        noise = torch.randn(action.shape, device=self.device)
+                        # sample a diffusion iteration for each data point
+                        timesteps = torch.randint(
+                            0,
+                            self.noise_scheduler.config.num_train_timesteps,
+                            (action.shape[0],),
+                            device=self.device,
+                        ).long()
+                        # add noise to the clean images according to the noise magnitude at each diffusion iteration
+                        # (this is the forward diffusion process)
+                        noisy_actions = self.noise_scheduler.add_noise(action, noise, timesteps)
+
+                        # forward
+                        img_emb = self.sensor_model(images)
+                        predicted_noise = self.model(
+                            noisy_actions, img_emb, text_features, timesteps
+                        )
+                        loss = self.criterion(noise, predicted_noise)
+                        val_losses.append(loss.cpu().item())
+                if len(val_losses) > 0:
+                    val_loss = torch.mean(torch.tensor(val_losses)).item()
+                    self.writer.add_scalar(
+                        "val_loss", val_loss, self.global_step
+                    )
+                    all_val_losses.append((epoch, val_loss))
+
+                ###############################################################
+                self.ema_nets.train()
+                self.sensor_model.train()
+                self.ema.copy_to(self.ema_nets.parameters())
+
+                if (epoch + 1) % self.args.train.eval_freq == 0:
+                    if val_loss > previous_val_loss and epoch >= 300:
+                        return
+
+                    previous_val_loss = val_loss
 
             # Update epoch
             epoch += 1
 
+    def set_normalizer(self, normalizer):
+        pass
+
+    def predict_action(self, obs_dict):
+        images = obs_dict["image"]
+        sentence = obs_dict["sentence"].to(self.device).unsqueeze(0)
+
+        with torch.no_grad():
+            text_features = self.clip_model.encode_text(sentence)
+            text_features = text_features.clone().detach()
+            text_features = text_features.to(torch.float32)
+            text_features = torch.stack([text_features] * 28, dim=0)
+            text_features = text_features.squeeze(1)
+
+            img_emb = self.sensor_model(images)
+
+            # initialize action from Gaussian noise
+            noisy_action = torch.randn((28, self.dim_x, self.win_size)).to(
+                self.device
+            )
+
+            # init scheduler
+            self.noise_scheduler.set_timesteps(50)
+
+            for k in self.noise_scheduler.timesteps:
+                # predict noise
+                t = [torch.stack([k]).to(self.device) for _ in range(28)]  # Shape (1,1) per tensor
+                t = torch.cat(t, dim=0)
+                predicted_noise = self.ema_nets(
+                    noisy_action, img_emb, text_features, t
+                )
+
+                # inverse diffusion step (remove noise)
+                noisy_action = self.noise_scheduler.step(
+                    model_output=predicted_noise, timestep=k, sample=noisy_action
+                ).prev_sample
+        return noisy_action
 
     # -----------------------------------------------------------------------------#
     # ---------------------------------    test     ------------------------------#
     # -----------------------------------------------------------------------------#
@@ -331,6 +455,13 @@ def online_test(self):
             else:
                 self.data_path = self.args.test.data_path
                 test_dataset = Drum(self.data_path)
+        elif self.args.train.dataset == "SquarePhDataset":
+            if self.mode == "train":
+                self.data_path = self.args.train.data_path
+            else:
+                self.data_path = self.args.test.data_path
+            test_dataset = self.dataset.get_validation_dataset()
+
         test_dataloader = torch.utils.data.DataLoader(
             test_dataset, batch_size=1, shuffle=True, num_workers=8
         )
diff --git a/Diff-Control/tomato_engine.py b/Diff-Control/tomato_engine.py
index 6da5716..3e8ee7b 100644
--- a/Diff-Control/tomato_engine.py
+++ b/Diff-Control/tomato_engine.py
@@ -1,24 +1,13 @@
-import argparse
-import logging
-import os
-import numpy as np
-import torch
-import torch.nn as nn
-from einops import rearrange, repeat
-import clip
 from model import UNetwithControl, SensorModel
 from dataset.tomato_pick_and_place import *
 from optimizer import build_optimizer
 from optimizer import build_lr_scheduler
 from torch.utils.tensorboard import SummaryWriter
-import copy
 import time
-import random
 import pickle
 from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
 from diffusers.training_utils import EMAModel
-from diffusers.optimization import get_scheduler
 
 
 class Engine:
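Both engines gain the same stopping rule: at every eval_freq-th epoch, training returns the first time validation loss exceeds the previously recorded value, but only after epoch 300. Distilled to a runnable toy loop (the loss curve here is fake):

    eval_freq, num_epochs = 50, 3000      # mirrors the configs above
    previous_val_loss = float("inf")
    for epoch in range(num_epochs):
        val_loss = 1.0 / (epoch + 1) if epoch < 400 else 1.0  # fake loss curve
        if (epoch + 1) % eval_freq == 0:
            if val_loss > previous_val_loss and epoch >= 300:
                break                     # the engines `return` out of train()
            previous_val_loss = val_loss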
diff --git a/Diff-Control/train.py b/Diff-Control/train.py
index ff9c418..5489c98 100644
--- a/Diff-Control/train.py
+++ b/Diff-Control/train.py
@@ -4,7 +4,7 @@
 import yaml
 from sys import argv
 from config import cfg
-import controlnet_engine, BCZ_engine, BCZ_LSTM_engine, prebuild_engine, tomato_engine
+import controlnet_engine, BCZ_engine, BCZ_LSTM_engine, prebuild_engine#, tomato_engine
 # mini_controlnet_engine, BCZ_LSTM_engine
 
 import warnings
@@ -18,7 +18,7 @@
     style="%",
 )
 logging.basicConfig(**logging_kwargs)
-logger = logging.getLogger("diffusion-pilicy")
+logger = logging.getLogger("diffusion-policy")
 
 
 def parse_args():
@@ -38,43 +38,43 @@ def parse_args():
 
 
 def main():
-    cfg, config_file = parse_args()
-    cfg.freeze()
+    cfg_diff_ctrl, config_file = parse_args()
+    cfg_diff_ctrl.freeze()
 
     ####### check all the parameter settings #######
-    logger.info("{}".format(cfg))
-    logger.info("check mode - {}".format(cfg.mode.mode))
+    logger.info("{}".format(cfg_diff_ctrl))
+    logger.info("check mode - {}".format(cfg_diff_ctrl.mode.mode))
 
     # Create directory for logs and experiment name
-    if not os.path.exists(cfg.train.log_directory):
-        os.mkdir(cfg.train.log_directory)
-    if not os.path.exists(os.path.join(cfg.train.log_directory, cfg.train.model_name)):
-        os.mkdir(os.path.join(cfg.train.log_directory, cfg.train.model_name))
+    if not os.path.exists(cfg_diff_ctrl.train.log_directory):
+        os.mkdir(cfg_diff_ctrl.train.log_directory)
+    if not os.path.exists(os.path.join(cfg_diff_ctrl.train.log_directory, cfg_diff_ctrl.train.model_name)):
+        os.mkdir(os.path.join(cfg_diff_ctrl.train.log_directory, cfg_diff_ctrl.train.model_name))
         os.mkdir(
-            os.path.join(cfg.train.log_directory, cfg.train.model_name, "summaries")
+            os.path.join(cfg_diff_ctrl.train.log_directory, cfg_diff_ctrl.train.model_name, "summaries")
         )
     else:
         logger.warning(
             "This logging directory already exists: {}. Over-writing current files".format(
-                os.path.join(cfg.train.log_directory, cfg.train.model_name)
+                os.path.join(cfg_diff_ctrl.train.log_directory, cfg_diff_ctrl.train.model_name)
             )
         )
 
     ####### start the training #######
-    if cfg.mode.model_zoo == "controlnet":
-        train_engine = controlnet_engine.Engine(args=cfg, logger=logger)
-    elif cfg.mode.model_zoo == "diffusion-model":
-        train_engine = prebuild_engine.Engine(args=cfg, logger=logger)
-    elif cfg.mode.model_zoo == "BCZ":
-        train_engine = BCZ_engine.Engine(args=cfg, logger=logger)
-    elif cfg.mode.model_zoo == "BCZ_LSTM":
-        train_engine = BCZ_LSTM_engine.Engine(args=cfg, logger=logger)
-    elif cfg.mode.model_zoo == "tomato-model":
-        train_engine = tomato_engine.Engine(args=cfg, logger=logger)
+    if cfg_diff_ctrl.mode.model_zoo == "controlnet":
+        train_engine = controlnet_engine.Engine(args=cfg_diff_ctrl, logger=logger)
+    elif cfg_diff_ctrl.mode.model_zoo == "diffusion-model":
+        train_engine = prebuild_engine.Engine(args=cfg_diff_ctrl, logger=logger)
+    elif cfg_diff_ctrl.mode.model_zoo == "BCZ":
+        train_engine = BCZ_engine.Engine(args=cfg_diff_ctrl, logger=logger)
+    elif cfg_diff_ctrl.mode.model_zoo == "BCZ_LSTM":
+        train_engine = BCZ_LSTM_engine.Engine(args=cfg_diff_ctrl, logger=logger)
+    elif cfg_diff_ctrl.mode.model_zoo == "tomato-model":
+        train_engine = tomato_engine.Engine(args=cfg_diff_ctrl, logger=logger)
 
-    if cfg.mode.mode == "train":
+    if cfg_diff_ctrl.mode.mode == "train":
         train_engine.train()
-    if cfg.mode.mode == "pretrain":
+    if cfg_diff_ctrl.mode.mode == "pretrain":
         train_engine.train()
-    if cfg.mode.mode == "test":
+    if cfg_diff_ctrl.mode.mode == "test":
         train_engine.test()
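train.py's parse_args body is not shown in this diff, but config.py builds a yacs CfgNode (`CN()`) and main() freezes the result, so the flow is presumably a clone-and-merge of one of the YAMLs above. A sketch under that assumption only:

    # Assumed flow: parse_args' internals are not part of this diff.
    from config import cfg

    cfg_diff_ctrl = cfg.clone()
    cfg_diff_ctrl.merge_from_file("Diff-Control/config/square_ph_diffusion.yaml")
    cfg_diff_ctrl.freeze()
    print(cfg_diff_ctrl.mode.model_zoo)   # -> 'diffusion-model'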