Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Pass class index as an integer to discriminator in multiprompt #657

Open
wants to merge 33 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
c54a516
feat: multiprompt
wr0124 May 14, 2024
66492d7
feat: multi-prompt local works
wr0124 May 15, 2024
c888f84
feat(ml): batched prompts for turbo
beniz May 17, 2024
55f5b1d
feat(ml): prompt print on the image real_B
wr0124 May 24, 2024
533689d
feat(ml): horse2zebra inference
wr0124 May 24, 2024
137a6ac
feat(ml): multi-prompt on image A and B
wr0124 May 27, 2024
4533645
feat(ml): inference for multiprompt
wr0124 May 30, 2024
cebeb85
feat(ml): pass class index as an int to D for batch_size 1 in
wr0124 Jun 5, 2024
5771969
Merge branch 'jolibrain:master' into netD_cls
wr0124 Jun 5, 2024
15fbbb9
feat(ml): multi batch_size with multi prompt fixed
wr0124 Jun 6, 2024
a7235d6
feat(ml): G_sem_mask_AB no zero
wr0124 Jun 10, 2024
f2400a2
feat(ml): black format and remove tensor2im_re
wr0124 Jun 10, 2024
2e07ee6
fix: cm with conditioning
beniz Jun 12, 2024
2f81f63
feat(ml): CM with added discriminator
beniz Mar 13, 2024
7633ffc
feat(ml): adding example of CM+discriminator
beniz Jun 18, 2024
8ecb554
chore(ml): added CM+GAN unit tests
beniz Jun 18, 2024
e47605f
doc: options auto update
Jun 19, 2024
863f32b
feat(ml): modif for horse2zebra prompt
wr0124 May 20, 2024
4474b77
feat(ml): prompt for inference horze2zebra
wr0124 May 21, 2024
7b8fb6e
fix: identity with cut turbo
beniz Jun 14, 2024
f183f14
fix: prompt unaligned loading
beniz Jun 18, 2024
b0a45a7
fix: paths loading prompts file
beniz Jun 18, 2024
bcba748
fix: gan inference script with prompts
beniz Jun 19, 2024
e8c0bd8
doc: options auto update
Jun 20, 2024
53efb67
chore: clamp input real images before psnr to avoid errors
beniz Jun 21, 2024
ec40e13
doc: options auto update
Jun 24, 2024
b68e9dd
fix: inference with images > 8bit and GANs
beniz Jun 21, 2024
5ba0a80
doc: options auto update
Jun 24, 2024
3615d47
feat: max number of visualized images from train/test set
beniz Jun 24, 2024
f106796
doc: options auto update
Jun 26, 2024
56659d8
feat(ml): horse2zebra inference
wr0124 May 24, 2024
7908371
Merge branch 'master' into netD_cls
wr0124 Jun 26, 2024
ac8138e
feat(ml):fix PR
wr0124 Jun 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions data/unaligned_labeled_mask_online_prompt_dataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os

from PIL import Image

from data.base_dataset import get_transform_ref
Expand Down Expand Up @@ -34,14 +35,14 @@ def get_img(
index,
clamp_semantics,
)
# print()

img_path_B = result["B_img_paths"]

real_B_prompt_path = self.B_img_prompt[img_path_B]

if len(real_B_prompt_path) == 1 and isinstance(real_B_prompt_path[0], str):
real_B_prompt = real_B_prompt_path[0]

# print("real_B_prompt=", real_B_prompt)
result.update({"real_B_prompt": real_B_prompt})

return result
73 changes: 59 additions & 14 deletions models/base_gan_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,31 @@ def get_current_APA_prob(self):

return current_APA_prob

