add svd temporal decoder #61

Open · wants to merge 1 commit into base: main
1 change: 1 addition & 0 deletions .gitignore
@@ -1 +1,2 @@
*.gif
.vscode
1 change: 1 addition & 0 deletions base/configs/sample.yaml
@@ -23,6 +23,7 @@ run_time: 0
guidance_scale: 7.5
sample_method: 'ddpm'
num_sampling_steps: 50
enable_vae_temporal_decoder: True
text_prompt: [
'a teddy bear walking on the street, 2k, high quality',
'a panda taking a selfie, 2k, high quality',
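The new enable_vae_temporal_decoder key switches sampling over to the SVD-style temporal VAE decoder added below. A minimal sketch of how the flag travels, assuming the OmegaConf-based config loading that base/pipelines/sample.py already uses (only the key name comes from the diff):

from omegaconf import OmegaConf

# Load the YAML the same way base/pipelines/sample.py does and read the new flag;
# sample.py forwards it to VideoGenPipeline.__call__.
args = OmegaConf.load("base/configs/sample.yaml")
print(args.enable_vae_temporal_decoder)  # True
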
32 changes: 29 additions & 3 deletions base/pipelines/pipeline_videogen.py
@@ -37,7 +37,8 @@
from diffusers.utils.torch_utils import randn_tensor


from diffusers.pipeline_utils import DiffusionPipeline
# from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from dataclasses import dataclass

import os, sys
@@ -409,7 +410,28 @@ def decode_latents(self, latents):
latents = einops.rearrange(latents, "b c f h w -> (b f) c h w")
video = self.vae.decode(latents).sample
video = einops.rearrange(video, "(b f) c h w -> b f h w c", f=video_length)
video = ((video / 2 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().contiguous()
video = ((video / 2.0 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().contiguous()
return video

def decode_latents_with_temporal_decoder(self, latents):
video_length = latents.shape[2]
latents = 1 / self.vae.config.scaling_factor * latents
latents = einops.rearrange(latents, "b c f h w -> (b f) c h w")
video = []

decode_chunk_size = 14
for frame_idx in range(0, latents.shape[0], decode_chunk_size):
num_frames_in = latents[frame_idx : frame_idx + decode_chunk_size].shape[0]

decode_kwargs = {}
decode_kwargs["num_frames"] = num_frames_in

video.append(self.vae.decode(latents[frame_idx:frame_idx+decode_chunk_size], **decode_kwargs).sample)

video = torch.cat(video)
video = einops.rearrange(video, "(b f) c h w -> b f h w c", f=video_length)
video = ((video / 2.0 + 0.5).clamp(0, 1) * 255).to(dtype=torch.uint8).cpu().contiguous()
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
return video

def prepare_extra_step_kwargs(self, generator, eta):
@@ -515,6 +537,7 @@ def __call__(
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
enable_vae_temporal_decoder: bool = False,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -672,6 +695,9 @@


# 8. Post-processing
video = self.decode_latents(latents)
if enable_vae_temporal_decoder:
video = self.decode_latents_with_temporal_decoder(latents)
else:
video = self.decode_latents(latents)

return StableDiffusionPipelineOutput(video=video)
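
For context (not part of the diff): unlike the stock image VAE, the SVD temporal decoder runs temporal layers across the frames of a clip, so decode must be told how many of the flattened (batch x frames) latents belong together. That is why decode_latents_with_temporal_decoder passes num_frames with every chunk, and decoding 14 frames at a time keeps peak memory bounded. A rough standalone sketch, assuming diffusers >= 0.24 (which is why environment.yml is bumped and the DiffusionPipeline import path changes) and the publicly released SVD checkpoint as the weight source; neither assumption is stated in the diff:

import torch
from diffusers.models import AutoencoderKLTemporalDecoder

# Assumed weight source; the PR itself only requires a folder named vae_temporal_decoder.
vae = AutoencoderKLTemporalDecoder.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid", subfolder="vae", torch_dtype=torch.float16
).to("cuda")

latents = torch.randn(14, 4, 40, 64, dtype=torch.float16, device="cuda")  # (frames, c, h/8, w/8)
with torch.no_grad():
    # num_frames tells the decoder how many of the flattened latents form one clip.
    frames = vae.decode(latents / vae.config.scaling_factor, num_frames=14).sample  # (14, 3, 320, 512)
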
128 changes: 66 additions & 62 deletions base/pipelines/sample.py
@@ -7,7 +7,7 @@

from download import find_model
from diffusers.schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, EulerDiscreteScheduler
from diffusers.models import AutoencoderKL
from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder
from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
from omegaconf import OmegaConf

@@ -17,73 +17,77 @@
import imageio

def main(args):
if args.seed is not None:
torch.manual_seed(args.seed)
torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"
if args.seed is not None:
torch.manual_seed(args.seed)
torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"

sd_path = args.pretrained_path + "/stable-diffusion-v1-4"
unet = get_models(args, sd_path).to(device, dtype=torch.float16)
state_dict = find_model(args.ckpt_path)
unet.load_state_dict(state_dict)

vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae", torch_dtype=torch.float16).to(device)
tokenizer_one = CLIPTokenizer.from_pretrained(sd_path, subfolder="tokenizer")
text_encoder_one = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder", torch_dtype=torch.float16).to(device) # huge
sd_path = os.path.join(args.pretrained_path, "stable-diffusion-v1-4")
unet = get_models(args, sd_path).to(device, dtype=torch.float16)
state_dict = find_model(args.ckpt_path)
unet.load_state_dict(state_dict)

# set eval mode
unet.eval()
vae.eval()
text_encoder_one.eval()

if args.sample_method == 'ddim':
scheduler = DDIMScheduler.from_pretrained(sd_path,
subfolder="scheduler",
beta_start=args.beta_start,
beta_end=args.beta_end,
beta_schedule=args.beta_schedule)
elif args.sample_method == 'eulerdiscrete':
scheduler = EulerDiscreteScheduler.from_pretrained(sd_path,
subfolder="scheduler",
beta_start=args.beta_start,
beta_end=args.beta_end,
beta_schedule=args.beta_schedule)
elif args.sample_method == 'ddpm':
scheduler = DDPMScheduler.from_pretrained(sd_path,
subfolder="scheduler",
beta_start=args.beta_start,
beta_end=args.beta_end,
beta_schedule=args.beta_schedule)
else:
raise NotImplementedError
if args.enable_vae_temporal_decoder:
vae = AutoencoderKLTemporalDecoder.from_pretrained(sd_path, subfolder="vae_temporal_decoder", torch_dtype=torch.float16).to(device)
else:
vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae", torch_dtype=torch.float16).to(device)
tokenizer_one = CLIPTokenizer.from_pretrained(sd_path, subfolder="tokenizer")
text_encoder_one = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder", torch_dtype=torch.float16).to(device) # huge

videogen_pipeline = VideoGenPipeline(vae=vae,
text_encoder=text_encoder_one,
tokenizer=tokenizer_one,
scheduler=scheduler,
unet=unet).to(device)
videogen_pipeline.enable_xformers_memory_efficient_attention()
# set eval mode
unet.eval()
vae.eval()
text_encoder_one.eval()

if not os.path.exists(args.output_folder):
os.makedirs(args.output_folder)
if args.sample_method == 'ddim':
scheduler = DDIMScheduler.from_pretrained(sd_path,
subfolder="scheduler",
beta_start=args.beta_start,
beta_end=args.beta_end,
beta_schedule=args.beta_schedule)
elif args.sample_method == 'eulerdiscrete':
scheduler = EulerDiscreteScheduler.from_pretrained(sd_path,
subfolder="scheduler",
beta_start=args.beta_start,
beta_end=args.beta_end,
beta_schedule=args.beta_schedule)
elif args.sample_method == 'ddpm':
scheduler = DDPMScheduler.from_pretrained(sd_path,
subfolder="scheduler",
beta_start=args.beta_start,
beta_end=args.beta_end,
beta_schedule=args.beta_schedule)
else:
raise NotImplementedError

video_grids = []
for prompt in args.text_prompt:
print('Processing the ({}) prompt'.format(prompt))
videos = videogen_pipeline(prompt,
video_length=args.video_length,
height=args.image_size[0],
width=args.image_size[1],
num_inference_steps=args.num_sampling_steps,
guidance_scale=args.guidance_scale).video
imageio.mimwrite(args.output_folder + prompt.replace(' ', '_') + '.mp4', videos[0], fps=8, quality=9) # highest quality is 10, lowest is 0

print('save path {}'.format(args.output_folder))
videogen_pipeline = VideoGenPipeline(vae=vae,
text_encoder=text_encoder_one,
tokenizer=tokenizer_one,
scheduler=scheduler,
unet=unet).to(device)
videogen_pipeline.enable_xformers_memory_efficient_attention()

if not os.path.exists(args.output_folder):
os.makedirs(args.output_folder)

video_grids = []
for prompt in args.text_prompt:
print('Processing the ({}) prompt'.format(prompt))
videos = videogen_pipeline(prompt,
video_length=args.video_length,
height=args.image_size[0],
width=args.image_size[1],
num_inference_steps=args.num_sampling_steps,
guidance_scale=args.guidance_scale,
enable_vae_temporal_decoder=args.enable_vae_temporal_decoder).video
imageio.mimwrite(args.output_folder + prompt.replace(' ', '_') + '.mp4', videos[0], fps=8, quality=9) # highest quality is 10, lowest is 0

print('save path {}'.format(args.output_folder))

if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="")
args = parser.parse_args()
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="")
args = parser.parse_args()

main(OmegaConf.load(args.config))
main(OmegaConf.load(args.config))
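
Note that the updated sample.py loads the decoder from the stable-diffusion-v1-4 folder with subfolder="vae_temporal_decoder", a directory the stock SD v1.4 checkpoint does not ship, so those weights must be placed there before enabling the flag. One plausible way to do that, assuming the released SVD checkpoint is the intended source and using a placeholder local path (neither is specified by the PR):

from diffusers.models import AutoencoderKLTemporalDecoder

# Assumption: take the temporal-decoder VAE from the public SVD checkpoint and save it
# under the subfolder name that sample.py expects next to the SD v1.4 weights.
vae = AutoencoderKLTemporalDecoder.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid", subfolder="vae"
)
vae.save_pretrained("./pretrained_models/stable-diffusion-v1-4/vae_temporal_decoder")  # placeholder path
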

2 changes: 1 addition & 1 deletion environment.yml
@@ -11,7 +11,7 @@ dependencies:
- accelerate==0.19.0
- av==10.0.0
- decord==0.6.0
- diffusers[torch]==0.16.0
- diffusers[torch]==0.24.0
- einops==0.6.1
- ffmpeg==1.4
- imageio==2.31.1