diff --git a/generation/2d_super_resolution/2d_sd_super_resolution.ipynb b/generation/2d_super_resolution/2d_sd_super_resolution.ipynb
index 03c07ddb6..7933c94fd 100644
--- a/generation/2d_super_resolution/2d_sd_super_resolution.ipynb
+++ b/generation/2d_super_resolution/2d_sd_super_resolution.ipynb
@@ -243,7 +243,10 @@
 " rotate_range=[(-np.pi / 36, np.pi / 36), (-np.pi / 36, np.pi / 36)],\n",
 " translate_range=[(-1, 1), (-1, 1)],\n",
 " scale_range=[(-0.05, 0.05), (-0.05, 0.05)],\n",
- " spatial_size=[image_size, image_size], padding_mode=\"zeros\", prob=0.5),\n",
+ " spatial_size=[image_size, image_size],\n",
+ " padding_mode=\"zeros\",\n",
+ " prob=0.5,\n",
+ " ),\n",
 " transforms.CopyItemsd(keys=[\"image\"], times=1, names=[\"low_res_image\"]),\n",
 " transforms.Resized(keys=[\"low_res_image\"], spatial_size=(16, 16)),\n",
 "]\n",
@@ -664,7 +667,7 @@
 " f\"epoch {epoch:d}/{max_epochs:d}:\",\n",
 " f\"recons loss: {epoch_loss / len(train_loader) :4f},\"\n",
 " f\"perc_epoch_loss: {perc_epoch_loss / len(train_loader):4f},\"\n",
- " f\"kl_epoch_loss: {kl_epoch_loss / len(train_loader):4f},\"\n",
+ " f\"kl_epoch_loss: {kl_epoch_loss / len(train_loader):4f},\",\n",
 " ]\n",
 "\n",
 " if epoch > autoencoder_warm_up_n_epochs:\n",
@@ -1015,10 +1018,7 @@
 "\n",
 " epoch_loss += loss.item()\n",
 "\n",
- " msgs = [\n",
- " f\"epoch {epoch:d}/{max_epochs:d}:\",\n",
- " f\"loss: {epoch_loss / len(train_loader) :4f},\"\n",
- " ]\n",
+ " msgs = [f\"epoch {epoch:d}/{max_epochs:d}:\", f\"loss: {epoch_loss / len(train_loader) :4f},\"]\n",
 "\n",
 " if epoch % print_interval == 0:\n",
 " print(\",\".join(msgs)),\n",
diff --git a/generation/2d_super_resolution/2d_sd_super_resolution_lightning.ipynb b/generation/2d_super_resolution/2d_sd_super_resolution_lightning.ipynb
index fde5a8c65..9f64109c3 100644
--- a/generation/2d_super_resolution/2d_sd_super_resolution_lightning.ipynb
+++ b/generation/2d_super_resolution/2d_sd_super_resolution_lightning.ipynb
@@ -233,8 +233,7 @@
 " [\n",
 " transforms.LoadImaged(keys=[\"image\"]),\n",
 " transforms.EnsureChannelFirstd(keys=[\"image\"]),\n",
- " transforms.ScaleIntensityRanged(keys=[\"image\"], a_min=0.0, a_max=255.0, b_min=0.0,\n",
- " b_max=1.0, clip=True),\n",
+ " transforms.ScaleIntensityRanged(keys=[\"image\"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True),\n",
 " transforms.RandAffined(\n",
 " keys=[\"image\"],\n",
 " rotate_range=[(-np.pi / 36, np.pi / 36), (-np.pi / 36, np.pi / 36)],\n",
@@ -256,8 +255,7 @@
 " [\n",
 " transforms.LoadImaged(keys=[\"image\"]),\n",
 " transforms.EnsureChannelFirstd(keys=[\"image\"]),\n",
