debug training pipeline
peterdavidfagan committed Mar 11, 2024
1 parent 690a50e commit ef82be5
Showing 10 changed files with 235 additions and 155 deletions.
1 change: 1 addition & 0 deletions .docker/train_job_jax/Dockerfile.train_job
@@ -34,6 +34,7 @@ RUN eval "$(/pyenv/bin/pyenv init -)" && /pyenv/bin/pyenv local 3.10.6 && poetry

FROM ghcr.io/peterdavidfagan/jax_container:latest as targetsys

ENV WANDB_API_KEY=b45be466fddfec3e14a65ff903fac3b3b7e78c41
COPY --from=compilesys /pyenv /pyenv
COPY --from=compilesys /app /app
WORKDIR /app
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -52,13 +52,15 @@ jaxonnxruntime = "^0.3.0"
dm-reverb = {version="0.13.0", markers = "sys_platform == 'linux'"}
tensorflow-cpu = {version="^2.14.0", markers = "sys_platform == 'linux'"}
envlogger = {version="^1.1", extras=["tfds"], markers = "sys_platform == 'linux'"}
tensorflow-datasets = "4.9.3"
rlds = {version="^0.1.7", markers = "sys_platform == 'linux'"}

# submodules
multi_modal_transformers = {path="./robot_learning_baselines/model_architectures/multi_modal_transformers", develop=true}
transporter_networks = {path="./robot_learning_baselines/model_architectures/transporter_networks", develop=true}
octo = {path="./robot_learning_baselines/submodules/octo", develop=true}
dlimp = {path="./robot_learning_baselines/submodules/dlimp", develop=true}
distrax = "^0.1.5"

[tool.black]
line-length = 120
@@ -1,6 +1,6 @@
stacked_encoder_1d_block:
_target_: multi_modal_transformers.attention_blocks.attention.StackedEncoder1DBlock
num_blocks: 1
num_blocks: 6
encoder_1d_block:
_target_: multi_modal_transformers.attention_blocks.attention.Encoder1DBlock

9 changes: 8 additions & 1 deletion robot_learning_baselines/config/octo-base.yaml
@@ -2,14 +2,21 @@ project:
base_path: /home/peter/Code/research_projects/robot_learning_baselines

wandb:
use: True
project: "robot_learning_baselines"
entity: "ipab-rad"
use: True
experiment_name: octo-base
tags: ["debugging octo pipeline"]
notes: "Debugging octo training pipeline."
resume: False

dataset_visualization:
columns:
task/language_instruction: "text"
observation/image_primary: "image"



defaults:
- architecture/multi_modal_transformer: octo-base
- dataset: debug-open-x-embodiment
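The new dataset_visualization block above maps batch keys to column types ("text" or "image") and is consumed by the visualize_dataset helper imported in train_multi_modal.py further down. That helper lives in utils/wandb.py, which is not part of the visible diff; the following is a hypothetical sketch of how such a helper could consume this config, with the row count and log key chosen purely for illustration.

# Hypothetical sketch only: the real visualize_dataset lives in utils/wandb.py
# (not shown in this diff); the row count and log key below are illustrative.
import numpy as np
import wandb

def visualize_dataset(cfg, batch):
    """Log a few rows of the batch to wandb using the configured columns."""
    columns = list(cfg.dataset_visualization.columns.keys())
    table = wandb.Table(columns=columns)
    num_rows = min(8, len(batch["task"]["language_instruction"]))
    for idx in range(num_rows):
        row = []
        for column, column_type in cfg.dataset_visualization.columns.items():
            outer_key, inner_key = column.split("/")  # e.g. "observation/image_primary"
            value = batch[outer_key][inner_key][idx]
            if column_type == "text":
                row.append(value.decode("utf-8"))
            elif column_type == "image":
                # image_primary carries a history axis; log the most recent frame
                row.append(wandb.Image(np.asarray(value[-1])))
        table.add_data(*row)
    wandb.log({"dataset_sample": table})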
2 changes: 1 addition & 1 deletion robot_learning_baselines/config/training/octo-base.yaml
@@ -8,7 +8,7 @@ save_interval: 50

# training hyperparameters
num_epochs: 50
batch_size: 8
batch_size: 32
momentum: 0.9

# learning rate scheduler
106 changes: 44 additions & 62 deletions robot_learning_baselines/train_multi_modal.py
@@ -1,6 +1,10 @@
"""Training script for concept learning model."""
# standard libraries
import os
import gc
from time import time
from functools import partial

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
@@ -15,6 +19,7 @@
import jax.numpy as jnp
import jax.random as random
from jax.nn import softmax
import flax.linen as nn
import optax
import orbax.checkpoint as ocp

@@ -28,23 +33,28 @@
from omegaconf import DictConfig
from clu import metrics

from utils.pipeline import (
# custom training pipeline utilities
from utils.data import (
oxe_load_single_dataset,
oxe_load_dataset,
preprocess_batch,
)

from utils.pipeline import (
inspect_model,
setup_checkpointing,
create_optimizer,
)

from utils.wandb import init_wandb
from utils.wandb import (
init_wandb,
visualize_dataset,
)

@hydra.main(version_base=None, config_path="./config", config_name="octo-base")
def main(cfg: DictConfig) -> None:
"""Model training loop."""

# initialize weights and biases
if cfg.wandb.use:
init_wandb(cfg)

# set up random number generators
key = random.PRNGKey(0)
key, model_key, dropout_key, image_tokenizer_key, diffusion_key = random.split(key, 5)
rngs = {
@@ -54,57 +64,39 @@ def main(cfg: DictConfig) -> None:
"diffusion": diffusion_key,
}

# load the dataset
train_data = oxe_load_single_dataset(cfg.dataset) # for now debug the datapipeline
#train_data = oxe_load_dataset(cfg.data.open-x-embodiment, cfg.training.decoder_only)

# set up model checkpointing
chkpt_manager = setup_checkpointing(cfg.training)
train_data = oxe_load_single_dataset(cfg.dataset) # load dataset for debugging
#train_data = oxe_load_dataset(cfg.data.open-x-embodiment, cfg.training.decoder_only) # load dataset

# instantiate model optimizer
learning_rate_scheduler = optax.warmup_cosine_decay_schedule(
init_value=cfg.training.initial_lr,
peak_value=cfg.training.peak_lr,
warmup_steps=cfg.training.warmup_epochs * cfg.training.steps_per_epoch,
decay_steps=(cfg.training.num_epochs - cfg.training.warmup_epochs)
* cfg.training.steps_per_epoch,
end_value=cfg.training.end_lr,
)

optimizer = optax.chain(
optax.clip_by_global_norm(cfg.training.max_grad_norm),
optax.adamw(learning_rate_scheduler, weight_decay=cfg.training.weight_decay),
)

# instantiate model and text tokenizer
model = Octo(cfg.architecture.multi_modal_transformer)
text_tokenizer = instantiate(cfg.architecture.multi_modal_transformer.tokenizers.text.tokenizer)
if cfg.wandb.use: # optionally initialise wandb
init_wandb(cfg)
visualize_dataset(cfg, next(train_data.as_numpy_iterator()))


# initialize the training state with a batch of data
batch = next(train_data.as_numpy_iterator())
text = [task.decode() for task in batch["task"]["language_instruction"]]
text_tokens = text_tokenizer(
text,
return_tensors="jax",
max_length=16, # hardcode while debugging
padding="max_length",
truncation=True,
)["input_ids"]
images = batch["observation"]["image_primary"]
time = jnp.ones((images.shape[0], 1))
actions = jnp.take(batch["action"], -1, axis=1)
chkpt_manager = setup_checkpointing(cfg.training) # set up model checkpointing
optimizer = create_optimizer(cfg) # instantiate model optimizer
model = Octo(cfg.architecture.multi_modal_transformer) # instantiate model
text_tokenizer = instantiate(cfg.architecture.multi_modal_transformer.tokenizers.text.tokenizer) # instantiate text tokenizer
text_tokenize_fn = partial(text_tokenizer,
return_tensors="jax",
max_length=16, # hardcode while debugging
padding="max_length",
truncation=True
)

# initialize the training state
batch = next(train_data.as_numpy_iterator())
input_data = preprocess_batch(batch, text_tokenize_fn, dummy=True)
inspect_model(model, rngs, input_data, method="predict_diffusion_denoise_term")
train_state = create_octo_train_state(
text_tokens,
images,
input_data["text_tokens"],
input_data["images"],
text_tokenizer,
{"time": time, "noisy_actions": actions},
{"time": input_data["time"], "noisy_actions": input_data["noisy_actions"]},
rngs,
model,
optimizer
)

# training loop
for epoch in tqdm(range(cfg.training.num_epochs), leave=False):

# epoch metrics
@@ -116,25 +108,15 @@ def main(cfg: DictConfig) -> None:
train_data = train_data.shuffle(10)
train_data_iter = train_data.as_numpy_iterator()

# cycle through batches of data
for batch in train_data_iter:

# tokenize text
text = [task.decode() for task in batch["task"]["language_instruction"]]
text_tokens = train_state.text_tokenize_fn(text)["input_ids"]
actions = jnp.take(batch["action"], -1, axis=1)

# perform diffusion train step
data = preprocess_batch(batch, train_state.text_tokenize_fn)
train_state = train_state.diffusion_train_step(
model,
train_state,
text_tokens,
batch["observation"]["image_primary"],
actions,
data["text_tokens"],
data["images"],
data["gt_action"],
)

# break loop for the purpose of debugging training pipeline
break

# compute and track metrics
for metric, value in train_state.metrics.compute().items():
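This commit moves the learning-rate schedule and optimizer construction out of the training script and into a create_optimizer helper in utils.pipeline. A minimal sketch of that helper, assuming it simply wraps the optax code removed above (the committed utils/pipeline.py is not shown in this diff):

# Sketch assuming create_optimizer wraps the schedule/optimizer code removed
# from train_multi_modal.py above; the committed utils/pipeline.py is not shown here.
import optax

def create_optimizer(cfg):
    """Warmup-cosine-decay AdamW with global-norm gradient clipping."""
    learning_rate_scheduler = optax.warmup_cosine_decay_schedule(
        init_value=cfg.training.initial_lr,
        peak_value=cfg.training.peak_lr,
        warmup_steps=cfg.training.warmup_epochs * cfg.training.steps_per_epoch,
        decay_steps=(cfg.training.num_epochs - cfg.training.warmup_epochs)
        * cfg.training.steps_per_epoch,
        end_value=cfg.training.end_lr,
    )
    return optax.chain(
        optax.clip_by_global_norm(cfg.training.max_grad_norm),
        optax.adamw(learning_rate_scheduler, weight_decay=cfg.training.weight_decay),
    )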
127 changes: 127 additions & 0 deletions robot_learning_baselines/utils/data.py
@@ -0,0 +1,127 @@
"""Utilities for loading and inspecting datasets."""

# standard libraries
import os
import shutil
import urllib.request

# dataset
import tensorflow as tf
import tensorflow_datasets as tfds
from octo.data.oxe import make_oxe_dataset_kwargs, make_oxe_dataset_kwargs_and_weights
from octo.data.dataset import make_interleaved_dataset, make_single_dataset

import jax.numpy as jnp

def oxe_load_single_dataset(cfg):
    """Load a single OXE dataset using the octo dataset loader."""
dataset_kwargs = make_oxe_dataset_kwargs(
cfg.dataset,
cfg.tfds_data_dir,
)
dataset = make_single_dataset(
dataset_kwargs,
train=True,
traj_transform_kwargs = {
"window_size": 2, # for octo we will take a history of two
},
frame_transform_kwargs = {
"resize_size": (280,280)
},
)
train_dataset = (
dataset.flatten() # flattens trajectories into individual frames
.shuffle(cfg.shuffle_buffer_size) # shuffles the frames
.batch(cfg.batch_size) # batches the frames
)

return train_dataset

def oxe_load_dataset(cfg):
"""Load dataset using the oxe dataset loader."""
dataset_kwargs_list, sample_weights = make_oxe_dataset_kwargs_and_weights(
"oxe_magic_soup",
        cfg.tfds_data_dir,
load_camera_views=("primary", "wrist"),
)

# each element of `dataset_kwargs_list` can be used with `make_single_dataset`, but let's
# use the more powerful `make_interleaved_dataset` to combine them for us!
dataset = make_interleaved_dataset(
dataset_kwargs_list,
sample_weights,
train=True,
shuffle_buffer_size=cfg.shuffle_buffer_size,
        batch_size=cfg.batch_size,
traj_transform_kwargs=dict(
goal_relabeling_strategy="uniform", # let's get some goal images
window_size=2, # let's get some history
future_action_window_size=3, # let's get some future actions for action chunking
subsample_length=100, # subsampling long trajectories improves shuffling a lot
),
frame_transform_kwargs=dict(
image_augment_kwargs=dict(
augment_order=["random_resized_crop", "random_brightness"],
random_resized_crop=dict(scale=[0.8, 1.0], ratio=[0.9, 1.1]),
random_brightness=[0.1],
),
resize_size=dict(
primary=(256, 256),
wrist=(128, 128),
),
            # If parallelism options are not provided, they will default to tf.data.AUTOTUNE.
# However, we would highly recommend setting them manually if you run into issues
# with memory or dataloading speed. Frame transforms are usually the speed
# bottleneck (due to image decoding, augmentation, and resizing), so you can set
# this to a very high value if you have a lot of CPU cores. Keep in mind that more
# parallel calls also use more memory, though.
num_parallel_calls=64,
),
# Same spiel as above about performance, although trajectory transforms and data reading
# are usually not the speed bottleneck. One reason to manually set these is if you want
# to reduce memory usage (since autotune may spawn way more threads than necessary).
traj_transform_threads=16,
traj_read_threads=16,
)

# Another performance knob to tune is the number of batches to prefetch -- again,
# the default of tf.data.AUTOTUNE can sometimes use more memory than necessary.
iterator = dataset.iterator(prefetch=1)

return iterator


def preprocess_batch(batch, text_tokenize_fn, dummy=False):
"""
Preprocess a batch of data.
"""

# tokenize text
text = [task.decode("utf-8") for task in batch["task"]["language_instruction"]]
text_tokens = text_tokenize_fn(
text,
)["input_ids"]

# get image observations
images = batch["observation"]["image_primary"]

# get action
gt_action = jnp.take(batch["action"], -1, axis=1)

# create dummy data for diffusion-based model init
if dummy:
time = jnp.ones((images.shape[0], 1))
actions = jnp.take(batch["action"], -1, axis=1)
data = {
"images": images,
"text_tokens": text_tokens,
"time": time,
"noisy_actions": actions,
}
else:
data = {
"images": images,
"text_tokens": text_tokens,
"gt_action": gt_action,
}

return data
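For reference, a minimal usage sketch of these utilities, mirroring how train_multi_modal.py calls them; the config values and the choice of a t5-base tokenizer are illustrative assumptions, not part of this commit.

# Illustrative usage only: dataset name, paths, and tokenizer are assumptions.
from functools import partial

from omegaconf import OmegaConf
from transformers import AutoTokenizer

dataset_cfg = OmegaConf.create({
    "dataset": "bridge_dataset",          # any registered OXE dataset name
    "tfds_data_dir": "~/tensorflow_datasets",
    "shuffle_buffer_size": 1000,
    "batch_size": 32,
})
train_data = oxe_load_single_dataset(dataset_cfg)

text_tokenize_fn = partial(
    AutoTokenizer.from_pretrained("t5-base"),
    return_tensors="jax",
    max_length=16,
    padding="max_length",
    truncation=True,
)

batch = next(train_data.as_numpy_iterator())
init_inputs = preprocess_batch(batch, text_tokenize_fn, dummy=True)  # for model init
step_inputs = preprocess_batch(batch, text_tokenize_fn)              # for a train step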