From 2ef3726c0bc1cd4216ca1f09799c26a9c3b2d034 Mon Sep 17 00:00:00 2001
From: julie wang
Date: Wed, 10 Jul 2024 09:41:34 +0200
Subject: [PATCH] feat(ml): add UNetVid for generating video with temporal
 consistency and inference

feat(ml): implement UNetVid for generating video with temporal consistency and inference
feat(ml): add UNetVid for generating video with temporal consistency and inference
feat(ml): add UNetVid for generating video with temporal consistency
feat(ml): step 1, ResBlock takes and returns 5D image tensors
feat(ml): step 2, replace AttentionBlock with MotionModule; ResBlock/MotionModule instantiation passes
feat(ml): debug so that the UNet works with 5D tensors
feat(ml): UNet = ResBlock + Attention (optional) + MotionModule
feat(ml): create UNetVid class with temporal MHA for the U-Net
feat(ml): add dataloader
feat(ml): dataloader works with the UNet
feat(ml): dataloader and UNetVid work with (b, f, c, h, w) input; no visdom yet
feat(ml): visdom shows the training
feat(ml): debug for 5D
feat(ml): debug black frames in 5D mode
feat(ml): fix typo
feat(ml): dataloader with mask
feat(ml): dataloader fixed with command-line options
feat(ml): visdom shows one batch of frames
feat(ml): frames are treated as a batch, so no additional normalisation is needed
feat(ml): inference for UNetVid
feat(ml): use efficient_attention_xformers for attention
feat(ml): xformers bug fix for the PR
feat(ml): create videos from the generated and original images
feat(ml): remove unnecessary option --UNetVid
feat(ml): add documentation for training and inference
feat(ml): fix inference paths requirement
feat(ml): clean up code
feat(ml): apply black formatting
feat(ml): improve inference for arbitrary paths.txt files and longer frame sequences
---
 data/__init__.py                              |    7 +-
 ...ed_temporal_labeled_mask_online_dataset.py |  201 +++
 docs/source/inference.rst                     |   35 +
 docs/source/training.rst                      |   12 +
 models/base_model.py                          |    4 +-
 models/diffusion_networks.py                  |   29 +
 models/modules/diffusion_generator.py         |   13 +-
 .../unet_generator_attn/unet_attn_utils.py    |   30 +-
 .../unet_generator_attn_vid.py                | 1404 +++++++++++++++++
 models/palette_model.py                       |   52 +-
 options/common_options.py                     |    2 +
 scripts/gen_vid_diffusion.py                  |  968 ++++++++++++
 12 files changed, 2734 insertions(+), 23 deletions(-)
 create mode 100644 data/self_supervised_temporal_labeled_mask_online_dataset.py
 create mode 100644 models/modules/unet_generator_attn/unet_generator_attn_vid.py
 create mode 100644 scripts/gen_vid_diffusion.py

diff --git a/data/__init__.py b/data/__init__.py
index aea5cfc7c..97dfbbfb5 100644
--- a/data/__init__.py
+++ b/data/__init__.py
@@ -61,7 +61,12 @@ def create_dataloader(opt, rank, dataset, batch_size):
 
 
 def create_dataset_temporal(opt, phase):
-    dataset_class = find_dataset_using_name("temporal_labeled_mask_online")
+    if opt.model_type == "palette":
+        dataset_class = find_dataset_using_name(
+            "self_supervised_temporal_labeled_mask_online"
+        )
+    elif opt.model_type == "cut":
+        dataset_class = find_dataset_using_name("temporal_labeled_mask_online")
 
     dataset = dataset_class(opt, phase)
     return dataset
diff --git a/data/self_supervised_temporal_labeled_mask_online_dataset.py b/data/self_supervised_temporal_labeled_mask_online_dataset.py
new file mode 100644
index 000000000..acd5217e0
--- /dev/null
+++ b/data/self_supervised_temporal_labeled_mask_online_dataset.py
@@ -0,0 +1,201 @@
+import os
+import random
+import re
+
+import torch
+
+from data.base_dataset import BaseDataset, get_transform_list
+from data.image_folder import make_labeled_path_dataset
+from data.online_creation import crop_image
+from data.online_creation import fill_mask_with_random,
fill_mask_with_color + + +def atoi(text): + return int(text) if text.isdigit() else text + + +def natural_keys(text): + return [atoi(c) for c in re.split("(\d+)", text)] + + +class SelfSupervisedTemporalLabeledMaskOnlineDataset(BaseDataset): + def __len__(self): + """Return the total number of images in the dataset. + As we have two datasets with potentially different number of images, + we take a maximum of + """ + if hasattr(self, "B_img_paths"): + return max(self.A_size, self.B_size) + else: + return self.A_size + + def __init__(self, opt, phase, name=""): + BaseDataset.__init__(self, opt, phase, name) + + self.A_img_paths, self.A_label_mask_paths = make_labeled_path_dataset( + self.dir_A, "/paths.txt" + ) # load images from '/path/to/data/trainA/paths.txt' as well as labels + # sort + self.A_img_paths.sort(key=natural_keys) + self.A_label_mask_paths.sort(key=natural_keys) + + if self.opt.data_sanitize_paths: + self.sanitize() + elif opt.data_max_dataset_size != float("inf"): + self.A_img_paths, self.A_label_mask_paths = ( + self.A_img_paths[: opt.data_max_dataset_size], + self.A_label_mask_paths[: opt.data_max_dataset_size], + ) + + self.transform = get_transform_list(self.opt, grayscale=(self.input_nc == 1)) + + self.num_frames = opt.data_temporal_number_frames + self.frame_step = opt.data_temporal_frame_step + + self.num_A = len(self.A_img_paths) + self.range_A = self.num_A - self.num_frames * self.frame_step + + self.num_common_char = self.opt.data_temporal_num_common_char + + self.opt = opt + + self.A_size = len(self.A_img_paths) # get the size of dataset A + + def get_img( + self, + A_img_path, + A_label_mask_path, + A_label_cls, + B_img_path=None, + B_label_mask_path=None, + B_label_cls=None, + index=None, + ): # all params are unused + index_A = random.randint(0, self.range_A - 1) + + images_A = [] + labels_A = [] + + ref_A_img_path = self.A_img_paths[index_A] + + ref_name_A = ref_A_img_path.split("/")[-1][: self.num_common_char] + + for i in range(self.num_frames): + cur_index_A = index_A + i * self.frame_step + + if ( + self.num_common_char != -1 + and self.A_img_paths[cur_index_A].split("/")[-1][: self.num_common_char] + not in ref_name_A + ): + return None + + cur_A_img_path, cur_A_label_path = ( + self.A_img_paths[cur_index_A], + self.A_label_mask_paths[cur_index_A], + ) + if self.opt.data_relative_paths: + cur_A_img_path = os.path.join(self.root, cur_A_img_path) + if cur_A_label_path is not None: + cur_A_label_path = os.path.join(self.root, cur_A_label_path) + + try: + if self.opt.data_online_creation_mask_delta_A_ratio == [[]]: + mask_delta_A = self.opt.data_online_creation_mask_delta_A + else: + mask_delta_A = self.opt.data_online_creation_mask_delta_A_ratio + + if i == 0: + crop_coordinates = crop_image( + cur_A_img_path, + cur_A_label_path, + mask_delta=mask_delta_A, + mask_random_offset=self.opt.data_online_creation_mask_random_offset_A, + crop_delta=self.opt.data_online_creation_crop_delta_A, + mask_square=self.opt.data_online_creation_mask_square_A, + crop_dim=self.opt.data_online_creation_crop_size_A, + output_dim=self.opt.data_load_size, + context_pixels=self.opt.data_online_context_pixels, + load_size=self.opt.data_online_creation_load_size_A, + get_crop_coordinates=True, + fixed_mask_size=self.opt.data_online_fixed_mask_size, + ) + cur_A_img, cur_A_label, ref_A_bbox, A_ref_bbox_id = crop_image( + cur_A_img_path, + cur_A_label_path, + mask_delta=mask_delta_A, + mask_random_offset=self.opt.data_online_creation_mask_random_offset_A, + 
crop_delta=self.opt.data_online_creation_crop_delta_A,
+                    mask_square=self.opt.data_online_creation_mask_square_A,
+                    crop_dim=self.opt.data_online_creation_crop_size_A,
+                    output_dim=self.opt.data_load_size,
+                    context_pixels=self.opt.data_online_context_pixels,
+                    load_size=self.opt.data_online_creation_load_size_A,
+                    crop_coordinates=crop_coordinates,
+                    fixed_mask_size=self.opt.data_online_fixed_mask_size,
+                )
+                if i == 0:
+                    A_ref_bbox = ref_A_bbox[1:]
+
+            except Exception as e:
+                print(e, f"{i+1}th frame of domain A in temporal dataloading")
+                return None
+
+            images_A.append(cur_A_img)
+            labels_A.append(cur_A_label)
+
+        images_A, labels_A, A_ref_bbox = self.transform(images_A, labels_A, A_ref_bbox)
+        A_ref_label = labels_A[0]
+        A_ref_img = images_A[0]
+        images_A = torch.stack(images_A)
+        labels_A = torch.stack(labels_A)
+
+        result = {
+            "A_ref": A_ref_img,
+            "A": images_A,
+            "A_img_paths": ref_A_img_path,
+            "A_ref_bbox": A_ref_bbox,
+            "A_label_mask": labels_A,
+            "A_ref_label_mask": A_ref_label,
+            "B_ref": A_ref_img,
+            "B": images_A,
+            "B_img_paths": ref_A_img_path,
+            "B_ref_bbox": A_ref_bbox,
+            "B_label_mask": labels_A,
+            "B_ref_label_mask": A_ref_label,
+        }
+
+        try:
+            if self.opt.data_online_creation_rand_mask_A:
+                A_ref_img = fill_mask_with_random(
+                    result["A_ref"], result["A_ref_label_mask"], -1
+                )
+                images_A = fill_mask_with_random(
+                    result["A"], result["A_label_mask"], -1
+                )
+            elif self.opt.data_online_creation_color_mask_A:
+                A_ref_img = fill_mask_with_color(
+                    result["A_ref"], result["A_ref_label_mask"], {}
+                )
+                images_A = fill_mask_with_color(result["A"], result["A_label_mask"], {})
+            else:
+                raise Exception(
+                    "self supervised dataset: no self supervised method specified"
+                )
+
+            result.update(
+                {
+                    "A_ref": A_ref_img,
+                    "A": images_A,
+                    "A_img_paths": ref_A_img_path,
+                }
+            )
+        except Exception as e:
+            print(
+                e,
+                "self supervised temporal labeled mask online data loading from ",
+                ref_A_img_path,
+            )
+            return None
+
+        return result
diff --git a/docs/source/inference.rst b/docs/source/inference.rst
index 150eb1469..1a7bd1b29 100644
--- a/docs/source/inference.rst
+++ b/docs/source/inference.rst
@@ -310,3 +310,38 @@ The output files will be in the ``mapillary_inference_output`` folder, with:
 - ``img_0_y_0.png``: the original image resized after conditioning image insertion
 
 - ``img_0_y_t.png``: the noisy image given to the model
+
+**********************************************************
+ Generate a video with a diffusion model for inpainting
+**********************************************************
+
+Download the test dataset & pretrained model
+=============================================
+
+.. code:: bash
+
+   wget https://www.joligen.com/models/mario_vid.zip
+   unzip mario_vid.zip -d checkpoints
+   rm mario_vid.zip
+
+   wget https://www.joligen.com/datasets/online_mario2sonic_lite2.zip
+   unzip online_mario2sonic_lite2.zip -d online_mario2sonic_lite2
+   rm online_mario2sonic_lite2.zip
+
+Run the inference script
+========================
+
+.. code:: bash
+
+   cd scripts
+   python3 gen_vid_diffusion.py \
+       --model_in_file ../checkpoints/latest_net_G_A.pth \
+       --img_in ../image_path \
+       --paths_file ../datasets/online_mario2sonic_video/trainA/paths.txt \
+       --mask_in ../mask_file \
+       --data_root ../datasets/online_mario2sonic_video/ \
+       --dir_out ../inference_mario_vid \
+       --img_width 128 \
+       --img_height 128
+
+The output files will be in the ``inference_mario_vid`` folder, with ``mario_video_0_generated.avi`` for the generated video and ``mario_video_0_orig.avi`` for the original frames.
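+
+If you need to rebuild a video from individual generated frames (for example after
+post-processing them), a minimal OpenCV sketch such as the one below can be used.
+The frame naming pattern and the frame rate are assumptions, not the exact output
+layout of ``gen_vid_diffusion.py``; adapt them to the files the script actually writes.
+
+.. code:: python
+
+   import glob
+
+   import cv2
+
+   # Hypothetical frame naming: adapt the glob pattern to the actual script output.
+   frames = sorted(glob.glob("../inference_mario_vid/*_generated_*.png"))
+
+   first = cv2.imread(frames[0])
+   height, width = first.shape[:2]
+   # MJPG codec and 12 fps are arbitrary choices for this example.
+   writer = cv2.VideoWriter(
+       "mario_video_custom.avi", cv2.VideoWriter_fourcc(*"MJPG"), 12, (width, height)
+   )
+   for path in frames:
+       writer.write(cv2.imread(path))
+   writer.release()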
diff --git a/docs/source/training.rst b/docs/source/training.rst index c9fceb05f..b06c754ba 100644 --- a/docs/source/training.rst +++ b/docs/source/training.rst @@ -196,3 +196,15 @@ Trains a consistency model to insert glasses onto faces. .. code:: bash python3 train.py --dataroot /path/to/data/noglasses2glasses_ffhq --checkpoints_dir /path/to/checkpoints --name noglasses2glasses --config_json examples/example_cm_noglasses2glasses.json + +************************************************************* + DDPM training for video generation with inpainting +************************************************************* + +Dataset: https://joligen.com/datasets/online_mario2sonic_full.tar + +Train a DDPM model to generate a sequence of frame images for inpainting, ensuring temporal consistency throughout the series of frames. + +.. code:: bash + + python3 train.py --dataroot /path/to/data/online_mario2sonic_full --checkpoints_dir /path/to/checkpoints --name mario_vid --config_json examples/example_ddpm_vid_mario.json diff --git a/models/base_model.py b/models/base_model.py index b9f465024..13462663c 100644 --- a/models/base_model.py +++ b/models/base_model.py @@ -768,7 +768,9 @@ def get_current_visuals(self, nb_imgs, phase="train", test_name=""): cur_visual[name] = getattr(self, name) visual_ret.append(cur_visual) if ( - self.opt.model_type != "cut" and self.opt.model_type != "cycle_gan" + self.opt.model_type != "cut" + and self.opt.model_type != "cycle_gan" + and not self.opt.G_netG == "unet_vid" ): # GANs have more outputs in practice, including semantics if i == nb_imgs - 1: break diff --git a/models/diffusion_networks.py b/models/diffusion_networks.py index 5e3450867..376112c2e 100644 --- a/models/diffusion_networks.py +++ b/models/diffusion_networks.py @@ -7,6 +7,8 @@ UViT, UNetGeneratorRefAttn, ) +from .modules.unet_generator_attn.unet_generator_attn_vid import UNetVid + from .modules.hdit.hdit import HDiT, HDiTConfig from .modules.palette_denoise_fn import PaletteDenoiseFn @@ -124,6 +126,33 @@ def define_G( freq_space=train_feat_wavelet, ) + elif G_netG == "unet_vid": + if model_prior_321_backwardcompatibility: + cond_embed_dim = G_ngf * 4 + else: + cond_embed_dim = alg_diffusion_cond_embed_dim + + model = UNetVid( + image_size=data_crop_size, + in_channel=in_channel, + inner_channel=G_ngf, + out_channel=model_output_nc, + res_blocks=G_unet_mha_res_blocks, + attn_res=G_unet_mha_attn_res, + num_heads=G_unet_mha_num_heads, + num_head_channels=G_unet_mha_num_head_channels, + tanh=False, + dropout=G_dropout, + n_timestep_train=G_diff_n_timestep_train, + n_timestep_test=G_diff_n_timestep_test, + channel_mults=G_unet_mha_channel_mults, + norm=G_unet_mha_norm_layer, + group_norm_size=G_unet_mha_group_norm_size, + efficient=G_unet_mha_vit_efficient, + cond_embed_dim=cond_embed_dim, + freq_space=train_feat_wavelet, + ) + elif G_netG == "unet_mha_ref_attn": cond_embed_dim = alg_diffusion_cond_embed_dim diff --git a/models/modules/diffusion_generator.py b/models/modules/diffusion_generator.py index 981596ab5..bf4ebb997 100644 --- a/models/modules/diffusion_generator.py +++ b/models/modules/diffusion_generator.py @@ -194,8 +194,10 @@ def p_mean_variance( embed_noise_level = self.compute_gammas(noise_level) - input = torch.cat([y_cond, y_t], dim=1) - + if len(y_cond.shape) == 5 and len(y_t.shape) == 5: + input = torch.cat([y_cond, y_t], dim=2) + else: + input = torch.cat([y_cond, y_t], dim=1) if guidance_scale > 0.0 and phase == "test": y_0_hat_uncond = predict_start_from_noise( self.denoise_fn.model, 
@@ -451,9 +453,10 @@ def forward(self, y_0, y_cond, mask, noise, cls, ref, dropout_prob=0.0): if mask is not None: temp_mask = torch.clamp(mask, min=0.0, max=1.0) y_noisy = y_noisy * temp_mask + (1.0 - temp_mask) * y_0 - - input = torch.cat([y_cond, y_noisy], dim=1) - + if len(y_cond.shape) == 5 and len(y_noisy.shape) == 5: + input = torch.cat([y_cond, y_noisy], dim=2) + else: + input = torch.cat([y_cond, y_noisy], dim=1) noise_hat = self.denoise_fn( input, embed_sample_gammas, cls=cls, mask=mask, ref=ref ) diff --git a/models/modules/unet_generator_attn/unet_attn_utils.py b/models/modules/unet_generator_attn/unet_attn_utils.py index e9046a910..06955ba15 100644 --- a/models/modules/unet_generator_attn/unet_attn_utils.py +++ b/models/modules/unet_generator_attn/unet_attn_utils.py @@ -6,10 +6,38 @@ import numpy as np import torch import torch.nn as nn - +from einops import rearrange from .switchable_norm import SwitchNorm2d +class InflatedConv3d(nn.Conv2d): + def forward(self, x): + video_length = x.shape[1] + input_channels = x.shape[2] + x = rearrange(x, "b f c h w -> (b f) c h w") + expected_channels = self.in_channels + if input_channels != expected_channels: + raise ValueError( + f"Expected input channels: {expected_channels}, but got: {input_channels}" + ) + + x = super().forward(x) + x = rearrange(x, "(b f) c h w -> b f c h w", f=video_length) + + return x + + +class InflatedGroupNorm(nn.GroupNorm): + def forward(self, x): + video_length = x.shape[1] + + x = rearrange(x, "b f c h w -> (b f) c h w") + x = super().forward(x) + x = rearrange(x, "(b f) c h w -> b f c h w", f=video_length) + + return x + + class GroupNorm(nn.Module): def __init__(self, group_size, channels): super().__init__() diff --git a/models/modules/unet_generator_attn/unet_generator_attn_vid.py b/models/modules/unet_generator_attn/unet_generator_attn_vid.py new file mode 100644 index 000000000..333fb4f57 --- /dev/null +++ b/models/modules/unet_generator_attn/unet_generator_attn_vid.py @@ -0,0 +1,1404 @@ +from abc import abstractmethod +import math +from typing import Optional +import torch +import torch.nn as nn +import torch.nn.functional as F + +import numpy as np + +from einops import rearrange +from einops.layers.torch import Rearrange + +from positional_encodings.torch_encodings import PositionalEncoding1D, Summer + +from .unet_attn_utils import ( + checkpoint, + zero_module, + normalization, + normalization1d, + count_flops_attn, + InflatedConv3d, + InflatedGroupNorm, +) + +from models.modules.diffusion_utils import gamma_embedding +import xformers, xformers.ops + + +class EmbedBlock(nn.Module): + """ + Any module where forward() takes embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` embeddings. + """ + + +class EmbedSequential(nn.Sequential, EmbedBlock): + """ + A sequential module that passes embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb): + for layer in self: + if isinstance(layer, EmbedBlock): + x = layer(x, emb) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. 
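+    :param out_channel: if given, the number of output channels (defaults to channels).
+    :param efficient: if True, apply the convolution before the nearest-neighbour upsampling.
+    :param freq_space: if True, operate in the Haar wavelet (frequency) domain.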
+ + """ + + def __init__( + self, channels, use_conv, out_channel=None, efficient=False, freq_space=False + ): + super().__init__() + self.channels = channels + self.out_channel = out_channel or channels + self.use_conv = use_conv + self.freq_space = freq_space + + if freq_space: + from ..freq_utils import InverseHaarTransform, HaarTransform + + self.iwt = InverseHaarTransform(3) + self.dwt = HaarTransform(3) + self.channels = int(self.channels / 4) + self.out_channel = int(self.out_channel / 4) + + if use_conv: + self.conv = nn.Conv2d(self.channels, self.out_channel, 3, padding=1) + self.efficient = efficient + + def forward(self, x): + if self.freq_space: + x = self.iwt(x) + + assert x.shape[1] == self.channels + if not self.efficient: + x = F.interpolate(x, scale_factor=2, mode="nearest") + if self.use_conv: + x = self.conv(x) + if self.efficient: # if efficient, we do the interpolation after the conv + x = F.interpolate(x, scale_factor=2, mode="nearest") + + if self.freq_space: + x = self.dwt(x) + + return x + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + """ + + def __init__(self, channels, use_conv, out_channel=None, freq_space=False): + super().__init__() + self.channels = channels + self.out_channel = out_channel or channels + self.use_conv = use_conv + self.freq_space = freq_space + + if self.freq_space: + from ..freq_utils import InverseHaarTransform, HaarTransform + + self.iwt = InverseHaarTransform(3) + self.dwt = HaarTransform(3) + self.channels = int(self.channels / 4) + self.out_channel = int(self.out_channel / 4) + + stride = 2 + if use_conv: + self.op = nn.Conv2d( + self.channels, self.out_channel, 3, stride=stride, padding=1 + ) + else: + assert self.channels == self.out_channel + self.op = nn.AvgPool2d(kernel_size=stride, stride=stride) + + def forward(self, x): + if self.freq_space: + x = self.iwt(x) + + assert x.shape[1] == self.channels + opx = self.op(x) + + if self.freq_space: + opx = self.dwt(opx) + + return opx + + +class ResBlock(EmbedBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of embedding channels. + :param dropout: the rate of dropout. + :param out_channel: if specified, the number of out channels. + :param use_conv: if True and out_channel is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param use_checkpoint: if True, use gradient checkpointing on this module. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. 
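+    :param norm: the normalization layer type, e.g. groupnorm.
+    :param efficient: if True, apply the convolution before interpolation when upsampling.
+    :param freq_space: if True, operate in the Haar wavelet (frequency) domain.
+
+    In this video variant the block takes 5D tensors of shape
+    (batch, frames, channels, height, width); frames are folded into the batch
+    dimension for the 2D convolutions and unfolded again on output.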
+ """ + + def __init__( + self, + channels, + emb_channels, + dropout, + norm, + out_channel=None, + use_conv=False, + use_scale_shift_norm=False, + use_checkpoint=False, + up=False, + down=False, + efficient=False, + freq_space=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channel = out_channel or channels + self.use_conv = use_conv + self.use_checkpoint = use_checkpoint + self.use_scale_shift_norm = use_scale_shift_norm + self.up = up + self.efficient = efficient + self.freq_space = freq_space + self.updown = up or down + + self.in_layers = nn.Sequential( + normalization(self.channels, norm), + torch.nn.SiLU(), + nn.Conv2d(self.channels, self.out_channel, 3, padding=1), + ) + + if up: + self.h_upd = Upsample(channels, False, freq_space=self.freq_space) + self.x_upd = Upsample(channels, False, freq_space=self.freq_space) + elif down: + self.h_upd = Downsample(channels, False, freq_space=self.freq_space) + self.x_upd = Downsample(channels, False, freq_space=self.freq_space) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + torch.nn.SiLU(), + nn.Linear( + emb_channels, + 2 * self.out_channel if use_scale_shift_norm else self.out_channel, + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channel, norm), + torch.nn.SiLU(), + nn.Dropout(p=dropout), + zero_module(nn.Conv2d(self.out_channel, self.out_channel, 3, padding=1)), + ) + + if self.out_channel == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = nn.Conv2d(channels, self.out_channel, 3, padding=1) + else: + self.skip_connection = nn.Conv2d(channels, self.out_channel, 1) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return checkpoint( + self._forward, (x, emb), self.parameters(), self.use_checkpoint + ) + + def _forward(self, x, emb): + b, f, c, h, w = x.shape + x = x.contiguous().view(b * f, c, h, w) + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + + h = in_rest(x) + + if self.efficient and self.up: + h = in_conv(h) + h = self.h_upd(h) + x = self.x_upd(x) + else: + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out.unsqueeze(-1) + # emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = torch.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + + skipw = 1.0 + if self.efficient: + skipw = 1.0 / math.sqrt(2) + output = self.skip_connection(x) + h + bf, c, h, w = output.shape + b = bf // f + f = bf // b + return output.view(b, f, c, h, w) + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
+ """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + use_transformer=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.use_transformer = use_transformer + self.norm = normalization1d(channels) + self.qkv = nn.Conv1d(channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(nn.Conv1d(channels, channels, 1)) + + def forward(self, x): + return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint) + + def _forward(self, x): + b, f, c, *spatial = x.shape + if self.use_transformer: + x = x.reshape(b, -1, c) + else: + x = x.reshape(b * f, c, -1) + qkv = self.qkv(self.norm(x)) + h = self.attention(qkv) + h = self.proj_out(h) + return (x + h).reshape(b, f, c, *spatial) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention and splits in a different order. + """ + + def __init__(self, n_heads): + super().__init__() + self.n_heads = n_heads + + def forward(self, qkv): + """ + Apply QKV attention. + :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.chunk(3, dim=1) + scale = 1 / math.sqrt(math.sqrt(ch)) + weight = torch.einsum( + "bct,bcs->bts", + (q * scale).view(bs * self.n_heads, ch, length), + (k * scale).view(bs * self.n_heads, ch, length), + ) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + a = torch.einsum( + "bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length) + ) + return a.reshape(bs, -1, length) + + @staticmethod + def count_flops(model, _x, y): + return count_flops_attn(model, _x, y) + + +###################motion_module +# from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py#L187 + + +class MotionModule(nn.Module): + def __init__( + self, + in_channels, + num_attention_heads=8, + num_transformer_block=2, + attention_block_types=("Temporal_Self", "Temporal_Self"), + cross_frame_attention_mode=None, + temporal_position_encoding=False, + temporal_position_encoding_max_len=24, + temporal_attention_dim_div=1, + zero_initialize=True, + ): + super().__init__() + + self.temporal_transformer = TemporalTransformer3DModel( + in_channels=in_channels, + num_attention_heads=num_attention_heads, + attention_head_dim=in_channels + // num_attention_heads + // temporal_attention_dim_div, + num_layers=num_transformer_block, + attention_block_types=attention_block_types, + cross_frame_attention_mode=cross_frame_attention_mode, + temporal_position_encoding=temporal_position_encoding, + temporal_position_encoding_max_len=temporal_position_encoding_max_len, + ) + + if zero_initialize: + self.temporal_transformer.proj_out = zero_module( + self.temporal_transformer.proj_out + ) + + def forward( + self, + input_tensor, + temb=None, + encoder_hidden_states=None, + 
attention_mask=None, + anchor_frame_idx=None, + ): + hidden_states = input_tensor + hidden_states = self.temporal_transformer( + hidden_states, encoder_hidden_states, attention_mask + ) + + output = hidden_states + return output + + +class TemporalTransformer3DModel(nn.Module): + def __init__( + self, + in_channels, + num_attention_heads, + attention_head_dim, + num_layers, + attention_block_types=( + "Temporal_Self", + "Temporal_Self", + ), + dropout=0.0, + norm_num_groups=32, + cross_attention_dim=768, + activation_fn="geglu", + attention_bias=False, + upcast_attention=False, + cross_frame_attention_mode=None, + temporal_position_encoding=False, + temporal_position_encoding_max_len=24, + ): + super().__init__() + + inner_dim = num_attention_heads * attention_head_dim + + self.norm = torch.nn.GroupNorm( + num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True + ) + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + TemporalTransformerBlock( + dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + attention_block_types=attention_block_types, + dropout=dropout, + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + attention_bias=attention_bias, + upcast_attention=upcast_attention, + cross_frame_attention_mode=cross_frame_attention_mode, + temporal_position_encoding=temporal_position_encoding, + temporal_position_encoding_max_len=temporal_position_encoding_max_len, + ) + for d in range(num_layers) + ] + ) + self.proj_out = nn.Linear(inner_dim, in_channels) + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): + assert ( + hidden_states.dim() == 5 + ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." 
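+        # Fold the frame axis into the batch axis for the spatial projections below;
+        # each transformer block then attends over frames per spatial location, and
+        # the frame axis is restored at the end of this method.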
+ video_length = hidden_states.shape[1] + hidden_states = rearrange(hidden_states, "b f c h w -> (b f) c h w") + + batch, channel, height, weight = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape( + batch, height * weight, inner_dim + ) + hidden_states = self.proj_in(hidden_states) + + # Transformer Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + video_length=video_length, + ) + + # output + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim) + .permute(0, 3, 1, 2) + .contiguous() + ) + + output = hidden_states + residual + output = rearrange(output, "(b f) c h w -> b f c h w", f=video_length) + + return output + + +class TemporalTransformerBlock(nn.Module): + def __init__( + self, + dim, + num_attention_heads, + attention_head_dim, + attention_block_types=( + "Temporal_Self", + "Temporal_Self", + ), + dropout=0.0, + norm_num_groups=32, + cross_attention_dim=768, + activation_fn="geglu", + attention_bias=False, + upcast_attention=False, + cross_frame_attention_mode=None, + temporal_position_encoding=False, + temporal_position_encoding_max_len=24, + ): + super().__init__() + + attention_blocks = [] + norms = [] + + for block_name in attention_block_types: + attention_blocks.append( + VersatileAttention( + attention_mode=block_name.split("_")[0], + cross_attention_dim=( + cross_attention_dim if block_name.endswith("_Cross") else None + ), + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + cross_frame_attention_mode=cross_frame_attention_mode, + temporal_position_encoding=temporal_position_encoding, + temporal_position_encoding_max_len=temporal_position_encoding_max_len, + ) + ) + norms.append(nn.LayerNorm(dim)) + + self.attention_blocks = nn.ModuleList(attention_blocks) + self.norms = nn.ModuleList(norms) + + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) + self.ff_norm = nn.LayerNorm(dim) + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + video_length=None, + ): + for attention_block, norm in zip(self.attention_blocks, self.norms): + norm_hidden_states = norm(hidden_states) + hidden_states = ( + attention_block( + norm_hidden_states, + encoder_hidden_states=( + encoder_hidden_states + if attention_block.is_cross_attention + else None + ), + video_length=video_length, + ) + + hidden_states + ) + + hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states + + output = hidden_states + return output + + +class CrossAttention(nn.Module): + r""" + A cross attention layer. + + Parameters: + query_dim (`int`): The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. 
+ """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias=False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + ): + super().__init__() + inner_dim = dim_head * heads + cross_attention_dim = ( + cross_attention_dim if cross_attention_dim is not None else query_dim + ) + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + + self.scale = dim_head**-0.5 + + self.heads = heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + self._slice_size = None + self._use_memory_efficient_attention_xformers = False + self.added_kv_proj_dim = added_kv_proj_dim + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm( + num_channels=inner_dim, + num_groups=norm_num_groups, + eps=1e-5, + affine=True, + ) + else: + self.group_norm = None + + self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(inner_dim, query_dim)) + self.to_out.append(nn.Dropout(dropout)) + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape( + batch_size * head_size, seq_len, dim // head_size + ) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape( + batch_size // head_size, seq_len, dim * head_size + ) + return tensor + + def set_attention_slice(self, slice_size): + if slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError( + f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}." 
+ ) + + self._slice_size = slice_size + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): + batch_size, sequence_length, _ = hidden_states.shape + + encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose( + 1, 2 + ) + + query = self.to_q(hidden_states) + dim = query.shape[-1] + query = self.reshape_heads_to_batch_dim(query) + + if self.added_kv_proj_dim is not None: + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim( + encoder_hidden_states_key_proj + ) + encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim( + encoder_hidden_states_value_proj + ) + + key = torch.concat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.concat([encoder_hidden_states_value_proj, value], dim=1) + else: + encoder_hidden_states = ( + encoder_hidden_states + if encoder_hidden_states is not None + else hidden_states + ) + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + # attention, what we cannot get enough of + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers( + query, key, value, attention_mask + ) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask) + else: + hidden_states = self._sliced_attention( + query, key, value, sequence_length, dim, attention_mask + ) + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states + + def _attention(self, query, key, value, attention_mask=None): + if self.upcast_attention: + query = query.float() + key = key.float() + + attention_scores = torch.baddbmm( + torch.empty( + query.shape[0], + query.shape[1], + key.shape[1], + dtype=query.dtype, + device=query.device, + ), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + + if attention_mask is not None: + attention_scores = attention_scores + attention_mask + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + + # cast back to the original dtype + attention_probs = attention_probs.to(value.dtype) + + # compute attention output + hidden_states = torch.bmm(attention_probs, value) + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + def _sliced_attention( + self, query, key, value, sequence_length, dim, attention_mask + ): + batch_size_attention = query.shape[0] + hidden_states = torch.zeros( + (batch_size_attention, 
sequence_length, dim // self.heads), + device=query.device, + dtype=query.dtype, + ) + slice_size = ( + self._slice_size if self._slice_size is not None else hidden_states.shape[0] + ) + for i in range(hidden_states.shape[0] // slice_size): + start_idx = i * slice_size + end_idx = (i + 1) * slice_size + + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + + if self.upcast_attention: + query_slice = query_slice.float() + key_slice = key_slice.float() + + attn_slice = torch.baddbmm( + torch.empty( + slice_size, + query.shape[1], + key.shape[1], + dtype=query_slice.dtype, + device=query.device, + ), + query_slice, + key_slice.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + + if attention_mask is not None: + attn_slice = attn_slice + attention_mask[start_idx:end_idx] + + if self.upcast_softmax: + attn_slice = attn_slice.float() + + attn_slice = attn_slice.softmax(dim=-1) + + # cast back to the original dtype + attn_slice = attn_slice.to(value.dtype) + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + def _memory_efficient_attention_xformers(self, query, key, value, attention_mask): + # TODO attention_mask + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + hidden_states = xformers.ops.memory_efficient_attention( + query, key, value, attn_bias=attention_mask + ) + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + ): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim) + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim) + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(nn.Linear(inner_dim, dim_out)) + + def forward(self, hidden_states): + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + +# feedforward +class GEGLU(nn.Module): + r""" + A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. 
+ """ + + def __init__(self, dim_in: int, dim_out: int): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def gelu(self, gate): + if gate.device.type != "mps": + return F.gelu(gate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + + def forward(self, hidden_states): + hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) + return hidden_states * self.gelu(gate) + + +class PositionalEncoding(nn.Module): + def __init__(self, d_model, dropout=0.0, max_len=24): + super().__init__() + self.dropout = nn.Dropout(p=dropout) + position = torch.arange(max_len).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model) + ) + pe = torch.zeros(1, max_len, d_model) + pe[0, :, 0::2] = torch.sin(position * div_term) + pe[0, :, 1::2] = torch.cos(position * div_term) + self.register_buffer("pe", pe) + + def forward(self, x): + x = x + self.pe[:, : x.size(1)] + return self.dropout(x) + + +class VersatileAttention(CrossAttention): + def __init__( + self, + attention_mode=None, + cross_frame_attention_mode=None, + temporal_position_encoding=False, + temporal_position_encoding_max_len=24, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + assert attention_mode == "Temporal" + + self.attention_mode = attention_mode + self.is_cross_attention = kwargs["cross_attention_dim"] is not None + + self.pos_encoder = ( + PositionalEncoding( + kwargs["query_dim"], + dropout=0.0, + max_len=temporal_position_encoding_max_len, + ) + if (temporal_position_encoding and attention_mode == "Temporal") + else None + ) + + def extra_repr(self): + return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}" + + def forward( + self, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + video_length=None, + ): + batch_size, sequence_length, _ = hidden_states.shape + + if self.attention_mode == "Temporal": + d = hidden_states.shape[1] + hidden_states = rearrange( + hidden_states, "(b f) d c -> (b d) f c", f=video_length + ) + + if self.pos_encoder is not None: + hidden_states = self.pos_encoder(hidden_states) + + encoder_hidden_states = ( + repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) + if encoder_hidden_states is not None + else encoder_hidden_states + ) + else: + raise NotImplementedError + + encoder_hidden_states = encoder_hidden_states + + query = self.to_q(hidden_states) + dim = query.shape[-1] + query = self.reshape_heads_to_batch_dim(query) + + if self.added_kv_proj_dim is not None: + raise NotImplementedError + + encoder_hidden_states = ( + encoder_hidden_states + if encoder_hidden_states is not None + else hidden_states + ) + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + # attention, what we cannot get enough of + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers( + query, key, value, attention_mask + ) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if 
self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask) + else: + hidden_states = self._sliced_attention( + query, key, value, sequence_length, dim, attention_mask + ) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + + if self.attention_mode == "Temporal": + hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) + + return hidden_states + + +class UNetVid(nn.Module): + """ + The full UNet model with attention and embedding. + :param in_channel: channels in the input Tensor, for image colorization : Y_channels + X_channels . + :param inner_channel: base channel count for the model. + :param out_channel: channels in the output Tensor. + :param res_blocks: number of residual blocks per downsample. + :param attn_res: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, if this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mults: channel multiplier for each level of the UNet. + :param conv_resample: if True, use learned convolutions for upsampling and + downsampling. + :param use_checkpoint: use gradient checkpointing to reduce memory usage. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: if specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + :param use_new_attention_order: use a different attention pattern for potentially + increased efficiency. 
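+    Video variant: inputs and outputs are 5D tensors of shape
+    (batch, frames, channels, height, width). Spatial layers treat frames as extra
+    batch elements, and MotionModule blocks with temporal attention are inserted
+    after the residual/attention stages so that the generated frames stay
+    temporally consistent.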
+ """ + + def __init__( + self, + image_size, + in_channel, + inner_channel, + out_channel, + res_blocks, + attn_res, + tanh, + n_timestep_train, + n_timestep_test, + norm, + group_norm_size, + cond_embed_dim, + dropout=0, + channel_mults=(1, 2, 4, 8), + conv_resample=True, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=True, + resblock_updown=True, + use_new_attention_order=True, # False, + efficient=False, + freq_space=False, + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.image_size = image_size + self.in_channel = in_channel + self.inner_channel = inner_channel + self.out_channel = out_channel + self.res_blocks = res_blocks + self.attn_res = attn_res + self.dropout = dropout + self.zero_dropout = 0.0 + self.channel_mults = channel_mults + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = torch.float16 if use_fp16 else torch.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + self.freq_space = freq_space + + if self.freq_space: + from ..freq_utils import InverseHaarTransform, HaarTransform + + self.iwt = InverseHaarTransform(3) + self.dwt = HaarTransform(3) + in_channel *= 4 + out_channel *= 4 + + if norm == "groupnorm": + norm = norm + str(group_norm_size) + + self.cond_embed_dim = cond_embed_dim + + ch = input_ch = int(channel_mults[0] * self.inner_channel) + self.input_blocks = nn.ModuleList( + [EmbedSequential(InflatedConv3d(in_channel, ch, 3, padding=1))] + ) + self._feature_size = ch + input_block_chans = [ch] + ds = 1 + for level, mult in enumerate(channel_mults): + for _ in range(res_blocks[level]): + layers = [ + ResBlock( + ch, + self.cond_embed_dim, + self.zero_dropout, + out_channel=int(mult * self.inner_channel), + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + norm=norm, + efficient=efficient, + freq_space=self.freq_space, + ), + ] + ch = int(mult * self.inner_channel) + if ds in attn_res: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + layers.append( + MotionModule( + in_channels=ch, + num_attention_heads=8, + num_transformer_block=2, + attention_block_types=("Temporal_self", "Temporal_Self"), + cross_frame_attention_mode=None, + temporal_position_encoding=True, + temporal_position_encoding_max_len=24, + temporal_attention_dim_div=1, + zero_initialize=True, + ) + ) + self.input_blocks.append(EmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mults) - 1: + out_ch = ch + self.input_blocks.append( + EmbedSequential( + ResBlock( + ch, + self.cond_embed_dim, + self.zero_dropout, + out_channel=out_ch, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + norm=norm, + efficient=efficient, + freq_space=self.freq_space, + ) + if resblock_updown + else Downsample( + ch, + conv_resample, + out_channel=out_ch, + freq_space=self.freq_space, + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = EmbedSequential( + ResBlock( + ch, + self.cond_embed_dim, + dropout, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + norm=norm, + efficient=efficient, + 
freq_space=self.freq_space, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + self.cond_embed_dim, + dropout, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + norm=norm, + efficient=efficient, + freq_space=self.freq_space, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(channel_mults))[::-1]: + for i in range(res_blocks[level] + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + self.cond_embed_dim, + self.zero_dropout, + out_channel=int(self.inner_channel * mult), + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + norm=norm, + efficient=efficient, + freq_space=self.freq_space, + ), + ] + ch = int(self.inner_channel * mult) + if ds in attn_res: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads_upsample, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ) + ) + + layers.append( + MotionModule( + in_channels=ch, + num_attention_heads=8, + num_transformer_block=2, + attention_block_types=("Temporal_self", "Temporal_Self"), + cross_frame_attention_mode=None, + temporal_position_encoding=True, + temporal_position_encoding_max_len=24, + temporal_attention_dim_div=1, + zero_initialize=True, + ) + ) + + if level and i == res_blocks[level]: + out_ch = ch + layers.append( + ResBlock( + ch, + self.cond_embed_dim, + self.zero_dropout, + out_channel=out_ch, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + norm=norm, + efficient=efficient, + freq_space=self.freq_space, + ) + if resblock_updown + else Upsample( + ch, + conv_resample, + out_channel=out_ch, + freq_space=self.freq_space, + ) + ) + ds //= 2 + self.output_blocks.append(EmbedSequential(*layers)) + self._feature_size += ch + + if tanh: + self.out = nn.Sequential( + normalization(ch, norm), + nn.Conv2d(input_ch, out_channel, 3, padding=1), + nn.Tanh(), + ) + else: + self.out = nn.Sequential( + normalization(ch, norm), + torch.nn.SiLU(), + zero_module(nn.Conv2d(input_ch, out_channel, 3, padding=1)), + ) + + self.beta_schedule = { + "train": { + "schedule": "linear", + "n_timestep": n_timestep_train, + "linear_start": 1e-6, + "linear_end": 0.01, + }, + "test": { + "schedule": "linear", + "n_timestep": n_timestep_test, + "linear_start": 1e-4, + "linear_end": 0.09, + }, + } + + def compute_feats(self, input, embed_gammas): + if embed_gammas is None: + # Only for GAN + b = (input.shape[0], self.cond_embed_dim) + embed_gammas = torch.ones(b).to(input.device) + + emb = embed_gammas + + hs = [] + + h = input.type(torch.float32) + + if self.freq_space: + h = self.dwt(h) + + for module in self.input_blocks: + h = module(h, emb) + hs.append(h) + h = self.middle_block(h, emb) + + outs, feats = h, hs + return outs, feats, emb + + def forward(self, input, embed_gammas=None): + + h, hs, emb = self.compute_feats(input, embed_gammas=embed_gammas) + for i, module in enumerate(self.output_blocks): + h = torch.cat([h, hs.pop()], dim=2) + h = module(h, emb) + h = h.type(input.dtype) + b, f, c, h_dim, w_dim = h.shape + h = h.reshape(b * f, c, h_dim, w_dim) + + outh = self.out(h) + + if self.freq_space: + outh = self.iwt(outh) + outh = outh.reshape(b, f, -1, h_dim, w_dim) + + return outh + + def get_feats(self, input, extract_layer_ids): + _, 
hs, _ = self.compute_feats(input, embed_gammas=None) + feats = [] + + for i, feat in enumerate(hs): + if i in extract_layer_ids: + feats.append(feat) + + return feats diff --git a/models/palette_model.py b/models/palette_model.py index 447a27ae6..f66dd19b8 100644 --- a/models/palette_model.py +++ b/models/palette_model.py @@ -103,12 +103,14 @@ def __init__(self, opt, rank): batch_size = self.opt.train_batch_size else: batch_size = self.opt.test_batch_size - max_visual_outputs = min( max(self.opt.train_batch_size, self.opt.num_test_images), self.opt.output_num_images, ) - + if self.opt.G_netG == "unet_vid": + max_visual_outputs = min( + self.opt.output_num_images, self.opt.data_temporal_number_frames + ) self.num_classes = max( self.opt.f_s_semantic_nclasses, self.opt.cls_semantic_nclasses ) @@ -251,8 +253,12 @@ def set_input(self, data): len(data["A"].to(self.device).shape) == 5 ): # we're using temporal successive frames self.previous_frame = data["A"].to(self.device)[:, 0] - self.y_t = data["A"].to(self.device)[:, 1] - self.gt_image = data["B"].to(self.device)[:, 1] + if self.opt.G_netG == "unet_vid": + self.y_t = data["A"].to(self.device) + self.gt_image = data["B"].to(self.device) + else: + self.y_t = data["A"].to(self.device)[:, 1] + self.gt_image = data["B"].to(self.device)[:, 1] if self.task == "inpainting": self.previous_frame_mask = data["B_label_mask"].to(self.device)[:, 0] ### Note: the sam related stuff should eventually go into the dataloader @@ -278,7 +284,10 @@ def set_input(self, data): self.mask[self.mask == 2] = 0 self.y_t = fill_mask_with_random(self.gt_image, self.mask, -1) else: - self.mask = data["B_label_mask"].to(self.device)[:, 1] + if self.opt.G_netG == "unet_vid": + self.mask = data["B_label_mask"].to(self.device) + else: + self.mask = data["B_label_mask"].to(self.device)[:, 1] else: self.mask = None else: @@ -461,6 +470,7 @@ def compute_palette_loss(self): min_snr_loss_weight * mask_binary * noise, min_snr_loss_weight * mask_binary * noise_hat, ) + else: loss = self.loss_fn( min_snr_loss_weight * noise, min_snr_loss_weight * noise_hat @@ -478,6 +488,7 @@ def compute_palette_loss(self): self.loss_G_tot = self.opt.alg_diffusion_lambda_G * loss def inference(self, nb_imgs, offset=0): + if hasattr(self.netG_A, "module"): netG = self.netG_A.module else: @@ -611,18 +622,29 @@ def inference(self, nb_imgs, offset=0): self.output, self.visuals = netG.restoration( y_cond=self.cond_image[:nb_imgs], sample_num=self.sample_num ) + if not self.opt.G_netG == "unet_vid": + for name in self.gen_visual_names: + whole_tensor = getattr(self, name[:-1]) # i.e. self.output, ... + for k in range(min(nb_imgs, self.get_current_batch_size())): + cur_name = name + str(offset + k) + cur_tensor = whole_tensor[k : k + 1] + if "mask" in name: + cur_tensor = cur_tensor.squeeze(0) + setattr(self, cur_name, cur_tensor) + for k in range(min(nb_imgs, self.get_current_batch_size())): + self.fake_B_pool.query(self.visuals[k : k + 1]) - for name in self.gen_visual_names: - whole_tensor = getattr(self, name[:-1]) # i.e. self.output, ... + else: + for name in self.gen_visual_names: + whole_tensor = getattr(self, name[:-1]) # i.e. self.output, ... 
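+                # unet_vid outputs are 5D (batch, frames, channels, h, w): expose each
+                # temporal frame as its own visual instead of each batch element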
+ for k in range(self.opt.data_temporal_number_frames): + cur_name = name + str(offset + k) + cur_tensor = whole_tensor[:, k, :, :, :] + if "mask" in name: + cur_tensor = cur_tensor.squeeze(0) + setattr(self, cur_name, cur_tensor) for k in range(min(nb_imgs, self.get_current_batch_size())): - cur_name = name + str(offset + k) - cur_tensor = whole_tensor[k : k + 1] - if "mask" in name: - cur_tensor = cur_tensor.squeeze(0) - setattr(self, cur_name, cur_tensor) - - for k in range(min(nb_imgs, self.get_current_batch_size())): - self.fake_B_pool.query(self.visuals[k : k + 1]) + self.fake_B_pool.query(self.visuals[k : k + 1, :, :, :, :]) if len(self.opt.gpu_ids) > 1 and self.opt.G_unet_mha_norm_layer == "batchnorm": netG = torch.nn.SyncBatchNorm.convert_sync_batchnorm(netG) diff --git a/options/common_options.py b/options/common_options.py index a4bc5b481..8e2fc1400 100644 --- a/options/common_options.py +++ b/options/common_options.py @@ -210,6 +210,7 @@ def initialize(self, parser): "dit", "hdit", "img2img_turbo", + "unet_vid", ], help="specify generator architecture", ) @@ -658,6 +659,7 @@ def initialize(self, parser): "aligned", "nuplet_unaligned_labeled_mask", "temporal_labeled_mask_online", + "self_supervised_temporal_labeled_mask_online", "self_supervised_temporal", "single", "unaligned_labeled_mask_ref", diff --git a/scripts/gen_vid_diffusion.py b/scripts/gen_vid_diffusion.py new file mode 100644 index 000000000..ffbabd2d5 --- /dev/null +++ b/scripts/gen_vid_diffusion.py @@ -0,0 +1,968 @@ +import argparse +import copy +import json +import math +import os +import random +import re +import sys +import tempfile +import warnings +import logging + +import cv2 +import numpy as np +import torch +import torchvision.transforms as T +from PIL import Image +import torchvision +from torchvision import transforms +from torchvision.utils import save_image +from tqdm import tqdm + +import re +from collections import defaultdict +from PIL import Image + +jg_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../") +sys.path.append(jg_dir) +from segment_anything import SamPredictor + +from data.online_creation import crop_image, fill_mask_with_color, fill_mask_with_random +from models import diffusion_networks +from models.modules.diffusion_utils import set_new_noise_schedule +from models.modules.sam.sam_inference import ( + compute_mask_with_sam, + init_sam_net, + load_sam_weight, + predict_sam_mask, +) +from models.modules.utils import download_sam_weight +from options.inference_diffusion_options import InferenceDiffusionOptions +from options.train_options import TrainOptions +from util.mask_generation import ( + fill_img_with_canny, + fill_img_with_depth, + fill_img_with_hed, + fill_img_with_hough, + fill_img_with_sam, + fill_img_with_sketch, +) +from util.script import get_override_options_names +from util.util import flatten_json + + +def atoi(text): + return int(text) if text.isdigit() else text + + +def natural_keys(text): + return [atoi(c) for c in re.split("(\d+)", text)] + + +def load_model( + model_in_dir, + model_in_filename, + device, + sampling_steps, + sampling_method, + model_prior_321_backwardcompatibility, +): + train_json_path = model_in_dir + "/train_config.json" + with open(train_json_path, "r") as jsonf: + train_json = json.load(jsonf) + + opt = TrainOptions().parse_json(train_json) + opt.jg_dir = jg_dir + if opt.data_online_creation_mask_random_offset_A != [0.0]: + warnings.warn( + f"disabling data_online_creation_mask_random_offset_A in inference mode" + ) + 
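+        # the random mask offset is a training-time augmentation; keep mask
+        # placement deterministic at inference time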
opt.data_online_creation_mask_random_offset_A = [0.0] + + opt.model_prior_321_backwardcompatibility = model_prior_321_backwardcompatibility + if opt.model_type in ["cm", "cm_gan"]: + opt.alg_palette_sampling_method = sampling_method + opt.alg_diffusion_cond_embed_dim = 256 + model = diffusion_networks.define_G(**vars(opt)) + model.eval() + + # handle old models + weights = torch.load( + os.path.join(model_in_dir, model_in_filename), map_location=torch.device(device) + ) + if opt.model_prior_321_backwardcompatibility: + weights = { + k.replace("denoise_fn.cond_embed", "cond_embed"): v + for k, v in weights.items() + } + if not any(k.startswith("denoise_fn.model") for k in weights.keys()): + weights = { + k.replace("denoise_fn", "denoise_fn.model"): v for k, v in weights.items() + } + if not any(k.startswith("denoise_fn.netl_embedder_") for k in weights.keys()): + weights = { + k.replace("l_embedder_", "denoise_fn.netl_embedder_"): v + for k, v in weights.items() + } + model.load_state_dict(weights, strict=False) + + # sampling steps + if sampling_steps > 0: + model.denoise_fn.model.beta_schedule["test"]["n_timestep"] = sampling_steps + set_new_noise_schedule(model.denoise_fn.model, "test") + + if opt.model_type == "palette": + model.set_new_sampling_method(sampling_method) + + if opt.alg_diffusion_task == "pix2pix": + opt.alg_diffusion_cond_image_creation = "pix2pix" + + model = model.to(device) + return model, opt + + +def to_np(img): + img = img.detach().data.cpu().float().numpy()[0] + img = (np.transpose(img, (1, 2, 0)) + 1) / 2.0 * 255.0 + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + + return img + + +def cond_augment(cond, rotation, persp_horizontal, persp_vertical): + cond = Image.fromarray(cond) + cond = transforms.RandomRotation(rotation, expand=True)(cond) + w, h = cond.size + startpoints = [[0, 0], [w, 0], [w, h], [0, h]] + endpoints = copy.deepcopy(startpoints) + # horizontal perspective + d = h * persp_horizontal * random.random() + if random.choice([True, False]): + # seen from left + endpoints[1][1] += d + endpoints[2][1] -= d + else: + # seen from right + endpoints[0][1] += d + endpoints[3][1] -= d + # vertical perspective + d = h * persp_vertical * random.random() + if random.choice([True, False]): + # seen from above + endpoints[3][0] += d + endpoints[2][0] -= d + else: + # seen from below + endpoints[0][0] += d + endpoints[1][0] -= d + cond = cond.crop(cond.getbbox()) + cond = transforms.functional.perspective(cond, startpoints, endpoints) + return np.array(cond) + + +def generate( + seed, + model_in_file, + lmodel, + lopt, + cpu, + gpuid, + sampling_steps, + img_in, + mask_in, + ref_in, + bbox_in, + cond_in, + cond_keep_ratio, + bbox_width_factor, + bbox_height_factor, + bbox_ref_id, + crop_width, + crop_height, + img_width, + img_height, + dir_out, + write, + previous_frame, + name, + mask_delta, + mask_square, + sampling_method, + cond_rotation, + cond_persp_horizontal, + cond_persp_vertical, + alg_diffusion_cond_image_creation, + alg_diffusion_sketch_canny_thresholds, + cls, + alg_diffusion_super_resolution_downsample, + alg_diffusion_guidance_scale, + data_refined_mask, + min_crop_bbox_ratio, + alg_palette_ddim_num_steps, + alg_palette_ddim_eta, + model_prior_321_backwardcompatibility, + logger, + iteration, + nb_samples, + **unused_options, +): + PROGRESS_NUM_STEPS = 4 + # seed + if seed >= 0: + torch.manual_seed(seed) + + if not cpu: + device = torch.device("cuda:" + str(gpuid)) + else: + device = torch.device("cpu") + + # loading model + if lmodel is None: + model, 
opt = load_model( + os.path.dirname(model_in_file), + os.path.basename(model_in_file), + device, + sampling_steps, + sampling_method, + model_prior_321_backwardcompatibility, + ) + else: + model = lmodel + opt = lopt + + if alg_diffusion_cond_image_creation is not None: + opt.alg_diffusion_cond_image_creation = alg_diffusion_cond_image_creation + + if logger: + logger.info( + f"[it: %i/%i] - [1/%i] model loaded" + % (iteration, nb_samples, PROGRESS_NUM_STEPS) + ) + + conditioning = opt.alg_diffusion_cond_embed + + for i, delta_values in enumerate(mask_delta): + if len(delta_values) == 1: + mask_delta[i].append(delta_values[0]) + # Load image + with open(args.paths_file, "r") as file: + lines = file.readlines() + paths_img = [] + paths_bbox = [] + + image_bbox_pairs = [] + for line in lines: + parts = line.strip().split() + image_bbox_pairs.append((parts[0], parts[1])) + + image_bbox_pairs.sort(key=lambda x: natural_keys(x[0])) + startframe = random.randint(100, 10000) + limited_image_bbox_pairs = image_bbox_pairs[ + startframe : startframe + opt.data_temporal_number_frames + 10 + ] + limited_paths_img = [pair[0] for pair in limited_image_bbox_pairs] + limited_paths_bbox = [pair[1] for pair in limited_image_bbox_pairs] + + cond_image_list = [] + y_t_list = [] + y0_tensor_list = [] + mask_list = [] + img_orig_list = [] + bbox_select_list = [] + img_tensor_list = [] + out_img_list = [] + + for img_path, bbox_path in zip(limited_paths_img, limited_paths_bbox): + img_in = os.path.join(args.data_root, img_path) + maskin = os.path.join(args.data_root, bbox_path) + bbox_select = None + # reading image + if opt.data_image_bits > 8: + img = Image.open(img_in) # we use PIL + local_img_width, local_img_height = img.size + else: + img = cv2.imread(img_in) + img_orig = img.copy() + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + local_img_width, local_img_height = img.shape[:2] + # reading the mask + mask = None + if mask_in: + mask = cv2.imread(mask_in, 0) + + # reading reference image + ref = None + if ref_in: + ref = cv2.imread(ref_in) + ref_orig = ref.copy() + ref = cv2.cvtColor(ref, cv2.COLOR_BGR2RGB) + + bboxes = [] + if bbox_in: + # mask = np.zeros(img.shape[:2], dtype=np.uint8) + with open(bbox_in, "r") as bboxf: + for line in bboxf: + elts = line.rstrip().split() + bboxes.append( + [int(elts[1]), int(elts[2]), int(elts[3]), int(elts[4])] + ) + if conditioning: + if cls <= 0: + cls = int(elts[0]) + else: + cls = 1 + + if bbox_ref_id == -1: + # sample a bbox here since we are calling crop_image multiple times + bbox_idx = random.choice(range(len(bboxes))) + else: + bbox_idx = bbox_ref_id + + if crop_width > 0 or crop_height > 0: + hc_width = int(crop_width / 2) + hc_height = int(crop_height / 2) + bbox_orig = bboxes[bbox_idx] + if bbox_width_factor > 0.0: + bbox_orig[0] -= max(0, int(bbox_width_factor * bbox_orig[0])) + bbox_orig[2] += max(0, int(bbox_width_factor * bbox_orig[2])) + if bbox_height_factor > 0.0: + bbox_orig[1] -= max(0, int(bbox_height_factor * bbox_orig[1])) + bbox_orig[3] += max(0, int(bbox_height_factor * bbox_orig[3])) + + # TODO: unused? 
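+            # grow the selected bbox by half the requested crop size on each side,
+            # clamped to the image borders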
+ bbox_select = bbox_orig.copy() + bbox_select[0] -= max(0, hc_width) + bbox_select[0] = max(0, bbox_select[0]) + bbox_select[1] -= max(0, hc_height) + bbox_select[1] = max(0, bbox_select[1]) + bbox_select[2] += hc_width + bbox_select[2] = min(img.shape[1], bbox_select[2]) + bbox_select[3] += hc_height + bbox_select[3] = min(img.shape[0], bbox_select[3]) + else: + bbox = bboxes[bbox_idx] + + crop_coordinates = crop_image( + img_path=img_in, + bbox_path=bbox_in, + mask_delta=mask_delta, # =opt.data_online_creation_mask_delta_A, + mask_random_offset=opt.data_online_creation_mask_random_offset_A, + crop_delta=0, + mask_square=mask_square, # opt.data_online_creation_mask_square_A, + crop_dim=opt.data_online_creation_crop_size_A, # we use the average crop_dim + output_dim=opt.data_load_size, + context_pixels=opt.data_online_context_pixels, + load_size=opt.data_online_creation_load_size_A, + get_crop_coordinates=True, + crop_center=True, + bbox_ref_id=bbox_idx, + min_crop_bbox_ratio=min_crop_bbox_ratio, + ) + + img, mask, ref_bbox, bbox_ref_id = crop_image( + img_path=img_in, + bbox_path=bbox_in, + mask_delta=mask_delta, # opt.data_online_creation_mask_delta_A, + mask_random_offset=opt.data_online_creation_mask_random_offset_A, + crop_delta=0, + mask_square=mask_square, # opt.data_online_creation_mask_square_A, + crop_dim=opt.data_online_creation_crop_size_A, # we use the average crop_dim + output_dim=opt.data_load_size, + context_pixels=opt.data_online_context_pixels, + load_size=opt.data_online_creation_load_size_A, + crop_coordinates=crop_coordinates, + crop_center=True, + bbox_ref_id=bbox_idx, + override_class=cls, + ) + + x_crop, y_crop, crop_size = crop_coordinates + + bbox = bboxes[bbox_idx] + + bbox_select = bbox.copy() + if len(mask_delta) == 1: + index_cls = 0 + else: + index_cls = int(cls) - 1 + + if not isinstance(mask_delta[0][0], float): + bbox_select[0] -= mask_delta[index_cls][0] + bbox_select[1] -= mask_delta[index_cls][1] + bbox_select[2] += mask_delta[index_cls][0] + bbox_select[3] += mask_delta[index_cls][1] + else: + bbox_select[0] *= 1 + mask_delta[index_cls][0] + bbox_select[1] *= 1 + mask_delta[index_cls][1] + bbox_select[2] *= 1 + mask_delta[index_cls][0] + bbox_select[3] *= 1 + mask_delta[index_cls][1] + + if mask_square: + sdiff = (bbox_select[2] - bbox_select[0]) - ( + bbox_select[3] - bbox_select[1] + ) # (xmax - xmin) - (ymax - ymin) + if sdiff > 0: + bbox_select[3] += int(sdiff / 2) + bbox_select[1] -= int(sdiff / 2) + else: + bbox_select[2] += -int(sdiff / 2) + bbox_select[0] -= -int(sdiff / 2) + + bbox_select[1] += y_crop + bbox_select[0] += x_crop + + bbox_select[3] = bbox_select[1] + crop_size + bbox_select[2] = bbox_select[0] + crop_size + + bbox_select[1] -= opt.data_online_context_pixels + bbox_select[0] -= opt.data_online_context_pixels + + bbox_select[3] += opt.data_online_context_pixels + bbox_select[2] += opt.data_online_context_pixels + + img, mask = np.array(img), np.array(mask) + + if img_width > 0 and img_height > 0: + if img_height != local_img_height or img_width != local_img_width: + if opt.data_image_bits > 8: + print( + "Requested image size differs from training crop size, resizing is not supported for images with more than 8 bits per channel" + ) + exit(1) + img = cv2.resize(img, (img_width, img_height)) + if mask is not None: + mask = cv2.resize(mask, (img_width, img_height)) + if ref is not None: + ref = cv2.resize(ref, (img_width, img_height)) + + if logger: + logger.info( + f"[it: %i/%i] - [2/%i] image loaded" + % (iteration, nb_samples, 
PROGRESS_NUM_STEPS) + ) + + # insert cond image into original image + generated_bbox = None + if cond_in: + generated_bbox = bbox + # fill the mask with cond image + mask_bbox = Image.fromarray(mask).getbbox() + x0, y0, x1, y1 = mask_bbox + bbox_w = x1 - x0 + bbox_h = y1 - y0 + cond = cv2.imread(cond_in) + cond = cv2.cvtColor(cond, cv2.COLOR_RGB2BGR) + cond = cond_augment( + cond, + cond_rotation, + cond_persp_horizontal, + cond_persp_vertical, + ) + if cond_keep_ratio: + # pad cond image to match bbox aspect ratio + bbox_ratio = bbox_w / bbox_h + new_w = cond_w = cond.shape[1] + new_h = cond_h = cond.shape[0] + cond_ratio = cond_w / cond_h + if cond_ratio < bbox_ratio: + new_w = round(cond_w * bbox_ratio / cond_ratio) + elif cond_ratio > bbox_ratio: + new_h = round(cond_h * cond_ratio / bbox_ratio) + cond_pad = np.zeros((new_h, new_w, 3), dtype=np.uint8) + x = (new_w - cond_w) // 2 + y = (new_h - cond_h) // 2 + cond_pad[y : y + cond_h, x : x + cond_w] = cond + cond = cond_pad + # bbox inside mask + generated_bbox = [ + (x, y), + (x + cond_w, y + cond_h), + ] + # bbox inside crop + generated_bbox = [ + (x0 + x * bbox_w / cond.shape[1], y0 + y * bbox_h / cond.shape[0]) + for x, y in generated_bbox + ] + # bbox inside original image + real_width = min(img_orig.shape[1], bbox_select[2] - bbox_select[0]) + real_height = min(img_orig.shape[0], bbox_select[3] - bbox_select[1]) + generated_bbox = [ + ( + bbox_select[0] + x * real_width / img.shape[1], + bbox_select[1] + y * real_height / img.shape[0], + ) + for x, y in generated_bbox + ] + # round & flatten + generated_bbox = list(map(round, generated_bbox[0] + generated_bbox[1])) + + # add 1 pixel margin for sketches + cond = cv2.resize( + cond, (bbox_w - 2, bbox_h - 2), interpolation=cv2.INTER_CUBIC + ) + cond = np.pad(cond, [(1, 1), (1, 1), (0, 0)]) + img[y0:y1, x0:x1] = cond + + # preprocessing to torch + totensor = transforms.ToTensor() + if opt.data_image_bits > 8: + tranlist = [totensor, torchvision.transforms.v2.ToDtype(torch.float32)] + bit_scaling = 2**opt.data_image_bits - 1 + tranlist += [transforms.Lambda(lambda img: img * (1 / float(bit_scaling)))] + tranlist += [ + transforms.Normalize((0.5,), (0.5,)) + ] # XXX: > 8bit, mono canal only for now + else: + tranlist = [ + totensor, + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + # resize, + ] + + tran = transforms.Compose(tranlist) + img_tensor = tran(img).clone().detach() + + if mask is not None: + mask = torch.from_numpy(np.array(mask, dtype=np.int64)).unsqueeze(0) + """if crop_width > 0 and crop_height > 0: + mask = resize(mask).clone().detach()""" + if ref is not None: + ref = cv2.resize( + ref, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_CUBIC + ) + ref_tensor = tran(ref).clone().detach() + + if not cpu: + img_tensor = img_tensor.to(device).clone().detach() + if mask is not None: + mask = mask.to(device).clone().detach() + if ref is not None: + ref_tensor = ref_tensor.to(device).clone().detach() + + if mask is not None: + if data_refined_mask or opt.data_refined_mask: + opt.f_s_weight_sam = "../" + opt.f_s_weight_sam + sam_model, _ = init_sam_net( + model_type_sam=opt.model_type_sam, + model_path=opt.f_s_weight_sam, + device=device, + ) + mask = compute_mask_with_sam( + img_tensor, mask, sam_model, device, batched=False + ).unsqueeze(0) + + if opt.data_inverted_mask: + mask[mask > 0] = 2 + mask[mask == 0] = 1 + mask[mask == 2] = 0 + + if opt.data_online_creation_rand_mask_A: + y_t = fill_mask_with_random( + img_tensor.clone().detach(), mask.clone().detach(), 
-1 + ) + elif opt.data_online_creation_color_mask_A: + y_t = fill_mask_with_color( + img_tensor.clone().detach(), mask.clone().detach(), {} + ) + else: + y_t = torch.randn_like(img_tensor) + + if opt.alg_diffusion_cond_image_creation == "previous_frame": + if previous_frame is not None: + if isinstance(previous_frame, str): + # load the previous frame + previous_frame = cv2.imread(previous_frame) + + previous_frame = cv2.cvtColor(previous_frame, cv2.COLOR_BGR2RGB) + previous_frame = previous_frame[ + bbox_select[1] : bbox_select[3], bbox_select[0] : bbox_select[2] + ] + previous_frame = cv2.resize( + previous_frame, (opt.data_load_size, opt.data_load_size) + ) + previous_frame = tran(previous_frame) + previous_frame = previous_frame.to(device).clone().detach().unsqueeze(0) + + cond_image = previous_frame + else: + cond_image = -1 * torch.ones_like(y_t.unsqueeze(0), device=y_t.device) + elif opt.alg_diffusion_cond_image_creation == "y_t": + if opt.model_type == "palette": + cond_image = y_t.unsqueeze(0) + else: + cond_image = None + elif opt.alg_diffusion_cond_image_creation == "sketch": + cond_image = fill_img_with_sketch( + img_tensor.unsqueeze(0), mask.unsqueeze(0) + ) + elif opt.alg_diffusion_cond_image_creation == "canny": + clamp = torch.clamp(mask, 0, 1) + if cond_in: + # mask the background to avoid canny edges around cond image + img_tensor_canny = clamp * img_tensor + clamp - 1 + else: + img_tensor_canny = img_tensor + cond_image = fill_img_with_canny( + img_tensor_canny.unsqueeze(0), + mask.unsqueeze(0), + low_threshold=alg_diffusion_sketch_canny_thresholds[0], + high_threshold=alg_diffusion_sketch_canny_thresholds[1], + low_threshold_random=-1, + high_threshold_random=-1, + ) + if cond_in: + # restore background + cond_image = cond_image * clamp + img_tensor * (1 - clamp) + elif opt.alg_diffusion_cond_image_creation == "sam": + opt.f_s_weight_sam = "../" + opt.f_s_weight_sam + if not os.path.exists(opt.f_s_weight_sam): + download_sam_weight(opt.f_s_weight_sam) + sam, _ = load_sam_weight(opt.f_s_weight_sam) + sam = sam.to(device) + cond_image = fill_img_with_sam( + img_tensor.unsqueeze(0), mask.unsqueeze(0), sam, opt + ) + elif opt.alg_diffusion_cond_image_creation == "hed": + cond_image = fill_img_with_hed(img_tensor.unsqueeze(0), mask.unsqueeze(0)) + elif opt.alg_diffusion_cond_image_creation == "hough": + cond_image = fill_img_with_hough(img_tensor.unsqueeze(0), mask.unsqueeze(0)) + elif opt.alg_diffusion_cond_image_creation == "depth": + cond_image = fill_img_with_depth(img_tensor.unsqueeze(0), mask.unsqueeze(0)) + elif opt.alg_diffusion_cond_image_creation == "low_res": + if alg_diffusion_super_resolution_downsample: + data_crop_size_low_res = int( + opt.data_crop_size / opt.alg_diffusion_super_resolution_scale + ) + transform_lr = T.Resize( + (data_crop_size_low_res, data_crop_size_low_res) + ) + cond_image = transform_lr(img_tensor.unsqueeze(0)).detach() + else: + cond_image = img_tensor.unsqueeze(0).clone().detach() + transform_hr = T.Resize((opt.data_crop_size, opt.data_crop_size)) + cond_image = transform_hr(cond_image).detach() + elif opt.alg_diffusion_cond_image_creation == "pix2pix": + # use same interpolation as get_transform + if (img_height > 0 and img_height != opt.data_crop_size) or ( + img_width > 0 and img_width != opt.data_crop_size + ): + transform_hr = T.Resize( + (img_height, img_width), interpolation=T.InterpolationMode.BICUBIC + ) + cond_image = transform_hr(img_tensor.unsqueeze(0)).detach() + else: + cond_image = img_tensor.unsqueeze(0).detach() + + if 
mask is None: + cl_mask = None + else: + cl_mask = mask.unsqueeze(0).clone().detach() + + y_t, cond_image, img_tensor, mask = ( + y_t.unsqueeze(0).clone().detach(), + cond_image.clone().detach() if cond_image is not None else None, + img_tensor.unsqueeze(0).clone().detach(), + cl_mask, + ) + if mask == None: + y0_tensor = None + else: + y0_tensor = img_tensor + + if opt.model_type == "palette": + if "class" in model.denoise_fn.conditioning: + cls_tensor = torch.ones(1, dtype=torch.int64, device=device) * cls + else: + cls_tensor = None + if ref is not None: + ref_tensor = ref_tensor.unsqueeze(0) + else: + ref_tensor = None + + cond_image_list.append(cond_image) + y_t_list.append(y_t) + y0_tensor_list.append(y0_tensor) + mask_list.append(mask) + bbox_select_list.append(bbox_select) + img_orig_list.append(img_orig) + img_tensor_list.append(img_tensor) + + cond_image = torch.stack(cond_image_list, dim=0).permute(1, 0, 2, 3, 4) + y_t = torch.stack(y_t_list, dim=0).permute(1, 0, 2, 3, 4) + if all(tensor is not None for tensor in y0_tensor_list): + y0_tensor = torch.stack(y0_tensor_list, dim=0).permute(1, 0, 2, 3, 4) + + if all(tensor is not None for tensor in mask_list): + mask = torch.stack(mask_list, dim=0).permute(1, 0, 2, 3, 4) + + # run through model + with torch.no_grad(): + if opt.model_type == "palette": + out_tensor, visu = model.restoration( + y_cond=cond_image, + y_t=y_t, + y_0=y0_tensor, + mask=mask, + cls=cls_tensor, + ref=ref_tensor, + sample_num=2, + guidance_scale=alg_diffusion_guidance_scale, + ddim_num_steps=alg_palette_ddim_num_steps, + ddim_eta=alg_palette_ddim_eta, + ) + elif opt.model_type == "cm" or opt.model_type == "cm_gan": + sampling_sigmas = (80.0, 24.4, 5.84, 0.9, 0.661) + + out_tensor = model.restoration(y_t, cond_image, sampling_sigmas, mask) + + # XXX: !=8bit images are converted to 8bit RGB for now + out_tensor = out_tensor.squeeze(0) # since batchsize is 1 with form [b,f,c,h,w] + for i in range(out_tensor.shape[0]): + out_img = to_np( + out_tensor[i : i + 1, :, :, :] + ) # out_img = out_img.detach().data.cpu().float().numpy()[0] + out_img_list.append(out_img) + + if logger: + logger.info( + f"[it: %i/%i] - [3/%i] processing completed" + % (iteration, nb_samples, PROGRESS_NUM_STEPS) + ) + + """ post-processing + + out_img = (np.transpose(out_img, (1, 2, 0)) + 1) / 2.0 * 255.0 + out_img = cv2.cvtColor(out_img, cv2.COLOR_RGB2BGR)""" + for i in range(len(out_img_list)): + out_img = out_img_list[i] + img_orig = img_orig_list[i] + y_t = y_t_list[i] + cond_image = cond_image_list[i] + y0_tensor = y0_tensor_list[i] + mask = mask_list[i] + bbox_select = bbox_select_list[i] + img_tensor = img_tensor_list[i] + + if bbox_in: + bbox_select = bbox_select[i] + out_img_resized = cv2.resize( + out_img, + ( + min(img_orig.shape[1], bbox_select[2] - bbox_select[0]), + min(img_orig.shape[0], bbox_select[3] - bbox_select[1]), + ), + ) + + out_img_real_size = img_orig.copy() + else: + out_img_resized = out_img + out_img_real_size = img_orig.copy() + + # fill out crop into original image + if bbox_in: + bbox_select = bbox_select[i] + out_img_real_size[ + bbox_select[1] : bbox_select[3], bbox_select[0] : bbox_select[2] + ] = out_img_resized + + if cond_image is not None: + cond_image = cond_image_list[i] + cond_img = to_np(cond_image) + name = str(i) + name + if write: + if opt.data_image_bits > 8: + img_tensor = img_tensor_list[i] + img_np = to_np(img_tensor) # comes from PIL + cv2.imwrite(os.path.join(dir_out, name + "_orig.png"), img_np) + if cond_image is not None: + 
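+                    # cond_img comes from to_np() above (0-255 range, BGR), so it can
+                    # be written directly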
cv2.imwrite(os.path.join(dir_out, name + "_cond.png"), cond_img) + cv2.imwrite( + os.path.join(dir_out, name + "_generated.png"), out_img_resized + ) + else: + cv2.imwrite(os.path.join(dir_out, name + "_orig.png"), img_orig) + if cond_image is not None: + cv2.imwrite(os.path.join(dir_out, name + "_cond.png"), cond_img) + cv2.imwrite( + os.path.join(dir_out, name + "_generated.png"), out_img_real_size + ) + cv2.imwrite(os.path.join(dir_out, name + "_y_t.png"), to_np(y_t)) + if mask is not None: + cv2.imwrite( + os.path.join(dir_out, name + "_y_0.png"), to_np(img_tensor) + ) + cv2.imwrite( + os.path.join(dir_out, name + "_generated_crop.png"), out_img + ) + cv2.imwrite(os.path.join(dir_out, name + "_mask.png"), to_np(mask)) + if ref is not None: + cv2.imwrite(os.path.join(dir_out, name + "_ref_orig.png"), ref_orig) + if cond_in: + # crop before cond image + orig_crop = img_orig[ + bbox_select[1] : bbox_select[3], bbox_select[0] : bbox_select[2] + ] + cv2.imwrite( + os.path.join(dir_out, name + "_orig_crop.png"), orig_crop + ) + if bbox_in: + with open( + os.path.join(dir_out, name + "_orig_bbox.json"), "w" + ) as out: + out.write(json.dumps(bbox)) + if generated_bbox: + with open( + os.path.join(dir_out, name + "_generated_bbox.json"), "w" + ) as out: + out.write(json.dumps(generated_bbox)) + + print("Successfully generated image ", name) + + if logger: + logger.info( + f"[it: %i/%i] - [4/%i] image written" + % (iteration, nb_samples, PROGRESS_NUM_STEPS) + ) + + return out_img_real_size, model, opt + + +def inference_logger(name): + PROCESS_NAME = "gen_video_diffusion" + LOG_PATH = os.environ.get( + "LOG_PATH", os.path.join(os.path.dirname(__file__), "../logs") + ) + if not os.path.exists(LOG_PATH): + os.makedirs(LOG_PATH) + + logging.basicConfig( + level=logging.DEBUG, + handlers=[ + logging.FileHandler(f"{LOG_PATH}/{name}.log", mode="w"), + logging.StreamHandler(), + ], + ) + + return logging.getLogger(f"inference %s %s" % (PROCESS_NAME, name)) + + +def inference(args): + PROGRESS_NUM_STEPS = 6 + logger = inference_logger(args.name) + + args.logger = logger + + if len(args.mask_delta_ratio[0]) == 1 and args.mask_delta_ratio[0][0] == 0.0: + mask_delta = args.mask_delta + else: + mask_delta = args.mask_delta_ratio + args.mask_delta = mask_delta + + args.write = True + + real_name = args.name + + args.lmodel = None + args.lopt = None + + for i in tqdm(range(args.nb_samples)): + args.iteration = i + 1 + logger.info(f"[it: %i/%i] launch inference" % (args.iteration, args.nb_samples)) + args.name = real_name + "_" + str(i).zfill(len(str(args.nb_samples))) + frame, lmodel, lopt = generate(**vars(args)) + args.lmodel = lmodel + args.lopt = lopt + + logger.info(f"success - end of inference") + + +def extract_number(filename): + number = "" + for char in filename: + if char.isdigit(): + number += char + else: + break + return int(number) + + +def img2video(args): + image_folder = args.dir_out + video_base_name = "mario_video" + + # Regular expression pattern to capture the number before "_generated.png" + patterns = { + "generated": re.compile(r"(\d+)_generated\.png$"), + "orig": re.compile(r"(\d+)_orig\.png$"), + } + + for suffix, pattern in patterns.items(): + generated_files = defaultdict(list) + images = [img for img in os.listdir(image_folder) if img.endswith(".png")] + + for image in images: + match = pattern.search(image) + if match: + number = match.group(1) + generated_files[number].append(image) + + # Process each category and create a video + for number, file_list in 
sorted(generated_files.items()):
+            sorted_list = sorted(file_list, key=extract_number)
+
+            if not sorted_list:
+                logging.warning(
+                    f"No images to process for number {number} with suffix {suffix}."
+                )
+                continue
+
+            first_image_path = os.path.join(image_folder, sorted_list[0])
+            frame = cv2.imread(first_image_path)
+            if frame is None:
+                print(f"Error reading the first image: {first_image_path}")
+                continue
+            height, width, layers = frame.shape
+
+            video_name = f"{video_base_name}_{number}_{suffix}.avi"
+            video_path = os.path.join(image_folder, video_name)
+
+            video = cv2.VideoWriter(
+                video_path,
+                cv2.VideoWriter_fourcc("M", "J", "P", "G"),
+                0.5,
+                (width, height),
+            )
+            for image in sorted_list:
+                image_path = os.path.join(image_folder, image)
+                frame = cv2.imread(image_path)
+                if frame is not None:
+                    video.write(frame)
+                else:
+                    print(f"Error reading image: {image_path}")
+
+            # close any OpenCV windows and release the video writer
+            cv2.destroyAllWindows()
+            video.release()
+            logging.info(f"Video created: {video_path}")
+            print(f"Video created: {video_path}")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    inference_options = InferenceDiffusionOptions()
+    parser = inference_options.initialize(parser)
+
+    parser.add_argument(
+        "--paths_file",
+        required=True,
+        help="Path to the paths file listing image / bounding box path pairs",
+    )
+    parser.add_argument(
+        "--data_root",
+        required=True,
+        help="Path to the data root that the paths in the paths file are relative to",
+    )
+    args = parser.parse_args()
+    inference(args)
+    img2video(args)
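+
+
+# Example invocation (illustrative only; --paths_file and --data_root are defined
+# above, the remaining flag names are assumed to come from InferenceDiffusionOptions
+# and to mirror the generate() keyword arguments -- see docs/source/inference.rst
+# in this patch for the documented options):
+#
+#   python3 gen_vid_diffusion.py \
+#       --model_in_file /path/to/checkpoints/latest_net_G_A.pth \
+#       --paths_file /path/to/dataset/trainA/paths.txt \
+#       --data_root /path/to/dataset \
+#       --dir_out /path/to/output \
+#       --nb_samples 1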