Skip to content

Commit

Permalink
pin transporter debug progress
Browse files Browse the repository at this point in the history
  • Loading branch information
peterdavidfagan committed May 7, 2024
1 parent ef43b93 commit 6a8a486
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 60 deletions.
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ octo = {path="./robot_learning_baselines/data_preprocessing/octo", develop=true}
dlimp = {path="./robot_learning_baselines/data_preprocessing/dlimp", develop=true}
distrax = "^0.1.5"
huggingface-hub = "^0.22.0"
tf2onnx = "^1.16.1"
onnxruntime = "^1.17.3"
opencv-python = "^4.9.0.80"

[tool.black]
line-length = 120
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ batch_size: 4

# checkpoint manager
checkpoint_dir: ${config.project.base_path}/.checkpoints/${config.wandb.experiment_name}/transporter_pick
max_checkpoints: 2
save_interval: 5
max_checkpoints: 8
save_interval: 10

optimizer:
_target_: optax.adam
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ batch_size: 4

# checkpoint manager
checkpoint_dir: ${config.project.base_path}/.checkpoints/${config.wandb.experiment_name}/transporter_place
max_checkpoints: 2
save_interval: 5
max_checkpoints: 8
save_interval: 10

optimizer:
_target_: optax.adam
Expand Down
4 changes: 2 additions & 2 deletions robot_learning_baselines/config/transporter.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ hf_upload:
repo: "transporter_networks"
branch: "main"
checkpoint_dir: "${config.project.base_path}/.checkpoints/${config.wandb.experiment_name}"
pick_checkpoint_dir: "${config.project.base_path}/.checkpoints/${config.wandb.experiment_name}/transporter_pick/50"
place_checkpoint_dir: "${config.project.base_path}/.checkpoints/${config.wandb.experiment_name}/transporter_place/50"
pick_checkpoint_dir: "${config.project.base_path}/.checkpoints/${config.wandb.experiment_name}/transporter_pick/150"
place_checkpoint_dir: "${config.project.base_path}/.checkpoints/${config.wandb.experiment_name}/transporter_place/150"

wandb:
use: True
Expand Down
17 changes: 9 additions & 8 deletions robot_learning_baselines/train_transporter_refactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,14 +153,14 @@ def main(cfg: DictConfig) -> None:
)

# compute ce loss for pick network and update pick network
pick_train_state, pick_loss = pick_train_step(
pick_train_state, pick_loss, pick_success_rate = pick_train_step(
transporter.pick_model_state,
rgbd_normalized,
ids[0])
transporter = transporter.replace(pick_model_state=pick_train_state)

# compute ce loss for place networks and update place network
place_train_state, place_loss = place_train_step(
place_train_state, place_loss, place_success_rate = place_train_step(
transporter.place_model_state,
rgbd_normalized,
rgbd_crop_normalized,
Expand All @@ -170,14 +170,15 @@ def main(cfg: DictConfig) -> None:


# report epoch metrics (optionally add to wandb)
pick_loss_epoch = transporter.pick_model_state.metrics.compute()
place_loss_epoch = transporter.place_model_state.metrics.compute()
print(f"Epoch {epoch}: pick_loss: {pick_loss_epoch}, place_loss: {place_loss_epoch}")
pick_metrics = transporter.pick_model_state.metrics.compute()
place_metrics = transporter.place_model_state.metrics.compute()

if cfg.wandb.use and (epoch%5==0):
wandb.log({
"pick_loss": pick_loss_epoch,
"place_loss": place_loss_epoch,
"pick_train_loss": pick_metrics["loss"],
"place_train_loss": place_metrics["loss"],
"pick_train_success_rate": pick_metrics["success_rate"],
"place_train_success_rate":place_metrics["success_rate"],
"epoch": epoch
})
visualize_transporter_predictions(cfg, transporter, eval_data, epoch)
Expand Down
56 changes: 10 additions & 46 deletions robot_learning_baselines/utils/hugging_face.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,82 +5,46 @@
from huggingface_hub.repocard import metadata_eval_result, metadata_save

def push_model(
branch: str,
checkpoint_dir: str,
upload_path: str,
entity: str = "peterdavidfagan",
repo_name: str = "robot_learning_baselines",
branch: str = "main",
):
"""
Uploads model to hugging face repository.
"""
api = HfApi()

# ensure repo exists
repo_id = f"{entity}/{repo_name}"
repo_url = api.create_repo(
repo_id=repo_id,
exist_ok=True,
private=False,
)

# generate model card
model_card = f"""
# (Robot Learning Baselines) Test**
OMG what a great model this is.
"""

# operations to upload flax model checkpoint
#operations=[]
#def compile_model_upload_ops(src_path):
# if os.path.isfile(src_path):
# print(src_path)
# dest_path = src_path.replace(checkpoint_dir, "")
# print(dest_path)
# operations.append(CommitOperationAdd(path_in_repo=dest_path, path_or_fileobj=src_path))
# else:
# for item in os.listdir(src_path + "/"):
# item = os.path.join(src_path, item)
# if os.path.isfile(item):
# print(item)
# dest_path = src_path.replace(checkpoint_dir, "")
# print(dest_path)
# operations.append(CommitOperationAdd(path_in_repo=dest_path, path_or_fileobj=item))
# else:
# compile_model_upload_ops(item)
#compile_model_upload_ops(checkpoint_dir)

#for filepath in glob(checkpoint_dir + "/**/*", recursive=True):
# if os.path.isfile(filepath):
# operations.append(CommitOperationAdd(path_in_repo="/", path_or_fileobj=filepath))

# create model branch
api.create_branch(
repo_id=repo_id,
branch=branch,
repo_type="model",
exist_ok=True,
)
if os.path.isdir(checkpoint_dir):

# upload requested files
if os.path.isdir(upload_path):
api.upload_folder(
folder_path=checkpoint_dir,
folder_path=upload_path,
repo_id=repo_id,
repo_type="model",
multi_commits=True,
)
elif os.path.isfile(checkpoint_dir):
elif os.path.isfile(upload_path):
api.upload_file(
path_or_fileobj=checkpoint_dir,
path_in_repo="checkpoint.tar.xz",
path_or_fileobj=upload_path,
path_in_repo=upload_path.split("/")[-1],
repo_id=repo_id,
repo_type="model",
)
else:
raise NotImplementedError

# commit changes to branch
#api.create_commit(
# repo_id=repo_id,
# commit_message="Nice Model Dude",
# operations=operations,
# repo_type="model",
# )

0 comments on commit 6a8a486

Please sign in to comment.