chore: minor changes adapting the code to our needs #1

Open · wants to merge 9 commits into base: main
2 changes: 2 additions & 0 deletions .gitignore
@@ -150,4 +150,6 @@ data/

# Rollout videos and wandb logs
rollouts/
checkpoints/
adapter_weights
wandb/
142 changes: 142 additions & 0 deletions evaluate.py
@@ -0,0 +1,142 @@
import os
from pathlib import Path
import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForVision2Seq, AutoProcessor
from peft import PeftModel
import wandb
from tqdm import tqdm
from natsort import natsorted
import gc

# Import necessary modules from your training script
from prismatic.vla.action_tokenizer import ActionTokenizer
from prismatic.vla.datasets import RLDSDataset, RLDSBatchTransform
from prismatic.util.data_utils import PaddedCollatorForActionPrediction
from prismatic.models.backbones.llm.prompting import PurePromptBuilder, VicunaV15ChatPromptBuilder

def load_checkpoint(base_model_path, adapter_path, device):
    base_model = AutoModelForVision2Seq.from_pretrained(
        base_model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, trust_remote_code=True
    ).to(device)
    model = PeftModel.from_pretrained(base_model, adapter_path)
    return model

def evaluate(model, dataloader, action_tokenizer, device):
    model.eval()
    total_loss = 0
    total_accuracy = 0
    total_l1_loss = 0
    total_samples = 0

    with torch.no_grad():
        for i, batch in enumerate(tqdm(dataloader, desc="Evaluating")):
            if i >= len(dataloader):
                break
            outputs = model(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
                pixel_values=batch["pixel_values"].to(torch.bfloat16).to(device),
                labels=batch["labels"].to(device)
            )

            loss = outputs.loss
            total_loss += loss.item() * batch["input_ids"].size(0)

            # Keep only the logits corresponding to action tokens (skip the image patch positions).
            action_logits = outputs.logits[:, model.vision_backbone.featurizer.patch_embed.num_patches : -1]
            action_preds = action_logits.argmax(dim=2)
            action_gt = batch["labels"][:, 1:].to(action_preds.device)
            mask = action_gt > action_tokenizer.action_token_begin_idx

            correct_preds = (action_preds == action_gt) & mask
            accuracy = correct_preds.sum().float() / mask.sum().float()
            total_accuracy += accuracy.item() * mask.sum().item()

            continuous_actions_pred = torch.tensor(
                action_tokenizer.decode_token_ids_to_actions(action_preds[mask].cpu().numpy())
            )
            continuous_actions_gt = torch.tensor(
                action_tokenizer.decode_token_ids_to_actions(action_gt[mask].cpu().numpy())
            )
            l1_loss = torch.nn.functional.l1_loss(continuous_actions_pred, continuous_actions_gt)
            total_l1_loss += l1_loss.item() * mask.sum().item()

            total_samples += mask.sum().item()

    # NOTE: loss is accumulated per batch element while accuracy/L1 are accumulated per action token;
    # total_samples counts action tokens, so avg_loss is only a rough per-token estimate.
    avg_loss = total_loss / total_samples
    avg_accuracy = total_accuracy / total_samples
    avg_l1_loss = total_l1_loss / total_samples

    return avg_loss, avg_accuracy, avg_l1_loss

def main():
    # Configuration
    VLA_PATH = "openvla/openvla-7b"
    BASE_PATH = "/robo-srv-004-storage-001/home/mkotynia/openvla"
    EXP = "openvla-7b+robotec_o3de_panda_dataset_200_train_episodes+b16+lr-0.0005+lr-decay+lora-r128+dropout-0.0"
    ADAPTERS_DIR = os.path.join(BASE_PATH, "adapter_weights", EXP)
    STEPS_RUN_DIR = os.path.join(BASE_PATH, "checkpoints", EXP)
    val_dataset_name = "robotec_o3de_panda_dataset_200_train_episodes"
    val_data_root_dir = "/home/mkotynia/tensorflow_datasets"
    batch_size = 8
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize wandb
    wandb.init(project="openvla-evaluation", name=f"val+{EXP}")

    # Load processor and create action tokenizer
    processor = AutoProcessor.from_pretrained(VLA_PATH, trust_remote_code=True)
    action_tokenizer = ActionTokenizer(processor.tokenizer)

    # Prepare validation dataset and dataloader
    batch_transform = RLDSBatchTransform(
        action_tokenizer,
        processor.tokenizer,
        image_transform=processor.image_processor.apply_transform,
        prompt_builder_fn=PurePromptBuilder if "v01" not in VLA_PATH else VicunaV15ChatPromptBuilder,
    )
    val_dataset = RLDSDataset(
        val_data_root_dir,
        val_dataset_name,
        batch_transform,
        resize_resolution=(224, 224),  # Update this with the correct image size
        shuffle_buffer_size=10000,
        image_aug=False,
        train=False
    )
    collator = PaddedCollatorForActionPrediction(
        processor.tokenizer.model_max_length, processor.tokenizer.pad_token_id, padding_side="right"
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        collate_fn=collator,
        num_workers=0,
    )

    # Evaluate checkpoints
    for checkpoint_dir in natsorted(Path(ADAPTERS_DIR).glob("step-*")):
        step = int(checkpoint_dir.name.split("-")[-1])
        print(f"Evaluating checkpoint at step {step}")

        model = load_checkpoint(VLA_PATH, checkpoint_dir, device)
        avg_loss, avg_accuracy, avg_l1_loss = evaluate(model, val_dataloader, action_tokenizer, device)

        # Log metrics to wandb
        wandb.log({
            "step": step,
            "val_loss": avg_loss,
            "val_accuracy": avg_accuracy,
            "val_l1_loss": avg_l1_loss,
        })

        print(f"Step {step}: Loss = {avg_loss:.4f}, Accuracy = {avg_accuracy:.4f}, L1 Loss = {avg_l1_loss:.4f}")

        # Free up GPU memory before loading the next checkpoint
        model.to('cpu')  # Move model to CPU
        del model  # Delete the model object
        torch.cuda.empty_cache()  # Clear CUDA cache
        gc.collect()  # Run garbage collector

    wandb.finish()


if __name__ == "__main__":
    main()
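For reference, here is a minimal, self-contained sketch of the masked accuracy / L1 computation that evaluate() performs, using toy tensors and a stand-in for the ActionTokenizer (the boundary id and the decode function below are hypothetical placeholders, not the real tokenizer API):

import torch

# Pretend token ids above 99 are action-bin tokens (the real boundary is
# action_tokenizer.action_token_begin_idx).
ACTION_TOKEN_BEGIN_IDX = 99

def toy_decode(ids):
    # Stand-in for ActionTokenizer.decode_token_ids_to_actions: map bin ids
    # to continuous values in [-1, 1].
    return (ids.float() - 100.0) / 10.0 - 1.0

preds = torch.tensor([[100, 105, 3, 110]])
labels = torch.tensor([[100, 104, 3, 110]])

mask = labels > ACTION_TOKEN_BEGIN_IDX  # ignore non-action tokens (id 3 here)
accuracy = ((preds == labels) & mask).sum().float() / mask.sum().float()
l1 = torch.nn.functional.l1_loss(toy_decode(preds[mask]), toy_decode(labels[mask]))
print(f"token accuracy = {accuracy:.3f}, continuous L1 = {l1:.3f}")  # 0.667, 0.033

The per-batch values are then weighted by the number of masked action tokens, which is why total_samples in evaluate() counts mask.sum() rather than batch elements.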
15 changes: 15 additions & 0 deletions finetune_robotec.sh
@@ -0,0 +1,15 @@
#!/bin/bash

torchrun --standalone --nnodes 1 --nproc-per-node 2 vla-scripts/finetune.py \
  --vla_path "openvla/openvla-7b" \
  --data_root_dir "/home/mkotynia/tensorflow_datasets" \
  --dataset_name robotec_o3de_panda_dataset_4_cameras \
  --run_root_dir "/robo-srv-004-storage-001/home/mkotynia/openvla/checkpoints" \
  --adapter_tmp_dir "/robo-srv-004-storage-001/home/mkotynia/openvla/adapter_weights" \
  --lora_rank 128 \
  --batch_size 1 \
  --grad_accumulation_steps 16 \
  --image_aug False \
  --wandb_project openvla_test \
  --wandb_entity robotecai-ml \
  --save_steps 50
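As a quick sanity check of the effective batch size: with --batch_size 1, --grad_accumulation_steps 16 and two processes, each optimizer step sees roughly 1 × 16 × 2 = 32 frames, assuming gradients are averaged across both GPUs.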
20 changes: 20 additions & 0 deletions merge_adapter_weights.py
@@ -0,0 +1,20 @@
from transformers import AutoModelForVision2Seq
from peft import PeftModel
import torch
import os

VLA_PATH = "openvla/openvla-7b"

BASE_PATH = "/robo-srv-004-storage-001/home/mkotynia/openvla"
EXP = "openvla-7b+robotec_o3de_panda_dataset_4_cameras+b16+lr-0.0005+lr-decay+lora-r128+dropout-0.0"
STEP = "step-1200"
ADAPTER_DIR = os.path.join(BASE_PATH, "adapter_weights", EXP, STEP)
STEP_RUN_DIR = os.path.join(BASE_PATH, "checkpoints", EXP, STEP)

base_vla = AutoModelForVision2Seq.from_pretrained(
VLA_PATH, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, trust_remote_code=True
)
merged_vla = PeftModel.from_pretrained(base_vla, ADAPTER_DIR)
merged_vla = merged_vla.merge_and_unload()
merged_vla.save_pretrained(STEP_RUN_DIR)
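Once merged, the checkpoint in STEP_RUN_DIR can be loaded like any OpenVLA model. A hedged sketch of running inference on it (the prompt format and predict_action call follow the upstream OpenVLA model card; the unnorm_key is an assumption based on the dataset name used in this PR, and un-normalization only works if the fine-tuning dataset statistics are present in the saved config):

from PIL import Image
import torch
from transformers import AutoModelForVision2Seq, AutoProcessor

processor = AutoProcessor.from_pretrained("openvla/openvla-7b", trust_remote_code=True)
vla = AutoModelForVision2Seq.from_pretrained(
    STEP_RUN_DIR, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, trust_remote_code=True
).to("cuda:0")

image = Image.open("frame.png")  # placeholder observation
prompt = "In: What action should the robot take to pick up the cube?\nOut:"
inputs = processor(prompt, image).to("cuda:0", dtype=torch.bfloat16)
action = vla.predict_action(**inputs, unnorm_key="robotec_o3de_panda_dataset_4_cameras", do_sample=False)
print(action)  # 7-DoF end-effector action

Saving the processor alongside the weights (processor.save_pretrained(STEP_RUN_DIR)) would make the merged checkpoint fully standalone; the script above currently saves only the model.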

5 changes: 4 additions & 1 deletion prismatic/vla/datasets/datasets.py
@@ -38,6 +38,7 @@ class RLDSBatchTransform:
def __call__(self, rlds_batch: Dict[str, Any]) -> Dict[str, Any]:
"""Converts a RLDS batch to the format expected by the OpenVLA collator/models."""
dataset_name, action = rlds_batch["dataset_name"], rlds_batch["action"][0]
# NOTE(mkotynia): this line means that only image_primary is used, even if more camera views are configured in configs.py
img = Image.fromarray(rlds_batch["observation"]["image_primary"][0])
lang = rlds_batch["task"]["language_instruction"].decode().lower()

@@ -92,8 +93,10 @@ def __init__(
per_dataset_kwargs, weights = get_oxe_dataset_kwargs_and_weights(
self.data_root_dir,
mixture_spec,
# NOTE (mkotynia): uncomment when other camera views are supported
# load_camera_views=("primary", "secondary", "tertiary", "quaternary"),
load_camera_views=("primary",),
- load_depth=False,
+ load_depth=False,  # NOTE(mkotynia): change to True when depth is available
load_proprio=False,
load_language=True,
action_proprio_normalization_type=NormalizationType.BOUNDS_Q99,
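The two NOTEs above point at the same limitation: only the primary view is wired through. A purely illustrative sketch of how __call__ could pick up additional views once load_camera_views is widened (the key names follow the primary/secondary/... convention from configs.py; how the extra images are actually fed to the model is left open):

obs = rlds_batch["observation"]
view_keys = [k for k in ("image_primary", "image_secondary", "image_tertiary", "image_quaternary") if k in obs]
imgs = [Image.fromarray(obs[k][0]) for k in view_keys]  # list of PIL images, one per available camera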
58 changes: 57 additions & 1 deletion prismatic/vla/datasets/rlds/oxe/configs.py
@@ -76,7 +76,7 @@ class ActionEncoding(IntEnum):
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
},
"bridge_orig": { # Original version of Bridge V2 from project website
"bridge": { # Original version of Bridge V2 from project website
"image_obs_keys": {"primary": "image_0", "secondary": "image_1", "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
@@ -641,4 +641,60 @@ class ActionEncoding(IntEnum):
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
},
"robotec_o3de_panda_dataset": { # Robotec O3DE dataset
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
},
"robotec_o3de_panda_dataset_one_episode": { # Robotec O3DE dataset
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
},
"robotec_o3de_panda_dataset_200_train_episodes": { # Robotec O3DE dataset
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
},
"robotec_o3de_panda_dataset_4_cameras": { # Robotec O3DE dataset
"image_obs_keys": {"primary": "image1", "secondary": "image2", "tertiary": "image3", "quaternary": "image4", "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
},
"robotec_o3de_panda_dataset_one_step": { # Robotec O3DE dataset
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
},
"robotec_o3de_panda_dataset_5_steps": { # Robotec O3DE dataset
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
},
"robotec_o3de_panda_dataset_200_train_episodes": { # Robotec O3DE dataset
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
},
"robotec_o3de_panda_dataset_v8": { # Robotec O3DE dataset
"image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
"depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
"state_obs_keys": ["EEF_state", None, "gripper_state"],
"state_encoding": StateEncoding.POS_EULER,
"action_encoding": ActionEncoding.EEF_POS,
},
}
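Since the single-camera Robotec entries above are identical apart from their names, one way to cut the repetition would be to generate them from a template (a sketch, assuming the enclosing dict is OXE_DATASET_CONFIGS as in the upstream configs.py):

_ROBOTEC_SINGLE_CAM = {
    "image_obs_keys": {"primary": "image", "secondary": None, "wrist": None},
    "depth_obs_keys": {"primary": None, "secondary": None, "wrist": None},
    "state_obs_keys": ["EEF_state", None, "gripper_state"],
    "state_encoding": StateEncoding.POS_EULER,
    "action_encoding": ActionEncoding.EEF_POS,
}
for _name in (
    "robotec_o3de_panda_dataset",
    "robotec_o3de_panda_dataset_one_episode",
    "robotec_o3de_panda_dataset_one_step",
    "robotec_o3de_panda_dataset_5_steps",
    "robotec_o3de_panda_dataset_200_train_episodes",
    "robotec_o3de_panda_dataset_v8",
):
    OXE_DATASET_CONFIGS[_name] = dict(_ROBOTEC_SINGLE_CAM)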
14 changes: 7 additions & 7 deletions prismatic/vla/datasets/rlds/oxe/mixtures.py
@@ -12,14 +12,14 @@
# === Bridge V2 Dataset ===
"bridge": [
# ("bridge_oxe", 1.0), # Version of Bridge V2 in Open-X GCP Bucket
("bridge_orig", 1.0), # Original Version of Bridge V2 from Project Website
("bridge", 1.0), # Original Version of Bridge V2 from Project Website
],


# === [Moderate-Scale] Bridge++ Mixtures ===
"bridge_rt_1": [
# ("bridge_oxe", 1.0) # Version of Bridge V2 in Open-X GCP Bucket
("bridge_orig", 1.0), # Original Version of Bridge V2 from Project Website
("bridge", 1.0), # Original Version of Bridge V2 from Project Website

("fractal20220817_data", 1.0), # Google RT-1 Robot Data (Large-Scale)
],
@@ -29,7 +29,7 @@
("fractal20220817_data", 0.54087122203), # Google RT-1 Robot Data (Large-Scale)
("kuka", 0.8341046294),
# ("bridge_oxe", 1.0) # Version of Bridge V2 in Open-X GCP Bucket
("bridge_orig", 1.0), # Original Version of Bridge V2 from Project Website
("bridge", 1.0), # Original Version of Bridge V2 from Project Website
("taco_play", 2.0),
("jaco_play", 2.0),
("berkeley_cable_routing", 3.0),
@@ -44,7 +44,7 @@
("fractal20220817_data", 0.54087122203), # Google RT-1 Robot Data (Large-Scale)
("kuka", 0.8341046294),
# ("bridge_oxe", 1.0) # Version of Bridge V2 in Open-X GCP Bucket
("bridge_orig", 1.0), # Original Version of Bridge V2 from Project Website
("bridge", 1.0), # Original Version of Bridge V2 from Project Website
("taco_play", 2.0),
("jaco_play", 2.0),
("berkeley_cable_routing", 3.0),
@@ -79,7 +79,7 @@
("fractal20220817_data", 0.54087122203), # Google RT-1 Robot Data (Large-Scale)
("kuka", 0.8341046294),
# ("bridge_oxe", 1.0) # Version of Bridge V2 in Open-X GCP Bucket
("bridge_orig", 1.0), # Original Version of Bridge V2 from Project Website
("bridge", 1.0), # Original Version of Bridge V2 from Project Website
("taco_play", 2.0),
("jaco_play", 1.0),
("berkeley_cable_routing", 1.0),
@@ -109,7 +109,7 @@
"oxe_magic_soup_plus": [
("fractal20220817_data", 0.54087122203), # Google RT-1 Robot Data (Large-Scale)
("kuka", 0.8341046294),
("bridge_orig", 1.0), # Original Version of Bridge V2 from Project Website
("bridge", 1.0), # Original Version of Bridge V2 from Project Website
("taco_play", 2.0),
("jaco_play", 1.0),
("berkeley_cable_routing", 1.0),
@@ -140,7 +140,7 @@
"oxe_magic_soup_plus_minus": [
("fractal20220817_data", 1.0), # Google RT-1 Robot Data (Large-Scale)
("kuka", 0.8341046294),
("bridge_orig", 1.0), # Original Version of Bridge V2 from Project Website
("bridge", 1.0), # Original Version of Bridge V2 from Project Website
("taco_play", 2.0),
("jaco_play", 1.0),
("berkeley_cable_routing", 1.0),
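If the Robotec data should also be usable through the mixture mechanism, a named mixture could be registered next to the entries above (a sketch, assuming the dict is named OXE_NAMED_MIXTURES as upstream):

OXE_NAMED_MIXTURES["robotec_o3de_panda"] = [
    ("robotec_o3de_panda_dataset_4_cameras", 1.0),
]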
17 changes: 14 additions & 3 deletions prismatic/vla/datasets/rlds/oxe/transforms.py
@@ -58,7 +58,7 @@ def bridge_oxe_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
return trajectory


- def bridge_orig_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
+ def bridge_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
"""
Applies to original version of Bridge V2 from the official project website.

@@ -824,11 +824,22 @@ def tdroid_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
return trajectory


def robotec_o3de_panda_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
return trajectory


# === Registry ===
OXE_STANDARDIZATION_TRANSFORMS = {
"robotec_o3de_panda_dataset": robotec_o3de_panda_dataset_transform, # NOTE: it is optional to use transforms
"robotec_o3de_panda_dataset_one_episode": robotec_o3de_panda_dataset_transform,
"robotec_o3de_panda_dataset_4_cameras": robotec_o3de_panda_dataset_transform,
"robotec_o3de_panda_dataset_one_step": robotec_o3de_panda_dataset_transform,
"robotec_o3de_panda_dataset_5_steps": robotec_o3de_panda_dataset_transform,
"robotec_o3de_panda_dataset_200_train_episodes": robotec_o3de_panda_dataset_transform,
"robotec_o3de_panda_dataset_v8": robotec_o3de_panda_dataset_transform,
"bridge_oxe": bridge_oxe_dataset_transform,
"bridge_orig": bridge_orig_dataset_transform,
"bridge_dataset": bridge_orig_dataset_transform,
"bridge": bridge_dataset_transform,
"bridge_dataset": bridge_dataset_transform,
"ppgm": ppgm_dataset_transform,
"ppgm_static": ppgm_dataset_transform,
"ppgm_wrist": ppgm_dataset_transform,
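The registered robotec_o3de_panda_dataset_transform is a no-op, which is fine when the RLDS dataset already matches the expected layout. For illustration only, a hypothetical non-trivial transform might look like the following (the gripper remapping is made up, not a property of the actual dataset):

import tensorflow as tf

def robotec_o3de_panda_dataset_transform_example(trajectory):
    # e.g. remap the last action dimension (gripper) from [0, 1] to [-1, 1]
    gripper = trajectory["action"][:, -1:] * 2.0 - 1.0
    trajectory["action"] = tf.concat([trajectory["action"][:, :-1], gripper], axis=-1)
    return trajectory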
6 changes: 6 additions & 0 deletions pyproject.toml
@@ -1,3 +1,9 @@
[tool.poetry]
name = "openvla"
version = "0.1.0"
description = ""
authors = [""]

[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"