- " transforms.ScaleIntensityRanged(keys=[\"image\"], a_min=0.0, a_max=255.0, b_min=0.0,\n",
- " b_max=1.0, clip=True),\n",
+ " transforms.ScaleIntensityRanged(keys=[\"image\"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True),\n",
 " transforms.CopyItemsd(keys=[\"image\"], times=1, names=[\"low_res_image\"]),\n",
 " transforms.Resized(keys=[\"low_res_image\"], spatial_size=(16, 16)),\n",
 " ]\n",
@@ -293,16 +291,17 @@
 " def __init__(self):\n",
 " super().__init__()\n",
 " self.data_dir = root_dir\n",
- " self.autoencoderkl = AutoencoderKL(spatial_dims=2,\n",
- " in_channels=1,\n",
- " out_channels=1,\n",
- " channels=(256, 512, 512),\n",
- " latent_channels=3,\n",
- " num_res_blocks=2,\n",
- " norm_num_groups=32,\n",
- " attention_levels=(False, False, True))\n",
- " self.discriminator = PatchDiscriminator(spatial_dims=2, in_channels=1,\n",
- " num_layers_d=3, channels=64)\n",
+ " self.autoencoderkl = AutoencoderKL(\n",
+ " spatial_dims=2,\n",
+ " in_channels=1,\n",
+ " out_channels=1,\n",
+ " channels=(256, 512, 512),\n",
+ " latent_channels=3,\n",
+ " num_res_blocks=2,\n",
+ " norm_num_groups=32,\n",
+ " attention_levels=(False, False, True),\n",
+ " )\n",
+ " self.discriminator = PatchDiscriminator(spatial_dims=2, in_channels=1, num_layers_d=3, channels=64)\n",
 " self.perceptual_loss = PerceptualLoss(spatial_dims=2, network_type=\"alex\")\n",
 " self.perceptual_weight = 0.002\n",
 " self.autoencoder_warm_up_n_epochs = 10\n",
@@ -318,12 +317,10 @@
 " self.train_ds, self.val_ds = get_datasets()\n",
 "\n",
 " def train_dataloader(self):\n",
- " return ThreadDataLoader(self.train_ds, batch_size=16, shuffle=True,\n",
- " num_workers=4, persistent_workers=True)\n",
+ " return ThreadDataLoader(self.train_ds, batch_size=16, shuffle=True, num_workers=4, persistent_workers=True)\n",
 "\n",
 " def val_dataloader(self):\n",
- " return ThreadDataLoader(self.val_ds, batch_size=16, shuffle=False,\n",
- " num_workers=4)\n",
+ " return ThreadDataLoader(self.val_ds, batch_size=16, shuffle=False, num_workers=4)\n",
 "\n",
 " def _compute_loss_generator(self, images, reconstruction, z_mu, z_sigma):\n",
 " recons_loss = F.l1_loss(reconstruction.float(), images.float())\n",
@@ -381,9 +378,9 @@
 " def on_validation_epoch_end(self):\n",
 " # ploting reconstruction\n",
 " plt.figure(figsize=(2, 2))\n",
- " plt.imshow(torch.cat([self.images[0, 0].cpu(),\n",
- " self.reconstruction[0, 0].cpu()],\n",
- " dim=1), vmin=0, vmax=1, cmap=\"gray\")\n",
+ " plt.imshow(\n",
+ " torch.cat([self.images[0, 0].cpu(), self.reconstruction[0, 0].cpu()], dim=1), vmin=0, vmax=1, cmap=\"gray\"\n",
+ " )\n",
 " plt.tight_layout()\n",
 " plt.axis(\"off\")\n",
 " plt.show()\n",
@@ -658,12 +655,14 @@
 "\n",
 "\n",
 "# initialise Lightning's trainer.\n",