def process_prompt(self):
    """Append each sample's prompt class index to real_B / fake_B as a channel.

    Every prompt in ``self.real_B_prompt`` is mapped to an integer class
    index, turned into a constant ``(1, H, W)`` plane and concatenated to
    the matching sample of ``self.real_B`` and ``self.fake_B`` along the
    channel dimension.  A single prompt (a bare string or a one-element
    sequence) is shared by the whole batch.

    Returns:
        A tuple ``(real_B_with_label, fake_B_with_label)``.  Both tensors
        are detached from the autograd graph and have one more channel
        than ``self.real_B`` / ``self.fake_B``.

    Raises:
        ValueError: if a prompt has no class index, or if the number of
            prompts is neither 1 nor the batch size of ``self.real_B``.
    """
    prompt_to_label = {
        "driving in foggy weather": 1,
        "driving in cloudy weather": 2,
        "driving in rainy weather": 3,
        "driving in snowy weather": 4,
    }

    prompts = self.real_B_prompt
    if isinstance(prompts, str):
        # A bare string would otherwise be iterated character by character.
        prompts = [prompts]

    batch_size, _, height, width = self.real_B.shape
    if len(prompts) not in (1, batch_size):
        raise ValueError(
            f"Expected 1 or {batch_size} prompts, got {len(prompts)}"
        )

    class_ids = []
    for prompt in prompts:
        if prompt not in prompt_to_label:
            raise ValueError(f"Unknown prompt: {prompt}")
        class_ids.append(prompt_to_label[prompt])

    # One class index per sample (or a single one broadcast to the batch),
    # created directly on the target device and expanded to a full plane.
    label_tensor = (
        torch.tensor(class_ids, dtype=self.real_B.dtype, device=self.device)
        .view(-1, 1, 1, 1)
        .expand(batch_size, 1, height, width)
    )

    # torch.cat allocates a new tensor, so detaching the inputs is enough
    # to keep the results independent of the originals.
    real_B_with_label = torch.cat((self.real_B.detach(), label_tensor), dim=1)
    fake_B_with_label = torch.cat((self.fake_B.detach(), label_tensor), dim=1)

    return real_B_with_label, fake_B_with_label

