
WIP: video_unet_generator_attn #669

Merged: 1 commit merged into jolibrain:master from video_gen on Aug 21, 2024
Conversation

wr0124 (Collaborator) commented Jul 10, 2024

  • step0: Build a UNet architecture in joliGEN (ResBlock+AttentionBlock) comparable to AnimateDiff's (ResBlock+TransformerBlock+MotionModule)

  • step1: Modify ResBlock to take 5D tensor images as input and output

  • step1.1: Test ResBlock with 5D tensor input and output

  • step1.2: ResBlock embedding?

  • step2: Replace AttentionBlock by MotionModule (a temporal-attention sketch follows this list)

  • step2.1: Use the MotionModule code to replace AttentionBlock

  • step2.2: Test the MotionModule replacement of AttentionBlock with 5D tensor input and output

  • step2.3: MotionModule embedding?

  • step3: Merge MotionModule into the Video_generator_attn file

  • step3.1: Align attention heads, input/output channels and possibly other variables; clean up the code?

  • step3.2: Test MotionModule in the Video_generator_attn file with 5D tensor input/output

  • step3.3: Use QKVAttention for attention score calculation throughout the file?

  • step4: Test UNet with 5D tensor input/output?

  • step5: Create the dataloader

  • step6: Test UNet with the dataloader

  • step7: Test training and visualization with visdom

  • step8: Inference script

  • step9: Unit tests
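As a reference for step2, here is a minimal sketch of the temporal self-attention at the heart of a MotionModule, using a stock nn.MultiheadAttention as a stand-in for the actual AnimateDiff/joliGEN module: spatial positions are folded into the batch so that attention runs across the f frames only.

```python
import torch
import torch.nn as nn

# Stand-in temporal attention: fold (h, w) into the batch so that
# self-attention mixes information across the f frames only.
b, f, c, h, w = 1, 4, 32, 8, 8
x = torch.randn(b, f, c, h, w)

attn = nn.MultiheadAttention(embed_dim=c, num_heads=4, batch_first=True)

t = x.permute(0, 3, 4, 1, 2).reshape(b * h * w, f, c)  # (b*h*w, f, c)
t, _ = attn(t, t, t)                                   # attend over frames
x = t.reshape(b, h, w, f, c).permute(0, 3, 4, 1, 2)    # back to (b, f, c, h, w)
```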

@wr0124 wr0124 changed the title from "feat(ml):step1 ResBlock input/output 5D tensor image" to "WIP: video_gen" on Jul 10, 2024
@wr0124 wr0124 changed the title from "WIP: video_gen" to "WIP: video_unet_generator_attn" on Jul 10, 2024
wr0124 (Collaborator, Author) commented Jul 10, 2024

UNet = ((ResBlock+Attention) * 2) * 4 for input_blocks

python3 -W ignore::UserWarning train.py \
  --dataroot /data1/juliew/mini_dataset/online_mario2sonic_lite \
  --checkpoints_dir /data1/juliew/checkpoints \
  --name mario \
  --config_json examples/example_ddpm_mario.json \
  --gpu_ids 1 \
  --output_display_env test_mario_unet \
  --output_display_freq 1 \
  --output_print_freq 1 \
  --G_diff_n_timestep_test 5 \
  --G_diff_n_timestep_train 2000 \
  --G_unet_mha_channel_mults 1 2 4 8 \
  --G_unet_mha_res_blocks 2 2 2 2 \
  --train_batch_size 1 \
  --G_unet_mha_attn_res 1 2 4 8 \
  --data_num_threads 1

)
]
ch = int(mult * self.inner_channel)
if ds in attn_res:
Contributor:
Variable attn_res should not condition the motion module (MM). The MM is mandatory, not conditioned, I believe.

Also, attn_res conditions the AttentionBlock in the frame-only UNet, and we should keep this code here as well.

This is because the MM is an addition to any configuration of the frame-only UNet.
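A minimal sketch of this layout, with nn.Identity stand-ins for the real joliGEN modules: the within-frame AttentionBlock stays gated on attn_res, while the MotionModule is appended unconditionally.

```python
import torch.nn as nn

# Stand-ins for the real ResBlock / AttentionBlock / MotionModule.
ResBlock = AttentionBlock = MotionModule = nn.Identity

def make_down_block(ds, attn_res):
    layers = [ResBlock()]
    if ds in attn_res:               # frame-only attention stays conditional
        layers.append(AttentionBlock())
    layers.append(MotionModule())    # the MM is mandatory, not conditioned
    return nn.Sequential(*layers)

block = make_down_block(ds=2, attn_res=(1, 2, 4, 8))
```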

efficient=efficient,
freq_space=self.freq_space,
),
MotionModule(
Contributor:
Let's verify this, because:

  • the frame-only UNet has a "within-frame" AttentionBlock here, which needs to be kept;
  • I'm not sure the MM applies to the bottleneck: please double-check in the publications and code.

wr0124 (Collaborator, Author) commented Jul 15, 2024:
The AttentionBlock is kept; the MM is added after it.

wr0124 (Collaborator, Author):
In the publication's code, whether the MM is applied to the bottleneck depends on two options. However, in the publication's two illustration figures, the bottleneck does not have an MM.

)
]
ch = int(self.inner_channel * mult)
if ds in attn_res:
Contributor:
Same remark here.

wr0124 (Collaborator, Author) commented Jul 18, 2024

Since joliGEN's temporal DDPM (use_temporal) creates tensors of shape (b, f, c, h, w), which differs from the original paper's format of (b, c, f, h, w), all tensors in this version flow in the (b, f, c, h, w) format (see the sketch below). For compatibility with other models in joliGEN, might it be advantageous to treat the tensor in 4D format during training?
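A small sketch of the two layouts and of the 4D view mentioned above (pure reshapes, no assumptions beyond the shapes):

```python
import torch

b, f, c, h, w = 2, 4, 3, 64, 64
x = torch.randn(b, f, c, h, w)       # joliGEN layout: (b, f, c, h, w)

x_paper = x.permute(0, 2, 1, 3, 4)   # original paper's layout: (b, c, f, h, w)
x_4d = x.reshape(b * f, c, h, w)     # 4D view: frames folded into the batch

assert torch.equal(x_4d.reshape(b, f, c, h, w), x)  # the fold is invertible
```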

wr0124 (Collaborator, Author) commented Jul 18, 2024

python3 -W ignore::UserWarning train.py \
  --dataroot /data1/juliew/dataset/online_mario2sonic_full_mario \
  --checkpoints_dir /data1/juliew/checkpoints \
  --name mario_temporal \
  --config_json examples/example_ddpm_mario.json \
  --gpu_ids 2 \
  --output_display_env test_mario_temporal \
  --output_print_freq 1 \
  --output_display_freq 1 \
  --data_dataset_mode self_supervised_temporal_labeled_mask_online \
  --train_batch_size 1 \
  --train_iter_size 4 \
  --data_temporal_number_frames 4 \
  --data_temporal_frame_step 1 \
  --data_num_threads 1 \
  --train_temporal_criterion \
  --G_diff_n_timestep_test 1000 \
  --G_diff_n_timestep_train 2000 \
  --train_temporal_criterion_lambda 1.0 \
  --G_netG unet_vid \
  --data_online_creation_crop_size_A 64 \
  --data_online_creation_crop_size_B 64 \
  --data_crop_size 64 \
  --data_load_size 64 \
  --G_unet_mha_attn_res 1 2 4 8 \
  --output_verbose

  • This UNetVid has 4 down and 4 up blocks of (ResBlock+Attention+MM), with a middle block of (ResBlock+Attention+ResBlock); a structural sketch follows this list. Its architecture is similar to the paper's, but due to GPU limitations it cannot handle image sizes of 128.
  • batch-size bug
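A structural sketch of that layout, again with nn.Identity stand-ins rather than the real modules:

```python
import torch.nn as nn

ResBlock = AttentionBlock = MotionModule = nn.Identity

down = nn.ModuleList(
    [nn.Sequential(ResBlock(), AttentionBlock(), MotionModule()) for _ in range(4)]
)
middle = nn.Sequential(ResBlock(), AttentionBlock(), ResBlock())  # no MM here
up = nn.ModuleList(
    [nn.Sequential(ResBlock(), AttentionBlock(), MotionModule()) for _ in range(4)]
)
```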

wr0124 (Collaborator, Author) commented Jul 22, 2024

Works with this command line:

python3 -W ignore::UserWarning train.py \
  --dataroot /data1/juliew/dataset/online_mario2sonic_full_mario \
  --checkpoints_dir /data1/juliew/checkpoints \
  --name mario_antoine \
  --gpu_ids 2 \
  --output_display_env test_mario_antoine \
  --model_type palette \
  --output_print_freq 1 \
  --output_display_freq 1 \
  --data_dataset_mode self_supervised_temporal_labeled_mask_online \
  --train_batch_size 1 \
  --train_iter_size 1 \
  --model_input_nc 3 \
  --model_output_nc 3 \
  --data_relative_paths \
  --train_G_ema \
  --train_optim adamw \
  --train_temporal_criterion_lambda 1.0 \
  --G_netG unet_vid \
  --data_online_creation_crop_size_A 64 \
  --data_online_creation_crop_size_B 64 \
  --data_crop_size 64 \
  --data_load_size 64 \
  --G_unet_mha_attn_res 16 \
  --data_online_creation_rand_mask_A \
  --train_G_lr 0.0001 \
  --dataaug_no_rotate \
  --G_diff_n_timestep_train 5 \
  --G_diff_n_timestep_test 6 \
  --data_temporal_number_frames 4 \
  --data_temporal_frame_step 1 \
  --data_num_threads 4 \
  --UNetVid

wr0124 (Collaborator, Author) commented Jul 22, 2024

python3 -W ignore::UserWarning train.py \
  --dataroot /data1/juliew/dataset/online_mario2sonic_full_mario \
  --checkpoints_dir /data1/juliew/checkpoints \
  --name mario_vid_bs1 \
  --gpu_ids 2 \
  --model_type palette \
  --output_print_freq 1 \
  --output_display_freq 1 \
  --data_dataset_mode self_supervised_temporal_labeled_mask_online \
  --train_batch_size 1 \
  --train_iter_size 4 \
  --model_input_nc 3 \
  --model_output_nc 3 \
  --data_relative_paths \
  --train_G_ema \
  --train_optim adamw \
  --train_temporal_criterion_lambda 1.0 \
  --G_netG unet_vid \
  --data_online_creation_crop_size_A 64 \
  --data_online_creation_crop_size_B 64 \
  --data_crop_size 64 \
  --data_load_size 64 \
  --G_unet_mha_attn_res 1 2 4 8 \
  --data_online_creation_rand_mask_A \
  --train_G_lr 0.0001 \
  --dataaug_no_rotate \
  --G_diff_n_timestep_train 8 \
  --G_diff_n_timestep_test 6 \
  --data_temporal_number_frames 10 \
  --data_temporal_frame_step 1 \
  --data_num_threads 8 \
  --UNetVid \
  --output_verbose

  • Due to broadcasting in PyTorch, this works when the batch size is 1; but when the batch size is larger than 1, diffusion_generator.py, which mainly works with 4D tensors, runs into an issue (a work-around sketch follows this list).
  • This setting reaches "22625 / 24564 MB" on one GPU.
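A minimal sketch of the 4D work-around mentioned above: fold frames into the batch dimension so that 4D-only code sees (b*f, c, h, w) whatever the batch size. frame_model is a stand-in for any per-frame module, not joliGEN's diffusion_generator.

```python
import torch
import torch.nn as nn

frame_model = nn.Conv2d(3, 3, kernel_size=3, padding=1)  # any 4D-only module

x = torch.randn(2, 4, 3, 64, 64)                 # (b, f, c, h, w) with b > 1
b, f = x.shape[:2]
y = frame_model(x.reshape(b * f, *x.shape[2:]))  # runs as a plain 4D batch
y = y.reshape(b, f, *y.shape[1:])                # restore (b, f, c, h, w)
```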

wr0124 (Collaborator, Author) commented Jul 26, 2024

Launch inference:

cd scripts/
python3 gen_vid_diffusion.py \
  --model_in_file /data1/juliew/checkpoints/mario_vid_bs1/latest_net_G_A.pth \
  --img_in /data1/juliew/mini_dataset/online_mario2sonic_video/trainA/paths_part.txt \
  --paths_file /data1/juliew/mini_dataset/online_mario2sonic_video/trainA/paths_part.txt \
  --mask_in /data1/juliew/mini_dataset/online_mario2sonic_video/trainA/paths_part.txt \
  --data_root /data1/juliew/mini_dataset/online_mario2sonic_video/ \
  --dir_out ../inference_mario \
  --img_width 128 \
  --img_height 128

wr0124 (Collaborator, Author) commented Jul 29, 2024

Create videos with this command line:

cd scripts/
python3 gen_vid_diffusion.py \
  --model_in_file /data1/juliew/checkpoints/test_vid/latest_net_G_A.pth \
  --img_in /data1/paths_part.txt \
  --paths_file /data1/juliew/ori_dataset/online_mario2sonic_full/trainA/paths_part4.txt \
  --mask_in /paths_part.txt \
  --data_root /data1/juliew/ori_dataset/online_mario2sonic_full/ \
  --dir_out ../inference_mario_vid \
  --img_width 128 \
  --img_height 128 \
  --nb_samples 2

for k in range(min(nb_imgs, self.get_current_batch_size())):
    self.fake_B_pool.query(self.visuals[k : k + 1])

if self.opt.G_netG == "unet_vid":
Contributor:
else?

efficient=efficient,
freq_space=self.freq_space,
),
# MotionModule(
Contributor:
Remove commented code?


# attention, what we cannot get enough of
###attention_score get
# hidden_states_select = self._attention(query, key, value, attention_mask)
Contributor:
Remove commented code?

data/__init__.py Outdated
@@ -61,11 +61,20 @@ def create_dataloader(opt, rank, dataset, batch_size):


def create_dataset_temporal(opt, phase):
dataset_class = find_dataset_using_name("temporal_labeled_mask_online")
dataset_class = find_dataset_using_name(
Contributor:
I believe this function needs to be changed so that either temporal_labeled_mask_online or self_supervised_temporal_labeled_mask_online is selected, depending on whether cut or palette is running; a sketch follows.
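A sketch of that dispatch (dataset names as in the snippet; testing opt.model_type for "palette" is an assumption based on the training commands above):

```python
def create_dataset_temporal(opt, phase):
    # Assumption: opt.model_type distinguishes palette from cut runs.
    if opt.model_type == "palette":
        name = "self_supervised_temporal_labeled_mask_online"
    else:  # cut and other non-palette models
        name = "temporal_labeled_mask_online"
    dataset_class = find_dataset_using_name(name)  # as in data/__init__.py
    return dataset_class(opt, phase)               # construction sketched
```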

# sort
self.A_img_paths.sort(key=natural_keys)
self.A_label_mask_paths.sort(key=natural_keys)
if self.use_domain_B:
Contributor:
In the self_supervised dataloader, domain B is not needed.

wr0124 (Collaborator, Author) commented Aug 5, 2024

Created one unit-test file "test_run_video_diffusion_online.py" for unit testing.

@wr0124 wr0124 force-pushed the video_gen branch 3 times, most recently from c7888cb to 1fe3f60 on August 9, 2024 13:44
wr0124 (Collaborator, Author) commented Aug 14, 2024

During inference, additional frames beyond the specified opt.data_temporal_number_frames can be added for video generation, but according to the literature this often degrades the results. The additional_frame variable in the gen_vid_diffusion file needs to be tested when its value is negative (see the sketch below).
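For the negative case, an illustrative guard (names hypothetical, not the repo's code):

```python
def total_frames(base_frames: int, additional_frame: int) -> int:
    # Clamp so a negative additional_frame cannot drop below one frame.
    return max(1, base_frames + additional_frame)

assert total_frames(8, -10) == 1
```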

@@ -346,7 +346,8 @@ def generate(
bbox_select[3] = min(img.shape[0], bbox_select[3])
else:
bbox = bboxes[bbox_idx]

opt.data_online_creation_load_size_A = (1280, 720)
Contributor:
In general we don't want hardcoded values here.

wr0124 (Collaborator, Author) commented Aug 19, 2024:
We temporarily did this to run inference with your bdd100k_vid_64_2 model, since in that model opt.data_online_creation_load_size_A is 720. Normally this hardcoded line is not required; it has been deleted.

"train_batch_size": 1,
"data_temporal_number_frames": 8,
"data_temporal_frame_step": 1,
"G_diff_n_timestep_train": 6,
Contributor:
Beware, I don't believe you can theoretically have timestep_test < timestep_train.

wr0124 (Collaborator, Author):

Maybe I misunderstood: opt.G_diff_n_timestep_train is 2000 and opt.G_diff_n_timestep_test is 1000 in the default settings?

wr0124 (Collaborator, Author):

I had overwritten "G_diff_n_timestep_test"; it is corrected now.
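For reference, sampling with fewer test steps than training steps is usually realized by respacing, i.e. running inference on a subset of the trained timesteps (DDIM-style); a minimal sketch, assuming this is the scheme behind joliGEN's 2000/1000 defaults:

```python
import numpy as np

n_timestep_train, n_timestep_test = 2000, 1000
# Evenly spaced subset of the trained timesteps, used at inference.
test_timesteps = np.linspace(0, n_timestep_train - 1, n_timestep_test, dtype=np.int64)
```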

…poral consistency and inference

feat(ml):step2 replace AttentionBlock by MotionModule.
ResBlock/MotionModule class instance pass

feat(ml):UNet=ResBlock+Attention(optional)+MM

feat(ml): create UNetVid class with temporal MHA for U-Net

feat(ml):add dataloader

feat(ml): dataloader works with UNet

feat(ml): dataloader and UNetVid work for input (b,f,c,h,w), not visdom yet

feat(ml): visdom shows the training

feat(ml):dataloader with mask

feat(ml): dataloader fixed with command-line

feat(ml): visdom show one batch of frame

feat(ml): frame is treated as a batch, so no additional normalisation is needed

feat(ml): inference for UNetVid

feat(ml): use efficient_attention_xformers for attention

feat(ml): xformer bug PR

feat(ml): create video based on generated and orig images

feat(ml):remove unnecessary option --UNetVid

feat(ml): add doc for training and inference

feat(ml): fix inference paths requirement

feat(ml): improve the inference for any paths.txt and longer frames

feat(ml): unit test only for vid

feat(ml): debug for unit test on metrics

doc: modify script for inference

feat(ml):debug inference paths_file

feat(ml): add one option for max frame

feat(ml): inference debug bbox_in not img_in, and for bdd100k video

feat(ml): delete hardcoding in inference

feat(ml): dataloader load frames from same video

feat(ml): adapt processing of frames from either a video series or a single video
@beniz beniz merged commit 43b7018 into jolibrain:master Aug 21, 2024
2 checks passed