- "trainer = pl.Trainer(devices=1,\n",
- " max_epochs=max_epochs,\n",
- " check_val_every_n_epoch=val_interval,\n",
- " num_sanity_val_steps=0,\n",
- " callbacks=checkpoint_callback,\n",
- " default_root_dir=root_dir)\n",
+ "trainer = pl.Trainer(\n",
+ " devices=1,\n",
+ " max_epochs=max_epochs,\n",
+ " check_val_every_n_epoch=val_interval,\n",
+ " num_sanity_val_steps=0,\n",
+ " callbacks=checkpoint_callback,\n",
+ " default_root_dir=root_dir,\n",
+ ")\n",
 "\n",
 "# train\n",
 "trainer.fit(ae_net)"
@@ -741,27 +740,22 @@
 " num_head_channels=(0, 0, 64, 64),\n",
 " )\n",
 " self.max_noise_level = 350\n",
- " self.scheduler = DDPMScheduler(num_train_timesteps=1000,\n",
- " schedule=\"linear_beta\",\n",
- " beta_start=0.0015,\n",
- " beta_end=0.0195)\n",
+ " self.scheduler = DDPMScheduler(\n",
+ " num_train_timesteps=1000, schedule=\"linear_beta\", beta_start=0.0015, beta_end=0.0195\n",
+ " )\n",
 " self.z = ae_net.autoencoderkl.eval()\n",
 "\n",
 " def forward(self, x, timesteps, low_res_timesteps):\n",
- " return self.unet(x=x,\n",
- " timesteps=timesteps,\n",
- " class_labels=low_res_timesteps)\n",
+ " return self.unet(x=x, timesteps=timesteps, class_labels=low_res_timesteps)\n",
 "\n",
 " def prepare_data(self):\n",
 " self.train_ds, self.val_ds = get_datasets()\n",
 "\n",
 " def train_dataloader(self):\n",
- " return ThreadDataLoader(self.train_ds, batch_size=16, shuffle=True,\n",
- " num_workers=4, persistent_workers=True)\n",
+ " return ThreadDataLoader(self.train_ds, batch_size=16, shuffle=True, num_workers=4, persistent_workers=True)\n",
 "\n",
 " def val_dataloader(self):\n",
- " return ThreadDataLoader(self.val_ds, batch_size=16, shuffle=False,\n",
- " num_workers=4)\n",
+ " return ThreadDataLoader(self.val_ds, batch_size=16, shuffle=False, num_workers=4)\n",
 "\n",
 " def _calculate_loss(self, batch, batch_idx, plt_image=False):\n",
 " images = batch[\"image\"]\n",
@@ -773,17 +767,16 @@
 " # Noise augmentation\n",
 " noise = torch.randn_like(latent)\n",
 " low_res_noise = torch.randn_like(low_res_image)\n",
- " timesteps = torch.randint(0, self.scheduler.num_train_timesteps, (latent.shape[0],),\n",
- " device=latent.device).long()\n",
+ " timesteps = torch.randint(\n",
+ " 0, self.scheduler.num_train_timesteps, (latent.shape[0],), device=latent.device\n",
+ " ).long()\n",
 " low_res_timesteps = torch.randint(\n",
 " 0, self.max_noise_level, (low_res_image.shape[0],), device=latent.device\n",
 " ).long()\n",
 "\n",
- " noisy_latent = self.scheduler.add_noise(original_samples=latent,\n",
- " noise=noise, timesteps=timesteps)\n",
+ " noisy_latent = self.scheduler.add_noise(original_samples=latent, noise=noise, timesteps=timesteps)\n",
 " noisy_low_res_image = self.scheduler.add_noise(\n",
- " original_samples=low_res_image, noise=low_res_noise,\n",
- " timesteps=low_res_timesteps\n",
+ " original_samples=low_res_image, noise=low_res_noise, timesteps=low_res_timesteps\n",
 " )\n",
 "\n",
 " latent_model_input = torch.cat([noisy_latent, noisy_low_res_image], dim=1)\n",
@@ -809,9 +802,9 @@
 " with autocast(\"cuda\", enabled=True):\n",
 " with torch.no_grad():\n",
 " latent_model_input = torch.cat([latents, noisy_low_res_image], dim=1)\n",
