Commit: profile cogvideox

a-r-r-o-w committed Sep 19, 2024
1 parent 2b443a5 commit 5687dc6
Showing 3 changed files with 187 additions and 111 deletions.
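The `record_function` ranges added below only show up when the model is run under the PyTorch profiler. A minimal driver sketch for collecting them (not part of this commit — the checkpoint id, prompt, and generation settings are illustrative):

import torch
from torch.profiler import ProfilerActivity, profile

from diffusers import CogVideoXPipeline

# Hypothetical driver, not part of this commit: run a short generation under the
# profiler so the record_function ranges added in these files get recorded.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    video = pipe(
        prompt="A panda playing a guitar in a bamboo forest",
        num_inference_steps=5,  # keep the trace small
        num_frames=9,
    ).frames[0]

# Labels added in this commit ("time embedding", "blocks", "transformer_iteration_0",
# "1.1 scheduler", ...) appear as rows in the aggregated table.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=50))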
99 changes: 52 additions & 47 deletions src/diffusers/models/transformers/cogvideox_transformer_3d.py
@@ -17,6 +17,7 @@

 import torch
 from torch import nn
+from torch.profiler import record_function

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
@@ -433,49 +434,52 @@ def forward(
         batch_size, num_frames, channels, height, width = hidden_states.shape

         # 1. Time embedding
-        timesteps = timestep
-        t_emb = self.time_proj(timesteps)
+        with record_function("time embedding"):
+            timesteps = timestep
+            t_emb = self.time_proj(timesteps)

-        # timesteps does not contain any weights and will always return f32 tensors
-        # but time_embedding might actually be running in fp16. so we need to cast here.
-        # there might be better ways to encapsulate this.
-        t_emb = t_emb.to(dtype=hidden_states.dtype)
-        emb = self.time_embedding(t_emb, timestep_cond)
+            # timesteps does not contain any weights and will always return f32 tensors
+            # but time_embedding might actually be running in fp16. so we need to cast here.
+            # there might be better ways to encapsulate this.
+            t_emb = t_emb.to(dtype=hidden_states.dtype)
+            emb = self.time_embedding(t_emb, timestep_cond)

         # 2. Patch embedding
-        hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
-        hidden_states = self.embedding_dropout(hidden_states)
+        with record_function("patch embedding"):
+            hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
+            hidden_states = self.embedding_dropout(hidden_states)

-        text_seq_length = encoder_hidden_states.shape[1]
-        encoder_hidden_states = hidden_states[:, :text_seq_length]
-        hidden_states = hidden_states[:, text_seq_length:]
+            text_seq_length = encoder_hidden_states.shape[1]
+            encoder_hidden_states = hidden_states[:, :text_seq_length]
+            hidden_states = hidden_states[:, text_seq_length:]

         # 3. Transformer blocks
-        for i, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
-                    hidden_states,
-                    encoder_hidden_states,
-                    emb,
-                    image_rotary_emb,
-                    **ckpt_kwargs,
-                )
-            else:
-                hidden_states, encoder_hidden_states = block(
-                    hidden_states=hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    temb=emb,
-                    image_rotary_emb=image_rotary_emb,
-                )
+        with record_function("blocks"):
+            for i, block in enumerate(self.transformer_blocks):
+                if self.training and self.gradient_checkpointing:
+
+                    def create_custom_forward(module):
+                        def custom_forward(*inputs):
+                            return module(*inputs)
+
+                        return custom_forward
+
+                    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                    hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(block),
+                        hidden_states,
+                        encoder_hidden_states,
+                        emb,
+                        image_rotary_emb,
+                        **ckpt_kwargs,
+                    )
+                else:
+                    hidden_states, encoder_hidden_states = block(
+                        hidden_states=hidden_states,
+                        encoder_hidden_states=encoder_hidden_states,
+                        temb=emb,
+                        image_rotary_emb=image_rotary_emb,
+                    )

         if not self.config.use_rotary_positional_embeddings:
             # CogVideoX-2B
@@ -487,16 +491,17 @@ def custom_forward(*inputs):
             hidden_states = hidden_states[:, text_seq_length:]

         # 4. Final block
-        hidden_states = self.norm_out(hidden_states, temb=emb)
-        hidden_states = self.proj_out(hidden_states)
-
-        # 5. Unpatchify
-        # Note: we use `-1` instead of `channels`:
-        # - It is okay to use `channels` for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels)
-        # - However, CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels)
-        p = self.config.patch_size
-        output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
-        output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
+        with record_function("final output"):
+            hidden_states = self.norm_out(hidden_states, temb=emb)
+            hidden_states = self.proj_out(hidden_states)
+
+            # 5. Unpatchify
+            # Note: we use `-1` instead of `channels`:
+            # - It is okay to use `channels` for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels)
+            # - However, CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels)
+            p = self.config.patch_size
+            output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
+            output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)

         if USE_PEFT_BACKEND:
             # remove `lora_scale` from each PEFT layer
100 changes: 68 additions & 32 deletions src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py
@@ -18,6 +18,7 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
+from torch.profiler import record_function
 from transformers import T5EncoderModel, T5Tokenizer

 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
@@ -679,39 +680,73 @@ def __call__(
                 timestep = t.expand(latent_model_input.shape[0])

                 # predict noise model_output
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
-                    timestep=timestep,
-                    image_rotary_emb=image_rotary_emb,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
-                noise_pred = noise_pred.float()
+                with record_function(f"transformer_iteration_{i}"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        encoder_hidden_states=prompt_embeds,
+                        timestep=timestep,
+                        image_rotary_emb=image_rotary_emb,
+                        attention_kwargs=attention_kwargs,
+                        return_dict=False,
+                    )[0]
+                    # noise_pred = noise_pred.float()

-                # perform guidance
-                if use_dynamic_cfg:
-                    self._guidance_scale = 1 + guidance_scale * (
-                        (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
-                    )
-                if do_classifier_free_guidance:
-                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-                # compute the previous noisy sample x_t -> x_t-1
-                if not isinstance(self.scheduler, CogVideoXDPMScheduler):
-                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
-                else:
-                    latents, old_pred_original_sample = self.scheduler.step(
-                        noise_pred,
-                        old_pred_original_sample,
-                        t,
-                        timesteps[i - 1] if i > 0 else None,
-                        latents,
-                        **extra_step_kwargs,
-                        return_dict=False,
+                with record_function(f"guidance_{i}"):
+                    if use_dynamic_cfg:
+                        self._guidance_scale = 1 + guidance_scale * (
+                            (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0))
+                            / 2
+                        )
+                    if do_classifier_free_guidance:
+                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                with record_function("1.1 scheduler"):
+                    prev_timestep = t - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
+
+                with record_function("1.2 scheduler"):
+                    alpha_prod_t = self.scheduler.alphas_cumprod[t]
+
+                with record_function("1.3 scheduler"):
+                    alpha_prod_t_prev = (
+                        self.scheduler.alphas_cumprod[prev_timestep]
+                        if prev_timestep >= 0
+                        else self.scheduler.final_alpha_cumprod
                     )
                 latents = latents.to(prompt_embeds.dtype)

+                with record_function("1.4 scheduler"):
+                    beta_prod_t = 1 - alpha_prod_t
+
+                with record_function("1.5 scheduler"):
+                    pred_original_sample = (alpha_prod_t**0.5) * latents - (beta_prod_t**0.5) * noise_pred
+
+                with record_function("1.6 scheduler"):
+                    a_t = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5
+
+                with record_function("1.7 scheduler"):
+                    b_t = alpha_prod_t_prev**0.5 - alpha_prod_t**0.5 * a_t
+
+                with record_function("1.8 scheduler"):
+                    prev_sample = a_t * latents + b_t * pred_original_sample
+
+                latents = prev_sample
+
+                # # compute the previous noisy sample x_t -> x_t-1
+                # with record_function(f"scheduler_step_{i}"):
+                #     if not isinstance(self.scheduler, CogVideoXDPMScheduler):
+                #         latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+                #     else:
+                #         latents, old_pred_original_sample = self.scheduler.step(
+                #             noise_pred,
+                #             old_pred_original_sample,
+                #             t,
+                #             timesteps[i - 1] if i > 0 else None,
+                #             latents,
+                #             **extra_step_kwargs,
+                #             return_dict=False,
+                #         )
+                # # latents = latents.to(prompt_embeds.dtype)

                 # call the callback, if provided
                 if callback_on_step_end is not None:
@@ -728,8 +763,9 @@ def __call__(
                     progress_bar.update()

         if not output_type == "latent":
-            video = self.decode_latents(latents)
-            video = self.video_processor.postprocess_video(video=video, output_type=output_type)
+            with record_function("decode_latents"):
+                video = self.decode_latents(latents)
+                video = self.video_processor.postprocess_video(video=video, output_type=output_type)
         else:
             video = latents

99 changes: 67 additions & 32 deletions src/diffusers/schedulers/scheduling_ddim_cogvideox.py
@@ -22,6 +22,7 @@

 import numpy as np
 import torch
+from torch.profiler import record_function

 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput
@@ -362,41 +363,75 @@ def step(
# - pred_prev_sample -> "x_t-1"

# 1. get previous step value (=t-1)
prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

# 2. compute alphas, betas
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod

beta_prod_t = 1 - alpha_prod_t

# 3. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
# To make style tests pass, commented out `pred_epsilon` as it is an unused variable
if self.config.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
# pred_epsilon = model_output
elif self.config.prediction_type == "sample":
pred_original_sample = model_output
# pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
elif self.config.prediction_type == "v_prediction":
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
# pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
" `v_prediction`"
with record_function("get original prediction"):
with record_function("step 1"):
prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

# 2. compute alphas, betas
with record_function("step 2"):
print(self.alphas_cumprod.device, self.alphas_cumprod.dtype)
print(timestep.device, timestep.type)
print(prev_timestep.device, prev_timestep.dtype)
with record_function("step 2.1"):
alpha_prod_t = self.alphas_cumprod[timestep]

with record_function("step 2.2"):
alpha_prod_t_prev = (
self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
)

with record_function("step 2.3"):
beta_prod_t = 1 - alpha_prod_t
print(beta_prod_t.device, beta_prod_t.dtype)
print("======")

with record_function("step 3"):
# 3. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
# To make style tests pass, commented out `pred_epsilon` as it is an unused variable
if self.config.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
# pred_epsilon = model_output
elif self.config.prediction_type == "sample":
pred_original_sample = model_output
# pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
elif self.config.prediction_type == "v_prediction":
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
print(
"vpred:",
sample.dtype,
model_output.dtype,
alpha_prod_t.dtype,
beta_prod_t.dtype,
pred_original_sample.dtype,
)
# pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
" `v_prediction`"
)

with record_function("compute prev sample"):
a_t = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5
b_t = alpha_prod_t_prev**0.5 - alpha_prod_t**0.5 * a_t

prev_sample = a_t * sample + b_t * pred_original_sample
print(
"prevsample devices:",
a_t.device,
b_t.device,
sample.device,
pred_original_sample.device,
prev_sample.device,
)
print("prevsample:", a_t.dtype, b_t.dtype, sample.dtype, pred_original_sample.dtype, prev_sample.dtype)
print("=== done ===")

a_t = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5
b_t = alpha_prod_t_prev**0.5 - alpha_prod_t**0.5 * a_t

prev_sample = a_t * sample + b_t * pred_original_sample

if not return_dict:
return (prev_sample,)
if not return_dict:
return (prev_sample,)

return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
def add_noise(
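A hedged follow-up to the profiling driver sketched before the first file: once `prof` has been collected, the nested ranges introduced in this commit can also be viewed on a timeline rather than in a summary table (the trace filename below is arbitrary):

# Hypothetical continuation of the profiling driver above: export the collected profile
# as a Chrome trace so nested ranges such as "get original prediction" -> "step 2" ->
# "step 2.1" and the per-step "transformer_iteration_{i}" labels can be inspected in
# chrome://tracing or Perfetto.
prof.export_chrome_trace("cogvideox_trace.json")

# Sorting by self CPU time helps attribute host-side overhead (e.g. the debug prints and
# scheduler indexing wrapped above) separately from the transformer blocks themselves.
print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=30))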
