Skip to content

Commit

Permalink
feat(ml): canny random
Browse files Browse the repository at this point in the history
  • Loading branch information
wr0124 committed Sep 13, 2024
1 parent b984ad8 commit 19d53a3
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 22 deletions.
20 changes: 7 additions & 13 deletions models/modules/diffusion_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
q_posterior,
gamma_embedding,
extract,
rearrange_5dto4d,
rearrange_4dto5d,
rearrange_5dto4d_fh,
rearrange_4dto5d_fh,
)


Expand Down Expand Up @@ -203,9 +203,7 @@ def p_mean_variance(
sequence_length = 0
if len(y_t.shape) == 5:
sequence_length = y_t.shape[1]
y_t, y_cond, mask = rearrange_5dto4d(
"b f c h w -> b c (f h) w", y_t, y_cond, mask
)
y_t, y_cond, mask = rearrange_5dto4d_fh(y_t, y_cond, mask)

noise_level = self.extract(
getattr(self.denoise_fn.model, "gammas_" + phase), t, x_shape=(1, 1)
Expand All @@ -215,9 +213,7 @@ def p_mean_variance(

input = torch.cat([y_cond, y_t], dim=1)
if sequence_length != 0:
input, y_t, mask = rearrange_4dto5d(
"b c (f h) w -> b f c h w", sequence_length, input, y_t, mask
)
input, y_t, mask = rearrange_4dto5d_fh(sequence_length, input, y_t, mask)

if guidance_scale > 0.0 and phase == "test":
y_0_hat_uncond = predict_start_from_noise(
Expand Down Expand Up @@ -455,9 +451,7 @@ def forward(self, y_0, y_cond, mask, noise, cls, ref, dropout_prob=0.0):
# vid only
if len(y_0.shape) == 5:
sequence_length = y_0.shape[1]
y_0, y_cond, mask = rearrange_5dto4d(
"b f c h w -> b c (f h) w", y_0, y_cond, mask
)
y_0, y_cond, mask = rearrange_5dto4d_fh(y_0, y_cond, mask)
b, *_ = y_0.shape

t = torch.randint(
Expand Down Expand Up @@ -487,8 +481,8 @@ def forward(self, y_0, y_cond, mask, noise, cls, ref, dropout_prob=0.0):
input = torch.cat([y_cond, y_noisy], dim=1)

if sequence_length != 0:
input, mask, noise = rearrange_4dto5d(
"b c (f h) w -> b f c h w", sequence_length, input, mask, noise
input, mask, noise = rearrange_4dto5d_fh(
sequence_length, input, mask, noise
)

noise_hat = self.denoise_fn(
Expand Down
22 changes: 18 additions & 4 deletions models/modules/diffusion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,25 @@ def extract(a, t, x_shape=(1, 1, 1, 1)):
return out.reshape(b, *((1,) * (len(x_shape) - 1)))


def rearrange_5dto4d(pattern, *tensors):
def rearrange_5dto4d_fh(*tensors):
"""Rearrange a tensor according to a given pattern using einops.rearrange."""
return [rearrange(tensor, pattern) for tensor in tensors]
return [rearrange(tensor, "b f c h w -> b c (f h) w") for tensor in tensors]


def rearrange_4dto5d(pattern, frame, *tensors):
def rearrange_4dto5d_fh(frame, *tensors):
"""Rearrange a tensor from 4D to 5D according to a given pattern using einops.rearrange."""
return [rearrange(tensor, pattern, f=frame) for tensor in tensors]
return [
rearrange(tensor, "b c (f h) w -> b f c h w", f=frame) for tensor in tensors
]


def rearrange_5dto4d_bf(*tensors):
"""Rearrange a tensor according to a given pattern using einops.rearrange."""
return [rearrange(tensor, "b f c h w -> (b f) c h w") for tensor in tensors]


def rearrange_4dto5d_bf(frame, *tensors):
"""Rearrange a tensor from 4D to 5D according to a given pattern using einops.rearrange."""
return [
rearrange(tensor, "(b f) c h w -> b f c h w", f=frame) for tensor in tensors
]
9 changes: 4 additions & 5 deletions models/palette_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from .modules.loss import MultiScaleDiffusionLoss
from .modules.unet_generator_attn.unet_attn_utils import revert_sync_batchnorm

from models.modules.diffusion_utils import rearrange_5dto4d, rearrange_4dto5d
from models.modules.diffusion_utils import rearrange_5dto4d_bf, rearrange_4dto5d_bf


class PaletteModel(BaseDiffusionModel):
Expand Down Expand Up @@ -401,8 +401,8 @@ def set_input(self, data):
and self.opt.alg_diffusion_cond_image_creation
== "computed_sketch"
):
self.mask, self.gt_image = rearrange_5dto4d(
"b f c h w -> (b f) c h w", self.mask, self.gt_image
self.mask, self.gt_image = rearrange_5dto4d_bf(
self.mask, self.gt_image
)

self.cond_image = (
Expand All @@ -420,8 +420,7 @@ def set_input(self, data):
== "computed_sketch"
):

self.mask, self.gt_image, self.cond_image = rearrange_4dto5d(
"(b f) c h w -> b f c h w",
self.mask, self.gt_image, self.cond_image = rearrange_4dto5d_bf(
self.opt.data_temporal_number_frames,
self.mask,
self.gt_image,
Expand Down

0 comments on commit 19d53a3

Please sign in to comment.