- " noise_pred = self.forward(latent_model_input,\n",
- " torch.Tensor((t,)).to(sampling_image.device)\n",
- " , noise_level)\n",
+ " noise_pred = self.forward(\n",
+ " latent_model_input, torch.Tensor((t,)).to(sampling_image.device), noise_level\n",
+ " )\n",
 " latents, _ = self.scheduler.step(noise_pred, t, latents)\n",
 " with torch.no_grad():\n",
 " decoded = self.z.decode_stage_2_outputs(latents / scale_factor)\n",
@@ -1152,12 +1145,14 @@
 "\n",
 "\n",
 "# initialise Lightning's trainer.\n",
- "trainer = pl.Trainer(devices=1,\n",
- " max_epochs=max_epochs,\n",
- " check_val_every_n_epoch=val_interval,\n",
- " num_sanity_val_steps=0,\n",
- " callbacks=checkpoint_callback,\n",
- " default_root_dir=root_dir)\n",
+ "trainer = pl.Trainer(\n",
+ " devices=1,\n",
+ " max_epochs=max_epochs,\n",
+ " check_val_every_n_epoch=val_interval,\n",
+ " num_sanity_val_steps=0,\n",
+ " callbacks=checkpoint_callback,\n",
+ " default_root_dir=root_dir,\n",
+ ")\n",
 "\n",
 "# train\n",
 "trainer.fit(d_net)"
@@ -1219,18 +1214,18 @@
 " noise_level = 10\n",
 " noise_level = torch.Tensor((noise_level,)).long().to(d_net.device)\n",
 " scheduler = d_net.scheduler\n",
- " noisy_low_res_image = scheduler.add_noise(original_samples=sampling_image,\n",
- " noise=low_res_noise,\n",
- " timesteps=torch.Tensor((noise_level,)).long())\n",
+ " noisy_low_res_image = scheduler.add_noise(\n",
+ " original_samples=sampling_image, noise=low_res_noise, timesteps=torch.Tensor((noise_level,)).long()\n",
+ " )\n",
 "\n",
 " scheduler.set_timesteps(num_inference_steps=1000)\n",
 " for t in tqdm(scheduler.timesteps, ncols=110):\n",
 " with autocast(\"cuda\", enabled=True):\n",
 " with torch.no_grad():\n",
 " latent_model_input = torch.cat([latents, noisy_low_res_image], dim=1)\n",
- " noise_pred = d_net.forward(x=latent_model_input,\n",
- " timesteps=torch.Tensor((t,)).to(d_net.device),\n",
- " low_res_timesteps=noise_level)\n",
+ " noise_pred = d_net.forward(\n",
+ " x=latent_model_input, timesteps=torch.Tensor((t,)).to(d_net.device), low_res_timesteps=noise_level\n",
+ " )\n",
 " # 2. compute previous image: x_t -> x_t-1\n",
 " latents, _ = scheduler.step(noise_pred, t, latents)\n",
 "\n",
diff --git a/generation/README.md b/generation/README.md
index cf7822310..351416fd1 100644
--- a/generation/README.md
+++ b/generation/README.md
@@ -74,4 +74,4 @@ Example shows the use cases of applying a spatial VAE to a 3D synthesis example.
 Examples show how to perform anomaly detection in 2D, using implicit guidance [2D-classifier free guiance](./anomaly_detection/2d_classifierfree_guidance_anomalydetection_tutorial.ipynb), transformers [using transformers](./anomaly_detection/anomaly_detection_with_transformers.ipynb) and [classifier free guidance](./anomalydetection_tutorial_classifier_guidance).
 
 ## 2D super-resolution using diffusion models: [using torch](./2d_super_resolution/2d_sd_super_resolution.ipynb) and [using torch lightning](./2d_super_resolution/2d_sd_super_resolution_lightning.ipynb).
-Examples show how to perform super-resolution in 2D, using PyTorch and PyTorch Lightning.
\ No newline at end of file
+Examples show how to perform super-resolution in 2D, using PyTorch and PyTorch Lightning.