Skip to content

Commit

Permalink
add hf upload
Browse files Browse the repository at this point in the history
  • Loading branch information
peterdavidfagan committed Mar 25, 2024
1 parent ba03717 commit 10d8e02
Show file tree
Hide file tree
Showing 8 changed files with 273 additions and 5 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ A set of baseline models for learning from demonstration with supporting trainin

<img src="./assets/robot_learning.jpeg" height=300/>

[**[Training Runs]**](https://wandb.ai/ipab-rad/robot_learning_baselines) &ensp; [**[Pretrained Models]**](https://huggingface.co/peterdavidfagan/robot_learning_baselines) &ensp; [**[Documentation]**](https://peterdavidfagan.com/robot_learning_baselines/) &ensp;
[**[Training Runs]**](https://wandb.ai/ipab-rad/robot_learning_baselines) &ensp; [**[Pretrained Models]**](https://huggingface.co/peterdavidfagan) &ensp; [**[Documentation]**](https://peterdavidfagan.com/robot_learning_baselines/) &ensp;


# Configuring your Local Development Environment
Expand Down
6 changes: 6 additions & 0 deletions robot_learning_baselines/config/octo-categorical.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
project:
base_path: /home/peter/Code/research_projects/robot_learning_baselines

hf_upload:
entity: "peterdavidfagan"
repo: "octo_categorical"
branch: "main"
checkpoint_dir: "${config.project.base_path}/.checkpoints/octo-categorical/octo/30"

wandb:
use: True
project: "robot_learning_baselines"
Expand Down
6 changes: 6 additions & 0 deletions robot_learning_baselines/config/octo-continuous.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
project:
base_path: /home/peter/Code/research_projects/robot_learning_baselines

hf_upload:
entity: "peterdavidfagan"
repo: "octo_continuous"
branch: "main"
checkpoint_dir: "${config.project.base_path}/.checkpoints/octo-continuous/octo/30"

wandb:
use: True
project: "robot_learning_baselines"
Expand Down
6 changes: 6 additions & 0 deletions robot_learning_baselines/config/octo-diffusion.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
project:
base_path: /home/peter/Code/research_projects/robot_learning_baselines

hf_upload:
entity: "peterdavidfagan"
repo: "octo_diffusion"
branch: "main"
checkpoint_dir: "${config.project.base_path}/.checkpoints/octo-diffusion/octo/30"

wandb:
use: True
project: "robot_learning_baselines"
Expand Down
2 changes: 1 addition & 1 deletion robot_learning_baselines/config/training/octo-base.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ parallel_training: false
# checkpoint manager
checkpoint_dir: ${config.project.base_path}/.checkpoints/${config.wandb.experiment_name}/octo
max_checkpoints: 2
save_interval: 50
save_interval: 10

# training hyperparameters
num_epochs: 50
Expand Down
172 changes: 172 additions & 0 deletions robot_learning_baselines/hf_upload.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
"""
A script to export model files and upload to huggingface.
"""
import os
from functools import partial

import jax
import jax.numpy as jnp
import jax.random as random
from jax.experimental import jax2tf
import flax
from flax.training import orbax_utils
import orbax
import tensorflow as tf

# model architecture/train state
from multi_modal_transformers.models.octo.octo import Octo
from multi_modal_transformers.models.octo.octo import create_octo_train_state

# tokenizer from huggingface
from transformers import AutoTokenizer

import hydra
from hydra.utils import instantiate, call
from omegaconf import DictConfig

# custom training pipeline utilities
from utils.data import (
oxe_load_single_dataset,
oxe_load_dataset,
preprocess_batch,
)

from utils.pipeline import (
inspect_model,
setup_checkpointing,
create_optimizer,
)

from utils.hugging_face import (
push_model,
)

from utils.wandb import (
init_wandb,
visualize_dataset,
visualize_multi_modal_predictions,
track_gradients,
)


@hydra.main(version_base=None, config_path=".")
def main(cfg: DictConfig) -> None:
    """Rebuild the Octo train state, restore a checkpoint, and upload it to Hugging Face.

    The model/optimizer/train state must be instantiated exactly as during
    training so the parameter tree matches the orbax checkpoint being
    restored. The checkpoint directory and target repository are read from
    the ``hf_upload`` section of the hydra config.
    """
    # Fail fast when no accelerator is available. A plain `assert` would be
    # stripped under `python -O`, so raise explicitly instead.
    if jax.default_backend() == "cpu":
        raise RuntimeError("No accelerator available (jax default backend is cpu).")

    cfg = cfg["config"]  # some hacky and wacky stuff from hydra (TODO: revise)

    # One PRNG stream per stochastic component of the model.
    key = random.PRNGKey(0)
    key, model_key, dropout_key, image_tokenizer_key, diffusion_key = random.split(key, 5)
    rngs = {
        "params": model_key,
        "patch_encoding": image_tokenizer_key,
        "dropout": dropout_key,
        "diffusion": diffusion_key,
    }

    train_data = oxe_load_single_dataset(cfg.dataset)  # load dataset for debugging

    # reinitialise=False: we are restoring an existing checkpoint, so the
    # checkpoint directory must NOT be wiped before use.
    chkpt_manager = setup_checkpointing(cfg.training, reinitialise=False)
    optimizer, lr_scheduler = create_optimizer(cfg, lr_schedule="cosine_decay")
    model = Octo(cfg.architecture.multi_modal_transformer)
    text_tokenizer = instantiate(cfg.architecture.multi_modal_transformer.tokenizers.text.tokenizer)
    text_tokenize_fn = partial(
        text_tokenizer,
        return_tensors="jax",
        max_length=16,  # hardcode while debugging
        padding="max_length",
        truncation=True,
    )

    # Initialise the training state from one dummy-processed batch so the
    # parameter structure matches the checkpoint being restored.
    batch = next(train_data.as_numpy_iterator())
    input_data = preprocess_batch(
        batch,
        text_tokenize_fn,
        action_head_type=cfg.architecture.multi_modal_transformer.prediction_type,
        dummy=True,
    )
    inspect_model(model, rngs, input_data, method=cfg.architecture.multi_modal_transformer.forward_method)

    # For now, due to the api, we need to generate time + noisy-actions data
    # regardless of prediction head type; this should be fixed in future.
    input_data = preprocess_batch(
        batch,
        text_tokenize_fn,
        dummy=True,
    )
    train_state = create_octo_train_state(
        input_data["text_tokens"],
        input_data["images"],
        text_tokenizer,
        {"time": input_data["time"], "noisy_actions": input_data["noisy_actions"]},
        rngs,
        model,
        optimizer,
        method=cfg.architecture.multi_modal_transformer.forward_method,
    )

    # Restore model weights with orbax.
    # NOTE(review): CheckpointManager.restore conventionally takes a step
    # number as its first argument; passing a directory path here relies on
    # the manager built by setup_checkpointing accepting it — confirm against
    # utils.pipeline.
    train_state = chkpt_manager.restore(cfg.hf_upload.checkpoint_dir, items=train_state)

    # Upload the restored checkpoint directory to the Hugging Face Hub.
    push_model(
        entity=cfg.hf_upload.entity,
        repo_name=cfg.hf_upload.repo,
        branch=cfg.hf_upload.branch,
        checkpoint_dir=cfg.hf_upload.checkpoint_dir,
    )

    # TODO(peterdavidfagan): also export an inference-only predict fn per
    # prediction_type (continuous/categorical/diffusion), convert it via
    # jax2tf -> TFLite (with quantisation) and ONNX, and upload those
    # artifacts alongside the flax checkpoint.


if __name__ == "__main__":
    main()
77 changes: 77 additions & 0 deletions robot_learning_baselines/utils/hugging_face.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""Utilities for interfacing with hugging face."""
import os
from glob import glob
from huggingface_hub import CommitOperationAdd, CommitOperationDelete, HfApi
from huggingface_hub.repocard import metadata_eval_result, metadata_save

def push_model(
    branch: str,
    checkpoint_dir: str,
    entity: str = "peterdavidfagan",
    repo_name: str = "robot_learning_baselines",
):
    """
    Uploads a local model checkpoint directory to a Hugging Face model repository.

    Args:
        branch: target branch of the model repo to upload to (created if absent).
        checkpoint_dir: local directory containing the checkpoint files to upload.
        entity: Hugging Face user/organisation owning the repo.
        repo_name: name of the model repository.
    """
    api = HfApi()
    repo_id = f"{entity}/{repo_name}"

    # Create the repository if it does not yet exist (idempotent).
    api.create_repo(
        repo_id=repo_id,
        exist_ok=True,
        private=False,
    )

    # Ensure the target branch exists before uploading to it (idempotent).
    api.create_branch(
        repo_id=repo_id,
        branch=branch,
        repo_type="model",
        exist_ok=True,
    )

    # Upload the checkpoint folder to the requested branch.
    # Bug fix: `revision=branch` was previously omitted, so every upload went
    # to the repo's default branch and the `branch` argument was ignored.
    api.upload_folder(
        folder_path=checkpoint_dir,
        repo_id=repo_id,
        repo_type="model",
        revision=branch,
        multi_commits=True,
    )
7 changes: 4 additions & 3 deletions robot_learning_baselines/utils/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,15 @@ def inspect_model(model, variables, data, method="__call__"):


# saving model checkpoints
def setup_checkpointing(cfg):
def setup_checkpointing(cfg, reinitialise=True):
"""Set up checkpointing."""
if os.path.exists(cfg.checkpoint_dir):
if os.path.exists(cfg.checkpoint_dir) and reinitialise:
# remove old files
shutil.rmtree(cfg.checkpoint_dir)

# create checkpoint directory
os.makedirs(cfg.checkpoint_dir)
if reinitialise:
os.makedirs(cfg.checkpoint_dir)

# setup checkpoint manager
chkpt_options = ocp.CheckpointManagerOptions(
Expand Down

0 comments on commit 10d8e02

Please sign in to comment.