Skip to content

Commit

Permalink
feat(ml): inference blocked at the p_mean_variance function
Browse files Browse the repository at this point in the history
  • Loading branch information
wr0124 committed Aug 2, 2024
1 parent f9685b6 commit d85258d
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 18 deletions.
6 changes: 5 additions & 1 deletion models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1112,7 +1112,11 @@ def ema_step(self, network_name):
b_ema.copy_(b)

def get_current_batch_size(self):
    """Return the batch size derived from the leading dimension of ``self.real_A``.

    When ``self.opt.G_netG`` is ``"unet_vid"`` the leading dimension is
    integer-divided by ``self.opt.data_temporal_number_frames``; for every
    other generator the leading dimension is returned unchanged.

    Returns:
        int: the computed batch size.
    """
    leading_dim = self.real_A.shape[0]
    if self.opt.G_netG == "unet_vid":
        # NOTE(review): presumably frames are folded into the leading
        # dimension for the video UNet, hence the division -- confirm
        # against the data loader.
        return leading_dim // self.opt.data_temporal_number_frames
    return leading_dim

def optimize_parameters(self):
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
Expand Down
70 changes: 63 additions & 7 deletions models/modules/diffusion_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def restoration_ddpm(
ref,
):
phase = "test"
print(self.cond_embed_dim)