def compute_D_loss_generic(
self, netD, domain_img, loss, real_name=None, fake_name=None
):
Expand All @@ -345,14 +370,15 @@ def compute_D_loss_generic(
context = ""
if self.opt.data_online_context_pixels > 0:
context = "_with_context"

if fake_name is None:
fake = getattr(self, "fake_" + domain_img + "_pool").query(
getattr(self, "fake_" + domain_img + context + noisy)
)
if self.opt.G_multiprompt:
fake = None
else:
fake = getattr(self, "fake_" + domain_img + "_pool").query(
getattr(self, "fake_" + domain_img + context + noisy)
)
else:
fake = getattr(self, fake_name)

if self.opt.dataaug_APA:
fake_2 = getattr(self, "fake_" + domain_img + "_pool").get_random(
fake.shape[0]
Expand All @@ -372,8 +398,22 @@ def compute_D_loss_generic(
if fake_2 is not None:
fake_2 = fake_2.expand(-1, 3, -1, -1)

with torch.cuda.amp.autocast(enabled=self.with_amp):
loss = loss.compute_loss_D(netD, real, fake, fake_2)
##todo change real fake with D_cls infor
if self.opt.G_multiprompt:
real_B_with_label, fake_B_with_label = self.process_prompt()
fake_B_with_label_query = getattr(
self, "fake_" + domain_img + "_pool"
).query(fake_B_with_label)
with torch.cuda.amp.autocast(enabled=self.with_amp):
loss = loss.compute_loss_D(
netD,
real_B_with_label.to(self.device),
fake_B_with_label_query.to(self.device),
fake_2,
)
else:
with torch.cuda.amp.autocast(enabled=self.with_amp):
loss = loss.compute_loss_D(netD, real, fake, fake_2)
return loss

def compute_D_loss(self):
Expand Down Expand Up @@ -421,7 +461,6 @@ def compute_G_loss_GAN_generic(
context = ""
if self.opt.data_online_context_pixels > 0:
context = "_with_context"

if fake_name is None:
fake = getattr(self, "fake_" + domain_img + context)
else:
Expand All @@ -448,7 +487,17 @@ def compute_G_loss_GAN_generic(
if self.opt.data_image_bits != 8 and type(netD) == ProjectedDiscriminator:
fake = fake.expand(-1, 3, -1, -1)
real = real.expand(-1, 3, -1, -1)
loss = loss.compute_loss_G(netD, real, fake)

##todo change real fake with D_cls infor
if self.opt.G_multiprompt:
real_B_with_label, fake_B_with_label = self.process_prompt()
loss = loss.compute_loss_G(
netD,
real_B_with_label.to(self.device),
fake_B_with_label.to(self.device),
)
else:
loss = loss.compute_loss_G(netD, real, fake)
return loss

def compute_G_loss(self):
Expand All @@ -474,7 +523,6 @@ def compute_G_loss_GAN(self):
else:
fake_name = None
real_name = None

loss_value = self.opt.alg_gan_lambda * self.compute_G_loss_GAN_generic(
netD,
domain,
Expand All @@ -487,7 +535,6 @@ def compute_G_loss_GAN(self):
loss_value = torch.zeros([], device=self.device)

loss_name = "loss_" + discriminator.loss_name_G

setattr(
self,
loss_name,
Expand Down Expand Up @@ -762,7 +809,6 @@ def compute_G_loss_semantic_mask_generic(self, domain_fake):
loss_G_sem_mask = self.opt.train_sem_mask_lambda * self.criterionf_s(
getattr(self, "pred_f_s_fake_%s" % domain_fake), label_fake
)

if self.opt.train_sem_idt:
if self.opt.train_mask_for_removal:
label_idt = torch.zeros_like(self.input_A_label_mask)
Expand All @@ -784,12 +830,11 @@ def compute_G_loss_semantic_mask_generic(self, domain_fake):
not hasattr(self, "loss_f_s")
or self.loss_f_s > self.opt.f_s_semantic_threshold
) and self.opt.f_s_net != "sam":
loss_G_sem_mask = 0 * loss_G_sem_mask
loss_G_sem_mask_ = 0 * loss_G_sem_mask
if self.opt.train_sem_idt:
loss_G_sem_mask_idt = 0 * loss_G_sem_mask_idt

setattr(self, "loss_G_sem_mask_%s" % direction, loss_G_sem_mask)

self.loss_G_tot += loss_G_sem_mask

if self.opt.train_sem_idt:
Expand Down
33 changes: 29 additions & 4 deletions models/base_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import copy
import os

import cv2
import torch

if torch.__version__[0] == "2":
Expand All @@ -23,7 +23,7 @@
from piq import MSID, KID, FID, psnr, ssim
from lpips import LPIPS

from util.util import save_image, tensor2im, delete_flop_param
from util.util import save_image, delete_flop_param, add_text2image, im2tensor


from util.diff_aug import DiffAugment
Expand All @@ -38,7 +38,7 @@
# Iter Calculator
from util.iter_calculator import IterCalculator
from util.network_group import NetworkGroup
from util.util import delete_flop_param, save_image, tensor2im, MAX_INT
from util.util import delete_flop_param, save_image, tensor2im, MAX_INT, im2tensor

from . import base_networks, semantic_networks

Expand Down Expand Up @@ -388,6 +388,7 @@ def set_input(self, data):
self.real_A_with_context = data["A"].to(self.device)
if "real_B_prompt" in data:
self.real_B_prompt = data["real_B_prompt"]

self.real_A = self.real_A_with_context.clone()
if self.opt.data_online_context_pixels > 0:
self.real_A = self.real_A[
Expand Down Expand Up @@ -427,6 +428,7 @@ def set_input(self, data):

if self.opt.train_semantic_mask:
self.set_input_semantic_mask(data)
self.set_input_prompt(data)
if self.opt.train_semantic_cls:
self.set_input_semantic_cls(data)

Expand Down Expand Up @@ -463,6 +465,25 @@ def set_input_semantic_mask(self, data):
self.opt.data_online_context_pixels : -self.opt.data_online_context_pixels,
]

def set_input_prompt(self, data):
    """Store the prompt-related visuals found in ``data`` on the model.

    * ``data["real_B_prompt_img"]`` (optional) is moved to ``self.device``,
      has its dimension 1 squeezed and is stored as
      ``self.real_B_prompt_img``.
    * When ``data["real_B_prompt"]`` is present, every image of
      ``data["A"]`` is paired with its prompt and passed through
      ``tensor2im`` -> ``add_text2image`` -> ``im2tensor`` (presumably
      rendering the prompt text onto the image -- confirm against
      ``util.util``).  The stacked result is stored as
      ``self.real_A2B_prompt_img`` and also written back into ``data``
      under ``"real_A2B_prompt_img"``.

    Args:
        data: batch dictionary produced by the dataset; it is updated in
            place when prompts are available.
    """
    if "real_B_prompt_img" in data:
        self.real_B_prompt_img = (
            data["real_B_prompt_img"].to(self.device).squeeze(1)
        )

    # Prompts are optional in the batch; without them there is nothing to
    # overlay.
    if "real_B_prompt" not in data:
        return

    # Iterate over the samples actually present rather than over
    # opt.train_batch_size, so a short final batch does not raise
    # IndexError.
    processed_images = [
        im2tensor(add_text2image(tensor2im(image.unsqueeze(0)), prompt))
        for image, prompt in zip(data["A"], data["real_B_prompt"])
    ]
    modified_image_batch = torch.stack(processed_images).to(self.device)
    self.real_A2B_prompt_img = modified_image_batch
    data.update({"real_A2B_prompt_img": modified_image_batch})

def set_input_semantic_cls(self, data):
if "A_label_cls" in data:
if not self.opt.train_cls_regression:
Expand Down Expand Up @@ -1331,7 +1352,11 @@ def compute_CLS_loss(self):

def forward_semantic_mask(self):
d = 1

if self.opt.G_netG == "img2img_turbo":
image_numpy_fake_B = tensor2im(self.fake_B)
image_fakeB_text = add_text2image(image_numpy_fake_B, self.real_B_prompt[0])
image_tensor = im2tensor(image_fakeB_text).unsqueeze(0)
self.fake_A2B_prompt_img = image_tensor
if self.opt.f_s_net == "sam":
self.pred_f_s_real_A = predict_sam(
self.real_A, self.f_s_mg, self.input_A_ref_bbox
Expand Down
10 changes: 10 additions & 0 deletions models/cut_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,7 @@ def data_dependent_initialize(self, data):

if self.opt.train_semantic_mask:
self.data_dependent_initialize_semantic_mask(data)
self.data_dependent_initialize_semantic_prompt(data)

def data_dependent_initialize_semantic_mask(self, data):
visual_names_seg_A = ["input_A_label_mask", "gt_pred_f_s_real_A_max", "pfB_max"]
Expand All @@ -534,6 +535,15 @@ def data_dependent_initialize_semantic_mask(self, data):
visual_names_out_mask_A = ["real_A_out_mask", "fake_B_out_mask"]
self.visual_names += [visual_names_out_mask_A]

def data_dependent_initialize_semantic_prompt(self, data):
    """Register the prompt visuals as one extra group of ``visual_names``.

    ``data`` is accepted for signature parity with the other
    ``data_dependent_initialize_*`` hooks and is not read here.
    """
    # The three names are appended together as a single nested group.
    self.visual_names += [
        [
            "real_B_prompt_img",
            "real_A2B_prompt_img",
            "fake_A2B_prompt_img",
        ]
    ]

def inference(self, nb_imgs, offset=0):
self.real = (
torch.cat((self.real_A, self.real_B), dim=0)
Expand Down
2 changes: 2 additions & 0 deletions models/modules/discriminators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from torch.nn import functional as F

from .utils import spectral_norm, normal_init
from util.util import tensor2im, save_image


class NLayerDiscriminator(nn.Module):
Expand All @@ -30,6 +31,7 @@ def __init__(
use_dropout (bool) -- whether to use dropout layers
use_spectral (bool) -- whether to use spectral norm
"""
input_nc += 1
super(NLayerDiscriminator, self).__init__()
if (
type(norm_layer) == functools.partial
Expand Down
11 changes: 9 additions & 2 deletions models/modules/img2img_turbo/img2img_turbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,14 @@ def forward(self, x, prompt):
).input_ids.cuda()
caption_enc = self.text_encoder(caption_tokens)[0]

# match batch size
captions_enc = caption_enc.repeat(x.shape[0], 1, 1)
batch_size = caption_enc.shape[0]
repeated_encs = [
caption_enc[i].repeat(int(x.shape[0] / batch_size), 1, 1)
for i in range(caption_enc.shape[0])
]

# Concatenate the repeated encodings along the batch dimension
captions_enc = torch.cat(repeated_encs, dim=0)

# deterministic forward
encoded_control = (
Expand All @@ -223,6 +229,7 @@ def forward(self, x, prompt):
return x

def compute_feats(self, input, extract_layer_ids=[]):

# caption_tokens = self.tokenizer(
# #self.prompt, # XXX: set externally
# prompt,
Expand Down
31 changes: 26 additions & 5 deletions models/modules/projected_d/discriminator.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,6 @@ def __init__(
super().__init__()

assert num_discs in [1, 2, 3, 4]

# the first disc is on the lowest level of the backbone
self.disc_in_channels = channels[:num_discs]
self.disc_in_res = resolutions[:num_discs]
Expand Down Expand Up @@ -244,7 +243,6 @@ def __init__(
):
super().__init__()
self.interp = interp

self.freeze_feature_network = Proj(
projector_model,
config_path=config_path,
Expand All @@ -255,7 +253,6 @@ def __init__(
**backbone_kwargs,
)
self.freeze_feature_network.requires_grad_(False)

self.discriminator = MultiScaleD(
channels=self.freeze_feature_network.CHANNELS,
resolutions=self.freeze_feature_network.RESOLUTIONS,
Expand All @@ -273,10 +270,34 @@ def eval(self):
return self.train(False)

def forward(self, x):
    """Compute discriminator logits from the RGB channels of ``x``.

    Only the first three channels are sent through the frozen feature
    network; any further channel (e.g. an appended class-index plane) is
    left out so the backbone always receives a 3-channel image.

    Args:
        x: input batch indexed as ``x[:, channel, :, :]``.

    Returns:
        Whatever ``self.discriminator`` returns for the extracted features.
    """
    x_rgb = x[:, :3, :, :]
    if self.interp > 0:
        x_rgb = F.interpolate(
            x_rgb, self.interp, mode="bilinear", align_corners=False
        )

    # Run the frozen backbone once, on the RGB view only.
    features = self.freeze_feature_network(x_rgb)

    # NOTE(review): the extra class channel is not consumed by the heads.
    # The previous code built per-scale "integrated" feature maps but never
    # passed them on, and building them failed for 3-channel inputs or when
    # interp == 0.  Feeding the class index to the heads would presumably
    # need them to accept one more input channel -- TODO confirm.
    logits = self.discriminator(features)

    return logits
Expand Down
3 changes: 3 additions & 0 deletions models/modules/vision_aided_d.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ def __init__(
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
self.model.cv_ensemble.requires_grad_(False) # freeze feature extractor
self.adapter = nn.Conv2d(in_channels=4, out_channels=3, kernel_size=1)

def forward(self, input):
    """Run ``input`` through the adapter, then return the model's first output."""
    adapted = self.adapter(input)
    outputs = self.model(adapted)
    return outputs[0]
3 changes: 3 additions & 0 deletions options/common_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,9 @@ def initialize(self, parser):
# default="",
# help="Text prompt for G",
# )
parser.add_argument(
"--G_multiprompt", action="store_true", help="activate the multiprompt"
)
parser.add_argument(
"--G_lora_unet",
type=int,
Expand Down
Loading
Loading