
Commit

chore: set custom lr default to False, add logging of actions' Euclidean distance in wandb
MagdalenaKotynia committed Aug 23, 2024
1 parent 6fc8adb commit 09435f4
Showing 1 changed file with 7 additions and 1 deletion.
vla-scripts/finetune.py (7 additions & 1 deletion)
@@ -42,6 +42,7 @@
 from prismatic.vla.action_tokenizer import ActionTokenizer
 from prismatic.vla.datasets import RLDSBatchTransform, RLDSDataset
 from prismatic.vla.datasets.rlds.utils.data_utils import save_dataset_statistics
+import numpy as np

 # Sane Defaults
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -98,7 +99,7 @@ class FinetuneConfig:
     wandb_entity: str = "stanford-voltron"  # Name of entity to log under

     # learning rate decay
-    use_lr_decay: bool = True               # Whether to use learning rate decay
+    use_lr_decay: bool = False              # Whether to use learning rate decay
     num_warmup_steps: int = 100             # Number of warmup steps
     # lr_decay_step_size: int = 50          # Number of steps over which to decay the learning rate
     # gamma = 0.5
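The diff does not show how use_lr_decay is consumed, but the later context line scheduler.get_last_lr()[0] if scheduler else cfg.learning_rate suggests a scheduler is created only when the flag is set. A minimal sketch of that wiring, assuming a HuggingFace transformers schedule; the optimizer variable and the cosine choice are assumptions, not code from this commit:

from transformers import get_cosine_schedule_with_warmup

# Hypothetical wiring; this commit only flips the default flag.
scheduler = None
if cfg.use_lr_decay:
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=cfg.num_warmup_steps,  # 100 by default, per the config above
        num_training_steps=cfg.max_steps,
    )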
@@ -241,6 +242,7 @@ def finetune(cfg: FinetuneConfig) -> None:
     recent_losses = deque(maxlen=cfg.grad_accumulation_steps)
     recent_action_accuracies = deque(maxlen=cfg.grad_accumulation_steps)
     recent_l1_losses = deque(maxlen=cfg.grad_accumulation_steps)
+    recent_action_distances = deque(maxlen=cfg.grad_accumulation_steps)

     # Train!
     with tqdm.tqdm(total=cfg.max_steps, leave=False) as progress:
@@ -280,11 +282,13 @@ def finetune(cfg: FinetuneConfig) -> None:
                 action_tokenizer.decode_token_ids_to_actions(action_gt[mask].cpu().numpy())
             )
             action_l1_loss = torch.nn.functional.l1_loss(continuous_actions_pred, continuous_actions_gt)
+            action_distance = np.linalg.norm(continuous_actions_gt - continuous_actions_pred)

             # Store recent train metrics
             recent_losses.append(loss.item())
             recent_action_accuracies.append(action_accuracy.item())
             recent_l1_losses.append(action_l1_loss.item())
+            recent_action_distances.append(action_distance.item())

             # Compute gradient step index
             gradient_step_idx = batch_idx // cfg.grad_accumulation_steps
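np.linalg.norm applied to the difference of the two decoded CPU tensors gives the Euclidean (L2) distance across all action dimensions in the batch, returned as a NumPy scalar that supports .item(). A self-contained sketch with made-up 7-DoF action vectors standing in for the decoded tensors:

import numpy as np
import torch

# Made-up stand-ins for continuous_actions_pred / continuous_actions_gt.
pred = torch.tensor([0.10, -0.20, 0.05, 0.00, 0.30, -0.10, 1.00])
gt = torch.tensor([0.12, -0.18, 0.00, 0.02, 0.25, -0.05, 1.00])

# np.linalg.norm accepts a CPU torch tensor and returns a NumPy scalar.
action_distance = np.linalg.norm(gt - pred)
print(action_distance.item())  # ~0.093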
@@ -295,6 +299,7 @@ def finetune(cfg: FinetuneConfig) -> None:
             smoothened_loss = sum(recent_losses) / len(recent_losses)
             smoothened_action_accuracy = sum(recent_action_accuracies) / len(recent_action_accuracies)
             smoothened_l1_loss = sum(recent_l1_losses) / len(recent_l1_losses)
+            smoothened_action_distance = sum(recent_action_distances) / len(recent_action_distances)

             # Push Metrics to W&B (every 5 gradient steps)
             if distributed_state.is_main_process and gradient_step_idx % 5 == 0:
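The smoothing above is a simple moving average over the last grad_accumulation_steps values, which deque(maxlen=...) provides by evicting the oldest entry on each append. An illustration with a hypothetical window of 4:

from collections import deque

window = deque(maxlen=4)  # hypothetical grad_accumulation_steps = 4
for value in [0.9, 0.8, 0.7, 0.6, 0.5]:
    window.append(value)

# Only the 4 most recent values remain; 0.9 has been evicted.
smoothed = sum(window) / len(window)
print(smoothed)  # ~(0.8 + 0.7 + 0.6 + 0.5) / 4 = 0.65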
@@ -303,6 +308,7 @@ def finetune(cfg: FinetuneConfig) -> None:
                     "action_accuracy": smoothened_action_accuracy,
                     "l1_loss": smoothened_l1_loss,
                     "learning_rate": scheduler.get_last_lr()[0] if scheduler else cfg.learning_rate,
+                    "action_distance": smoothened_action_distance
                 }, step=gradient_step_idx)

             # Optimizer Step
