Commit ef9d231

Merge remote-tracking branch 'origin/main' into main
Virginia Fernandez committed Sep 25, 2024
2 parents d46e6ce + fac4f65 commit ef9d231
Showing 3 changed files with 61 additions and 66 deletions.
12 changes: 6 additions & 6 deletions generation/2d_super_resolution/2d_sd_super_resolution.ipynb
@@ -243,7 +243,10 @@
" rotate_range=[(-np.pi / 36, np.pi / 36), (-np.pi / 36, np.pi / 36)],\n",
" translate_range=[(-1, 1), (-1, 1)],\n",
" scale_range=[(-0.05, 0.05), (-0.05, 0.05)],\n",
" spatial_size=[image_size, image_size], padding_mode=\"zeros\", prob=0.5),\n",
" spatial_size=[image_size, image_size],\n",
" padding_mode=\"zeros\",\n",
" prob=0.5,\n",
" ),\n",
" transforms.CopyItemsd(keys=[\"image\"], times=1, names=[\"low_res_image\"]),\n",
" transforms.Resized(keys=[\"low_res_image\"], spatial_size=(16, 16)),\n",
"]\n",
@@ -664,7 +667,7 @@
" f\"epoch {epoch:d}/{max_epochs:d}:\",\n",
" f\"recons loss: {epoch_loss / len(train_loader) :4f},\"\n",
" f\"perc_epoch_loss: {perc_epoch_loss / len(train_loader):4f},\"\n",
" f\"kl_epoch_loss: {kl_epoch_loss / len(train_loader):4f},\"\n",
" f\"kl_epoch_loss: {kl_epoch_loss / len(train_loader):4f},\",\n",
" ]\n",
"\n",
" if epoch > autoencoder_warm_up_n_epochs:\n",
@@ -1015,10 +1018,7 @@
"\n",
" epoch_loss += loss.item()\n",
"\n",
" msgs = [\n",
" f\"epoch {epoch:d}/{max_epochs:d}:\",\n",
" f\"loss: {epoch_loss / len(train_loader) :4f},\"\n",
" ]\n",
" msgs = [f\"epoch {epoch:d}/{max_epochs:d}:\", f\"loss: {epoch_loss / len(train_loader) :4f},\"]\n",
"\n",
" if epoch % print_interval == 0:\n",
" print(\",\".join(msgs))\n",
113 changes: 54 additions & 59 deletions generation/2d_super_resolution/2d_sd_super_resolution_lightning.ipynb
@@ -233,8 +232,7 @@
" [\n",
" transforms.LoadImaged(keys=[\"image\"]),\n",
" transforms.EnsureChannelFirstd(keys=[\"image\"]),\n",
" transforms.ScaleIntensityRanged(keys=[\"image\"], a_min=0.0, a_max=255.0, b_min=0.0,\n",
" b_max=1.0, clip=True),\n",
" transforms.ScaleIntensityRanged(keys=[\"image\"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True),\n",
" transforms.RandAffined(\n",
" keys=[\"image\"],\n",
" rotate_range=[(-np.pi / 36, np.pi / 36), (-np.pi / 36, np.pi / 36)],\n",
@@ -256,8 +255,7 @@
" [\n",
" transforms.LoadImaged(keys=[\"image\"]),\n",
" transforms.EnsureChannelFirstd(keys=[\"image\"]),\n",
" transforms.ScaleIntensityRanged(keys=[\"image\"], a_min=0.0, a_max=255.0, b_min=0.0,\n",
" b_max=1.0, clip=True),\n",
" transforms.ScaleIntensityRanged(keys=[\"image\"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True),\n",
" transforms.CopyItemsd(keys=[\"image\"], times=1, names=[\"low_res_image\"]),\n",
" transforms.Resized(keys=[\"low_res_image\"], spatial_size=(16, 16)),\n",
" ]\n",
@@ -293,16 +291,17 @@
" def __init__(self):\n",
" super().__init__()\n",
" self.data_dir = root_dir\n",
" self.autoencoderkl = AutoencoderKL(spatial_dims=2,\n",
" in_channels=1,\n",
" out_channels=1,\n",
" channels=(256, 512, 512),\n",
" latent_channels=3,\n",
" num_res_blocks=2,\n",
" norm_num_groups=32,\n",
" attention_levels=(False, False, True))\n",
" self.discriminator = PatchDiscriminator(spatial_dims=2, in_channels=1,\n",
" num_layers_d=3, channels=64)\n",
" self.autoencoderkl = AutoencoderKL(\n",
" spatial_dims=2,\n",
" in_channels=1,\n",
" out_channels=1,\n",
" channels=(256, 512, 512),\n",
" latent_channels=3,\n",
" num_res_blocks=2,\n",
" norm_num_groups=32,\n",
" attention_levels=(False, False, True),\n",
" )\n",
" self.discriminator = PatchDiscriminator(spatial_dims=2, in_channels=1, num_layers_d=3, channels=64)\n",
" self.perceptual_loss = PerceptualLoss(spatial_dims=2, network_type=\"alex\")\n",
" self.perceptual_weight = 0.002\n",
" self.autoencoder_warm_up_n_epochs = 10\n",
@@ -318,12 +317,10 @@
" self.train_ds, self.val_ds = get_datasets()\n",
"\n",
" def train_dataloader(self):\n",
" return ThreadDataLoader(self.train_ds, batch_size=16, shuffle=True,\n",
" num_workers=4, persistent_workers=True)\n",
" return ThreadDataLoader(self.train_ds, batch_size=16, shuffle=True, num_workers=4, persistent_workers=True)\n",
"\n",
" def val_dataloader(self):\n",
" return ThreadDataLoader(self.val_ds, batch_size=16, shuffle=False,\n",
" num_workers=4)\n",
" return ThreadDataLoader(self.val_ds, batch_size=16, shuffle=False, num_workers=4)\n",
"\n",
" def _compute_loss_generator(self, images, reconstruction, z_mu, z_sigma):\n",
" recons_loss = F.l1_loss(reconstruction.float(), images.float())\n",
@@ -381,9 +378,9 @@
" def on_validation_epoch_end(self):\n",
" # ploting reconstruction\n",
" plt.figure(figsize=(2, 2))\n",
" plt.imshow(torch.cat([self.images[0, 0].cpu(),\n",
" self.reconstruction[0, 0].cpu()],\n",
" dim=1), vmin=0, vmax=1, cmap=\"gray\")\n",
" plt.imshow(\n",
" torch.cat([self.images[0, 0].cpu(), self.reconstruction[0, 0].cpu()], dim=1), vmin=0, vmax=1, cmap=\"gray\"\n",
" )\n",
" plt.tight_layout()\n",
" plt.axis(\"off\")\n",
" plt.show()\n",
@@ -658,12 +655,14 @@
"\n",
"\n",
"# initialise Lightning's trainer.\n",
"trainer = pl.Trainer(devices=1,\n",
" max_epochs=max_epochs,\n",
" check_val_every_n_epoch=val_interval,\n",
" num_sanity_val_steps=0,\n",
" callbacks=checkpoint_callback,\n",
" default_root_dir=root_dir)\n",
"trainer = pl.Trainer(\n",
" devices=1,\n",
" max_epochs=max_epochs,\n",
" check_val_every_n_epoch=val_interval,\n",
" num_sanity_val_steps=0,\n",
" callbacks=checkpoint_callback,\n",
" default_root_dir=root_dir,\n",
")\n",
"\n",
"# train\n",
"trainer.fit(ae_net)"
@@ -741,27 +740,22 @@
" num_head_channels=(0, 0, 64, 64),\n",
" )\n",
" self.max_noise_level = 350\n",
" self.scheduler = DDPMScheduler(num_train_timesteps=1000,\n",
" schedule=\"linear_beta\",\n",
" beta_start=0.0015,\n",
" beta_end=0.0195)\n",
" self.scheduler = DDPMScheduler(\n",
" num_train_timesteps=1000, schedule=\"linear_beta\", beta_start=0.0015, beta_end=0.0195\n",
" )\n",
" self.z = ae_net.autoencoderkl.eval()\n",
"\n",
" def forward(self, x, timesteps, low_res_timesteps):\n",
" return self.unet(x=x,\n",
" timesteps=timesteps,\n",
" class_labels=low_res_timesteps)\n",
" return self.unet(x=x, timesteps=timesteps, class_labels=low_res_timesteps)\n",
"\n",
" def prepare_data(self):\n",
" self.train_ds, self.val_ds = get_datasets()\n",
"\n",
" def train_dataloader(self):\n",
" return ThreadDataLoader(self.train_ds, batch_size=16, shuffle=True,\n",
" num_workers=4, persistent_workers=True)\n",
" return ThreadDataLoader(self.train_ds, batch_size=16, shuffle=True, num_workers=4, persistent_workers=True)\n",
"\n",
" def val_dataloader(self):\n",
" return ThreadDataLoader(self.val_ds, batch_size=16, shuffle=False,\n",
" num_workers=4)\n",
" return ThreadDataLoader(self.val_ds, batch_size=16, shuffle=False, num_workers=4)\n",
"\n",
" def _calculate_loss(self, batch, batch_idx, plt_image=False):\n",
" images = batch[\"image\"]\n",
@@ -773,17 +767,16 @@
" # Noise augmentation\n",
" noise = torch.randn_like(latent)\n",
" low_res_noise = torch.randn_like(low_res_image)\n",
" timesteps = torch.randint(0, self.scheduler.num_train_timesteps, (latent.shape[0],),\n",
" device=latent.device).long()\n",
" timesteps = torch.randint(\n",
" 0, self.scheduler.num_train_timesteps, (latent.shape[0],), device=latent.device\n",
" ).long()\n",
" low_res_timesteps = torch.randint(\n",
" 0, self.max_noise_level, (low_res_image.shape[0],), device=latent.device\n",
" ).long()\n",
"\n",
" noisy_latent = self.scheduler.add_noise(original_samples=latent,\n",
" noise=noise, timesteps=timesteps)\n",
" noisy_latent = self.scheduler.add_noise(original_samples=latent, noise=noise, timesteps=timesteps)\n",
" noisy_low_res_image = self.scheduler.add_noise(\n",
" original_samples=low_res_image, noise=low_res_noise,\n",
" timesteps=low_res_timesteps\n",
" original_samples=low_res_image, noise=low_res_noise, timesteps=low_res_timesteps\n",
" )\n",
"\n",
" latent_model_input = torch.cat([noisy_latent, noisy_low_res_image], dim=1)\n",
@@ -809,9 +802,9 @@
" with autocast(\"cuda\", enabled=True):\n",
" with torch.no_grad():\n",
" latent_model_input = torch.cat([latents, noisy_low_res_image], dim=1)\n",
" noise_pred = self.forward(latent_model_input,\n",
" torch.Tensor((t,)).to(sampling_image.device)\n",
" , noise_level)\n",
" noise_pred = self.forward(\n",
" latent_model_input, torch.Tensor((t,)).to(sampling_image.device), noise_level\n",
" )\n",
" latents, _ = self.scheduler.step(noise_pred, t, latents)\n",
" with torch.no_grad():\n",
" decoded = self.z.decode_stage_2_outputs(latents / scale_factor)\n",
@@ -1152,12 +1145,14 @@
"\n",
"\n",
"# initialise Lightning's trainer.\n",
"trainer = pl.Trainer(devices=1,\n",
" max_epochs=max_epochs,\n",
" check_val_every_n_epoch=val_interval,\n",
" num_sanity_val_steps=0,\n",
" callbacks=checkpoint_callback,\n",
" default_root_dir=root_dir)\n",
"trainer = pl.Trainer(\n",
" devices=1,\n",
" max_epochs=max_epochs,\n",
" check_val_every_n_epoch=val_interval,\n",
" num_sanity_val_steps=0,\n",
" callbacks=checkpoint_callback,\n",
" default_root_dir=root_dir,\n",
")\n",
"\n",
"# train\n",
"trainer.fit(d_net)"
@@ -1219,18 +1214,18 @@
" noise_level = 10\n",
" noise_level = torch.Tensor((noise_level,)).long().to(d_net.device)\n",
" scheduler = d_net.scheduler\n",
" noisy_low_res_image = scheduler.add_noise(original_samples=sampling_image,\n",
" noise=low_res_noise,\n",
" timesteps=torch.Tensor((noise_level,)).long())\n",
" noisy_low_res_image = scheduler.add_noise(\n",
" original_samples=sampling_image, noise=low_res_noise, timesteps=torch.Tensor((noise_level,)).long()\n",
" )\n",
"\n",
" scheduler.set_timesteps(num_inference_steps=1000)\n",
" for t in tqdm(scheduler.timesteps, ncols=110):\n",
" with autocast(\"cuda\", enabled=True):\n",
" with torch.no_grad():\n",
" latent_model_input = torch.cat([latents, noisy_low_res_image], dim=1)\n",
" noise_pred = d_net.forward(x=latent_model_input,\n",
" timesteps=torch.Tensor((t,)).to(d_net.device),\n",
" low_res_timesteps=noise_level)\n",
" noise_pred = d_net.forward(\n",
" x=latent_model_input, timesteps=torch.Tensor((t,)).to(d_net.device), low_res_timesteps=noise_level\n",
" )\n",
" # 2. compute previous image: x_t -> x_t-1\n",
" latents, _ = scheduler.step(noise_pred, t, latents)\n",
"\n",
2 changes: 1 addition & 1 deletion generation/README.md
@@ -74,4 +74,4 @@ Example shows the use cases of applying a spatial VAE to a 3D synthesis example.
Examples show how to perform anomaly detection in 2D, using implicit guidance [2D classifier-free guidance](./anomaly_detection/2d_classifierfree_guidance_anomalydetection_tutorial.ipynb), transformers [using transformers](./anomaly_detection/anomaly_detection_with_transformers.ipynb) and [classifier-free guidance](./anomalydetection_tutorial_classifier_guidance).

## 2D super-resolution using diffusion models: [using torch](./2d_super_resolution/2d_sd_super_resolution.ipynb) and [using torch lightning](./2d_super_resolution/2d_sd_super_resolution_lightning.ipynb).
-Examples show how to perform super-resolution in 2D, using PyTorch and PyTorch Lightning.
\ No newline at end of file
+Examples show how to perform super-resolution in 2D, using PyTorch and PyTorch Lightning.