print(
" restoration_ddpm ",
Expand All @@ -140,22 +141,45 @@ def restoration_ddpm(
guidance_scale,
ref,
)
#
# bs, frame, channel, height, width = y_0.shape
# y_cond = y_cond.view(bs * frame, channel, height, width)
# y_t = y_t.view(bs * frame, channel, height, width)
# y_0 = y_0.view(bs * frame, channel, height, width)
# bs, frame, channel, height, width = mask.shape
# mask = mask.view(bs * frame, channel, height, width)
#
print(
" y_cond, y_t, y_0 mask ", y_cond.shape, y_t.shape, y_0.shape, mask.shape
)
b, *_ = y_cond.shape

assert (
self.denoise_fn.model.num_timesteps_test > sample_num
), "num_timesteps must greater than sample_num"
sample_inter = self.denoise_fn.model.num_timesteps_test // sample_num

print("sample_inter", sample_inter)
y_t = self.default(y_t, lambda: torch.randn_like(y_cond))
ret_arr = y_t
print(" y_t", y_t.shape, ret_arr.shape)
for i in tqdm(
reversed(range(0, self.denoise_fn.model.num_timesteps_test)),
desc="sampling loop time step",
total=self.denoise_fn.model.num_timesteps_test,
):
b = y_cond.shape[-4]
t = torch.full((b,), i, device=y_cond.device, dtype=torch.long)

print(
"p_sample block input ",
y_t.shape,
t,
y_cond.shape,
phase,
cls,
mask.shape,
ref,
guidance_scale,
)
y_t = self.p_sample(
y_t,
t,
Expand All @@ -166,7 +190,6 @@ def restoration_ddpm(
ref=ref,
guidance_scale=guidance_scale,
)

if mask is not None:
temp_mask = torch.clamp(mask, min=0.0, max=1.0)
y_t = y_0 * (1.0 - temp_mask) + temp_mask * y_t
Expand Down Expand Up @@ -202,10 +225,19 @@ def p_mean_variance(
noise_level = self.extract(
getattr(self.denoise_fn.model, "gammas_" + phase), t, x_shape=(1, 1)
).to(y_t.device)

embed_noise_level = self.compute_gammas(noise_level)

input = torch.cat([y_cond, y_t], dim=1)
print(
" p_mean_variance y_cond y_t ",
y_cond.shape,
y_t.shape,
embed_noise_level.shape,
)
print(
" ",
getattr(self.denoise_fn.model, "gammas_" + phase).shape,
dir(self.denoise_fn.model),
)
input = torch.cat([y_cond, y_t], dim=-3)

if guidance_scale > 0.0 and phase == "test":
y_0_hat_uncond = predict_start_from_noise(
Expand All @@ -231,16 +263,31 @@ def p_mean_variance(
),
phase=phase,
)
print(
" after predictstart fromnoise y_0_hat ",
y_0_hat.shape,
input.shape,
embed_noise_level.shape,
cls,
mask.shape,
ref,
)

if guidance_scale > 0.0 and phase == "test":
y_0_hat = (1 + guidance_scale) * y_0_hat - guidance_scale * y_0_hat_uncond

if clip_denoised:
y_0_hat.clamp_(-1.0, 1.0)

print(" before q_posterior ", y_0_hat.shape, y_t.shape, y_t.shape, t, phase)
model_mean, posterior_log_variance = q_posterior(
self.denoise_fn.model, y_0_hat=y_0_hat, y_t=y_t, t=t, phase=phase
)
print(
" model_mean, posterior_log_variance ",
model_mean.shape,
posterior_log_variance.shape,
)
return model_mean, posterior_log_variance

def q_sample(self, y_0, sample_gammas, noise=None):
Expand Down Expand Up @@ -438,7 +485,15 @@ def ddim_p_mean_variance(
return model_mean, posterior_log_variance

def forward(self, y_0, y_cond, mask, noise, cls, ref, dropout_prob=0.0):
# print("diffusiongenerator forward ", y_0.shape , y_cond.shape, mask.shape, noise,cls, ref )
print(
"diffusiongenerator forward ",
y_0.shape,
y_cond.shape,
mask.shape,
noise,
cls,
ref,
)
bs, frame, channel, height, width = y_0.shape
y_0 = y_0.view(bs * frame, channel, height, width)
y_cond = y_cond.view(bs * frame, channel, height, width)
Expand Down Expand Up @@ -519,5 +574,6 @@ def set_new_sampling_method(self, sampling_method):
self.sampling_method = sampling_method

def compute_gammas(self, gammas):
    """Embed noise levels with the generator's conditioning embedder.

    Args:
        gammas: noise-level values; passed to ``gamma_embedding`` together
            with ``self.cond_embed_gammas_in``.

    Returns:
        The result of applying ``self.cond_embed`` to that embedding.
    """
    # Deliberately no debug print here: this is reached from
    # p_mean_variance, and dumping ``gammas`` on every call floods stdout.
    emb = self.cond_embed(gamma_embedding(gammas, self.cond_embed_gammas_in))
    return emb
1 change: 1 addition & 0 deletions models/modules/diffusion_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def set_new_noise_schedule(model, phase):


def predict_start_from_noise(model, y_t, t, noise, phase):
print(" inside predictstartfromnoise ", y_t.shape, t.shape, noise.shape, phase)
return (
extract(getattr(model, "sqrt_recip_gammas_" + phase), t, y_t.shape) * y_t
- extract(getattr(model, "sqrt_recipm1_gammas_" + phase), t, y_t.shape) * noise
Expand Down
64 changes: 54 additions & 10 deletions models/palette_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,7 @@ def inference(self, nb_imgs, offset=0):

# no class conditioning
else:
print("is here ??? ")
if self.cls is not None:
cls = self.cls[:nb_imgs]
else:
Expand All @@ -650,17 +651,56 @@ def inference(self, nb_imgs, offset=0):
mask = self.mask[:nb_imgs]
else:
mask = self.mask

self.output, self.visuals = netG.restoration(
y_cond=self.cond_image[:nb_imgs],
y_t=self.y_t[:nb_imgs],
y_0=self.gt_image[:nb_imgs],
mask=mask,
sample_num=self.sample_num,
cls=cls,
ddim_num_steps=self.ddim_num_steps,
ddim_eta=self.ddim_eta,
print("befroe restoration ")
print(
" self.cond_image y_t gt_image, mask ,cls, ddim_num_steps, ddim_eta ",
self.cond_image.shape,
self.y_t.shape,
self.gt_image.shape,
self.mask.shape,
self.sample_num,
cls,
self.ddim_num_steps,
self.ddim_eta,
)
print(self.opt.train_batch_size)
if self.opt.G_netG == "unet_vid":
bf, channel, height, width = self.cond_image.shape
frame = bf // self.opt.train_batch_size
self.cond_image = self.cond_image.contiguous().view(
self.opt.train_batch_size, frame, channel, height, width
)
self.y_t = self.y_t.contiguous().view(
self.opt.train_batch_size, frame, channel, height, width
)
self.gt_image = self.gt_image.contiguous().view(
self.opt.train_batch_size, frame, channel, height, width
)
bf, channel, height, width = mask.shape
mask = mask.contiguous().view(
self.opt.train_batch_size, frame, channel, height, width
)
self.output, self.visuals = netG.restoration(
y_cond=self.cond_image[0:1],
y_t=self.y_t[0:1],
y_0=self.gt_image[0:1],
mask=mask[0:1],
sample_num=self.sample_num,
cls=cls,
ddim_num_steps=self.ddim_num_steps,
ddim_eta=self.ddim_eta,
)
else:
self.output, self.visuals = netG.restoration(
y_cond=self.cond_image[:nb_imgs],
y_t=self.y_t[:nb_imgs],
y_0=self.gt_image[:nb_imgs],
mask=mask,
sample_num=self.sample_num,
cls=cls,
ddim_num_steps=self.ddim_num_steps,
ddim_eta=self.ddim_eta,
)
self.fake_B = self.output

# task: super resolution, pix2pix
Expand Down Expand Up @@ -696,6 +736,10 @@ def inference(self, nb_imgs, offset=0):
def compute_visuals(self, nb_imgs):
    """Run the parent visual computation, then inference without gradients.

    Args:
        nb_imgs: number of images requested by the caller. It is forwarded
            unchanged to the parent implementation. For the ``"unet_vid"``
            generator it is replaced by ``self.batch_size`` before being
            passed to ``self.inference``.
    """
    super().compute_visuals(nb_imgs)
    with torch.no_grad():
        if self.opt.G_netG == "unet_vid":
            # The video generator ignores the requested count and uses the
            # model's own batch size instead.
            nb_imgs = self.batch_size
        self.inference(nb_imgs)

def get_dummy_input(self, device=None):
Expand Down
2 changes: 2 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ def train_gpu(rank, world_size, opt, trainset, trainset_temporal):
t_comp = (time.time() - iter_start_time) / opt.train_batch_size

batch_size = model.get_current_batch_size() * len(opt.gpu_ids)
print("trainpy batch_size ", batch_size)
opt.total_iters += batch_size
epoch_iter += batch_size
if (
Expand Down Expand Up @@ -247,6 +248,7 @@ def train_gpu(rank, world_size, opt, trainset, trainset_temporal):
opt.total_iters % opt.output_display_freq < batch_size
): # display images on visdom and save images to a HTML file
save_result = opt.total_iters % opt.output_update_html_freq == 0
print("trainpy from here nb_imgs ???", opt.train_batch_size)
model.compute_visuals(opt.train_batch_size)
if not "none" in opt.output_display_type:
visualizer.display_current_results(
Expand Down

0 comments on commit d85258d

Please sign in to comment.