-
Notifications
You must be signed in to change notification settings - Fork 31
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(ml): The implementation of UNetVid for generating video with tem…
…poral consistency and inference feat(ml): The implementation of UNetVid for generating video with temporal consistency and inference feat(ml): add UNetVid for generating video with temporal consistency and inference feat(ml):add UNetVid for generating video with temporal consistency feat(ml):step1 ResBlock input/output 5D tensor image feat(ml):step2 replace AttentionBlock by MotionModule. ResBlock/MotionModule class instance pass feat(ml): debug for UNet works for 5D tensor feat(ml):UNet=ResBlock+Attention(optional)+MM feat(ml): create UNetVid class with temporal MHA for U-Net feat(ml):add dataloader feat(ml): dataloader works with UNet feat(ml): dataloader and UNetVid works for input (b,f,c,h,w),not visdom yet feat(ml):visdom shows the training feat(ml): debug for 5D feat(ml):debug 5D black feat(ml):typo feat(ml):dataloader with mask feat(ml): dataloader fixed with command-line feat(ml): visdom show one batch of frame feat(ml): frame is treated as a batch, so no additional normalisation is needed feat(ml): inference for UNetVid feat(ml): use efficient_attention_xformers for attention feat(ml): xformer bug PR feat(ml): create video based on generated and orig images feat(ml):remove unnecessary option --UNetVid feat(ml): add doc for training and inference feat(ml): fix inference paths requirement feat(ml):clear code feat(ml):black format feat(ml): improve the inference for any paths.txt and longer frames
- Loading branch information
Showing
12 changed files
with
2,734 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
201 changes: 201 additions & 0 deletions
201
data/self_supervised_temporal_labeled_mask_online_dataset.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,201 @@ | ||
import os | ||
import random | ||
import re | ||
|
||
import torch | ||
|
||
from data.base_dataset import BaseDataset, get_transform_list | ||
from data.image_folder import make_labeled_path_dataset | ||
from data.online_creation import crop_image | ||
from data.online_creation import fill_mask_with_random, fill_mask_with_color | ||
|
||
|
||
def atoi(text):
    """Return *text* as an ``int`` when it is purely numeric, else unchanged."""
    if text.isdigit():
        return int(text)
    return text
|
||
|
||
def natural_keys(text):
    """Return a natural-sort key for *text*: digit runs compare numerically.

    e.g. ``"frame10.png"`` -> ``["frame", 10, ".png"]``, so that
    ``"frame2"`` sorts before ``"frame10"`` when used as a ``sort`` key.
    """
    # r"(\d+)" — the raw string fixes the invalid escape sequence "\d"
    # (SyntaxWarning on Python >= 3.12); the capturing group keeps the
    # digit runs in the split output so they can be converted to ints.
    return [int(part) if part.isdigit() else part for part in re.split(r"(\d+)", text)]
