```bash
python examples/train_pointcloud.py --cuda --dataset="datasets/modelnet40"
```
1 parent 5b97f67 · commit 630f15a · 1 changed file (examples/train_pointcloud.py) with 373 additions and 0 deletions.
```python
# Copyright (c) 2021-2022, InterDigital Communications, Inc
# All rights reserved.

# Redistribution and use in source and binary forms, with or without
# modification, are permitted (subject to the limitations in the disclaimer
# below) provided that the following conditions are met:

# * Redistributions of source code must retain the above copyright notice,
#   this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
# * Neither the name of InterDigital Communications, Inc nor the names of its
#   contributors may be used to endorse or promote products derived from this
#   software without specific prior written permission.

# NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY
# THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
# CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT
# NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
# PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import argparse
import random
import shutil
import sys

import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader
from torchvision.transforms import Compose

import compressai.transforms as transforms

from compressai.datasets import ModelNetDataset
from compressai.losses import ChamferPccRateDistortionLoss
from compressai.optimizers import net_aux_optimizer
from compressai.registry import MODELS
from compressai.zoo import pointcloud_models


class AverageMeter:
    """Compute running average."""

    def __init__(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class CustomDataParallel(nn.DataParallel):
    """Custom DataParallel to access the module methods."""

    def __getattr__(self, key):
        try:
            return super().__getattr__(key)
        except AttributeError:
            return getattr(self.module, key)


def configure_optimizers(net, args):
    """Separate parameters for the main optimizer and the auxiliary optimizer.
    Return two optimizers."""
    conf = {
        "net": {"type": "Adam", "lr": args.learning_rate},
        "aux": {"type": "Adam", "lr": args.aux_learning_rate},
    }
    optimizer = net_aux_optimizer(net, conf)
    return optimizer["net"], optimizer["aux"]


def train_one_epoch(
    model, criterion, train_dataloader, optimizer, aux_optimizer, epoch, clip_max_norm
):
    model.train()
    device = next(model.parameters()).device

    for i, d in enumerate(train_dataloader):
        d = {k: v.to(device) for k, v in d.items()}

        optimizer.zero_grad()
        aux_optimizer.zero_grad()

        out_net = model(d)

        out_criterion = criterion(out_net, d)
        out_criterion["loss"].backward()
        if clip_max_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_max_norm)
        optimizer.step()

        aux_loss = model.aux_loss()
        aux_loss.backward()
        aux_optimizer.step()

        if i % 10 == 0:
            # d is a dict of batched tensors, so take the batch size from the
            # "pos" tensor; len(d) would count the dict keys instead.
            batch_size = len(d["pos"])
            print(
                f"Train epoch {epoch}: ["
                f"{i * batch_size}/{len(train_dataloader.dataset)} "
                f"({100. * i / len(train_dataloader):.0f}%)] "
                f'Loss: {out_criterion["loss"].item():.3f} | '
                f'Bpp loss: {out_criterion["bpp_loss"].item():.3f} | '
                f'Rec loss: {out_criterion["rec_loss"].item():.4f} | '
                # f'Aux loss: {aux_loss.item():.0f} | '
                "\n"
            )


def test_epoch(epoch, test_dataloader, model, criterion):
    model.eval()
    model.update(force=True, update_quantiles=True)
    device = next(model.parameters()).device

    meter_keys = ["loss", "bpp_loss", "rec_loss", "aux_loss"]
    meters = {key: AverageMeter() for key in meter_keys}

    with torch.no_grad():
        for d in test_dataloader:
            d = {k: v.to(device) for k, v in d.items()}

            out_net = model(d)
            out_criterion = criterion(out_net, d)
            out_criterion["aux_loss"] = model.aux_loss()

            for key in meters:
                if key in out_criterion:
                    meters[key].update(out_criterion[key])

    print(
        f"Test epoch {epoch}: Average losses: "
        f'Loss: {meters["loss"].avg:.3f} | '
        f'Bpp loss: {meters["bpp_loss"].avg:.3f} | '
        f'Rec loss: {meters["rec_loss"].avg:.4f} | '
        # f'Aux loss: {meters["aux_loss"].avg:.0f} | '
        "\n"
    )

    return meters["loss"].avg


def save_checkpoint(state, is_best, filename="checkpoint.pth.tar"):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, "checkpoint_best_loss.pth.tar")


def parse_args(argv):
    parser = argparse.ArgumentParser(description="Example training script.")
    parser.add_argument(
        "-m",
        "--model",
        default="sfu2023-pcc-rec-pointnet",
        choices=pointcloud_models.keys(),
        help="Model architecture (default: %(default)s)",
    )
    parser.add_argument(
        "-d", "--dataset", type=str, required=True, help="Training dataset"
    )
    parser.add_argument(
        "-e",
        "--epochs",
        default=100,
        type=int,
        help="Number of epochs (default: %(default)s)",
    )
    parser.add_argument(
        "-lr",
        "--learning-rate",
        default=1e-4,
        type=float,
        help="Learning rate (default: %(default)s)",
    )
    parser.add_argument(
        "-n",
        "--num-workers",
        type=int,
        default=4,
        help="Dataloader threads (default: %(default)s)",
    )
    parser.add_argument(
        "--lambda",
        dest="lmbda",
        type=float,
        default=100,
        help="Bit-rate distortion parameter (default: %(default)s)",
    )
    parser.add_argument(
        "--batch-size", type=int, default=16, help="Batch size (default: %(default)s)"
    )
    parser.add_argument(
        "--test-batch-size",
        type=int,
        default=64,
        help="Test batch size (default: %(default)s)",
    )
    parser.add_argument(
        "--aux-learning-rate",
        type=float,
        default=1e-3,
        help="Auxiliary loss learning rate (default: %(default)s)",
    )
    # NOTE: --patch-size is parsed but not used anywhere in this script.
    parser.add_argument(
        "--patch-size",
        type=int,
        nargs=2,
        default=(256, 256),
        help="Size of the patches to be cropped (default: %(default)s)",
    )
    parser.add_argument("--cuda", action="store_true", help="Use cuda")
    parser.add_argument(
        "--save", action="store_true", default=True, help="Save model to disk"
    )
    parser.add_argument("--seed", type=int, help="Set random seed for reproducibility")
    parser.add_argument(
        "--clip_max_norm",
        default=1.0,
        type=float,
        help="Gradient clipping max norm (default: %(default)s)",
    )
    parser.add_argument("--checkpoint", type=str, help="Path to a checkpoint")
    args = parser.parse_args(argv)
    return args


def main(argv):
    args = parse_args(argv)

    if args.seed is not None:
        torch.manual_seed(args.seed)
        random.seed(args.seed)

    num_points = 1024

    train_dataset = ModelNetDataset(
        args.dataset,
        split="train",
        pre_transform=Compose(
            [
                transforms.ToDict(wrapper="torch_geometric.data.Data"),
                transforms.SamplePointsV2(
                    num=8192, remove_faces=True, include_normals=True, static_seed=1234
                ),
                transforms.NormalizeScaleV2(center=True, scale_method="l2"),
                transforms.ToDict(wrapper="dict"),
            ]
        ),
        transform=Compose(
            [
                transforms.ToDict(wrapper="torch_geometric.data.Data"),
                transforms.RandomSample(num=num_points, attrs=["pos", "normal"]),
                transforms.ToDict(wrapper="dict"),
            ]
        ),
    )

    test_dataset = ModelNetDataset(
        args.dataset,
        split="test",
        pre_transform=Compose(
            [
                transforms.ToDict(wrapper="torch_geometric.data.Data"),
                transforms.SamplePointsV2(
                    num=8192, remove_faces=True, include_normals=True, static_seed=1234
                ),
                transforms.NormalizeScaleV2(center=True, scale_method="l2"),
                transforms.ToDict(wrapper="dict"),
            ]
        ),
        transform=Compose(
            [
                transforms.ToDict(wrapper="torch_geometric.data.Data"),
                transforms.RandomSample(
                    num=num_points, attrs=["pos", "normal"], static_seed=1234
                ),
                transforms.ToDict(wrapper="dict"),
            ]
        ),
    )

    device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu"

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        shuffle=True,
        pin_memory=(device == "cuda"),
    )

    test_dataloader = DataLoader(
        test_dataset,
        batch_size=args.test_batch_size,
        num_workers=args.num_workers,
        shuffle=False,
        pin_memory=(device == "cuda"),
    )

    net = MODELS[args.model]()
    net = net.to(device)

    if args.cuda and torch.cuda.device_count() > 1:
        net = CustomDataParallel(net)

    optimizer, aux_optimizer = configure_optimizers(net, args)
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")
    criterion = ChamferPccRateDistortionLoss(lmbda={"bpp": 1.0, "rec": args.lmbda})

    last_epoch = 0
    if args.checkpoint:  # load from previous checkpoint
        print("Loading", args.checkpoint)
        checkpoint = torch.load(args.checkpoint, map_location=device)
        last_epoch = checkpoint["epoch"] + 1
        net.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        aux_optimizer.load_state_dict(checkpoint["aux_optimizer"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])

    best_loss = float("inf")
    for epoch in range(last_epoch, args.epochs):
        print(f"Learning rate: {optimizer.param_groups[0]['lr']}")
        train_one_epoch(
            net,
            criterion,
            train_dataloader,
            optimizer,
            aux_optimizer,
            epoch,
            args.clip_max_norm,
        )
        loss = test_epoch(epoch, test_dataloader, net, criterion)
        lr_scheduler.step(loss)

        is_best = loss < best_loss
        best_loss = min(loss, best_loss)

        if args.save:
            save_checkpoint(
                {
                    "epoch": epoch,
                    "state_dict": net.state_dict(),
                    "loss": loss,
                    "optimizer": optimizer.state_dict(),
                    "aux_optimizer": aux_optimizer.state_dict(),
                    "lr_scheduler": lr_scheduler.state_dict(),
                },
                is_best,
            )


if __name__ == "__main__":
    main(sys.argv[1:])


# NOTE: A more complete trainer with experiment tracking, visualizations, etc.,
# that uses CompressAI Trainer can be found at:
#
# https://github.com/multimedialabsfu/learned-point-cloud-compression-for-classification
```
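For reference, a minimal sketch of a fuller invocation built only from the flags the script defines above (`--lambda` to pick a rate point, `--checkpoint` to resume from the `checkpoint_best_loss.pth.tar` that `save_checkpoint` writes); the dataset path and lambda value here are illustrative placeholders, not values prescribed by the commit:

```bash
# Resume training at a different rate point; a higher --lambda weights the
# reconstruction (Chamfer) term more heavily relative to the bpp term.
python examples/train_pointcloud.py \
    --cuda \
    --dataset="datasets/modelnet40" \
    --lambda=200 \
    --epochs=200 \
    --checkpoint="checkpoint_best_loss.pth.tar"
```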