diff --git a/.gitignore b/.gitignore
index 2d0c929..5473854 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,2 @@
 *.gif
+.vscode
diff --git a/base/configs/sample.yaml b/base/configs/sample.yaml
index b8ca240..6925e25 100644
--- a/base/configs/sample.yaml
+++ b/base/configs/sample.yaml
@@ -23,6 +23,7 @@ run_time: 0
 guidance_scale: 7.5
 sample_method: 'ddpm'
 num_sampling_steps: 50
+enable_vae_temporal_decoder: True
 text_prompt: [
     'a teddy bear walking on the street, 2k, high quality',
     'a panda taking a selfie, 2k, high quality',
diff --git a/base/pipelines/pipeline_videogen.py b/base/pipelines/pipeline_videogen.py
index 97031fc..8b67c6f 100644
--- a/base/pipelines/pipeline_videogen.py
+++ b/base/pipelines/pipeline_videogen.py
@@ -37,7 +37,8 @@
 from diffusers.utils.torch_utils import randn_tensor
 
 
-from diffusers.pipeline_utils import DiffusionPipeline
+# from diffusers.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from dataclasses import dataclass
 
 import os, sys
@@ -409,7 +410,28 @@ def decode_latents(self, latents):
         latents = einops.rearrange(latents, "b c f h w -> (b f) c h w")
         video = self.vae.decode(latents).sample
         video = einops.rearrange(video, "(b f) c h w -> b f h w c", f=video_length)
-        video = ((video / 2 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().contiguous()
+        video = ((video / 2.0 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().contiguous()
+        return video
+
+    def decode_latents_with_temporal_decoder(self, latents):
+        video_length = latents.shape[2]
+        latents = 1 / self.vae.config.scaling_factor * latents
+        latents = einops.rearrange(latents, "b c f h w -> (b f) c h w")
+        video = []
+
+        decode_chunk_size = 14
+        for frame_idx in range(0, latents.shape[0], decode_chunk_size):
+            num_frames_in = latents[frame_idx : frame_idx + decode_chunk_size].shape[0]
+
+            decode_kwargs = {}
+            decode_kwargs["num_frames"] = num_frames_in
+
+            video.append(self.vae.decode(latents[frame_idx:frame_idx+decode_chunk_size], **decode_kwargs).sample)
+
+        video = torch.cat(video)
+        video = einops.rearrange(video, "(b f) c h w -> b f h w c", f=video_length)
+        video = ((video / 2.0 + 0.5).clamp(0, 1) * 255).to(dtype=torch.uint8).cpu().contiguous()
+        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         return video
 
     def prepare_extra_step_kwargs(self, generator, eta):
@@ -515,6 +537,7 @@ def __call__(
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        enable_vae_temporal_decoder: bool = False,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -672,6 +695,9 @@ def __call__(
 
         # 8. Post-processing
-        video = self.decode_latents(latents)
+        if enable_vae_temporal_decoder:
+            video = self.decode_latents_with_temporal_decoder(latents)
+        else:
+            video = self.decode_latents(latents)
 
         return StableDiffusionPipelineOutput(video=video)
 
diff --git a/base/pipelines/sample.py b/base/pipelines/sample.py
index 3a99d3a..c3d17fd 100644
--- a/base/pipelines/sample.py
+++ b/base/pipelines/sample.py
@@ -7,7 +7,7 @@
 
 from download import find_model
 from diffusers.schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, EulerDiscreteScheduler
-from diffusers.models import AutoencoderKL
+from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder
 from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
 from omegaconf import OmegaConf
 
@@ -17,73 +17,77 @@
 import imageio
 
 def main(args):
-	if args.seed is not None:
-		torch.manual_seed(args.seed)
-	torch.set_grad_enabled(False)
-	device = "cuda" if torch.cuda.is_available() else "cpu"
+    if args.seed is not None:
+        torch.manual_seed(args.seed)
+    torch.set_grad_enabled(False)
+    device = "cuda" if torch.cuda.is_available() else "cpu"
 
-	sd_path = args.pretrained_path + "/stable-diffusion-v1-4"
-	unet = get_models(args, sd_path).to(device, dtype=torch.float16)
-	state_dict = find_model(args.ckpt_path)
-	unet.load_state_dict(state_dict)
-
-	vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae", torch_dtype=torch.float16).to(device)
-	tokenizer_one = CLIPTokenizer.from_pretrained(sd_path, subfolder="tokenizer")
-	text_encoder_one = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder", torch_dtype=torch.float16).to(device) # huge
+    sd_path = os.path.join(args.pretrained_path, "stable-diffusion-v1-4")
+    unet = get_models(args, sd_path).to(device, dtype=torch.float16)
+    state_dict = find_model(args.ckpt_path)
+    unet.load_state_dict(state_dict)
 
-	# set eval mode
-	unet.eval()
-	vae.eval()
-	text_encoder_one.eval()
-
-	if args.sample_method == 'ddim':
-		scheduler = DDIMScheduler.from_pretrained(sd_path,
-							  subfolder="scheduler",
-							  beta_start=args.beta_start,
-							  beta_end=args.beta_end,
-							  beta_schedule=args.beta_schedule)
-	elif args.sample_method == 'eulerdiscrete':
-		scheduler = EulerDiscreteScheduler.from_pretrained(sd_path,
-							  subfolder="scheduler",
-							  beta_start=args.beta_start,
-							  beta_end=args.beta_end,
-							  beta_schedule=args.beta_schedule)
-	elif args.sample_method == 'ddpm':
-		scheduler = DDPMScheduler.from_pretrained(sd_path,
-							  subfolder="scheduler",
-							  beta_start=args.beta_start,
-							  beta_end=args.beta_end,
-							  beta_schedule=args.beta_schedule)
-	else:
-		raise NotImplementedError
+    if args.enable_vae_temporal_decoder:
+        vae = AutoencoderKLTemporalDecoder.from_pretrained(sd_path, subfolder="vae_temporal_decoder", torch_dtype=torch.float16).to(device)
+    else:
+        vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae", torch_dtype=torch.float16).to(device)
+    tokenizer_one = CLIPTokenizer.from_pretrained(sd_path, subfolder="tokenizer")
+    text_encoder_one = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder", torch_dtype=torch.float16).to(device) # huge
 
-	videogen_pipeline = VideoGenPipeline(vae=vae,
-					     text_encoder=text_encoder_one,
-					     tokenizer=tokenizer_one,
-					     scheduler=scheduler,
-					     unet=unet).to(device)
-	videogen_pipeline.enable_xformers_memory_efficient_attention()
+    # set eval mode
+    unet.eval()
+    vae.eval()
+    text_encoder_one.eval()
 
-	if not os.path.exists(args.output_folder):
-		os.makedirs(args.output_folder)
+    if args.sample_method == 'ddim':
+        scheduler = DDIMScheduler.from_pretrained(sd_path,
+                                                  subfolder="scheduler",
+                                                  beta_start=args.beta_start,
+                                                  beta_end=args.beta_end,
+                                                  beta_schedule=args.beta_schedule)
+    elif args.sample_method == 'eulerdiscrete':
+        scheduler = EulerDiscreteScheduler.from_pretrained(sd_path,
+                                                           subfolder="scheduler",
+                                                           beta_start=args.beta_start,
+                                                           beta_end=args.beta_end,
+                                                           beta_schedule=args.beta_schedule)
+    elif args.sample_method == 'ddpm':
+        scheduler = DDPMScheduler.from_pretrained(sd_path,
+                                                  subfolder="scheduler",
+                                                  beta_start=args.beta_start,
+                                                  beta_end=args.beta_end,
+                                                  beta_schedule=args.beta_schedule)
+    else:
+        raise NotImplementedError
 
-	video_grids = []
-	for prompt in args.text_prompt:
-		print('Processing the ({}) prompt'.format(prompt))
-		videos = videogen_pipeline(prompt,
-					   video_length=args.video_length,
-					   height=args.image_size[0],
-					   width=args.image_size[1],
-					   num_inference_steps=args.num_sampling_steps,
-					   guidance_scale=args.guidance_scale).video
-		imageio.mimwrite(args.output_folder + prompt.replace(' ', '_') + '.mp4', videos[0], fps=8, quality=9) # highest quality is 10, lowest is 0
-
-	print('save path {}'.format(args.output_folder))
+    videogen_pipeline = VideoGenPipeline(vae=vae,
+                                         text_encoder=text_encoder_one,
+                                         tokenizer=tokenizer_one,
+                                         scheduler=scheduler,
+                                         unet=unet).to(device)
+    videogen_pipeline.enable_xformers_memory_efficient_attention()
+
+    if not os.path.exists(args.output_folder):
+        os.makedirs(args.output_folder)
+
+    video_grids = []
+    for prompt in args.text_prompt:
+        print('Processing the ({}) prompt'.format(prompt))
+        videos = videogen_pipeline(prompt,
+                                   video_length=args.video_length,
+                                   height=args.image_size[0],
+                                   width=args.image_size[1],
+                                   num_inference_steps=args.num_sampling_steps,
+                                   guidance_scale=args.guidance_scale,
+                                   enable_vae_temporal_decoder=args.enable_vae_temporal_decoder).video
+        imageio.mimwrite(args.output_folder + prompt.replace(' ', '_') + '.mp4', videos[0], fps=8, quality=9) # highest quality is 10, lowest is 0
+
+    print('save path {}'.format(args.output_folder))
 
 
 if __name__ == "__main__":
-	parser = argparse.ArgumentParser()
-	parser.add_argument("--config", type=str, default="")
-	args = parser.parse_args()
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--config", type=str, default="")
+    args = parser.parse_args()
 
-	main(OmegaConf.load(args.config))
+    main(OmegaConf.load(args.config))
diff --git a/environment.yml b/environment.yml
index 987bae6..53200ff 100644
--- a/environment.yml
+++ b/environment.yml
@@ -11,7 +11,7 @@ dependencies:
     - accelerate==0.19.0
     - av==10.0.0
    - decord==0.6.0
-    - diffusers[torch]==0.16.0
+    - diffusers[torch]==0.24.0
     - einops==0.6.1
     - ffmpeg==1.4
     - imageio==2.31.1
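
Note (not part of the patch): as a self-contained illustration of what the new decode_latents_with_temporal_decoder method above does, here is a minimal sketch of the same chunked decode against AutoencoderKLTemporalDecoder from diffusers 0.24. The checkpoint path and the "vae_temporal_decoder" subfolder simply mirror what sample.py expects in this diff and are assumptions about your local layout; the latent tensor is random and only illustrates the (b, c, f, h, w) shape the pipeline works with.

# Sketch only: chunked decoding with the temporal VAE decoder.
import torch
import einops
from diffusers.models import AutoencoderKLTemporalDecoder

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

vae = AutoencoderKLTemporalDecoder.from_pretrained(
    "pretrained_models/stable-diffusion-v1-4",  # hypothetical sd_path, matches sample.py's layout
    subfolder="vae_temporal_decoder",
    torch_dtype=dtype,
).to(device)

# Random latents with the pipeline's layout: (batch, channels, frames, h/8, w/8).
latents = torch.randn(1, 4, 16, 40, 64, dtype=dtype, device=device)

video_length = latents.shape[2]
latents = latents / vae.config.scaling_factor
latents = einops.rearrange(latents, "b c f h w -> (b f) c h w")

# Decode at most `decode_chunk_size` frames per VAE call to bound peak memory,
# telling the temporal decoder how many frames each chunk holds.
decode_chunk_size = 14
frames = []
with torch.no_grad():
    for i in range(0, latents.shape[0], decode_chunk_size):
        chunk = latents[i : i + decode_chunk_size]
        frames.append(vae.decode(chunk, num_frames=chunk.shape[0]).sample)

video = torch.cat(frames)
video = einops.rearrange(video, "(b f) c h w -> b f h w c", f=video_length)
video = ((video / 2.0 + 0.5).clamp(0, 1) * 255).to(torch.uint8).cpu()
print(video.shape)  # torch.Size([1, 16, 320, 512, 3])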
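The new enable_vae_temporal_decoder key added to sample.yaml reaches the pipeline through the OmegaConf namespace that main() receives and is forwarded to the pipeline call. A tiny check of that plumbing, assuming the repository root as working directory:

# Sketch only: inspect the flag added to base/configs/sample.yaml in this diff.
from omegaconf import OmegaConf

args = OmegaConf.load("base/configs/sample.yaml")
print(args.enable_vae_temporal_decoder)          # True (the default added above)
print(args.sample_method, args.num_sampling_steps)  # ddpm 50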