|
||
|
||
class SelfSupervisedTemporalLabeledMaskOnlineDataset(BaseDataset):
    """Dataset of short frame sequences with mask labels, cropped online.

    Image/label paths are read from ``paths.txt`` under ``dir_A`` and sorted
    with a natural key so that consecutive list indices correspond to
    consecutive video frames.  Each sample is a window of ``num_frames``
    frames spaced ``frame_step`` apart, cropped around the mask via
    ``crop_image``.  This is a self-supervised dataset: the "B" entries of
    the returned dict are filled with the same tensors as the "A" entries.
    """

    def __len__(self):
        """Return the total number of images in the dataset.

        As we have two datasets with potentially different numbers of
        images, we take the maximum of the two sizes.
        """
        # B_img_paths is never set in this class's __init__; the guard keeps
        # the method safe if a subclass or BaseDataset adds a B domain.
        if hasattr(self, "B_img_paths"):
            return max(self.A_size, self.B_size)
        else:
            return self.A_size

    def __init__(self, opt, phase, name=""):
        """Build the dataset.

        Parameters
        ----------
        opt : options object with ``data_*`` attributes (frame count, frame
            step, online-creation crop settings, ...).
        phase : dataset phase (e.g. train/test), forwarded to BaseDataset.
        name : optional dataset name, forwarded to BaseDataset.
        """
        BaseDataset.__init__(self, opt, phase, name)

        self.A_img_paths, self.A_label_mask_paths = make_labeled_path_dataset(
            self.dir_A, "/paths.txt"
        )  # load images from '/path/to/data/trainA/paths.txt' as well as labels
        # Natural sort so that frame order follows the numeric part of the
        # filenames (frame2 < frame10), a prerequisite for temporal windows.
        self.A_img_paths.sort(key=natural_keys)
        self.A_label_mask_paths.sort(key=natural_keys)

        if self.opt.data_sanitize_paths:
            self.sanitize()
        elif opt.data_max_dataset_size != float("inf"):
            # Truncate both lists in lockstep so images stay aligned with
            # their mask labels.
            self.A_img_paths, self.A_label_mask_paths = (
                self.A_img_paths[: opt.data_max_dataset_size],
                self.A_label_mask_paths[: opt.data_max_dataset_size],
            )

        self.transform = get_transform_list(self.opt, grayscale=(self.input_nc == 1))

        # Temporal window geometry: num_frames frames, frame_step apart.
        self.num_frames = opt.data_temporal_number_frames
        self.frame_step = opt.data_temporal_frame_step

        self.num_A = len(self.A_img_paths)
        # Highest valid window start index (exclusive).  NOTE(review): if the
        # dataset has fewer than num_frames * frame_step images this becomes
        # <= 0 and random.randint in get_img will raise — TODO confirm callers
        # guarantee enough frames.
        self.range_A = self.num_A - self.num_frames * self.frame_step

        # Number of leading filename characters that must match across a
        # window; -1 disables the check (see get_img).
        self.num_common_char = self.opt.data_temporal_num_common_char

        self.opt = opt

        self.A_size = len(self.A_img_paths)  # get the size of dataset A

    def get_img(
        self,
        A_img_path,
        A_label_mask_path,
        A_label_cls,
        B_img_path=None,
        B_label_mask_path=None,
        B_label_cls=None,
        index=None,
    ):  # all params are unused
        """Sample one temporal window of frames plus masks.

        All parameters are ignored; a random window start is drawn instead.

        Returns a dict with stacked frame tensors ``A`` of shape
        (num_frames, C, H, W) after ``self.transform``, the reference (first)
        frame/label, the crop bbox, and the same data duplicated under the
        "B" keys (self-supervised setting) — or ``None`` when a frame fails
        the common-prefix check or online cropping raises, so the caller can
        retry with another index.
        """
        index_A = random.randint(0, self.range_A - 1)

        images_A = []
        labels_A = []

        ref_A_img_path = self.A_img_paths[index_A]

        # Prefix of the first frame's filename; used to reject windows that
        # straddle two different videos.
        ref_name_A = ref_A_img_path.split("/")[-1][: self.num_common_char]

        for i in range(self.num_frames):
            cur_index_A = index_A + i * self.frame_step

            # If a frame's filename prefix differs from the reference, the
            # window crosses a video boundary: abandon this sample.
            if (
                self.num_common_char != -1
                and self.A_img_paths[cur_index_A].split("/")[-1][: self.num_common_char]
                not in ref_name_A
            ):
                return None

            cur_A_img_path, cur_A_label_path = (
                self.A_img_paths[cur_index_A],
                self.A_label_mask_paths[cur_index_A],
            )
            if self.opt.data_relative_paths:
                cur_A_img_path = os.path.join(self.root, cur_A_img_path)
                if cur_A_label_path is not None:
                    cur_A_label_path = os.path.join(self.root, cur_A_label_path)

            try:
                # [[]] is the sentinel for "ratio deltas unset": fall back to
                # absolute pixel deltas.
                if self.opt.data_online_creation_mask_delta_A_ratio == [[]]:
                    mask_delta_A = self.opt.data_online_creation_mask_delta_A
                else:
                    mask_delta_A = self.opt.data_online_creation_mask_delta_A_ratio

                # First frame: only compute the crop coordinates, so the same
                # crop window can be reused for every frame below.
                if i == 0:
                    crop_coordinates = crop_image(
                        cur_A_img_path,
                        cur_A_label_path,
                        mask_delta=mask_delta_A,
                        mask_random_offset=self.opt.data_online_creation_mask_random_offset_A,
                        crop_delta=self.opt.data_online_creation_crop_delta_A,
                        mask_square=self.opt.data_online_creation_mask_square_A,
                        crop_dim=self.opt.data_online_creation_crop_size_A,
                        output_dim=self.opt.data_load_size,
                        context_pixels=self.opt.data_online_context_pixels,
                        load_size=self.opt.data_online_creation_load_size_A,
                        get_crop_coordinates=True,
                        fixed_mask_size=self.opt.data_online_fixed_mask_size,
                    )
                # Crop the current frame at the shared coordinates, keeping
                # the whole sequence spatially aligned.
                cur_A_img, cur_A_label, ref_A_bbox, A_ref_bbox_id = crop_image(
                    cur_A_img_path,
                    cur_A_label_path,
                    mask_delta=mask_delta_A,
                    mask_random_offset=self.opt.data_online_creation_mask_random_offset_A,
                    crop_delta=self.opt.data_online_creation_crop_delta_A,
                    mask_square=self.opt.data_online_creation_mask_square_A,
                    crop_dim=self.opt.data_online_creation_crop_size_A,
                    output_dim=self.opt.data_load_size,
                    context_pixels=self.opt.data_online_context_pixels,
                    load_size=self.opt.data_online_creation_load_size_A,
                    crop_coordinates=crop_coordinates,
                    fixed_mask_size=self.opt.data_online_fixed_mask_size,
                )
                if i == 0:
                    # Drop the bbox's first element (presumably a class/id
                    # field — TODO confirm against crop_image) and keep the
                    # first frame's bbox as the window reference.
                    A_ref_bbox = ref_A_bbox[1:]

            except Exception as e:
                print(e, f"{i+1}th frame of domain A in temporal dataloading")
                return None

            images_A.append(cur_A_img)
            labels_A.append(cur_A_label)

        # Transform frames/labels jointly so augmentations are consistent
        # across the window, then stack to (num_frames, C, H, W).
        images_A, labels_A, A_ref_bbox = self.transform(images_A, labels_A, A_ref_bbox)
        A_ref_label = labels_A[0]
        A_ref_img = images_A[0]
        images_A = torch.stack(images_A)
        labels_A = torch.stack(labels_A)

        # Self-supervised: B is a copy of A.
        result = {
            "A_ref": A_ref_img,
            "A": images_A,
            "A_img_paths": ref_A_img_path,
            "A_ref_bbox": A_ref_bbox,
            "A_label_mask": labels_A,
            "A_ref_label_mask": A_ref_label,
            "B_ref": A_ref_img,
            "B": images_A,
            "B_img_paths": ref_A_img_path,
            "B_ref_bbox": A_ref_bbox,
            "B_label_mask": labels_A,
            "B_ref_label_mask": A_ref_label,
        }

        try:
            # Degrade the A side by filling the masked region (randomly or
            # with a flat color); exactly one method must be enabled.
            if self.opt.data_online_creation_rand_mask_A:
                A_ref_img = fill_mask_with_random(
                    result["A_ref"], result["A_ref_label_mask"], -1
                )
                images_A = fill_mask_with_random(
                    result["A"], result["A_label_mask"], -1
                )
            elif self.opt.data_online_creation_color_mask_A:
                A_ref_img = fill_mask_with_color(
                    result["A_ref"], result["A_ref_label_mask"], {}
                )
                images_A = fill_mask_with_color(result["A"], result["A_label_mask"], {})
            else:
                raise Exception(
                    "self supervised dataset: no self supervised method specified"
                )

            # Only the A entries are overwritten with the masked versions;
            # the B entries keep the clean frames as the reconstruction
            # target.
            result.update(
                {
                    "A_ref": A_ref_img,
                    "A": images_A,
                    "A_img_paths": ref_A_img_path,
                }
            )
        except Exception as e:
            print(
                e,
                "self supervised temporal labeled mask online data loading from ",
                ref_A_img_path,
            )
            return None

        return result
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.