[experiment] CogVideoX 🤝🏼 FreeNoise #9389

Open · wants to merge 10 commits into base: main
5 changes: 2 additions & 3 deletions src/diffusers/models/attention.py
@@ -1054,9 +1054,8 @@ def forward(
accumulated_values = torch.zeros_like(hidden_states)

for i, (frame_start, frame_end) in enumerate(frame_indices):
# The reason for slicing here is to ensure that if (frame_end - frame_start) is to handle
# cases like frame_indices=[(0, 16), (16, 20)], if the user provided a video with 19 frames, or
# essentially a non-multiple of `context_length`.
# The reason for slicing here is to handle cases like frame_indices=[(0, 16), (16, 20)],
# if the user provided a video with 19 frames, or essentially a non-multiple of `context_length`.
weights = torch.ones_like(num_times_accumulated[:, frame_start:frame_end])
weights *= frame_weights
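For orientation, a self-contained toy sketch of the accumulation pattern this comment refers to; the shapes and the way `frame_weights` is sliced down to the shorter last window are illustrative assumptions, not the block's actual implementation:

```python
import torch

# Toy sizes; `frame_weights` stands in for whatever per-window weighting the
# block applies, sliced to match a final window shorter than `context_length`.
num_frames, dim = 20, 8
hidden_states = torch.randn(1, num_frames, dim)
accumulated_values = torch.zeros_like(hidden_states)
num_times_accumulated = torch.zeros(1, num_frames, 1)

frame_indices = [(0, 16), (16, 20)]  # context_length=16, 20 latent frames total
frame_weights = torch.ones(1, 16, 1)

for frame_start, frame_end in frame_indices:
    # Slicing `num_times_accumulated` keeps `weights` the same length as the
    # final, shorter window when the frame count is not a multiple of `context_length`.
    weights = torch.ones_like(num_times_accumulated[:, frame_start:frame_end])
    weights *= frame_weights[:, : frame_end - frame_start]
    accumulated_values[:, frame_start:frame_end] += hidden_states[:, frame_start:frame_end] * weights
    num_times_accumulated[:, frame_start:frame_end] += weights

# Positions covered by more than one window (none in this toy split) are averaged here.
hidden_states = accumulated_values / num_times_accumulated.clamp(min=1.0)
print(hidden_states.shape)  # torch.Size([1, 20, 8])
```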

126 changes: 110 additions & 16 deletions src/diffusers/models/embeddings.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
@@ -354,9 +354,12 @@ def __init__(
super().__init__()

self.patch_size = patch_size
self.in_channels = in_channels
self.embed_dim = embed_dim
self.sample_height = sample_height
self.text_embed_dim = text_embed_dim
self.bias = bias
self.sample_width = sample_width
self.sample_height = sample_height
self.sample_frames = sample_frames
self.temporal_compression_ratio = temporal_compression_ratio
self.max_text_seq_length = max_text_seq_length
@@ -377,7 +380,6 @@ def _get_positional_embeddings(self, sample_height: int, sample_width: int, samp
post_patch_height = sample_height // self.patch_size
post_patch_width = sample_width // self.patch_size
post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
num_patches = post_patch_height * post_patch_width * post_time_compression_frames

pos_embedding = get_3d_sincos_pos_embed(
self.embed_dim,
@@ -387,12 +389,7 @@ def _get_positional_embeddings(self, sample_height: int, sample_width: int, samp
self.temporal_interpolation_scale,
)
pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
joint_pos_embedding = torch.zeros(
1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
)
joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)

return joint_pos_embedding
return pos_embedding

def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
r"""
@@ -409,11 +406,108 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
image_embeds = self.proj(image_embeds)
image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
image_embeds = image_embeds.flatten(1, 2).contiguous() # [batch, num_frames x height x width, channels]

if self.use_positional_embeddings:
pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
if (
self.sample_height != height
or self.sample_width != width
or self.sample_frames != pre_time_compression_frames
):
pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames)
pos_embedding = pos_embedding.to(image_embeds.device, dtype=image_embeds.dtype)
else:
pos_embedding = self.pos_embedding

image_embeds = image_embeds + pos_embedding

return text_embeds, image_embeds
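To make the positional-embedding bookkeeping above concrete, here is a worked example with the defaults from this file (`patch_size=2`, `sample_height=60`, `sample_width=90`, `sample_frames=49`, `temporal_compression_ratio=4`). It only traces the arithmetic; no module is instantiated, and the buffer layout noted in the comments is what the `flatten(0, 1)` and the addition to `image_embeds` above imply:

```python
# Defaults from __init__ above; plain arithmetic, no model is built.
patch_size, embed_dim = 2, 1920
sample_height, sample_width, sample_frames = 60, 90, 49
temporal_compression_ratio = 4

# Size of the buffer built in _get_positional_embeddings():
post_patch_height = sample_height // patch_size                                       # 30
post_patch_width = sample_width // patch_size                                         # 45
post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1  # 13
num_patches = post_patch_height * post_patch_width * post_time_compression_frames     # 17550
# `pos_embedding` ends up as [num_patches, embed_dim] and is now added to the image
# tokens only; the removed code padded it to [1, max_text_seq_length + num_patches,
# embed_dim] with zeros in the text slots.

# Reuse check in forward(), where `num_frames` counts latent frames:
num_frames = 13
pre_time_compression_frames = (num_frames - 1) * temporal_compression_ratio + 1       # 49
assert pre_time_compression_frames == sample_frames
# Equal to `sample_frames`, so the cached buffer is reused; any other height, width,
# or frame count triggers a fresh _get_positional_embeddings() call.
```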


class FreeNoiseCogVideoXPatchEmbed(nn.Module):
def __init__(
self,
patch_size: int = 2,
in_channels: int = 16,
embed_dim: int = 1920,
text_embed_dim: int = 4096,
bias: bool = True,
sample_width: int = 90,
sample_height: int = 60,
sample_frames: int = 49,
temporal_compression_ratio: int = 4,
max_text_seq_length: int = 226,
spatial_interpolation_scale: float = 1.875,
temporal_interpolation_scale: float = 1.0,
use_positional_embeddings: bool = True,
) -> None:
super().__init__()

self.patch_size = patch_size
self.in_channels = in_channels
self.embed_dim = embed_dim
self.text_embed_dim = text_embed_dim
self.bias = bias
self.sample_width = sample_width
self.sample_height = sample_height
self.sample_frames = sample_frames
self.temporal_compression_ratio = temporal_compression_ratio
self.max_text_seq_length = max_text_seq_length
self.spatial_interpolation_scale = spatial_interpolation_scale
self.temporal_interpolation_scale = temporal_interpolation_scale
self.use_positional_embeddings = use_positional_embeddings

self.proj = nn.Conv2d(
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
)
self.text_proj = nn.Linear(text_embed_dim, embed_dim)

if use_positional_embeddings:
pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
self.register_buffer("pos_embedding", pos_embedding, persistent=False)

# Copied from diffusers.models.embeddings.CogVideoXPatchEmbed
def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
post_patch_height = sample_height // self.patch_size
post_patch_width = sample_width // self.patch_size
post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1

pos_embedding = get_3d_sincos_pos_embed(
self.embed_dim,
(post_patch_width, post_patch_height),
post_time_compression_frames,
self.spatial_interpolation_scale,
self.temporal_interpolation_scale,
)
pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
return pos_embedding

embeds = torch.cat(
[text_embeds, image_embeds], dim=1
).contiguous() # [batch, seq_length + num_frames x height x width, channels]
def forward(self, text_embeds: Union[torch.Tensor, Tuple[Dict[int, torch.Tensor]]], image_embeds: torch.Tensor):
r"""
Args:
text_embeds (`torch.Tensor`):
Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
image_embeds (`torch.Tensor`):
Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
"""
if isinstance(text_embeds, torch.Tensor):
text_embeds = self.text_proj(text_embeds)
else:
assert isinstance(text_embeds, tuple)
text_embeds_output = []
for tuple_index in range(len(text_embeds)):
text_embeds_output.append({})
for key, text_embed in list(text_embeds[tuple_index].items()):
text_embeds_output[tuple_index][key] = self.text_proj(text_embed)
text_embeds = tuple(text_embeds_output)

batch, num_frames, channels, height, width = image_embeds.shape
image_embeds = image_embeds.reshape(-1, channels, height, width)
image_embeds = self.proj(image_embeds)
image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
image_embeds = image_embeds.flatten(1, 2).contiguous() # [batch, num_frames x height x width, channels]

if self.use_positional_embeddings:
pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
@@ -423,13 +517,13 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
or self.sample_frames != pre_time_compression_frames
):
pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames)
pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype)
pos_embedding = pos_embedding.to(image_embeds.device, dtype=image_embeds.dtype)
else:
pos_embedding = self.pos_embedding

embeds = embeds + pos_embedding
image_embeds = image_embeds + pos_embedding

return embeds
return text_embeds, image_embeds
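As a hedged usage sketch of the tuple-of-dicts branch in the new `forward` above: the frame-index keys, batch size, and sequence length below are illustrative assumptions, and the `nn.Linear(4096, 1920)` merely mirrors the default `text_proj` dimensions instead of instantiating the class:

```python
import torch
import torch.nn as nn

text_proj = nn.Linear(4096, 1920)  # stands in for self.text_proj with the default dims

# One dict per tuple entry; the keys are assumed to be the frame indices at which
# each FreeNoise window (and its prompt) starts.
text_embeds = (
    {0: torch.randn(1, 226, 4096), 16: torch.randn(1, 226, 4096)},
)

# The same structure-preserving projection as the `else` branch of forward():
text_embeds_output = []
for window_dict in text_embeds:
    text_embeds_output.append({key: text_proj(embed) for key, embed in window_dict.items()})
text_embeds = tuple(text_embeds_output)

print({key: tuple(value.shape) for key, value in text_embeds[0].items()})
# {0: (1, 226, 1920), 16: (1, 226, 1920)}
```

Plain tensor inputs still take the single `self.text_proj(text_embeds)` call; only per-window prompt dictionaries go through this loop.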


def get_3d_rotary_pos_embed(