
change callbacks
sordonia committed Jun 26, 2023
1 parent 21d77a7 commit c4d6846
Showing 4 changed files with 10 additions and 10 deletions.
8 changes: 4 additions & 4 deletions mttl/models/encoder_decoder.py
@@ -107,7 +107,7 @@ def validation_step(self, batch, batch_idx):
self.log("val/loss", mean_loss, on_epoch=True, prog_bar=True)
return loss, batch['task_ids']

-    def validation_epoch_end(self, outputs):
+    def on_validation_epoch_end(self, outputs):
losses = torch.cat([out[0].sum(-1) for out in outputs], 0)
task_ids = torch.cat([out[1] for out in outputs], 0)

@@ -229,13 +229,13 @@ def validation_step(self, batch, batch_idx):
def test_step(self, batch, batch_idx):
return self.inference_step(batch)

-    def validation_epoch_end(self, outputs):
+    def on_validation_epoch_end(self, outputs):
return self.inference_end(outputs, self.trainer.datamodule.dataset_reader, "val")

-    def test_epoch_end(self, outputs):
+    def on_test_epoch_end(self, outputs):
return self.inference_end(outputs, self.trainer.datamodule.dataset_reader, "test")

-    def training_epoch_end(self, losses):
+    def on_training_epoch_end(self, losses):
avg_loss = (sum([x["loss"] for x in losses]) / len(losses)).item()
lrs = [x["lr"] for x in self.optimizers().param_groups]
print(f"loss : {avg_loss:.4f}\tlr {lrs}\n")
4 changes: 2 additions & 2 deletions mttl/models/t0_encoder_decoder.py
@@ -480,7 +480,7 @@ def inference_epoch_end(self, outputs, split="val"):
metrics = {}
return metrics

-    def validation_epoch_end(self, outputs):
+    def on_validation_epoch_end(self, outputs):
try:
# differentiate between fine-tuning phase / zero-shot phase and
# validation phase during training. this will raise because
@@ -502,7 +502,7 @@ def validation_epoch_end(self, outputs):
)
f.write(json.dumps(task_losses) + "\n")

-    def test_epoch_end(self, outputs):
+    def on_test_epoch_end(self, outputs):
return self.inference_epoch_end(outputs, split="test")

def configure_optimizers(self):
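For illustration, a small sketch of the JSON-lines pattern visible in the second hunk above, which appends one JSON object of per-task losses per evaluation. The function name, output path, and loss values here are hypothetical, not the script's actual ones.

import json

def append_task_losses(task_losses, path="task_losses.jsonl"):
    # One JSON object per line, so repeated evaluations can simply append.
    with open(path, "a") as f:
        f.write(json.dumps(task_losses) + "\n")

append_task_losses({"task_001": 0.83, "task_002": 1.12})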
3 changes: 2 additions & 1 deletion mttl/online_eval.py
@@ -1,6 +1,7 @@
import copy
import torch
- from pytorch_lightning.callbacks.base import Callback

+ from pytorch_lightning.callbacks import Callback
from pytorch_lightning import Trainer

from mttl.datamodule.ni_data_module import NIDataModule
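For reference, the import path this hunk switches to exposes Callback at the callbacks package level in recent PyTorch Lightning releases. A minimal sketch of a callback using it; the class name and hook body are illustrative, not the callback defined in mttl/online_eval.py.

from pytorch_lightning.callbacks import Callback

class EpochLogger(Callback):
    # Hypothetical callback; shown only to illustrate the new import path.
    def on_validation_epoch_end(self, trainer, pl_module):
        print(f"validation epoch {trainer.current_epoch} finished")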
5 changes: 2 additions & 3 deletions pl_train.py
@@ -105,17 +105,16 @@ def run_multitask(args):
kwargs["enable_checkpointing"] = False

trainer = Trainer(
-        gpus=-1,
+        devices=-1,
accelerator="gpu",
logger=loggers,
num_sanity_val_steps=5,
-        amp_backend="native",
default_root_dir=args.output_dir,
max_epochs=args.num_train_epochs,
max_steps=args.total_steps + 1 if args.total_steps != -1 else -1,
gradient_clip_val=args.max_grad_norm,
log_every_n_steps=50,
-        strategy=args.compute_strategy if args.compute_strategy else None,
+        strategy=args.compute_strategy if args.compute_strategy else "auto",
callbacks=callbacks,
accumulate_grad_batches=args.gradient_accumulation_steps,
precision=int(args.precision)
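A minimal sketch of the Lightning 2.x Trainer call style this hunk migrates to: devices plus accelerator replace gpus, strategy falls back to "auto" rather than None, and amp_backend (removed in newer Lightning releases) is dropped. The values below are placeholders, not the configuration built from args in pl_train.py.

from pytorch_lightning import Trainer

trainer = Trainer(
    devices=-1,            # all visible devices
    accelerator="gpu",
    strategy="auto",       # let Lightning choose, instead of passing None
    max_epochs=3,
    gradient_clip_val=1.0,
    log_every_n_steps=50,
    precision=32,
)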
