Commit

feat(ml): The implementation of UNetVid for generating video with temporal consistency and inference

feat(ml): add UNetVid for generating video with temporal consistency and inference

feat(ml): add UNetVid for generating video with temporal consistency

feat(ml): step 1: ResBlock input/output as 5D image tensor

feat(ml): step 2: replace AttentionBlock with MotionModule;
ResBlock/MotionModule class instantiation passes

feat(ml): debug so that UNet works with 5D tensors

feat(ml): UNet = ResBlock + Attention (optional) + MotionModule

feat(ml): create UNetVid class with temporal MHA for U-Net

feat(ml): add dataloader

feat(ml): dataloader works with UNet

feat(ml): dataloader and UNetVid work for input (b,f,c,h,w); no visdom yet

feat(ml): visdom shows the training

feat(ml): debug for 5D

feat(ml): debug 5D black

feat(ml): fix typo

feat(ml): dataloader with mask

feat(ml): dataloader fixed with command-line

feat(ml): visdom shows one batch of frames

feat(ml): frames are treated as a batch, so no additional normalisation is needed

feat(ml): inference for UNetVid

feat(ml): use efficient_attention_xformers for attention

feat(ml): xformers bug PR

feat(ml): create video based on generated and orig images

feat(ml): remove unnecessary option --UNetVid

feat(ml): add docs for training and inference

feat(ml): fix inference paths requirement

feat(ml): clean up code

feat(ml): black format

feat(ml): improve inference for any paths.txt and longer frame sequences
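The messages above sketch the design: UNetVid extends the attention U-Net so that ResBlocks take and return 5D (b, f, c, h, w) tensors, and attention blocks are complemented by a MotionModule, i.e. temporal multi-head attention across the frame axis. A minimal sketch of that idea, with hypothetical names rather than the repository's exact MotionModule:

import torch
import torch.nn as nn
from einops import rearrange


class TemporalAttention(nn.Module):
    """Toy MotionModule: attention over frames, per spatial location."""

    def __init__(self, channels, num_heads=8):
        super().__init__()
        self.norm = nn.LayerNorm(channels)
        self.attn = nn.MultiheadAttention(channels, num_heads, batch_first=True)

    def forward(self, x):
        # x: (batch, frames, channels, height, width)
        b, f, c, h, w = x.shape
        # fold spatial positions into the batch; frames become the sequence
        x = rearrange(x, "b f c h w -> (b h w) f c")
        y = self.norm(x)
        x = x + self.attn(y, y, y, need_weights=False)[0]
        return rearrange(x, "(b h w) f c -> b f c h w", b=b, h=h, w=w)


x = torch.randn(2, 8, 64, 16, 16)      # 2 clips, 8 frames, 64 channels
print(TemporalAttention(64)(x).shape)  # torch.Size([2, 8, 64, 16, 16])

Because attention runs only along the frame axis, its cost is quadratic in the number of frames but linear in the spatial grid, which is what makes temporal consistency affordable inside a U-Net.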
wr0124 committed Aug 2, 2024
1 parent 81ab4d2 commit 2ef3726
Showing 12 changed files with 2,734 additions and 23 deletions.
7 changes: 6 additions & 1 deletion data/__init__.py
@@ -61,7 +61,12 @@ def create_dataloader(opt, rank, dataset, batch_size):


 def create_dataset_temporal(opt, phase):
-    dataset_class = find_dataset_using_name("temporal_labeled_mask_online")
+    if opt.model_type == "palette":
+        dataset_class = find_dataset_using_name(
+            "self_supervised_temporal_labeled_mask_online"
+        )
+    if opt.model_type == "cut":
+        dataset_class = find_dataset_using_name("temporal_labeled_mask_online")
     dataset = dataset_class(opt, phase)
     return dataset

201 changes: 201 additions & 0 deletions data/self_supervised_temporal_labeled_mask_online_dataset.py
@@ -0,0 +1,201 @@
import os
import random
import re

import torch

from data.base_dataset import BaseDataset, get_transform_list
from data.image_folder import make_labeled_path_dataset
from data.online_creation import crop_image
from data.online_creation import fill_mask_with_random, fill_mask_with_color


def atoi(text):
    return int(text) if text.isdigit() else text


def natural_keys(text):
    return [atoi(c) for c in re.split(r"(\d+)", text)]


class SelfSupervisedTemporalLabeledMaskOnlineDataset(BaseDataset):
    def __len__(self):
        """Return the total number of images in the dataset.

        As we have two datasets with potentially different numbers of images,
        we take the maximum of the two.
        """
        if hasattr(self, "B_img_paths"):
            return max(self.A_size, self.B_size)
        else:
            return self.A_size

    def __init__(self, opt, phase, name=""):
        BaseDataset.__init__(self, opt, phase, name)

        self.A_img_paths, self.A_label_mask_paths = make_labeled_path_dataset(
            self.dir_A, "/paths.txt"
        )  # load images from '/path/to/data/trainA/paths.txt' as well as labels
        # sort in natural (human) order so consecutive frames stay consecutive
        self.A_img_paths.sort(key=natural_keys)
        self.A_label_mask_paths.sort(key=natural_keys)

        if self.opt.data_sanitize_paths:
            self.sanitize()
        elif opt.data_max_dataset_size != float("inf"):
            self.A_img_paths, self.A_label_mask_paths = (
                self.A_img_paths[: opt.data_max_dataset_size],
                self.A_label_mask_paths[: opt.data_max_dataset_size],
            )

        self.transform = get_transform_list(self.opt, grayscale=(self.input_nc == 1))

        self.num_frames = opt.data_temporal_number_frames
        self.frame_step = opt.data_temporal_frame_step

        self.num_A = len(self.A_img_paths)
        self.range_A = self.num_A - self.num_frames * self.frame_step

        self.num_common_char = self.opt.data_temporal_num_common_char

        self.opt = opt

        self.A_size = len(self.A_img_paths)  # get the size of dataset A

    def get_img(
        self,
        A_img_path,
        A_label_mask_path,
        A_label_cls,
        B_img_path=None,
        B_label_mask_path=None,
        B_label_cls=None,
        index=None,
    ):  # all params are unused
        index_A = random.randint(0, self.range_A - 1)

        images_A = []
        labels_A = []

        ref_A_img_path = self.A_img_paths[index_A]

        ref_name_A = ref_A_img_path.split("/")[-1][: self.num_common_char]

        for i in range(self.num_frames):
            cur_index_A = index_A + i * self.frame_step

            if (
                self.num_common_char != -1
                and self.A_img_paths[cur_index_A].split("/")[-1][: self.num_common_char]
                not in ref_name_A
            ):
                return None

            cur_A_img_path, cur_A_label_path = (
                self.A_img_paths[cur_index_A],
                self.A_label_mask_paths[cur_index_A],
            )
            if self.opt.data_relative_paths:
                cur_A_img_path = os.path.join(self.root, cur_A_img_path)
                if cur_A_label_path is not None:
                    cur_A_label_path = os.path.join(self.root, cur_A_label_path)

            try:
                if self.opt.data_online_creation_mask_delta_A_ratio == [[]]:
                    mask_delta_A = self.opt.data_online_creation_mask_delta_A
                else:
                    mask_delta_A = self.opt.data_online_creation_mask_delta_A_ratio

                if i == 0:
                    # the first frame defines the crop used for the whole window
                    crop_coordinates = crop_image(
                        cur_A_img_path,
                        cur_A_label_path,
                        mask_delta=mask_delta_A,
                        mask_random_offset=self.opt.data_online_creation_mask_random_offset_A,
                        crop_delta=self.opt.data_online_creation_crop_delta_A,
                        mask_square=self.opt.data_online_creation_mask_square_A,
                        crop_dim=self.opt.data_online_creation_crop_size_A,
                        output_dim=self.opt.data_load_size,
                        context_pixels=self.opt.data_online_context_pixels,
                        load_size=self.opt.data_online_creation_load_size_A,
                        get_crop_coordinates=True,
                        fixed_mask_size=self.opt.data_online_fixed_mask_size,
                    )
                cur_A_img, cur_A_label, ref_A_bbox, A_ref_bbox_id = crop_image(
                    cur_A_img_path,
                    cur_A_label_path,
                    mask_delta=mask_delta_A,
                    mask_random_offset=self.opt.data_online_creation_mask_random_offset_A,
                    crop_delta=self.opt.data_online_creation_crop_delta_A,
                    mask_square=self.opt.data_online_creation_mask_square_A,
                    crop_dim=self.opt.data_online_creation_crop_size_A,
                    output_dim=self.opt.data_load_size,
                    context_pixels=self.opt.data_online_context_pixels,
                    load_size=self.opt.data_online_creation_load_size_A,
                    crop_coordinates=crop_coordinates,
                    fixed_mask_size=self.opt.data_online_fixed_mask_size,
                )
                if i == 0:
                    A_ref_bbox = ref_A_bbox[1:]

            except Exception as e:
                print(e, f"{i+1}th frame of domain A in temporal dataloading")
                return None

            images_A.append(cur_A_img)
            labels_A.append(cur_A_label)

        images_A, labels_A, A_ref_bbox = self.transform(images_A, labels_A, A_ref_bbox)
        A_ref_label = labels_A[0]
        A_ref_img = images_A[0]
        images_A = torch.stack(images_A)
        labels_A = torch.stack(labels_A)

        result = {
            "A_ref": A_ref_img,
            "A": images_A,
            "A_img_paths": ref_A_img_path,
            "A_ref_bbox": A_ref_bbox,
            "A_label_mask": labels_A,
            "A_ref_label_mask": A_ref_label,
            "B_ref": A_ref_img,
            "B": images_A,
            "B_img_paths": ref_A_img_path,
            "B_ref_bbox": A_ref_bbox,
            "B_label_mask": labels_A,
            "B_ref_label_mask": A_ref_label,
        }

        try:
            if self.opt.data_online_creation_rand_mask_A:
                A_ref_img = fill_mask_with_random(
                    result["A_ref"], result["A_ref_label_mask"], -1
                )
                images_A = fill_mask_with_random(
                    result["A"], result["A_label_mask"], -1
                )
            elif self.opt.data_online_creation_color_mask_A:
                A_ref_img = fill_mask_with_color(
                    result["A_ref"], result["A_ref_label_mask"], {}
                )
                images_A = fill_mask_with_color(result["A"], result["A_label_mask"], {})
            else:
                raise Exception(
                    "self supervised dataset: no self supervised method specified"
                )

            result.update(
                {
                    "A_ref": A_ref_img,
                    "A": images_A,
                    "A_img_paths": ref_A_img_path,
                }
            )
        except Exception as e:
            print(
                e,
                "self supervised temporal labeled mask online data loading from ",
                ref_A_img_path,
            )
            return None

        return result
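
To make the temporal window concrete, here is a tiny self-contained illustration of the index sampling in get_img above (numbers illustrative): a random start index is drawn, then num_frames indices spaced frame_step apart are cropped identically and stacked, so each sample is a (num_frames, c, h, w) tensor that a standard DataLoader collates into the 5D (b, f, c, h, w) batches UNetVid consumes.

import random

# Toy version of the window sampling in get_img.
num_images, num_frames, frame_step = 100, 8, 2
range_A = num_images - num_frames * frame_step
index_A = random.randint(0, range_A - 1)
window = [index_A + i * frame_step for i in range(num_frames)]
print(window)  # e.g. [37, 39, 41, 43, 45, 47, 49, 51]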
35 changes: 35 additions & 0 deletions docs/source/inference.rst
@@ -310,3 +310,38 @@ The output files will be in the ``mapillary_inference_output`` folder, with:
- ``img_0_y_0.png``: the original image resized after conditioning image insertion

- ``img_0_y_t.png``: the noisy image given to the model

*******************************************************
Generate a video with a diffusion model for inpainting
*******************************************************

Download the test dataset & pretrained model
=============================================

.. code:: bash

   wget https://www.joligen.com/models/mario_vid.zip
   unzip mario_vid.zip -d checkpoints
   rm mario_vid.zip
   wget https://www.joligen.com/datasets/online_mario2sonic_lite2.zip
   unzip online_mario2sonic_lite2.zip -d online_mario2sonic_lite2
   rm online_mario2sonic_lite2.zip

Run the inference script
========================

.. code:: bash

   cd scripts
   python3 gen_vid_diffusion.py \
       --model_in_file ../checkpoints/latest_net_G_A.pth \
       --img_in ../image_path \
       --paths_file ../datasets/online_mario2sonic_video/trainA/paths.txt \
       --mask_in ../mask_file \
       --data_root ../datasets/online_mario2sonic_video/ \
       --dir_out ../inference_mario_vid \
       --img_width 128 \
       --img_height 128

The output files will be in the ``inference_mario_vid`` folder, with ``mario_video_0_generated.avi`` for the generated video and ``mario_video_0_orig.avi`` for the original frames.
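
For orientation, below is a minimal sketch of assembling generated frames into an ``.avi`` the way the script's output suggests; it is illustrative only (the glob pattern, frame naming, and frame rate are assumptions, not the actual code of ``gen_vid_diffusion.py``):

.. code:: python

   import glob

   import cv2

   # Write the generated frames, in sorted order, into an AVI container.
   frames = sorted(glob.glob("../inference_mario_vid/*_generated.png"))
   writer = None
   for path in frames:
       img = cv2.imread(path)
       if writer is None:
           h, w = img.shape[:2]
           fourcc = cv2.VideoWriter_fourcc(*"MJPG")
           writer = cv2.VideoWriter(
               "mario_video_0_generated.avi", fourcc, 25.0, (w, h)
           )
       writer.write(img)
   if writer is not None:
       writer.release()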
12 changes: 12 additions & 0 deletions docs/source/training.rst
@@ -196,3 +196,15 @@ Trains a consistency model to insert glasses onto faces.
.. code:: bash

   python3 train.py --dataroot /path/to/data/noglasses2glasses_ffhq --checkpoints_dir /path/to/checkpoints --name noglasses2glasses --config_json examples/example_cm_noglasses2glasses.json

*************************************************************
DDPM training for video generation with inpainting
*************************************************************

Dataset: https://joligen.com/datasets/online_mario2sonic_full.tar

Train a DDPM model to generate a sequence of frame images for inpainting, ensuring temporal consistency throughout the series of frames.

.. code:: bash

   python3 train.py --dataroot /path/to/data/online_mario2sonic_full --checkpoints_dir /path/to/checkpoints --name mario_vid --config_json examples/example_ddpm_vid_mario.json

4 changes: 3 additions & 1 deletion models/base_model.py
@@ -768,7 +768,9 @@ def get_current_visuals(self, nb_imgs, phase="train", test_name=""):
                 cur_visual[name] = getattr(self, name)
             visual_ret.append(cur_visual)
             if (
-                self.opt.model_type != "cut" and self.opt.model_type != "cycle_gan"
+                self.opt.model_type != "cut"
+                and self.opt.model_type != "cycle_gan"
+                and not self.opt.G_netG == "unet_vid"
             ):  # GANs have more outputs in practice, including semantics
                 if i == nb_imgs - 1:
                     break
29 changes: 29 additions & 0 deletions models/diffusion_networks.py
@@ -7,6 +7,8 @@
     UViT,
     UNetGeneratorRefAttn,
 )
+from .modules.unet_generator_attn.unet_generator_attn_vid import UNetVid
+
 from .modules.hdit.hdit import HDiT, HDiTConfig
 
 from .modules.palette_denoise_fn import PaletteDenoiseFn
@@ -124,6 +126,33 @@ def define_G(
             freq_space=train_feat_wavelet,
         )
 
+    elif G_netG == "unet_vid":
+        if model_prior_321_backwardcompatibility:
+            cond_embed_dim = G_ngf * 4
+        else:
+            cond_embed_dim = alg_diffusion_cond_embed_dim
+
+        model = UNetVid(
+            image_size=data_crop_size,
+            in_channel=in_channel,
+            inner_channel=G_ngf,
+            out_channel=model_output_nc,
+            res_blocks=G_unet_mha_res_blocks,
+            attn_res=G_unet_mha_attn_res,
+            num_heads=G_unet_mha_num_heads,
+            num_head_channels=G_unet_mha_num_head_channels,
+            tanh=False,
+            dropout=G_dropout,
+            n_timestep_train=G_diff_n_timestep_train,
+            n_timestep_test=G_diff_n_timestep_test,
+            channel_mults=G_unet_mha_channel_mults,
+            norm=G_unet_mha_norm_layer,
+            group_norm_size=G_unet_mha_group_norm_size,
+            efficient=G_unet_mha_vit_efficient,
+            cond_embed_dim=cond_embed_dim,
+            freq_space=train_feat_wavelet,
+        )
+
     elif G_netG == "unet_mha_ref_attn":
         cond_embed_dim = alg_diffusion_cond_embed_dim

13 changes: 8 additions & 5 deletions models/modules/diffusion_generator.py
@@ -194,8 +194,10 @@ def p_mean_variance(

         embed_noise_level = self.compute_gammas(noise_level)
 
-        input = torch.cat([y_cond, y_t], dim=1)
-
+        if len(y_cond.shape) == 5 and len(y_t.shape) == 5:
+            input = torch.cat([y_cond, y_t], dim=2)
+        else:
+            input = torch.cat([y_cond, y_t], dim=1)
         if guidance_scale > 0.0 and phase == "test":
             y_0_hat_uncond = predict_start_from_noise(
                 self.denoise_fn.model,
@@ -451,9 +453,10 @@ def forward(self, y_0, y_cond, mask, noise, cls, ref, dropout_prob=0.0):
         if mask is not None:
             temp_mask = torch.clamp(mask, min=0.0, max=1.0)
             y_noisy = y_noisy * temp_mask + (1.0 - temp_mask) * y_0
-
-        input = torch.cat([y_cond, y_noisy], dim=1)
-
+        if len(y_cond.shape) == 5 and len(y_noisy.shape) == 5:
+            input = torch.cat([y_cond, y_noisy], dim=2)
+        else:
+            input = torch.cat([y_cond, y_noisy], dim=1)
         noise_hat = self.denoise_fn(
             input, embed_sample_gammas, cls=cls, mask=mask, ref=ref
         )
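
A quick shape check for the 5D branches added above: video tensors are laid out (b, f, c, h, w), so the conditioning is concatenated with the noisy target along dim=2, the channel axis, instead of dim=1 as for 4D images.

import torch

# 4D images concatenate channels at dim=1; 5D video at dim=2.
y_cond = torch.randn(2, 8, 3, 64, 64)  # (batch, frames, channels, h, w)
y_t = torch.randn(2, 8, 3, 64, 64)
print(torch.cat([y_cond, y_t], dim=2).shape)  # torch.Size([2, 8, 6, 64, 64])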
30 changes: 29 additions & 1 deletion models/modules/unet_generator_attn/unet_attn_utils.py
@@ -6,10 +6,38 @@
 import numpy as np
 import torch
 import torch.nn as nn
 
+from einops import rearrange
 from .switchable_norm import SwitchNorm2d
 
 
+class InflatedConv3d(nn.Conv2d):
+    def forward(self, x):
+        video_length = x.shape[1]
+        input_channels = x.shape[2]
+        x = rearrange(x, "b f c h w -> (b f) c h w")
+        expected_channels = self.in_channels
+        if input_channels != expected_channels:
+            raise ValueError(
+                f"Expected input channels: {expected_channels}, but got: {input_channels}"
+            )
+
+        x = super().forward(x)
+        x = rearrange(x, "(b f) c h w -> b f c h w", f=video_length)
+
+        return x
+
+
+class InflatedGroupNorm(nn.GroupNorm):
+    def forward(self, x):
+        video_length = x.shape[1]
+
+        x = rearrange(x, "b f c h w -> (b f) c h w")
+        x = super().forward(x)
+        x = rearrange(x, "(b f) c h w -> b f c h w", f=video_length)
+
+        return x
+
+
 class GroupNorm(nn.Module):
     def __init__(self, group_size, channels):
         super().__init__()
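
A usage sketch for the inflated layers above (shapes illustrative): frames are folded into the batch axis so plain 2D convolution and GroupNorm weights run framewise on video, which is how the rest of the 2D U-Net machinery is reused for 5D inputs.

import torch

# InflatedConv3d is an nn.Conv2d applied per frame of a (b, f, c, h, w)
# tensor; InflatedGroupNorm does the same for nn.GroupNorm.
conv = InflatedConv3d(3, 16, kernel_size=3, padding=1)
norm = InflatedGroupNorm(num_groups=4, num_channels=16)
x = torch.randn(2, 8, 3, 64, 64)  # (batch, frames, channels, h, w)
y = conv(x)
print(y.shape)        # torch.Size([2, 8, 16, 64, 64])
print(norm(y).shape)  # torch.Size([2, 8, 16, 64, 64])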