diff --git a/generation/2d_super_resolution/2d_sd_super_resolution.ipynb b/generation/2d_super_resolution/2d_sd_super_resolution.ipynb new file mode 100644 index 000000000..15d111bc4 --- /dev/null +++ b/generation/2d_super_resolution/2d_sd_super_resolution.ipynb @@ -0,0 +1,1249 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "33a9aedb-b4d8-48c6-9590-58b221405ca5", + "metadata": {}, + "source": [ + "Copyright (c) MONAI Consortium
\n", + "Licensed under the Apache License, Version 2.0 (the \"License\");
\n", + "you may not use this file except in compliance with the License.
\n", + "You may obtain a copy of the License at
\n", + "http://www.apache.org/licenses/LICENSE-2.0
\n", + "Unless required by applicable law or agreed to in writing, software
\n", + "distributed under the License is distributed on an \"AS IS\" BASIS,
\n", + "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
\n", + "See the License for the specific language governing permissions and
\n", + "limitations under the License.
" + ] + }, + { + "cell_type": "markdown", + "id": "95c08725", + "metadata": {}, + "source": [ + "# Super-resolution using Stable Diffusion v2 Upscalers\n", + "\n", + "This tutorial illustrates how to perform **super-resolution** on medical images using Latent Diffusion Models (LDMs) [1]. The idea is to train a spatial autoencoder whose latent space has the same spatial size as the low-resolution image, so that high-resolution images are encoded into latents of that size. The LDM then learns how to go from **noise to a latent representation of a high-resolution image**. During training and inference, the **low-resolution image is concatenated to the latent** to condition the generative process. Finally, the high-resolution latent representation is decoded into a high-resolution image.\n", + "\n", + "To improve the performance of our models, we will use a method called \"noise conditioning augmentation\" (introduced in [2] and used in Stable Diffusion v2.0 and Imagen Video [3]). During training, we add noise to the low-resolution images using a random signal-to-noise ratio, and we condition the diffusion model on the amount of noise added. At sampling time, we use a fixed signal-to-noise ratio, representing a small amount of augmentation that helps remove artefacts in the samples.\n", + "\n", + "\n", + "[1] - Rombach et al. \"High-Resolution Image Synthesis with Latent Diffusion Models\" https://arxiv.org/abs/2112.10752\n", + "\n", + "[2] - Ho et al. \"Cascaded diffusion models for high fidelity image generation\" https://arxiv.org/abs/2106.15282\n", + "\n", + "[3] - Ho et al. 
\"Imagen Video: High Definition Video Generation with Diffusion Models\" https://arxiv.org/abs/2210.02303" + ] + }, + { + "cell_type": "markdown", + "id": "b839bf2d", + "metadata": {}, + "source": [ + "## Setup environment" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "77f7e633", + "metadata": {}, + "outputs": [], + "source": [ + "!python -c \"import monai\" || pip install -q \"monai-weekly[tqdm]\"\n", + "!python -c \"import matplotlib\" || pip install -q matplotlib\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "id": "214066de", + "metadata": {}, + "source": [ + "## Setup imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "de71fe08", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import shutil\n", + "import tempfile\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn.functional as F\n", + "from monai import transforms\n", + "from monai.apps import MedNISTDataset\n", + "from monai.config import print_config\n", + "from monai.data import CacheDataset, DataLoader\n", + "from monai.utils import first, set_determinism\n", + "from torch import nn\n", + "from torch.amp import GradScaler, autocast\n", + "from tqdm import tqdm\n", + "from monai.losses import PatchAdversarialLoss, PerceptualLoss\n", + "from monai.networks.nets import AutoencoderKL, DiffusionModelUNet, PatchDiscriminator\n", + "from monai.networks.schedulers import DDPMScheduler\n", + "\n", + "print_config()" + ] + }, + { + "cell_type": "markdown", + "id": "c0dde922", + "metadata": {}, + "source": [ + "## Setup a data directory and download dataset\n", + "Specify the `MONAI_DATA_DIRECTORY` environment variable to choose where the data will be downloaded. If it is not set, a temporary directory is used." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "ded618a7", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/tmp/tmpj53lse09\n" + ] + } + ], + "source": [ + "directory = os.environ.get(\"MONAI_DATA_DIRECTORY\")\n", + "root_dir = tempfile.mkdtemp() if directory is None else directory\n", + "print(root_dir)" + ] + }, + { + "cell_type": "markdown", + "id": "645f97bb-6879-4b2e-8fc9-29dd1a6e904f", + "metadata": {}, + "source": [ + "## Set deterministic training for reproducibility" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "9f0a17bc", + "metadata": {}, + "outputs": [], + "source": [ + "# for reproducibility purposes set a seed\n", + "set_determinism(42)" + ] + }, + { + "cell_type": "markdown", + "id": "d80e045b", + "metadata": {}, + "source": [ + "## Description of data and download the training set\n", + "\n", + "For this tutorial, we use the head CT dataset from MedNIST." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "c8cf204a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2024-09-23 09:27:05,757 - INFO - Verified 'MedNIST.tar.gz', md5: 0bc7306e7427e00ad1c5526a6677552d.\n", + "2024-09-23 09:27:05,758 - INFO - File exists: /tmp/tmpj53lse09/MedNIST.tar.gz, skipped downloading.\n", + "2024-09-23 09:27:05,759 - INFO - Non-empty folder exists in /tmp/tmpj53lse09/MedNIST, skipped extracting.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading dataset: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 47164/47164 [00:16<00:00, 2923.68it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2024-09-23 09:27:22,258 - INFO - Verified 'MedNIST.tar.gz', md5: 
0bc7306e7427e00ad1c5526a6677552d.\n", + "2024-09-23 09:27:22,258 - INFO - File exists: /tmp/tmpj53lse09/MedNIST.tar.gz, skipped downloading.\n", + "2024-09-23 09:27:22,259 - INFO - Non-empty folder exists in /tmp/tmpj53lse09/MedNIST, skipped extracting.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading dataset: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5895/5895 [00:01<00:00, 2964.04it/s]\n" + ] + } + ], + "source": [ + "train_data = MedNISTDataset(root_dir=root_dir, section=\"training\", download=True, seed=0)\n", + "train_datalist = [{\"image\": item[\"image\"]} for item in train_data.data if item[\"class_name\"] == \"HeadCT\"]\n", + "val_data = MedNISTDataset(root_dir=root_dir, section=\"validation\", download=True, seed=0)\n", + "val_datalist = [{\"image\": item[\"image\"]} for item in val_data.data if item[\"class_name\"] == \"HeadCT\"]" + ] + }, + { + "cell_type": "markdown", + "id": "cacdb233", + "metadata": {}, + "source": [ + "## Prepare dataloaders\n", + "\n", + "Here, we create the data loaders that we will use to train our models. We will use data augmentation and create low-resolution images using MONAI's transformations:\n", + "\n", + "1. `LoadImaged`: to load the images\n", + "2. `EnsureChannelFirstd`: to make sure there is a channel dimension at the beginning of the output tensor\n", + "3. `ScaleIntensityRanged`: normalise the images\n", + "4. `RandAffined`: affine augmentation (training only)\n", + "5. `CopyItemsd`: we copy the image item to obtain the low-resolution representation\n", + "6. 
`Resized`: we resize the copy we just made to 16x16 to obtain the low-resolution image" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "c7997edf", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading dataset: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7991/7991 [00:05<00:00, 1544.42it/s]\n", + "Loading dataset: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 972/972 [00:01<00:00, 804.53it/s]\n" + ] + } + ], + "source": [ + "image_size = 64\n", + "\n", + "# Transforms\n", + "all_transforms = [\n", + " transforms.LoadImaged(keys=[\"image\"]),\n", + " transforms.EnsureChannelFirstd(keys=[\"image\"]),\n", + " transforms.ScaleIntensityRanged(keys=[\"image\"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True),\n", + " transforms.RandAffined(\n", + " keys=[\"image\"],\n", + " rotate_range=[(-np.pi / 36, np.pi / 36), (-np.pi / 36, np.pi / 36)],\n", + " translate_range=[(-1, 1), (-1, 1)],\n", + " scale_range=[(-0.05, 0.05), (-0.05, 0.05)],\n", + " spatial_size=[image_size, image_size],\n", + " padding_mode=\"zeros\",\n", + " prob=0.5,\n", + " ),\n", + " transforms.CopyItemsd(keys=[\"image\"], times=1, names=[\"low_res_image\"]),\n", + " transforms.Resized(keys=[\"low_res_image\"], spatial_size=(16, 16)),\n", + "]\n", + "\n", + "train_transforms = transforms.Compose(all_transforms)\n", + "val_transforms = transforms.Compose(all_transforms[:3] + all_transforms[4:])\n", + "\n", + "# Datasets\n", + "train_ds = CacheDataset(data=train_datalist, 
transform=train_transforms)\n", + "train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=4, persistent_workers=True)\n", + "val_ds = CacheDataset(data=val_datalist, transform=val_transforms)\n", + "val_loader = DataLoader(val_ds, batch_size=32, shuffle=True, num_workers=4)" + ] + }, + { + "cell_type": "markdown", + "id": "166e4242", + "metadata": {}, + "source": [ + "### Visualise examples from the training set" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "8c0fe41c", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Plot 3 examples from the training set\n", + "check_data = first(train_loader)\n", + "fig, ax = plt.subplots(nrows=1, ncols=3)\n", + "for i in range(3):\n", + " ax[i].imshow(check_data[\"image\"][i, 0, :, :], cmap=\"gray\")\n", + " ax[i].axis(\"off\")" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "76412555", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAgMAAAClCAYAAADBAf6NAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAMcUlEQVR4nO3cS4iW5d8H8GscZxpNnZnUUVJHyyJNsEUuOieUhAuTDkRBB3BVUAS1aRUELVqFLSJatEiCoiKEwE7mtCjUFiHhAQopNdE0z47j6IzzX7ybd/Hn/f2k+30c5/p81l/u+5rncM+XZ/FtGxsbGysAQLUmXekDAABXljIAAJVTBgCgcsoAAFROGQCAyikDAFA5ZQAAKqcMAEDlJmeDbW1tjWQmTYr7R1OZ8XimTKaVZ87eL+Ps2bNhZnBwsJF7XY6urq4w09HREWYy+1xNZZq+Vqvu1cozZ6916dKllt2radnvMPxfMt8BnzQAqJwyAACVUwYAoHLKAABUThkAgMopAwBQOWUAACqnDABA5dKjQ319ff+f52ACGBkZudJH+K86OzvDTGaYiNbJDvzMmzevkWsdOHAgdT+YqPwyAACVUwYAoHLKAABUThkAgMopAwBQOWUAACqnDABA5ZQBAKhcenSIZkybNi3MXLhwoZFMq7W1tV3pI/xX4/VcTciO87RK5rXOnDn7nh08eDDMzJo1K8wsW7YsdT9aa7x9vlutlc8uvwwAQOWUAQConDIAAJVTBgCgcsoAAFROGQCAyikDAFA5ZQAAKnfVjg5lxxiaGkFpb28PM/Pnzw8zmbGgQ4cOhZnxaNKk8dktr9bRocxnLvO3Xbp0qZHrjIyMhJnxOBJz7NixMHPkyJEWnIT/LfNZWbRoUZh5/vnnw8xNN90UZo4ePRpmNm7cGGY2b94cZkopZXR0NJVrlfH59AYAWkYZAIDKKQMAUDllAAAqpwwAQOWUAQConDIAAJVTBgCgcm1jyZWQvr6++GINjbtkrpO91+LFi8PMzTffHGbOnj0bZq677row093dHWZuvPHGMLNp06Yws3379jBTSnNDMefOnQszmdexab29vWGmo6OjBSf5H1OmTEnlXn311TCT+axs2bIlzPz1119hZubMmWEm49tvvw0zg4ODqWs1NdySGVQ6ceJEI/e6HE0NeWWel9dff32YefHFF1P327BhQ5h56aWXwsx9990XZnbt2hVmzp8/H2Yyz4AbbrihkeuUUsqjjz4aZvbt2xdmmhoe88sAAFROGQCAyikDAFA5ZQAAKqcMAEDllAEAqJwyAACVUwYAoHLKAABUbnI2mFkwam9vDzOZtbvMdTJrf6XkFqyWLl0aZk6ePBlm9uzZE2aGhobCzKpVq8LM008/HWZef/31MFNKKZ988kmYSS1YNbSW1rSmljEzf19mpW/58uWp+61du7aR+/X09ISZzOLfwYMHw0xXV1cj5/n888/DTCm51cuLFy+Gmczne6L77rvvwszp06dT11qzZk2Yy
awCfvHFF2Hm+PHjYSbzOens7AwzmbXDhQsXhplSShkYGAgzmf9Nw8PDqftFxufTGwBoGWUAACqnDABA5ZQBAKicMgAAlVMGAKByygAAVE4ZAIDKpUeHMmNBGZkBmMy4y4IFC1L3yww7ZO63Y8eOMJMZQnrsscfCTHd3d5jZu3dvmFm3bl2YKaWUrVu3hpk//vgjzIzX0aGMjo6OMPPMM8+Emcz7O3369NSZNm3aFGbuvvvuMJMZbslc54knnggz69evDzNLliwJM48//niYKaWUjRs3hpl//vknda2J7J577gkzr7zySpg5evRo6n4ff/xxmNm5c2eYyYxhZcaLMmNvv/32W5jJDBNlzlNK7n/Yk08+GWY+/PDD1P0iV+/TGwBohDIAAJVTBgCgcsoAAFROGQCAyikDAFA5ZQAAKqcMAEDl0qNDly5dauSGM2bMCDOrVq0KM11dXan7TZ06NczMmTMnzAwPD4eZ5557Lsz09/eHma+//jrMZMYo3n333TBTSikrV64MM/v37w8zmUGpKyFzrswoy2uvvRZmjh8/HmZOnToVZkrJjaBMmzYtzBw7dizM3HnnnWFmy5YtYSYzXvTVV1+Fmfnz54eZUnIDL5nXaLwOE2UGqubOnRtmHnnkkTCzbdu2MPPDDz+EmVJKWbNmTZjJjFhl/he89dZbYaa3tzfMZEbj+vr6wkz2OZh5VqxevTrMbNiwIXW/iF8GAKByygAAVE4ZAIDKKQMAUDllAAAqpwwAQOWUAQConDIAAJVLjw6Njo6GmczYwm233RZm3nnnnTDzyy+/hJlSSlm6dGmYefnll8PMrl27wkxmmGf37t1hZmBgIMxkRmJOnjwZZkoppbu7O8y0t7eHmUmTrt5uOXPmzDCTeZ127twZZrLvy+DgYJj59NNPw8yZM2fCTGZULDPM09HREWYOHz7cyL1KKWXy5PgRNjIyEmbG62DW0NBQmNm3b1+YyTybli9fHmYOHDgQZkrJPQsymaNHj4aZsbGxRjKZ/3GZEajMs7KU3PMk8/5n/raMq/fpDQA0QhkAgMopAwBQOWUAACqnDABA5ZQBAKicMgAAlVMGAKBy6dGhzLBBZkRi9uzZYSYzAHLu3LkwU0ophw4dCjNLliwJMz/++GOYyQylvPDCC2Fm9erVYWbx4sVh5qGHHgozpZTy/vvvh5nMIMd4HR3KnCszBvX777+HmRUrVoSZ8+fPh5lScgNG/f39YSYzgrJ3794ws27dujDz3nvvhZk5c+aEmcw4WSmlbN26NcxcuHAhzGRGl66EixcvNnKdzOc7M7521113pe63cuXKMLN+/fowk3nuLliwIMxkhqdOnDgRZvbs2RNmbrnlljBTSm4w6/vvvw8zTQ1mjc+nNwDQMsoAAFROGQCAyikDAFA5ZQAAKqcMAEDllAEAqJwyAACVaxvLrAmVUq699towM23atDDT29sbZp566qkw09PTE2ZKKaWzszPMdHV1hZldu3aFmV9//TXM3HHHHWEmM5Ly5ptvhpkPPvggzJSSG//InOnMmTNh5siRI5kjNaqvr6+R6yxbtizMrF27NsxMnTq1ieOUUkp5+OGHw0xmeCszXDJjxowwkxkdWrhwYZj56KOPwkwppRw7dizMnDp1KswMDg42kmlaU0NemescOHAgzGQ/u5nnZeb1HBgYCDOnT58OM2fPng0zmdeou7s7zGQG4UopZc2aNWHm1ltvDTNDQ0NhJjOq5ZcBAKicMgAAlVMGAKByygAAVE4ZAIDKKQMAUDllAAAqpwwAQOXSo0OZsYnp06f/6wOVUkpHR0eYyQwclVLKrFmzwsyDDz4YZqZMmRJmdu/eHWYyQxv33ntvmMmMrXz55ZdhppRSTp48GWbOnTuXulbk77//buQ6l6Op0aGmZMZ7Sinl7bffDjOLFi0KM5n3d9OmTWEm89498MADYWbz5s1h5qeffgozpeTGZEZGRsJM5nvZ1HfgcjQ1OpQZlco8426//
fbU/f78888w88Ybb4SZFStWhJlt27aFmfPnz4eZ0dHRMNPf3x9msqNDmXG9PXv2pK4VMToEAISUAQConDIAAJVTBgCgcsoAAFROGQCAyikDAFA5ZQAAKpceHerq6goz3d3d//pApeQGMtrb21PXmj9/fphZtWpVmMkMjuzYsSPM9PT0hJnMQMTPP/8cZjJjK6XkXu+mTPTRoczXKTskkxnMmjdvXpg5ceJEmMmMeGWGWw4dOhRmWj3ek3lPMuNFQ0NDTRznsjQ1OjQeZcbl7r///jDz7LPPhpm5c+eGmcOHD4eZb775Jsx89tlnYaaUUoaHh8NMU89mo0MAQEgZAIDKKQMAUDllAAAqpwwAQOWUAQConDIAAJVTBgCgcsoAAFQuvUDY2dkZZnp7e8NM5naZ1aUmV/Myf1t28TCSWXI8fvx4mGnqdWy1ib5A2KSm3uOJep2szLWuueaaMLN///4mjnNZJvIC4dX4eWryXq08twVCACCkDABA5ZQBAKicMgAAlVMGAKByygAAVE4ZAIDKKQMAULnJ2WBmtCCjleMPWcPDw2GmqdGhoaGhRq4zHgeFaFZT7/F4u06T9+rv7w8zPT09YWb79u2p+9Gcify5bPW1muCXAQConDIAAJVTBgCgcsoAAFROGQCAyikDAFA5ZQAAKqcMAEDl2saS6z2TJsW9Yfbs2f/6QFnZ0aFMLjOo1FRmvJ0nm2vq3KOjo6kzNWn69OlhJjMA0lQmq5VnauXf1uS9MmNgIyMjYebUqVNhJjNO1rTMcxcimWezTxoAVE4ZAIDKKQMAUDllAAAqpwwAQOWUAQConDIAAJVTBgCgcunRIQBgYvLLAABUThkAgMopAwBQOWUAACqnDABA5ZQBAKicMgAAlVMGAKByygAAVO4/gaFP9p1bq/QAAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Plot 3 examples from the training set in low resolution\n", + "fig, ax = plt.subplots(nrows=1, ncols=3)\n", + "for i in range(3):\n", + " ax[i].imshow(check_data[\"low_res_image\"][i, 0, :, :], cmap=\"gray\")\n", + " ax[i].axis(\"off\")" + ] + }, + { + "cell_type": "markdown", + "id": "9fc99896", + "metadata": {}, + "source": [ + "## Define the autoencoder network and training components" + ] + }, + { + "cell_type": "markdown", + "id": "9b52c4a7-26eb-47e7-8aac-99c62ca88ee3", + "metadata": {}, + "source": [ + "To yield a 16x16 latent representation from the high-resolution images, we use AutoencoderKL. We train it using a Patch-GAN adversarial loss, as well as a perceptual loss, to boost image fidelity." + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "610bd118", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using cuda\n" + ] + } + ], + "source": [ + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "print(f\"Using {device}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "0e4ef480", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [], + "source": [ + "autoencoderkl = AutoencoderKL(\n", + " spatial_dims=2,\n", + " in_channels=1,\n", + " out_channels=1,\n", + " channels=(256, 512, 512),\n", + " latent_channels=3,\n", + " num_res_blocks=2,\n", + " norm_num_groups=32,\n", + " attention_levels=(False, False, True),\n", + ")\n", + "autoencoderkl = autoencoderkl.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "bd5197a4-ec30-4f13-9b7a-5e3e43a42637", + "metadata": {}, + "outputs": [], + "source": [ + "discriminator = PatchDiscriminator(spatial_dims=2, in_channels=1, num_layers_d=3, channels=64)\n", + "discriminator = discriminator.to(device)\n", + "adv_loss = PatchAdversarialLoss(criterion=\"least_squares\")" + ] + 
}, + { + "cell_type": "code", + "execution_count": null, + "id": "dfd826c6", + "metadata": {}, + "outputs": [], + "source": [ + "perceptual_loss = PerceptualLoss(spatial_dims=2, network_type=\"alex\")\n", + "perceptual_loss.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "410911c9", + "metadata": {}, + "outputs": [], + "source": [ + "scaler_g = GradScaler()\n", + "scaler_d = GradScaler()" + ] + }, + { + "cell_type": "markdown", + "id": "c16de505", + "metadata": {}, + "source": [ + "## Train Autoencoder" + ] + }, + { + "cell_type": "markdown", + "id": "a93437fe-d6ef-42d2-bedd-4da735c59dd1", + "metadata": {}, + "source": [ + "In this section, we train a spatial autoencoder to learn how to compress high-resolution images into a latent space representation. We need to ensure that the latent space spatial shape matches that of the low resolution images." + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "2789508d-9fa8-483e-8b4c-dd17bf2f39b8", + "metadata": {}, + "outputs": [], + "source": [ + "# Loss weights\n", + "perceptual_weight = 0.002\n", + "adv_weight = 0.005\n", + "kl_weight = 1e-6\n", + "\n", + "# Optimizers\n", + "optimizer_g = torch.optim.Adam(autoencoderkl.parameters(), lr=5e-5)\n", + "optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=1e-4)" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "830a3979", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 0/75:,recons loss: 0.153782,perc_epoch_loss: 0.505703,kl_epoch_loss: 2163.503702,\n", + "Validation. recons loss: 0.002920,\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 5/75:,recons loss: 0.055186,perc_epoch_loss: 0.259082,kl_epoch_loss: 2720.869604,\n", + "epoch 10/75:,recons loss: 0.047029,perc_epoch_loss: 0.211522,kl_epoch_loss: 2314.558471,\n", + "Validation. recons loss: 0.001608,\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 15/75:,recons loss: 0.042698,perc_epoch_loss: 0.174223,kl_epoch_loss: 2112.795490,,gen_loss: 1.012134,disc_loss: 0.002280,\n", + "epoch 20/75:,recons loss: 0.035483,perc_epoch_loss: 0.102338,kl_epoch_loss: 2271.543911,,gen_loss: 1.012229,disc_loss: 0.002381,\n", + "Validation. recons loss: 0.001392,\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAKIAAABbCAYAAADwb17KAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAv6klEQVR4nO19WWxb55X/j9vlvlOkSEmk1niNLXlJnKLjOpimKTBBUxczaIC8FCgKDDroWzF9m6d5mLeZtpn2YR5aFOgGd0mTtJjUjvc4i9M4tisvsiRroyiK4r6vdx70P8cfry8l2lUb+Q8dQCB1ebfvu+c7y+8sVyPLsowd2qFPmbSf9g3s0A4BO4y4Q9uEdhhxh7YF7TDiDm0L2mHEHdoWtMOIO7QtaIcRd2hb0A4j7tC2oB1G3KFtQfpudwyFQn/N+3gkkmUZjUYDAKDRaKDVaqHRaNr2oe3KwJH4f6PRQKVSgdVqhclkAgBUKhU0Gg0YDAYYjUbodDrIsoxWq4VisQiNRgOHw4FCoYBisQir1QqtVot6vQ6NRvPQfWzVeNVI7Voajeah/Wkb/Yn70DYA0Gq1vH+r1ep4fbX5Fq+lvN7q6uqmY+yaEbcTaTQamEwmaDQaNJtN1Ot1aLVaGAwGaLVatFot1Ot1VKvVh44TP2VZhsFgQL1eR7PZ5AnW6/VoNpsoFovMhACg1+shSRLK5TJkWYbZbIYsy2g2mxsyIT38bkm5v/hdZIrHic7qdDro9XpmularxecR56XZbLYxo0jKsYrMrHb/3dATyYiyLKNer8NgMLDkMxgMMJlM0Gq1qNVq0Ov1CAaD0Gq10Gq10Ov10Ov1MBgM0Ol0bRKzVCohl8shl8uhVCqh2Wy2XU+v18NisaBer6NUKvE5ZFlmxu/00ICHJddmD2qz3zZjQHHB0Th1Oh0MBgMsFgssFgv0ej3fS6PRQLPZZIZqNpuoVqu8QIkpRSlHC1SUrOL1lIt+M3oiGZGIGECj0cBmsyEYDMLr9cJoNEKv18NsNm94PE0aTTYA1Go1ZDIZxONxrK2toVQqodVqoVKpQKfTQZIk1Ot1SJIEnU6HcrncJk2JOqlH8XdxWzcP7FEYEAAvQp1OB5PJBKvVCofDAbvdDqPRCGDdPKnX68x0xIiVSoW3NRoNNoXovuk42o+eg06n6zjmjeiJZETRDiEGkmWZJ9xut7Pd1mq12iZTucI1Gg2MRiNsNhtcLhdMJhP8fj96enqQyWSQTqcRi8WwuLgIk8kEp9OJdDqNWq0Gs9mMRqOBVqvV9gDUSO2BKJmx2+O7YXLlfqIqJnuWNAuZNzQnJN1FzSFJEjOawWCALMuo1WrI5/MoFAqoVCptgoH+NtIUbWPoNg1suzkrwAOjutFosPqlbaJ67cQE9AD0ej3bTna7HX19fRgeHsbAwAAajQZmZmZw9+5dpNNpGAwGFItFVtEkWZTqXO1e6f/NVPVGklVpj6ntJ9pwoulCDEVzRYxDc6F2byTpyXaWJAkWiwUmkwkGg
wGVSgW5XA7pdBqlUgn1er3tuGazibW1NdW5EemJZESgXf3Qd5pYUdqJ+6odSw9ElJa08nU6Hfr7+3H06FHs3r0bn3zyCX75y1/CbDbDbrejVCpBkiSWLIA64yiv2+l32qeThOtEIrOrHafUIJ2uSd9Jeop2HxGZJw6HAx6Ph82ffD7P0rFcLrfZnclkcsP7B55wRqSBktPRarXaVn+tVuvqPMDDUokkD0lKs9mMl19+GaOjo/jpT3+Kjz/+GFarlW2nTqp5MxtwI2+T/id1p7zfTmMRj1VeVyn51Bhf7X7EBS7LMmsPr9cLt9sNg8GAVquFcrmMQqGAXC7HDLmysqI69rZ7346MqNfrmbFEGIGklFarZayPJoiYkVYiqRKCKpRwRLdGNKlco9GIQqGAgwcP4uTJk/jkk0/w+9//HrVaje0nNdrIQ1aDPZTUrdcp2mTiMaQxRKZW2pCbXVfJ4KKNbTQaYbFY4HK5YLPZ2hiSUIj5+fnN73+7MaIIuoqAtDiBwAMJVKvVGA8kJiVoolKp8LHAAy9SBFw3IzpnpVLhBxCJRPCFL3wBAHDhwgXMzc11ZDg1NShu38wJUQOjldJRdA7Ec6ltp7mgz42YUe0a4nnIpJFlGZIksWdOMFqlUkGlUsHMzIzquduus10ZUWRGtd9IbQIPmIWcD/pNp9O12X0EZYjMsNnwCScslUpsoJfLZQSDQezbtw89PT24ffs2bt68ycc8CmyhtO86OR3KfcVPccFuxFBKRlJzmsTvneAlpc1JzKjVamE0Ghk+A4B6vY65ublN52HbwTfEMKLjQduUjgntR2qcVrm4v1IyiEzeDZEN6HA4UK/XUavVYDAYEI1G0Wq1cOjQoYeOUXNYNrveRrai6DyonV+032j8G3nmneagkzpWuybtQ8KAficwvNFoPPk4oshAatRqtVCr1Rg6EaWnKCXIvhOZkh7YRvFSJWk0GkiSxECvxWKBJEnIZrO4efMmmwBqJDKP0jHaTB0q56STjSsuRGJCJYPQd7Woh3g/SudlM2aia4n3puaZb0bbkhFFlF7NRiTbhGwRAmPFBAUCmgG04YvihD3Kiq3Vamx/NptNZszFxcW2B0/0KA9iowWxkZQS71/EU5XaQ1T5xGgbRYKU87IRnNTJPlXaoJvRtmNEEYahhyt6acD64C0WC2w2G2fQEHMQrENhN3pA3Uo/NSKmJjC3Xq+jXC4DAAwGA9+3chxq11SDTZTMJkpRpbRSs9lojPV6nXE+gp3E+VAiDEoNojaGjf5XG5dy/yeWEQE8JAVJQlIkg5IdfD4farUayuUy6vU6TCYT9Ho9QzhGoxHFYrFNYtKEi6t2s3shI7xWq3FojzBEUd2LpLRJxe1qY1WqRNEpE4/t5NBQdMhsNsPpdMLlcsFsNsNgMLBJQREhSnVTAvji/SmvJTqO3ahrpYe/GW05I3ZrmHfahyYVAHtkADiG7PF4YLfbYbFYcOjQIYyOjvIqB4BUKoXFxUXcv38fS0tLSCaTSKfTbeqIJJzSK1cjkrL5fB5msxlerxfZbBaNRoMTIDoZ/moSQcloZMd2cjDEe6Y/gkssFgtHOJxOJ3/v6emB2+2G1WplXK9SqSCbzWJ1dZXnJJVKIZVKIZfLoVqtotVqcTBgMw9eCSupPUfxczPaUka0WCxYWVmB2+2GyWRCqVRCtVplNdFsNmEwGJDNZmGz2QCAgWmdTod6vc4TbbFYMDg4iAMHDmB8fBy7d+9GOByG1+vlZNRmswm9Xo9CocAqU5R6rVYL9+7dw2uvvYa33noLAFCtVmGxWNBsNtkDJmaiiaOFUKvV0Gw2WdIAQLFYbFsoZKh3o5pFxqffRYyToCVRyorRDK1WC6vVCovFAr/fj8HBQYyMjGBkZAQ+nw8OhwNGoxFms5nviyQfzW25XEa1WkWpVEI2m8X8/Dzu3LmDe/fuYWVlBZVKhY8HHtjrapJZvP9OHvhmC533l7tU4pvhiLIsI5fLIRQKY
XV1FbIsM9JOapUC5D6fDwBQLpf5YRuNRvT09GBgYACRSATHjx/H5z73ObhcLuj1etRqNeRyOaytrSGVSkGr1eL48eNIJBKIx+PIZDLI5XK8fzKZRL1ex4kTJ1AsFvHrX/8a3/3ud5HP5zlfkZhQNOABtKkrcbI7kVpyqfgbbVdzDkgikmqlRSRGgmw2G0KhEEKhEILBIHw+H/x+P3w+H9xuNxwOx0NecrVaRbFY5DkWw58kTS0WC0qlEmZnZ/GnP/0JV65cwczMDGR5PWIi2o2i40OfnexCcaw6nQ5LS0sbzh+whRJRo9HAbrcjn8/DarXyBNNqpMC3zWZDtVpFtVqFJEno6emBz+fD2NgYDh06hAMHDmBsbAySJCGdTuPWrVtIp9NsaNNqpvN/5jOfwXe/+11cuHABTqcTn//85zE4OIjp6Wm8//77eP/99/Hv//7vePXVVyFJEn70ox9hZmamLeFB9KhFaUqLqFtbUmnjqTkj4v6iFy9KPmJCnU4Ht9uN3bt347nnnsPu3bvh9/thNBrZSatWq1haWmKHTafTMRPVarW2fEHxXqxWK6vwvXv3wmq1sgkwOzuLarXKWfC0SJULhJ67kg/Ez0/FWSG1ptfrOT+NJqe/vx979uzB4uIiVlZW0N/fD7/fD6/Xi127dmF8fBx79+6Fz+dDMpnE3bt3MT09jbW1Nciy3MYUsizD5/PBarVCr9cjk8lgfHwcR44cwf79+2GxWOD1ejE0NASr1YpUKgWfz4eXXnoJpVIJp06dwtzcHFqtFseiRTVMZoRa9GGjsSupk8FODKjT6WCz2WCz2VgdUma0LK+XIgwODmJ8fByHDh1CKBRizZLNZrG2tsY5k6VSCTqdDna7HW63m3Myyasn+5KSEtbW1pDP59HX1we/34/h4WFGGZrNJqLRKKrVKktqGqPIlJ0W2uMgFFvKiLRSqFak2WzC6XTiwIED+PznP4+xsTEsLS3h3XffhcPhQKvVgtfrxbPPPou9e/dCkiSsrq7i3LlzbESTWq9WqyzF7HY7+vv7EYlEcP36dUQiEezbtw8TExNIJpOYnp5GNpsFAGSzWbz55ps4fvw4wuEwXnrpJZTLZfzhD3/AvXv3AIANdDGbRxzPozJjJy9Z3K7X6+F2uxGJRBAOh+Hz+WAymdh2JUbt6elBX18fJEnijPFCoYB0Oo10Oo1CocDmBhWAmc1m2Gw2thepfqdYLKLZbLLtTppFq9UiFArhqaee4rDopUuXEI1G2XShOVFzZNSoG6dVpC1lRJ1Oh1KpxMzjcrlw5MgRnDx5kpMEPvvZz2JoaAharRb3799n1WO1WjE1NYXbt2/jxo0bPIkEWANg+8ZoNMLhcKBcLuPMmTMYGhrCrl27YDabcefOHZw/fx75fJ6r8Oh+zGYzIpEIXn75ZdRqNRQKBUSj0Tabkbx0yvCRZfkhsLpbEqWhiAnq9XpYrVYMDg7i4MGDGB0dRW9vL+x2Oy8KEf8rlUpYXl5GIpFAJpNhdUveP30SE5NTaLfbIUkSZ1JTWYOo1mluTCYTwuEwJiYm2HG8fPkyEonEQ3aiaK6o2YxKM6Ub2lJGlCSJYQ69Xo/Dhw/jH//xH3H06FEsLS0hHo/D6/ViYGAAfX19CIfDPFn37t3DxYsXcf/+fUiSBKfTCb1ez6uXsoOBdW+2Wq3iz3/+M5LJJMbHx2E2mxGLxXDr1i0sLCzAarVCkiQEAgHo9XrMzc3hzp070Ov1eOqpp/CVr3wFsizj9ddfRzweZwOd6jeMRuMjTybwcNKGmsdpNBrZ6w2FQjCZTIyHSpIEl8vFTlcul8O9e/cwNzeHhYUF5HI5AOtwFuGEZLYQ6fV62Gw2uN1uSJKEQqGAUqnUZjOSjV0sFrG6uspSdGhoCM8++ywajQaKxSKuXr2KbDbbFi6lcYhjpk8lM3ZLW8qI5XIZ/f39SCQS6Ovrw0svvYSxsTGcPXsWr7/+OjweD27evIkXX3wRL730Eg4eP
Aiz2YybN2/i/PnzuHHjBk9sLpfj1U52J0ESrVYLTqcTu3fvxjPPPMNe9/nz53Ht2jU0Gg3s2rULVqsVhUIBZrMZBw8exMzMDPR6PXbv3o29e/fCbDZDo9Hg9ddfx+LiIsNCFIl5FEC2E8MqwXOdTgeHwwG/3w+3241UKsWmhCzLcLlciEQi2LVrF8LhMLRaLarVKvL5PNLpNMrlMs8R8LDZYLFY4PF42E4kuzCZTGJtbQ25XI4dGwK7y+UyEokETCYTJEnC4OAgjhw5guXlZSSTSdy+fRulUumhsKAoFf+SyBXwV3BWGo0GstksXn31VQwODuLHP/4xTp06hUgkgmKxiPn5efzXf/0X3n77bfznf/4nxsfHEY1GsbS0BL1ez5KQ8EeHw9FmRFssFvT09CAQCCAQCKBcLsNsNuOXv/wl3njjDWQyGeh0Oty5cwfj4+PszV+7dg0ulwvLy8uo1+sYHR3F8PAwvva1r+H27duYnZ1lpwhAR3xwM1JmAJGKJaaWJAkejwd+vx/1eh3RaBRTU1NYW1vjAn6Px4OxsTFMTExgaGiIHT86HxVAic0ANJr1wn/SNL29vdDr9UgkElhZWcHCwgLi8ThLQ7ofis+Xy2XOpJZlGYFAAIcPH0YsFkOhUMD8/HybJBVJmaTSCczfiLaUEWVZRqFQwP79+7Fv3z7cvn0bk5OTkCSJVQ6pzGg0in/5l3/B97//fbZ77HY7G9yNRgM6nQ7Ly8u8whuNBmq1GrxeL4LBIBqNBkvUqakp9j6JoWgVk+qpVqttFX3VahXhcBjf+c53MD09jVgsBo/Hg5WVFaTTac4/3Ay+UYsNEyMS07RaLZhMJpaGHo8H2WwWy8vLyOfz0Gq1MJvNqNVqSKfTmJycRC6XY3OGnAYALMUoEYOYu7e3F2NjYxgdHYXD4UAqlUI8Hsf9+/exsLDQBsaLn1RQValU2GYul8uw2WyYmJjA6uoqisUiYrHYQ6FStbruR3XygL9CiM9oNOLw4cNwuVy4dOkS7t69i8HBQQQCAWSzWQQCAdy7dw+yLCORSODf/u3fcPLkSYyNjeE3v/kNOzkjIyNIJpMwmUwol8uIRqNwOByYmJiAz+dDOp2G2+1GqVTC6dOnEY/HMTAwAJPJxIY8TTbwIErSbDaRyWSwurqKarWKQCCA4eFhvPLKK/iP//gPlMtleL1eZtpuMEQ1IsmlDMv19PQgEonA5/Mhk8kgn8+zDexwOAA8eJDFYhErKyvM1CaTqS0rmrQERWVcLhe8Xi8cDgdkWUY6ncbKygpSqRRqtVpbIgSZOpTISguuVCohFotBp9NheHgYQ0NDOHr0KN9nJpMBsHFCw6OYNERb3oQpkUjA5XJBp9PB5/Oht7cXALC4uIhisYh0Os31Dc1mE5OTk/jd736HdDqNL33pSwCAUqmEhYUFzrCx2WxsU9F3gn/Onj2Ly5cvI5lMMnhOHiEZ6KRSKDMFWC8jWFtbw8rKCmw2G06ePAmv18tlomQrkhfdLYm2JalSSpqw2+0cFXE4HJwlQ/dE1yMGazQaKBQKKBQKbGK4XC4ukCcGJKmm1+tRLpcRj8cxPT2Ne/fuIRqNolAosPSiY+i7mLkjFs1TNZ7dbseuXbuwa9cu+P1+PkYNO1Ta1Y/CkFsqEamlxbPPPotQKIQXX3wRlUoFV65cQTweZ8CWJpnUcDQaRTabxYkTJ+DxePDRRx9hfn4efr+fxb7VaoXX64Xf74fL5cLq6ipOnTqF//3f/0WtVsPIyAjq9ToXe5N3Rxk5wDq+ZjQaOZdQltdLHQcHB7ls9Ny5c8zInWqV1UhUyfQnZk6TJ+t0OtmEcLvd8Pv93E+HMpvJHiQGo3oQETWoVCp8DEm1arWK5eVlrK6uIp1OY2lpCbFYDKVSCUB78RnwoAkVAft0z5StU6lUYDAYEAqFMDw8jFu3bmFpaYlT4ESTREmPKhG3jBFleb3yf2JiAhMTE6hWq3C73QiHw1hbW8OhQ4eQT
qeRzWYZuCW7aWRkBEeOHEFPTw/C4TAGBgbw3nvvIZvNotVqMRBLMeqpqSmcPn0a77zzDnK5HBv+NOGUwU2rn5iBJB154RqNhlPIrFYrTpw4gQsXLnAmyqPgh0pJQOEwwiGJoWgxSpIEs9mMgYEB9PT0cHcxuh+NRgOn0wmv1wun09nGkAS+k+QiVU44I+GGmUwGhUKhzUwRHShKaxP7ApHNKWY+UTjQ6/VCkiS2vcWwpDgPj0NbKhFrtRpeeOEF1Ot1nDt3jkHYkZERfO5zn8PS0hLbcORElEolhEIhSJKE69ev84Cfe+453L9/H4uLi6yaGo0GpqenceHCBVy+fBn1eh39/f0sUQjspkkSEwFERiXJKGa6NJtNTExMQKPRcFaOWhmqGolYoYgdigA2geukOmmRUU0wqWECnjUaDfeq0Wq1KBQK0Ol0aDQasNvtsNls7NzRMbFYrA2sJrtYNBGUtT4ULNDr9TCZTCxdgQemAmGM1JJFKfW3HXwjSRKOHz+Oixcv4gc/+AEWFhbQ29uLL3zhCzCbzfB4PHj++edx+vRpWCwWaLVafPzxx7h69SqSySRn7ITDYXzxi1/EsWPHAKxXgtlsNpTLZXzyySe4ffs2+vr62NMUi+qpZRxVkimxLzGhQMTi6vU6fD4fh7jEephHIeVDETNeAKBQKGBpaQmyLMNkMsHr9cLlcsFgMDCiUCwWOemi2Wwim80iFouhUqmwjRwKhbjTQrlcRrFYRKFQYCkn5jqS9CftINbw0JwAYMkpZseTx280GjnnUYy2qGUdfaqxZsLI+vr68Otf/5ofZjAYZDghm83iN7/5DV577TUA6/ZboVBANpvl+LRGo2Fc79VXX4XT6cTIyAju3LmDDz74AKlUCvv374fT6UStVmMDmx4aqTWLxcKTSeqG7EUy8MUWdSSxent7MTc3h1qttmH2jBqpqSnyTMnGo5BdqVRiqWy1WjkKQrFhUtPFYhHxeJydPUmS0Nvbi3w+j5GREY7AAGC1TrapCKMQ9CU6LCTVCMCnDHiTyQSLxdKmNSisKuK64jiVkZZPzUYE1ldULBaDzWbD6OgoZwKTQXzt2jXcunULExMTjFd95StfwY0bN3D+/Hm2zRqNBm7cuIGpqSl885vfxJe//GU4HA6k02nE43H2xBOJBMMWkiRxVgoZ2pFIhO00g8HAWTWkomnC6eFYrVbs27ePs1rEaE43pBbqogVqMplY9ZFDUq1W2TakB28ymbi8gaIhmUwGxWKRM3MSiQQsFgs7PnR+svtEIglHdjONha4nSRJvM5lMrILtdjvMZjN74ySJe3p6YDabeQ6VGUaPq6a31FnJ5XKYn5/Hv/7rv+LSpUtsTDscDjz99NOw2Wz44IMPcOnSJXg8HvzzP/8z+vr6EIvF2KingPvAwAC+9a1vYWhoCJFIBHNzc3C73axaqPMUMRg5QgQtVCoV5PN5dhAMBgMnSjQaDVitVkQiEezZswcWiwXAuu34zW9+E7OzsyxdgMfPvqFPgm+sVitsNhurTwDsjNlstjZYpV6vtzkcZFKQ3Voul7mMlSAuYkbggYcsFlCJRV80J+QEkTS2WCzMgBRKpNyBYDCIvr4+2Gw2hoTEMg2lVHwUptxS1ex2u6HRaPDWW28hGo3ydmoBTH1QSC1QrJdWN6kjys27desW3n77bXz5y1/G4uIiZmZmUCqV4Pf7OUUqkUjAbDbDbDZjdnYWuVwOe/bswXPPPceMQD1YSDqQNz8wMAC73d4mEfbv3w+Xy9V1irtyDpT/U/IALQDq4Uh5hbVaDcViEfl8nr1rwk8J3iKMjzx+MilI3Yr1N3TfIqNpNBqWplROQdCMUq1SJKhcLrdld+t0Oq6RoecjNu9UYosbQTtqtOUhvr6+Ppw/fx4///nPUSgUoNfr8f777+PgwYPI5/O4evUq3G43ZFnm8NPU1BR7xkajEZlMBkajEZOTk5ziValU4HQ6OTZKEoIcDqPRiJGREQ4XX
rt2DR6Ph8F1koyNRgO5XA5LS0uc10gSw2g0cvSFxvOoNqJIdC5ZlpHJZNjbFRmRHnw+n4dGo+HoCUlwStwgB4TGQjYnqXhKZKA4sphsIWKHIiRFjAw8DD8B67CNqBVEkF6sqVGWFDwqEwJbrJrJ9jGZTCyFZFnGe++9B2A9F1Gn0yGTyWB5eRl/+MMfsLa2hlu3bqFQKKzf0P9TJdFoFP/wD//AmcPk2WYyGe5DUy6XufEPhck8Hg8/EDGiQvdHTFwsFpFIJBCLxTAwMMCTf/HiRU6M6HYyO+FoYlczEYCmRkXk4VNuZKVSYQlG0pPeYEASTavVtnXfIqlJjhHZtLSYRBiLPsVKPbIrlWOgrB+aw1arhXw+j2w2ywm14jHKHMRPTTUTXb58GXv37sWxY8eQzWZRLpeRyWSY4Y4fP465uTlUKhXMzc1hdnYWxWKRvV4q6olEInjmmWc4x5DSt2RZxtLSEvr7+9smAQCrPfIi6aGSJ0nbyK6sVquIx+MIh8MMNP/+979HIpFo65rQDanBNnRP5ICRXSeWbVYqFciyzNtFu5LwR6fTyeiA3W7nOh/KVALW7VsyKYiBRO9ZtBkJ2hKrEele6U90YsjBW1tb42wccZxqXvOj0pYz4h//+Ec8//zz+OpXv4rFxUXMzs5ifn4e5XIZ169fx1NPPQUAiEQiSKVSnD9HsIDf70ez2cQLL7wAi8WC5eVlpFIpBAIBjIyMwOFw4OrVq8hkMgw7UPthwuyUYSwx6UCEU4AHZQ3AelnBe++9xzYtnacbQLsTqE3fRbuLJBs5HiS9KeGX8DwqHaX6E0IIAoEA3G43dDodQzJUGpDL5djBETu3isA6Ya9iFjpJUZJ05O0ThJPJZLC4uIhoNMrpaqSixZqWx42ybHnNyvz8PJaWlrB3716USiVcv34dsViM8+3++7//G61WCxMTEzAYDAiHwwgGg1xM1d/fj0qlgqeeegq/+tWvUCwWodVqcefOHYyMjGB8fBzDw8N48803kUqlODpBq50KoSg6QmqGYAhS/QDYDqOJX1hYQCKR4G4RYjSiG+qUCACAnZZcLodCocCFXxTJIWlJOYZ0jJhtbTab4XA4GCMlz9VkMqGnpwfBYBDNZhPxeJwXMIVHRUakRUhMQ9JftBGpPTMJiHQ6jYWFBUSjUVQqFZhMprbFKp5PiR50Q1vqNVNbjlOnTvFrIpLJJDKZDGRZ5slfWVnBBx98wCn5fr+fU+DJ1nv//fcxOzvLeYxUWJXNZvH3f//3CAaD+MEPfoBPPvkEZrMZoVCIbStiIrPZzA+BVDZBKeSpu1wuhoEuXbrEr64gSdHNZBLD0v70J2a4kJ2YSqVgNBrZliUmpCwhOk5MQJAkCT6fjyGeer3ODN1sNmG327kGxuVywe12t0WGxBQw0hhiQofYMxJYX6A2m43j3MViEclkEtFoFKlUisOGNEd0nb8k+WHLJWKhUMCVK1fQbDbhcrm4RnlychIul4tXO6kNYL1NSCaT4Tj0uXPnYLFY0NfXx5kmNCgqADp8+DC+/e1v47XXXsPk5CSHysrlMsdL6Tir1cqpU8QQAJgBiVHOnTuHcrnMuCJN9ON4zbQAgAdOQrPZ5AVHsXGCcEiF0jhJ0hBsYrVaeaFTxIXeAQOAywCCwSCcTiejCQCQTqeZySliBDywQ8XvFGGhOaO6GcrqKZfLfF/kCCpVMtGjqOYty0ckO4MKgQgkjcVimJ6eZib1er3o6+tjnKtcLmP//v04fvw4dDodpqen+XeaNHHyV1ZWcP/+feh0OkQiEXznO99heAdYb6hJTOn1elmFED5HEopUHnWjuHHjBq5cucJ4KD2gbiZTTAIQ1ZLofNB2sWMFIQBiHiIxPiXCknokaIkSGmgM5JkXCgUkk0l+PyDVvQwPD8Pn83Hmerlcbqvoo2QHijLZbDZYLBaeg0wmg3v37mFqaoqLzChCJUJEnbJw/uY4Iq0kg8GAfD7PRe5GoxErKyvIZ
rOw2+0oFoussgHgi1/8Il555RWYTCbOBaSkUVKlVK0GrEvdZDKJVqvFybVf//rX8c477yCVSqFQKLBnXK1W4XQ6+ZVf1GGC0pooSlAsFvGTn/yEa6ZzuRzDJ51Wu0id1JJYZ0IQEm0Xa0fIIyaVJ0kSrFYrfD4fgsEgQqEQdDodJ0OIthmpcgDsDJHTR1lGJpMJc3NzSCQSHD4Uu6oRikD4JEFDrVYLi4uL+POf/4zp6WnGhcWxKJ00Ecbp1rQBtlgiFgoF9srEjBGqp6AwIAX9X375ZXzjG9/A7t27OexGnR4oPix6lpTlTJINWJeAu3btgsfjQalUQjweR6VSgdfrhc/nY/VCjZQotcrpdMJqtaJSqeDjjz/G66+/zrl+1Wr1sSAIEfSlP1EyiEC8mOZPv9M5CLLp7e1Ff38/JzYQAE0mDV2LkhVkWeZierLJA4EA+vr6EAgEOK9RDCXSH4XyXC4X/H4/rFYrisUi7t69izt37mB1dZUXZ6dcRJEXHpURt1QiUjYNQQ8AOEYZDoe521QymcSLL76If/qnf8LBgwdhMplQrVY5w8Tj8QB4gGsR7qUsZqKwktPpxNGjR3Hr1q22Pi/UZ4dwO7PZjJ6eHoRCIQwMDMBqtWJhYQG//e1vkUgkEAgEUCgU2tKgHochiTaym+g3uicRfqIFGQqF0Nvby16uqCVET57miLC/YrHIsAtlt/f09DDMlc/nWcXS4iFtEAgE4Pf7YTKZuFPY7Owsl7uKuCp93whV+FScFYPBwCEjskeMRiN2796NoaEhLhY6evQojh8/Do/Hg6mpKe5ydejQIWg0GpZqsiyz9CP1RWptcXERy8vLXDA1OjqKgwcPciPKer0Op9PJapEyX0KhECKRCNxuN5LJJN59912cPXuW36tHvblF9bLZZKqtfGWkQdxPjPLY7XaMjY0hFArBbrfzAqAkCer+AICTXem78nVjNOekgmVZ5vCfRrNe8yI+J7IzKcWLFqnH40Eul8PMzAzbhvQ81MYkfiq/d0tbyohk85BRTalMlUoFwWAQx44dY0fGbDbj7NmzmJ6eRjgcxuHDhzE6Ogqj0YibN29yqznydKmVCeFxs7OzePfddyFJEoaHh3HixAl89rOfRTqdxtTUVFtmCnmRTqcToVAINpsNqVQKH374IU6fPo3FxUW43W6W5lR8LobKHoc2si9FFerz+bBnzx4EAgHO4la+147aNOt0OlabtOjIg6U0OLITKckBANt1ZrOZkYNWq8W1M2QG+P1+NBoN3L9/Hzdv3sT9+/eRz+fbUADg4RLajeagG9rSWDOJfgCMxd29e5dX8SuvvAKHwwGbzYapqSlMTU1heXkZsVgMuVwOx48fx/DwMJxOJ+7cuYO7d+9y/p4YRaEE0Vgshh//+Mf4u7/7O/T09ODgwYMYHx9nTxQAh8RcLhczYSKRwKVLl/Dmm2/iww8/ZIlDUpTSxx4FflCSGhOKEReyqVdWVjje7XK5GKKh6AhJTavVCpfLxeE+cvoymQwXejWbTVa7zeaD10yQB04OiuickE3Y29vLGeqzs7O4cuUKrl27hlgsxj1y1KShWg1zp+8b0ZbaiIT4kyqwWCxotVqYmZnBD3/4Q7zxxht44YUXcOzYMezfv78NN5ufn8fbb7+No0eP4plnnuF2dbQ67969y6lLkiRx96pIJILJyUm89dZbHBLbu3cvA+WUPEoQRiwWwxtvvIFf/epXuHbtGtuYZFfRyhehl62cI/Fh1ut1LCwsoNlscjWhzWZDvV5nyabT6eDxeLhRJwHbwHrZbSaTQTab5QgKMSAlWVAIsFarMSBOZalutxuBQABer5cbEExPT+P06dM4c+YMFhYWuOOGmCCszNTZCLb5mzsrFLinSADVaFAko9VqYXZ2Fv/zP/+DX/ziF/D5fIjH4wgEAhgaGmLVtLCwgJWVFRw4cAAjIyOM98my3Pa6VavViqeffhoHDhzAhx9+CJPJhLW1NS6myufzn
IxLlMvlcOrUKfzsZz/DzMwM3yeFBL1eL6LRKLxeb5taflzJqHxAYkID2XXVahXRaBTpdBo3b96ExWJpq6XR69dfSOn3+xEOhzE0NMR2HCXaBoNBzqChlsS5XI4zZajmxG63M0OHQiHOtgaATCaDyclJvPPOOzhz5gzm5uYYsFY6bjQWcbGKn0pIpxv6m78CTYRjyOslKUcqAwACgQBGR0e5hXEwGGwr/qHJSaVSOHPmDCwWCy5evIiJiQm88MILXHpKE7O8vIwf/vCHOHXqFJaXlxlmUhbQK50Tmh7yMMnJIGYRE1OVtNHDEDE3sciJVKcYHqT9CJv1+Xzw+Xzwer3wer3o7e1FMBhEMBhkLUTJtsVikePuNpuNewaJSbczMzO4evUqLl68iI8++gjLy8vsVXfyiNUYkMalHG88Hlc9R9t8/K0ZUbmaCAsjfIrsQIqm+Hw+jI+P4+mnn8a+ffuwZ88eDA4OssNDnqLBYMDMzAxcLhfHcanM9PLly/je976HS5cucdSAOmyR5O5EYpZMtVrlXoL0gC0WCzOoSGoPRjTslWC0iDmKsWYx5qxM56Iak4GBAYyOjuLAgQMYHBxEb28vrFYrOyjksJCNSB0hVlZWMDk5iY8//hjXr1/HzMwM0uk0APC8iNnf4hhEW1e5XUmrq6ub88XfmhHbLt4BCiCmFCEcmnSSAISzHTlyBC6XC8FgkCUVlSTMzMzgzJkz+OijjzjdnjxVtZWrRnQvlORL5Q1i2SoAVUYUF5o4PiUp8TnlHAEPd/cXmYCycsLhMMLhMPr7+9Hb2wu3291WzUjoQz6fRzwe53zQpaUlpFIplvYk/ZWLSQ2kVvuudFa25fua1ZhAlJD0nVQHqS1yIig7mRJox8bGOF3eZrNBo9Fwj+h4PI54PM4SU5m5TDbYRkT3Se/gI0CYcvXILlY6IjQWZV11p0iEKDU7fVdGapRahVq+OBwObvZJsXQAXJIgetxUHSjLclvoDoDqvSvvX3RglIxI99YNI34qb55SincicdBkC4rZIQRmZzIZpFIpVssEcNvtdq5wI8iHGIYwNbHBUjd5hhSOExMMxHsme1fMdlGORbldqdo2gjs6YZEE1NN8AuD+2slkksdtMpk4Zk6APYUJxZxLZS2KclGpjU3c7y+BuoBPmRGJSJrQn5g2RfajuI36slAvQerwIFa2kS1EITHgwQon6dENWE2MaDKZOFWLyjap3yCB4GpM9CikhETEOaFFIz50UQrTNrIlidmofYlStdI8i2YDjVXUWMoMIrFQSrwH8b6VY+mGPtV38SknVWQY6ttC6pMmgCaGPF6qfhPPQ54nJZuKwXp6UMrm7RsROQcEMLvdbmSzWdTrdQQCATQaDaRSKb4+jU0cp/ibSGoPTDkvNCcbnYPmSfR0OzGJyNSiSUTqX2mvKhlfOb5O9CgluX9zRlRTQ8rVRzaL2DtbfBCkVsjuq9Vqba9FI1vQ5XIhnU4jl8tx5g1l11AmuFKlKkms7SCqVqswm83o6+vDyMgI4vE4mwpKqaDmYaqNX/lbJ5NFideJx6upUvH6on0s/kZCYCMHTun5K9WyqNLFc3ZLn4pEVN6k0huz2+1sPBNsQv/ThFBeIQCuAqRwFknDtbU12Gw2Ljcl5hZb1W1GZMCXSiWO966urnIjeYpR07jos5NnqSQl03azjxojdXMtpcqnbeLxornSCZJRLjDxPGRGiYkY3dCnCt88LpH6pWxlAqdlWW5LFNjMGSF7U8mc4iRSwiq9ibTRaODYsWP4zGc+g6WlJVy4cAGrq6ttselOzsWj0EYMoLyGUrNsZKdt5BwpVfVGTC7W11AGD5lEYgkthXg3o235vubNiGxDu90Op9PJiRCUvk5SsZNKIyLPMp/PMyRDCah0HsINc7kcnE4nTp48CbfbjQsXLmBmZoYzdh6HOqlftf2UXvZGzKR2nPi/2j2I3nOn84oJvfS7WO1HTeEbjQYnP3fb+vmJZEStVsu9B
AnS0ev1bVEO8rg3EvjVapVzJh0OB8drxSKrarUKl8uFEydO4NChQ1hYWMDZs2exurrKJRFiVZyaqt3I7upWdakxmGhbqzGjmupUO5fyftSOV5pP1GHD6/VyRzKy7ylXUmw2tRk9kYwoyzK33CWi0kzRUBbbYogPQWkfiV41pVR5vV6MjY2hv78fPT09aDQa+PDDD7G4uIi1tTXUajW4XC5WQXTuTvfbrR2nPE7t/808VzXGVztWzcERP8UwpEbzoOCe3gxGKX3EgGJ1ofhW1G7oiWREshFJCondscREUNHTFv/EbQQNiW+Ep/6AlB2eyWQwNzfHSaKUmq8Ef9UkXCdVuNE+4rZO51U7nxrso5R45Ewo1bvoNYsFXzS/YpcIsekpqeFsNsvdJSgRpJu2z0RPJCMS44iOBE0qZZnQK8CUk6uccPoj1U59oguFAlKpFJaXlxGNRvl1vWLPQbG1m5oa7sZB2cw2VO7bDUN2gsgoJY/S32gREkwmdtMV65xpm1jBR3Y5lX/kcjkG0cUF8f81IxImRpMk9gn0eDyIRCIYGRmB1+sF0G6Ii5KR7EDCJemVZLOzs9xonqSmXq9HsVjkQqdUKsU102K/GJE6wRzdqmY1UpPqSjtO7T6IsajiT9mMk6Se6PmK+KAYsaEmqNlslttOi10fOtmjG47rSYRviIixKAZMcWeabBETU1PNtVoNpVKJmxJRnYj4dgJqjGQwGOB2u7m1HGGVxIzKOuhOku4vYUKR1FS+khGU9p7IjGIfILpv+qR5pVCq2KBJXLRiaauYICHOc71e56jThuN50hlR7HpPkRZxdRKpeYoUxxYxMdHTFh82dZ6g4+iBURiRFoQadqckNThFTVJuBIp3ug6NYyMSHRH6X+msiB600ksn9UzaSGx/p1wM1Bh1M+qaEXdoh/6atOXv4tuhHXoc2mHEHdoWtMOIO7QtaIcRd2hb0A4j7tC2oB1G3KFtQTuMuEPbgnYYcYe2Be0w4g5tC/o/edvAZ/11eCwAAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 25/75:,recons loss: 0.030027,perc_epoch_loss: 0.057049,kl_epoch_loss: 2255.536555,,gen_loss: 1.012236,disc_loss: 0.012159,\n", + "epoch 30/75:,recons loss: 0.025783,perc_epoch_loss: 0.041187,kl_epoch_loss: 2198.212172,,gen_loss: 0.993313,disc_loss: 0.022697,\n", + "Validation. recons loss: 0.000970,\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 35/75:,recons loss: 0.026168,perc_epoch_loss: 0.023821,kl_epoch_loss: 2140.541582,,gen_loss: 0.633519,disc_loss: 0.151763,\n", + "epoch 40/75:,recons loss: 0.025719,perc_epoch_loss: 0.017355,kl_epoch_loss: 2062.614804,,gen_loss: 0.479063,disc_loss: 0.183541,\n", + "Validation. recons loss: 0.000736,\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 45/75:,recons loss: 0.024476,perc_epoch_loss: 0.015675,kl_epoch_loss: 1969.286549,,gen_loss: 0.510046,disc_loss: 0.174579,\n", + "epoch 50/75:,recons loss: 0.023567,perc_epoch_loss: 0.014090,kl_epoch_loss: 1853.637946,,gen_loss: 0.489276,disc_loss: 0.174202,\n", + "Validation. recons loss: 0.000862,\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 55/75:,recons loss: 0.022395,perc_epoch_loss: 0.012438,kl_epoch_loss: 1734.904024,,gen_loss: 0.462214,disc_loss: 0.178832,\n", + "epoch 60/75:,recons loss: 0.021742,perc_epoch_loss: 0.012144,kl_epoch_loss: 1622.974387,,gen_loss: 0.528392,disc_loss: 0.165096,\n", + "Validation. recons loss: 0.000838,\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 65/75:,recons loss: 0.021505,perc_epoch_loss: 0.011629,kl_epoch_loss: 1531.166383,,gen_loss: 0.506063,disc_loss: 0.167129,\n", + "epoch 70/75:,recons loss: 0.020647,perc_epoch_loss: 0.010862,kl_epoch_loss: 1449.738802,,gen_loss: 0.516405,disc_loss: 0.174172,\n", + "Validation. recons loss: 0.000698,\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "max_epochs = 75\n", + "val_interval = 10\n", + "print_interval = 5\n", + "autoencoder_warm_up_n_epochs = 10\n", + "\n", + "for epoch in range(max_epochs):\n", + " autoencoderkl.train()\n", + " discriminator.train()\n", + " epoch_loss = 0\n", + " gen_epoch_loss = 0\n", + " disc_epoch_loss = 0\n", + " perc_epoch_loss = 0\n", + " kl_epoch_loss = 0\n", + "\n", + " for batch in train_loader:\n", + " images = batch[\"image\"].to(device)\n", + " optimizer_g.zero_grad(set_to_none=True)\n", + "\n", + " with autocast(\"cuda\", enabled=True):\n", + " reconstruction, z_mu, z_sigma = autoencoderkl(images)\n", + " recons_loss = F.l1_loss(reconstruction.float(), images.float())\n", + " p_loss = perceptual_loss(reconstruction.float(), images.float())\n", + " kl_loss = 0.5 * torch.sum(z_mu.pow(2) + z_sigma.pow(2) - torch.log(z_sigma.pow(2)) - 1, dim=[1, 2, 3])\n", + " kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]\n", + " loss_g = recons_loss + (kl_weight * kl_loss) + (perceptual_weight * p_loss)\n", + "\n", + " if epoch > autoencoder_warm_up_n_epochs:\n", + " logits_fake = discriminator(reconstruction.contiguous().float())[-1]\n", + " generator_loss = adv_loss(logits_fake, target_is_real=True, for_discriminator=False)\n", + " loss_g += adv_weight * generator_loss\n", + "\n", + " scaler_g.scale(loss_g).backward()\n", + " scaler_g.step(optimizer_g)\n", + " scaler_g.update()\n", + "\n", + " if epoch > autoencoder_warm_up_n_epochs:\n", + " optimizer_d.zero_grad(set_to_none=True)\n", + "\n", + " with autocast(\"cuda\", enabled=True):\n", + " logits_fake = discriminator(reconstruction.contiguous().detach())[-1]\n", + " loss_d_fake = adv_loss(logits_fake, target_is_real=False, for_discriminator=True)\n", + " logits_real = discriminator(images.contiguous().detach())[-1]\n", + " loss_d_real = adv_loss(logits_real, target_is_real=True, for_discriminator=True)\n", + " discriminator_loss = (loss_d_fake + 
loss_d_real) * 0.5\n", + "\n", + " loss_d = adv_weight * discriminator_loss\n", + "\n", + " scaler_d.scale(loss_d).backward()\n", + " scaler_d.step(optimizer_d)\n", + " scaler_d.update()\n", + "\n", + " epoch_loss += recons_loss.item()\n", + " perc_epoch_loss += p_loss.item()\n", + " kl_epoch_loss += kl_loss.item()\n", + "\n", + " if epoch > autoencoder_warm_up_n_epochs:\n", + " gen_epoch_loss += generator_loss.item()\n", + " disc_epoch_loss += discriminator_loss.item()\n", + "\n", + " if epoch % print_interval == 0:\n", + " msgs = [\n", + " f\"epoch {epoch:d}/{max_epochs:d}:\",\n", + " f\"recons loss: {epoch_loss / len(train_loader) :4f},\"\n", + " f\"perc_epoch_loss: {perc_epoch_loss / len(train_loader):4f},\"\n", + " f\"kl_epoch_loss: {kl_epoch_loss / len(train_loader):4f},\",\n", + " ]\n", + "\n", + " if epoch > autoencoder_warm_up_n_epochs:\n", + " msgs += [\n", + " f\"gen_loss: {gen_epoch_loss / len(train_loader):4f},\"\n", + " f\"disc_loss: {disc_epoch_loss / len(train_loader):4f},\"\n", + " ]\n", + "\n", + " print(\",\".join(msgs))\n", + "\n", + " if epoch % val_interval == 0:\n", + " autoencoderkl.eval()\n", + " val_loss = 0\n", + " with torch.no_grad():\n", + " for batch in val_loader:\n", + " images = batch[\"image\"].to(device)\n", + " reconstruction, z_mu, z_sigma = autoencoderkl(images)\n", + " recons_loss = F.l1_loss(images.float(), reconstruction.float())\n", + " val_loss += recons_loss.item()\n", + "\n", + " msgs = f\"Validation. 
recons loss: {val_loss / len(val_loader) :4f},\"\n", + " print(msgs)\n", + "\n", + " # Plot reconstruction\n", + " plt.figure(figsize=(2, 2))\n", + " plt.imshow(torch.cat([images[0, 0].cpu(), reconstruction[0, 0].cpu()], dim=1), vmin=0, vmax=1, cmap=\"gray\")\n", + " plt.tight_layout()\n", + " plt.axis(\"off\")\n", + " plt.show()\n", + "\n", + "\n", + "del discriminator\n", + "del perceptual_loss\n", + "torch.cuda.empty_cache()" + ] + }, + { + "cell_type": "markdown", + "id": "c7108b87", + "metadata": {}, + "source": [ + "## Rescaling factor\n", + "\n", + "As mentioned in Rombach et al. [1], Sections 4.3.2 and D.1, the signal-to-noise ratio (induced by the scale of the latent space) is crucial in image-to-image translation models (such as the ones used for super-resolution). For this reason, we will compute the standard deviation of the latent components and use its inverse as the scaling factor." + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "id": "ccb6ba9f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Scaling factor set to 0.8571630120277405\n" + ] + } + ], + "source": [ + "with torch.no_grad():\n", + " with autocast(\"cuda\", enabled=True):\n", + " z = autoencoderkl.encode_stage_2_inputs(check_data[\"image\"].to(device))\n", + "\n", + "print(f\"Scaling factor set to {1/torch.std(z)}\")\n", + "scale_factor = 1 / torch.std(z)" + ] + }, + { + "cell_type": "markdown", + "id": "b386a0c2", + "metadata": {}, + "source": [ + "## Train Diffusion Model\n", + "\n", + "To train the diffusion model to perform super-resolution, we need to **concatenate the latent representation of the high-resolution image with the low-resolution image**. Therefore, the number of input channels to the diffusion model is the sum of the number of channels in the low-resolution image (1) and the number of channels in the latent representation of the high-resolution image (3). In this case, we create a diffusion model with `in_channels=4`.
Since we are only interested in the latent representation at the output, we set `out_channels=3`.\n", + "\n", + "**At inference time**, we do not have a high-resolution image. Instead, we pass the concatenation of the low-resolution image and noise of the same shape as the latent representation." + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "id": "92f3e348", + "metadata": {}, + "outputs": [], + "source": [ + "unet = DiffusionModelUNet(\n", + " spatial_dims=2,\n", + " in_channels=4,\n", + " out_channels=3,\n", + " num_res_blocks=2,\n", + " channels=(256, 256, 512, 1024),\n", + " attention_levels=(False, False, True, True),\n", + " num_head_channels=(0, 0, 64, 64),\n", + ")\n", + "unet = unet.to(device)\n", + "\n", + "scheduler = DDPMScheduler(num_train_timesteps=1000, schedule=\"linear_beta\", beta_start=0.0015, beta_end=0.0195)" + ] + }, + { + "cell_type": "markdown", + "id": "8fb22b1a", + "metadata": {}, + "source": [ + "As mentioned, we will use noise conditioning augmentation (introduced in [2], Section 3, and used in Stable Diffusion Upscalers and Imagen Video [3], Section 2.5), as it has been shown to be critical for cascaded diffusion models as well as for super-resolution tasks. For this, we apply Gaussian noise augmentation to the low-resolution images. We use a scheduler, `low_res_scheduler`, to add this noise, with the timestep `t` defining the signal-to-noise ratio, and we use the `t` value to condition the diffusion model (passed via the `class_labels` argument)."
+ ] + }, + { + "cell_type": "code", + "execution_count": 43, + "id": "25d9d3e3", + "metadata": {}, + "outputs": [], + "source": [ + "low_res_scheduler = DDPMScheduler(num_train_timesteps=1000, schedule=\"linear_beta\", beta_start=0.0015, beta_end=0.0195)\n", + "max_noise_level = 350" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "id": "aa959db4", + "metadata": { + "lines_to_next_cell": 2 + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Validation loss: 0.005906,\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "epoch 20/200:,loss: 0.144559,\n", + "Validation loss: 0.004613,\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Validation loss: 0.004093,\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Validation loss: 0.004793,\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Validation loss: 0.004594,\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Validation loss: 0.004505,\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAKIAAABDCAYAAAAf6t48AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAASi0lEQVR4nO1dyW8T5/9+PJs93uIlzgJOcIDQEEBEUAnRokpF4lBxrKjUqlVvrSpxqyr11Ev/D07c6KW7RA+VqGjToiIQaiEkNIgYZ3G8JrbHHs/4d+D3eXk9jLeQBOereSSU8Wzv+877vJ99Blej0WjAgYNXDOFVd8CBA8AhooM+gUNEB30Bh4gO+gIOER30BRwiOugLOER00BdwiOigLyB1e+K+fft2sh8O2iCXy6HRaEAQBDQaDZimCZfLxY43Gg24XC5QbqKXHAV/H5fL1XQfgiAI7JggCE3X0X67e9LfTCbTsR9dE9HBq4MkSfD7/YwErYhmJWQ36EQiuqcsywAARVEgCAJEUWRtGYbB2jQMAwBgmibq9Tr73XGMXffYwSsHScJ2ROs1Y9uO1KIoIhqN4tChQxgYGIAoipAkCS6XC5IkQRAE6LrOyFar1WAYBnRdR7VahWmaqFarXfXDIeIeQKPRYJJpN0Aq2DRN1Go1LC0tIZVKwTAMaJqGer2Oer3OSGiaJlskpmmy/oqi6EjE/yVYpRZJxVbE7JW0/PlEIvqXz+eRzWaZfWpV3fSXzAZRFJvuTfs7wSHiHoGVXO3IaN3XiZjW+/LX8fYg7yTxjpGVoPzC6XZBOETcA9iOSr1u78ETiXeO2nnkdg5Pr1LZIeIeAD/5ViLQ71695a20xaNVe9bruyWjE9DeA2in6vg4Hm23U9f8Oa3O7RVWe9Gun53gEHEXwau4rVxLsNpxVinUSlLZ2Xf8/ezI2ckGbRVIp/YcG3EH0S72BrSWBtsdfunGUSFYicyTx0psu3u2sw3541uxD4FXQETrZO1mfGw7QBKHYmeUVTBNkx2nOBz9FUXxhdTYVmAnqeyen92+Tu3aqdd213eS7L1meV4ZEWVZhmEYqNfrTbEnChlQuICCojSptP0qQQS0C+oCzyZNFEWWiSBQGATYXU+YwEsrK/H4Z8znsztJw632xYpdJSIfl6pUKhAEAYqioF6vQ5IklrcEng+0Wq1CURS4XC7U63UoioJKpcJyn7vVbx5EMLfb3dWi6NYTbXc9H6+z85ztVKJVHZumCUmSmqQ3EVAURbbAALDMCi8Y7OzPduZIL9puV4lIK6xcLsPj8UAQBFSrVSbpeAlJksTj8UDTNMiyDFVVkc/nMTAwgFqttuP95e0oSuJTn3w+H4LBIAYGBuDxeOB2uyHLMhuLpmkolUoolUooFouoVCqoVquo1+ts4rdTsts5M1b13Wg0UKvVIIoiFEVpClATUYHnJOTTdnzmhMbQCX0dRzRNk0kzXdcRi8Vw4sQJHD9+HIlEAiMjI9jY2MCDBw9w79493Lt3D7lcDi6XC6VSCV6vF6VSadckIkkFSuYDQCAQwODgIMbHxzExMYHBwUEEAgH4fD5IkoRarYZMJoN0Oo1UKoXHjx9jdXUV1WoV1WqV2YzdTtTLpOus6lUQBAiCAFVVMTQ0BJ/PB8MwEAgEEAgE4HK5kM/nUS6XUSgUkM/nUSqVmLS0lqBR/7ZS+cPjlXjNjUYDkUgE58+fx8WLF/H6668jEokwyUMrslQq4fbt27h69Sp++uknpj7cbjdbvTsNmkhZluHxeODxeBCPxzExMYHDhw/j6NGjGB0dR
SgUQjAYhCRJ0DQNq6urWFlZweLiIqtUqdfrXVej8LALs9hJQKsnzNt8tC8Wi+HQoUOYnJzE1NQUwuEw3G43VFWFoigoFotYW1tDsVgEAKytreHWrVtYWFiApmlNZgLfNt/+VrCrRCRxf+bMGXz66ac4d+5cU52d1dtUVRVvvfUWpqenEQwGce3atSanYCdgZwOJoohgMIhoNIqRkREcOXIE4+PjiMfjOHToEJOIHo+HXRcMBhGLxRAOh9l4dF2Hpmmo1WpssdG4u8kFd5sz5qtfALDFHQwG8e677+KNN95APB6H1+uFaZrIZrMolUrIZrOoVquQZRmhUAh+vx8HDx7EgQMHcOPGDczOzmJzc9M2m2O1UXsl5Y4RkZwLt9sNXdchyzJEUcTHH3+MTz75BKOjo+xcEvuiKDatZjKsY7EYvvjiC2QyGVy/fr1J3XQ7mb2AV2MkDYeHhzE9PY3p6WkcP34cg4OD8Pl8CIfD8Hq9L5gKwWAQfr8fgUCA2b+bm5vY2NhAoVBgpVQAmsbd6Zm26q+duuQ9X4/Hg7Nnz+LixYuIRqPIZrN4+PAhUqkUUqkUGy9vD9L4IpEIzp07h5WVFczNzTF7mVfHdkUZvWBHiEgTSSSpVqvw+/34/PPP8eGHH0JVVdTrdTZ5NBGAfTiBVMr777+PO3fuIJ1Os7Y6xb+2AvIcZVmG1+tFOBzG4cOHMTMzg1OnTmF6eppJE77MiVdb5GzFYjFMTk6iVCoxm4u3w4DndlsntCMr//zIO3a5XDAMA36/H2fOnMGlS5cQCoUwNzeHf//9F7quo1QqQdM0rK+vw+fzIRAIAAC8Xi+q1SrK5TJcLhf8fj9Onz6NQqGAZDIJSZKaKrPtHKNe5mRHUnz0ABRFQblcxvDwMC5fvowPPvgAXq+3SW1YwzUkOegBuFwu6LoOSZLw5ptv4uzZs8xo7pTJ2AoajQaTVm63GyMjI5iamsLMzAymp6cxMTEBr9cL4MVaO3JoSNIRYrEYEokEJicnkUgkEI1G2UR2WzgKvGgL8vutthpJN1EUceLECXz00UeYnp5GKpXC4uIiNjY2oGkaRFFsikYcOHAAkiRBFEV4vV7UajVUKhXkcjkcOHAA77zzDkZHR5mgaOW49IodU82KokDXdbjdbpw+fRrvvfcefD4fC13wjgnwbBKz2SySySQqlQrz6EjFi6IIj8eDS5cu4ccff4Su6+xBbHeQmIjkdrsxNjaGkydP4uTJk0gkEoyEPEzThK7rLLhNXjZP2EgkgvHxcaRSKWQyGTx9+rQpON9pAq1qr12VC68uVVXF2bNnMTIygjt37mBubo4tcrJpK5UKJicnEYlEkE6nWWiMxkJV2aFQCDMzM8hkMvjuu++g63pTANyuv90KiR0reiAbAgDOnz+PWCzWpLp4247Ptvj9fkSjUfh8Pib1KHAsCAJOnTqFRCJha5dsB4hEhmFAlmWEw2HE43HE43GEQqGmTAnwXAqSJKRtCtXQ2ILBIEZHR5mH3U4ivuyC4kkQCoUQi8WwtLSEhw8fMo9YlmWYpglN0yBJEsLhMIrFIorFIiMXPy5RFFGr1bC8vAxVVZlZZRdE34onvWMSkVSvLMtN+VcioSRJqNfrbFsURfj9fiiK0pSxIIlH3pzX68WJEycwNzfXtTTpFfy9KLhO4RsrEa2ODV9Sz9uLsixjcHCQLTIKRdn1vVP6rNs8L/As+K6qKorFIgsfSZLEpB/w3Dzig91WD5hipwsLC0in03C73czGbdd+t9jRMjDKovz9998oFotNKT4ALB9LkCQJiqI869j/TyjwbGC0ghuNBkKhEAzDYJPfKtC6VZDzYBgGNjY2sL6+jmw2i83NTdsx0pgURWn6R9kT6pcsy1AUBbIsNxHVLnVmhV3csN1vWvAUsaBnq6oqGo0GisUiM5Mon08E9Xq9Ta+LDgwM4Pjx44jH4yiVSlhZWWEZMT4jw/elV8GwI0Qk9UZk+uOPP/DXX38xCci/6
cVLhHq9zjxLXdeZLcnbg5IkIZvNvpAxoHZfFvzk1Go1rK+v4/Hjx1hYWEAymUQ+n286n0hIUo8kuizLjHAUnqKUH71qudWQUzsJabW9C4UCcxgpHSkIAhsH9Y2ecalUQj6fZ2lKRVEwPj6OwcFBPHr0CA8ePMDy8vK2aiBgB1UzGfwulwuZTAZXr17F0aNHMTY21vQuBF/wIAjCC1kT2k8PeGNjA3fv3n3hrbLtApEdeKaqVldXoSgK/H4/I5zf72fndEumYrGI5eVlpFIpNtFWidlLH9tJSMoHC4KAcrmM+fl5vPbaawiHw6jVaixuCICZRhRWovy/qqqIRqOYmprCsWPHkMlk8NtvvyGZTAJ4Jt3JBrbL6PSKHSEiBYFN02QhnO+//x4ejwdffvklJiYmoGkavF5vE6EoXkWglU0P3jAMXLt2DU+fPm0KnWwnGXki6rqOtbU1Vu0jSRI8Hg8ikQiGhoa6vqdpmlhbW8P8/Dzm5+exsrLCQlK9TFyrhddqP9nWs7OzSCQSjFyCICCXy0HTNEZEyjlTQYksyzh27BhOnz4NwzDw66+/4v79+6jValAUpSlOybfXqU+tsGMBbVI9siyjXq9DVVX88MMPEEURly9fxrFjx1p6vjwxyWPTNA03btzAlStXUC6XmS35MmklO5BqBp5JdcqGyLIMt9sNt9sNn88HTdMQiUSaFo61/4RkMon5+Xn8888/mJ+fZyES8jy3eyEZhgG3282qZJ48eYJbt27hwoULSCQSqFarWFpaQjqdZqEc6ouqqlBVFfv378fU1BRcLhd+//133Lx5Ez6fD4qiYHNz84VYqXXcfZFZIalimiY2NzebKmZ+/vlnFItFfPXVVzhy5IitRCDVTKrLMAzMzs7i66+/xtOnT1l94nbbKQSSthTYNgwD6XQa8/PzzOZbXV1l4Rgy7qkMjD61UavVkM/nWSbj/v37SCaTzIMlwvcyad2MmexbcroKhQJu376NaDSKUCiESCSC4eFhdh6fQIjFYhgdHcW+ffvg9XqxsLCA69evo1KpQFEUpNNpaJrG+mJ1Vqz97HZsrkaXs7nVr4FZHRJCPB7HZ599hrfffhvDw8Ms+EsTRBmVJ0+e4JtvvsGVK1dYGooe8k6DvE7TNOHxeFi6jwoeqAqHih4CgQAL5GezWaTTaSwtLWFubg5PnjxBMplsChh367BkMhlW2dMuu8JPPhGEl1B+vx+nTp3CzMwMqtUq8vk8EokE/H4/NjY2EAqFMDU1hbGxMWxsbODRo0f45ZdfcPfuXeRyOayurjbd32of0jEreJu0FXaciDz48M3m5ib8fj+mp6dx4cIFHDlyBGNjY4hGozBNE6lUCvfv38e3336LP//8k3mypVIJqqruShmY1RslB2BgYABDQ0PYt28fJiYmMDIygkgkglAoBLfbjVqthrW1NSwtLWFxcRH//fcfq3CpVCoA0ESqTrASsZspo/OsdpvX68Xk5CQOHjyIcrkMv9+PiYkJDAwMIBaLYf/+/SgUCrh58yZmZ2exuLiIcrncVORA92oV/7TuX15e7tzf3SYipbVUVYVhGCiXy+whRyIRBAIBmKaJXC6H9fV1ZkxXKhWm8q3vuewkaELr9TpLdbndbgQCAUSjUYyOjjKVRxKxVqshl8thZWUFyWSSOTw0mYIgMDXeDdpJRLsshvUcahN47gAODQ2xGtBQKIR4PI5oNAqXy4W7d+9ibm4OhUKBnW8n/axt2TkrjUaj/4gIgAWKATS9J8GXp/Nl9HwowuoA7Sao9IkmhoLDXq+XFcySZ02ps1KpxOwvIiGRopcK7Uwmwz4L141qpt8EnhQUouHfR+G/bwM8C1vxpLY6ldb78f3h99N1KysrHce46y9PEalom8rm+Xci+BVIKTK+2mO3bEQeJMV4VCoVpmqB1i8Q0fU8XtZT7kRIq4Qi4pENC6BJ1ZMtTM9YlmVomtbkeFolLn9/63avjuSuvzxFq1GWZTZ4foWSk
c2rAYpHUpEABcB3G/zk8y97kSTnx0A2Lf9xy60Ge+3CW1Z02scvfr7iiXdq6BwaF+3jS/V6Uc+t+mWHV/JeM8UHrQOwW2GNRvMHH/ki2lcJIhu/cKwTTxO5VQK2a7uTo9AqwM2fQ9u8/UjntTq3leR72Xhu339yZDsncLvAq6ZuTYTtUsVA64m2kqZT2+0I3e7edsfsvPRe0PdE7Ffs5gJ52bY6EbcVgaltu+tb2YdbhfM1sD2Cbibe6kTYmQS8rcoft9tuZSpZYTVLtkJMRyLuQXRyTOwkXbtrWx3jbfdu1G47u7QTHIm4B9CNfdjrffh97WKC7dBOCvJVU93AkYj/I2gXQul0Hr9tJQ8vFdtVYnejstvBIeIeg10GxTrhdsftftupYus5lDWi93Xo8ylEToqj0nUUM6VPw3T7sSyHiHsA7eKQnbzcVmRr1Y71WspFj42NIRQKsTiuIAiQJAm6rrMguK7r7H+fotx6qxesrHCIuAfQixe6VRvSLlcMPMs7P378GEtLS03fKKIqdusrsUReqhmgdGInOETcI+AzIFbn5WWDyQT+PrwTA6Ap7QfgBZVr5+z08iULh4h7APQZFip969YJ2EpIx041E7ohupWQ3UrorsvAHDjYSThxRAd9AYeIDvoCDhEd9AUcIjroCzhEdNAXcIjooC/gENFBX8AhooO+gENEB32B/wNdGPXBGwjkhwAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Validation loss: 0.004176,\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Validation loss: 0.004405,\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Validation loss: 0.004649,\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Validation loss: 0.004160,\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAKIAAABDCAYAAAAf6t48AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAhCklEQVR4nO1dW2xcV9X+ztzvMx6Pxx6PHY9jJ6RN0/RC2/SCGgR9QKBW4oEn3kA8wCNSn5EQfUJCSH1DQrygipdKFFSholYQgUKLCijOxXFsJ07sjD33++3MnPkfrG9ln5MZZ5I0rat/lmTFl5kz++y99re+9a21T7R+v9/H2Mb2BZvtix7A2MYGjB1xbIfExo44tkNhY0cc26GwsSOO7VDY2BHHdihs7IhjOxQ2dsSxHQpzjPrC2dnZRzmO+zZN09Dr9WAYBpxOp/zscDjQ6/XQ7/fR7/ehaZq8/jCYOi673Q6n0wmXywWbbR8Tut0udF2Xr16vB13XTe+nqfdkvb9h98v3W+eGXwBgs9lMPw+bQ5vNZrqe+jMA9Ho9tNttlMvle87LyI54mExdTJfLBcMw0O12AQCNRgM2mw12u10W91EUj0a95igbwDAMWXi73W66tt1uh8PhMDnH/dwPP191PGDfifi5fJ3NZoPD4ZBx8HXqe/m+fr+PVquFTqeDXq8n47Lb7fKzy+WSdbmXfSkdkWaz2aDrOux2O5599lk899xz2NnZwcrKCra2ttDtdmURPytnVBFg0L+jIjA3U6/Xk3vh5nE6nfK3fr+Per2Obrd712cNM9VZ1deqDtVutxEOh3Hy5EkEAgGJKk6nU5xe13UZh2EY0HUdDocDdrsdLpcLAFCpVLC3t4dMJoNcLieOR2cd1bRRa82HKTT3+325US5aKBTCkSNHcOTIESwvL8MwDFy8eBEXLlxAPp//TEKzurjWL9XUUKeGOOt1uOB2ux1utxsejwcej0cWm5ZOp9FoNOTa6jXu1/h+UppQKASbzSaOPojW9Pt9oQmGYQiSut1uTExMYGpqCoFAAP1+H7du3UI+n0e73Ybdbkez2UQ+n7/3uL6Mjkjr9XqymHRMwzDgcrnEIScnJ7Gzs4N//vOfD/QZB03PQc590Pus6Gmz2eDxeOD1euH1esUZ6aQ3btxArVYb6tCDfn+Qw1r/NizcW8O5dcyGYcgaOJ1OzM7OYn5+HsViETdu3ECz2USz2UShUBg6F3Ldz8MRB92Q+vtBiDLsPQBMaMjv6YxEoG63C7vdjlgsBofDgb29vQcat4q+KpKRu/FL5U5MOHq9niAMTeV6vE+bzQaXyyWo6HK5BBU1TUM6nUa9Xv/MEq5hCD7s79bfWxMdwzDk+0gkguXlZfR6PWxubiKbzY6EiI+MI3KwKo+z3mCv14Pb7ZYwpN5Uu90WbmfNzvg7lWzz/fzZ4di/tVEm4aB74K5nSAIAp9Mp2a7P5xMEA4BOp4NWq4VGo4FWqyVhDYDpfqyZKTNM8lr1b+12W8ZjtVGdU3V860a3bjJ1jtUwbeW/nB91DPl8HqVSCUePHsWpU6dw4cKFkcb3yByRqOR0OuFwOGQhuQDBYFBCUjgcRiQSgdvtRj6fR7fbRbvdRq1WQ6vVQrPZNE0WrwfcvTsflDsNW2Q1iWBWSS7n9/sRCATg8/mEvHc6HTQaDRk7v7rdrjj1IF5J9Bw0Bmaho46bTsTv+V5GCV6X/JqvIU+kA/JzmQmrAMBrqOMyDAMOhwP9fh/r6+tot9syL/eyR+KIvAmv14tmsyl6WbfbRTgchsfjQSKRwNzcH
OLxOHw+HxKJBBwOBwqFAkqlEvb29rC9vY18Po9GoyHZY61WMzkh/+XEP+h41TDMSWf4dbvdiEQiCIfDCIVCsmmcTif8fr84IheMiEh0rFQqaDQaqFarKBQKqNVqaDab6HQ6JsnEmvyoHE5NXu5lVqclavH+CAb9fl8+G7izwTkGUh++R72WOl5+pjWC3b59G51OZ6QxPxJH1DQNDodDJpoDX1hYwMmTJ/HCCy9genoaHo8Hc3NziEaj8Hq96Ha78Pv90DQNnU4H9Xodly5dwgcffIBPP/1UFp5oyZAJQHbvg/IoLlK320Wn00G/34fb7YbX60UwGMT8/Dzm5+cxOzuLeDwui+ZyueDxeOB2u+FyuQRZuICtVgu5XA6FQgG7u7vY3NzEzs6OaHBq5sy5syYTo6D8sIRFTebIOen8Pp8PwWAQHo9HNpPT6TQhc7/fR7lcRrVaRbVaNYntVue0cshByD/MHllodjgc0HVdkCIWi+Fb3/oW3njjDSSTSdjtdmSzWWiahna7jUwmg2azKeR/cnISkUgEX//61/Hyyy9jbW0N586dw0cffYSVlRV4PB5xmvvVrFRTuRGzP4/HA7vdjmAwiGg0iqmpKRw9ehRHjhzBzMwMJiYm0O/30W63RbjlFxMXVYsLhUIIhUJwuVzCG/mZ3W7XxIGtnBeAKRRaxz4syVA5uaZp8Pv9cDgc8Pv9mJiYwOzsLJLJJKanpxEOhxEIBISv8/N4zXw+j0wmg8uXL2NjYwOVSgWdTge1Wu0u/j9Mw7yXPXTWPCi7Jaz7/X4sLCzg9ddfx3PPPYelpSU0Gg387W9/w9WrV7G3twdN0xAKheD3+xGNRqFpGprNJoLBILxeL+bn5zE1NYXp6Wn4fD7cunUL77//Pn7/+9+j2Wyi3W5LhsqJGMXUMKg6YSAQQDAYRDgcRiKRQCwWE2fkYhKxW62WSb4gojDjZaimjFEul5HNZlEqlZDL5bC7u4tsNotyuSw6oRoqrdzuXo7I96hIGAqFMD8/j3A4DK/Xi9nZWRw7dgzJZFIQnxuISKeGU1WBKBQKWFtbw8WLF3HlyhUUCgUR28kdrQlOp9N59Doiwxl3BEOyYRiYmZnBN77xDfzgBz/A0aNHYbPZUC6X8dZbb+H8+fMIBAKo1+twOp2SIU9MTAhCMhslv7TZbHj22WfxzW9+E3Nzc/jXv/6Fn//858hms+h2uyiVSneV9NSs3cohVfGWWXAgEMDc3JwI44uLi5ienobX64Wu68jn89jd3cXe3h7q9bosHJ1ZlXOYRTMpoKNSptF1Hbdu3cLa2hpWV1dFXnK73TLWQUmG6iDqvRLJuKn8fj9mZmawtLSEJ598El6vF06nE8FgEOVyGaVSCd1uV7RLJkuqsE1QYBiPRqOIRCJwuVz43//+h48++gjXr18XZBw0/61WayQd8aFCMxeazhgMBnH8+HG89NJLePnll3H69GmEQiHouo5yuYyVlRUJ071eD5OTk8jn8zKB2WxWJlXlacA+UvzpT3/CxsYGvve97+GVV17BL3/5S/zkJz9BsVhEOBw2hYpROApDusvlQiQSwfT0NE6cOIFjx44hlUohlUohGo3CMAxkMhlkMhnk83lsbW2hVqtJdq82JXD8RAaHwyEIOzk5idnZWUxOTgqvbLfbkpB1Oh1xqHshO0HA7Xab6rk2mw3hcBiPP/44Tp8+jSeeeAJzc3OoVCpYXV3FrVu3UCgUZIOoHJLozvup1WpwOByIxWIolUoolUoIBoM4ceIEXn31VXg8Hrz77rvY2dmRJPIgunCQPTRH5MT1ej089thj+OlPf4pnnnkGXq8XwP5uvnz5Mq5evYpms4lkMgmfz4darQYAWF1dRaFQMCErdyd3qGEYCAQCcLlcWFtbwzvvvAMAOHXqFH7xi1/gV7/6Fba3t6XLQ5URhk2Muuu9Xi/m5uZw/PhxnD59GgsLC4jFYohEInA4HGg2m6ILqjVfNfvUdV10QPUznE6nJ
Cb9fh9+vx+hUAgejwfBYBDxeBxHjhxBt9tFuVxGs9kUHnmv6ojdbker1RL0tdlsmJycxIsvvohXXnkFi4uLcLvdokAUi0VRMehw3CwMx5qmoVwuw+FwSCRQG0r6/T5u3LgBv9+Pl19+GZVKBe+//75QlUH67ij2UI7IBbDZbIhEInj99dfx1FNPCQlvNpu4cOECVldXBT1YIWB2vLy8jFwuh1wuh2azKUhD52aBvVwuw+l0Ck989913kcvlcPbsWfz4xz/Gm2++ienpaZRKJRPHURfRuqBc7EAggMXFRTzzzDM4deoUwuGwhHtyu0wmIwsUi8UQDofR6/XEQa3aoSoxtdttCYU2m03ex1CZSqVgs9lw+/ZtpNNptNttUwI2DFVU2cowDExOTuLs2bN49tln4fF48Omnn6JQKJicjI6nvlfTNEHWXq+HWCyGbrcLt9stiQk5b7PZRCaTgaZpmJ2dxUsvvQQAeO+997C9vT00sbqXPTQiUvA9deoUzp49K4S9Wq3iP//5DzY3NwVJWq2WCLScHMoHrAlzl3Kigf2dyOybaLG+vi7o8vzzz+Pb3/42/vjHPyISiSCXy5kyP9VU/kKNcHJyEgsLCzh+/DiOHj2KXq+H3d1dCcUMS+SyoVBI0JZo6fV6TQjJagg3Jct+wB3Re2pqSqQhanbFYtE0zmFoqGp9rFA9/fTTeOGFF9BoNKQ+bZWj1DIk6QPHpWqNANBsNoWzkyuGQiGpH7daLaRSKbz44otYX19HPp+X4sP92kNzxH6/j1gshu985ztIJBJwOp3QdR3//ve/RV0nbAPm9iA2sZIHJhIJ6LqO7e1tcdRerycVAF3XBSkdDgfW19el6+b73/8+/v73v8PhcKBYLJp65NQwx0nmhgmFQpKVMzMulUqoVquC1NVqFfV6XURutdeRSQp5LUM0cIfHccydTkeyUF3X4XQ6EYlEMDk5CU3TkM/n4Xa771khUjcTKc3i4iJee+01JBIJrKysoNPpSK8gAAEBcmgAJgekJmuz2dBsNk0htt/vo9FoCCft9/sSLXq9Hp588kmcOXMGly5dQjqdNikYoyLjQyOipmmYm5vDmTNnxKG2trZw7do1dDoddLtdNBoNU7VCJcoUkVlrJUHmhPN7VmbI/6rVKlqtlqDxSy+9hOeeew7nz5+Hw+GQ8EaUotER7XY7AoEApqampI0J2Cfo+Xxeeuwo4qrZpJrJkzr0ej1JQJxOp0g36jwB+2hYKpVgGAYikQi63S58Ph/C4bDIQzRrpYXXUTeW3W5HMpnE2bNncfz4cdy8eRONRsOUaXe7XXg8HnFaZtbcTNzYXAdV/uE9+v1+WaNOp4NgMIhWq4VSqYR0Oo14PI5EIoF8Pm9SJEa1hz6zYrPZkEgkMDU1JTe1vr4uu0UVcHO5HLLZrAjBdCxWJZxOJ6LRKBwOh8l5NU2TMONyuSR8G4aB3d1dbGxswDAMvPbaawgGg4JQqqalLiwdNBAIYGJiAtFoFD6fT3Z+pVIRVKRWCOwjSLPZRL1eR71eR6vVMjmptSuH9+R0OuF2uwXtuDk5Dz6fD4FAAH6/3ySKc8yqWTmv1+vFE088gePHj0tSQkck8tLBut2uoBSR2tpZxO8594ZhCJLTWdvtNiqVCqrVKgzDQKVSQbfbRTQaFYpB+1wQUc086RzdbhfFYtEUglUx9uLFi5ifn8fRo0fhcrkkQaH2lUwmZTGYrbVaLTkSwJDG/r1YLAZgf9enUimEw2Hs7OyYpCXr4tFpXC4XQqGQiOck8QxTavc0eR83GPVA/q3T6UjWTMQHIP+qn8974H1wEwQCAalwqCHdapxLm82GYDCIubk5aXUjYnGT83UMsxwf1wWAOB15PB1Q1UOdTqcp6anX64hGo5J9MyLw/M39FhgemiOyIYAalMfjwdNPPw2bzYZcLifo1mq14PF4RMh1Op2YmpqSBKTT6QgJJsfixLA1iuGON5lMJnH69GlMTU2JI09MTJi6i
LkQNGuyQuRiRUMNtxTUAcgCVyoVkUx8Pp8pGWm322i32yaRnEmWuigMlZFIRFCw0+lIR4/f7zclN8MqKsz4ieKD5CO14gFAEM5axuO4VQFe/Sxd12W9DcNANpuVOjw/MxKJiGynjnEUe2iO2O/3MTMzg2AwaEok4vE4dnd3Ua1WZdEpZt+8eRMXLlwQROCO4oQwIfD7/Zifnxctj8hVLBbRaDQQiUQEGVmfnpyclAkf1pGjOiNDuDppanhVz2sQvavVquiD3P10Rm4A8kQuqurkAODz+RCJRAQBO52ONMayTDjs4JGqL7J17vbt24hEIne1x6n/qg3F7XZbEhBek5tGbX3jtVS+TRmNm4V/j0Qisjk/d/lG0zRxEKb6H3/8MUqlklQteOhG0zQ0Gg1ZIHIsOiHhndWBkydPYmJiQnQ5cq5kMmk6ZMTeRq/Xi4mJCemC4SSqi6eOmw5v5WREVGbAnU4H1WoVtVoNjUZDRGcilpWck4qoDQT8fDpmKBSSZgOv14t2uy1/U8czaOw00iJGCY5Z5YWkG+wEYuTK5XKyPoZhwOfzmZLEYeGbZ10ikYiJc+u6Dq/XC5/PBwCyQT83jqhWVrjjuJsJ6bquo1qtYn19HYVCATab7a4+NcI+sJ9ZZrNZfPLJJ9IhwnIYEwwudjwex9TUlGSJavhmeBlWciJK0dQMnqG2Wq2KqN1oNEwNFpSk1OuSa1kFaW4+YN8ZvV4v/H4/vF4v3G73XT2JoxJ+hmOv1ztQyFclHpfLJYkYj92qWTIpRyAQMKkb6poCEBmNmTc3KgDZRAzlo9pDI2Kv10OhUJCWL6/Xi3g8Lm1dDAObm5tIp9OmCefNqtkYbwbYF7KvXbuG69evS8Y8OzuLxcVFJBIJ2O127O3t4fHHH4fP50O5XJa6raobWtuUuFAqryO6qaGaO73VaqHdbt/Vsaw6i7VZVzUuGpseGL6taEx5xMr1rGbVRdUmW7WT3JoZc4wsL7ZaLZTLZQnJ/DwmU0RD9fPolKQnnAcK5Fb73JIVXdelesKy3czMDLa2ttDv91GtVpFOp3Hr1i1ZDOskM0Twhkmw1bYiYB+Btre3JUw+9dRTpgNUxWIR6+vrJjSy1j75L9uzVLHa7XbD5/PB7/cLSql6m8odVUdQvwYtABGUpTKWOFnx4PVZcWEDBK8ziFZw3tTF93q9oq0SqVQVgJqo2gBL6UxVAJjZk09y0/LzyJfVjiA6v7peVg30IHtoROx0OlhdXUU6nUY0GkW73ZazspQlotEoEomEhDc6JJGQFQAmNNb2dC4k+RMRdmlpCclkUnjQ1tYWNjc3hX/SwRly6OREh1qthkKhgHK5DF3X4Xa7EQwGpT+SnIdNF5xwyiAATDqiaqq0pb6W9+N2u006JWvz1Oc6nY6I5sOMvI9hcXFx0VSrV52QtIKd2IZhyNFVwNxG1uv1UCqVAOy3pYXDYQAQ1OP6scrCRIfzMIyXH2QP7Yjchevr6zh58iTcbjdisRiOHj2K9fV1OBwOzMzMwO12I5PJoNFooFQqoVgsCq/zeDwiKaicZdAi8GYZihcWFhAKhWAYBj7++GPp6lE1RCvxZ9htNpsoFovS4lUoFKRm7Pf7pfTGA0AOhwOtVsvUIKC2zhPFVEQgSlqFY+AOKpfLZRSLRRHRm82m6HjDFpNz1Gw2pVs9kUggm83i1q1bpqMabHqlA3EcRGfOh1paJCpynaLRqGwYOh3vm4kkGz9o91NZeehkhYT+448/xgsvvCCHoJaXl5HP5+F0OlEqlRCPxxEMBk2n3LiYjUYDuVwOpVLprgK8GoZU7Y/nLaanpxGJRFAoFPDhhx/CbrejXq8LjyOyqOIucIdbaZqG27dvY21tTcIltdHFxUWEQiFks1lZRIZOoo6apZJr0ilVZCK38ng80vFNDS6bzWJnZwe5XE4qSvead3X+KTazCYGHtYA7TbVMlNRrU6ZhJg3AFEkoJXW7XVQqFeGSuVwOe3t7WF5el
obicDiMmzdvolKp3FVEGMUemiPa7XaUy2X84Q9/wNbWFn70ox/hzJkzmJ6exquvvorV1VVsbGzIxDCz9Xg8knVGo1HMz8+j1Wohk8lge3sbtVpNxFMeS+XOSyaTeOyxx5BKpbCwsAAA+N3vfocbN26g0+lIyFeTDnXyVYmm0Whgd3cXKysrohMeO3ZM+hHD4bAQf2a7DG0cG0MgnZT3auWMPp8Pk5OTOHLkCFKplGT729vbuH79OjKZjKDYQXOu8t9arYZisQhd1yWcsh2OlITjUxUForzKA3mPwD4NYlcR0ZKFB03bb9+jGJ9KpeByueQ8i1XDHMUeGhG54M1mE+fOncO1a9fw5ptv4rvf/S5CoRCeeuopuFwuXL58GR6PR/RFla9Rq3O73Ugmk3A6ndjZ2UGlUpEEAdgXgVOpFE6ePIlQKITp6WnMz8/jv//9L/7yl79IlcOq3Q1T+Vll4Flq1ky5yDzqGolEoOs6PB4PGo0GgsGg6UkOvBaPjBI91MYLu92OcDiMubk5OUbL47Obm5tYX1/H3t7ewGMBw+aezrO9vY1z584hHo9jZmYGmqZhYmICuq5L17p6BIHVKvJ0lb6QdrAZo9froVgsSojv9/e7rXw+H3Rdl87z8+fPY21tTRD1c0dE1aE0TUOxWMRbb72FjY0NvPHGG1haWhLuuLa2Bk3TJDlgR4imadKJYhgGPB4PJiYmZFEBCJokEgmEQiHMzMzgq1/9KnZ3d/H2228jm81KWFCdjpOiOqGq1/X7fdMZZDpBvV7H0tKSVD4SiQSi0aipAZYbiI5PBPH5fGg2m7JJWU8OhUKie3o8HhSLRezs7GBjYwObm5vSxHoQIlrnH9jfNBsbG/jtb3+L5eVlfOUrX8HS0hLC4bCABGkDO4ToxIOSIvJzOjI/y+VyYWJiQhzU7/fjxIkTuHHjBs6dOyfjvx8klHvpj+i6w07xadr+AZtQKIRqtSptUP1+HwsLC/jhD3+I119/XeSQq1evYn19HZlMRngTSb7axMkQDtzhNexSCYfDeP7559Hr9fCzn/0MH3zwgfA18h3qYvyyLq5a8qLjqv2B8XgcJ06cwPHjx7GwsIDp6WnpUK7VavLFMplhGNIUQc3RZrPB7XabqigMidVqFRsbG1hbW8OVK1dw8+ZNQZOD+vnU0ExTj1ho2v7JvTNnziCVSkmTcrlcliID552OyO9ZBOBr2NNIiYeiuKZpiEajeOKJJ6DrOt555x2srq6iXq+bqimc41GeO/TQoZnSBrM3YF/ScblcuHnzJn7zm9+gWq1KqF5aWoLP5zNxIgDSYECiz8V1u90IBALweDyIxWKYnp6G3+9HtVrF22+/jQ8//FCSHsB8juQgG1RV4Um9XC6Hra0tOXbAurff7wcAU6ezmp2z7q124HD8wWBQuo3K5TJ2dnawubmJGzduIJ1Oo1gsipY5KiJywxJ56VCVSgWXL1/G4uIiUqkUgP3nGObzeZTLZUFwSmbqz2qNmc6nFiDY0f7YY4+h0Wjgvffew8bGxl0P7Pzca81M89XEgBUCu92OdDqNd999F263W85SxONxnDlzxiTcsgivnrGw2WxYWFiQWjZlnitXruDXv/41Pv30U8lYOXnWGrP6O6upojQRhdkvs2QW8Xu9nhylZF8hT+ipEgj5lcqDNW2/bapQKCCfzyObzWJ7exvr6+u4ffs2yuWyOLyK3sM20yCxnOvAa6TTaXz44Ye4fv06kskkFhYWkEqlpCpCDVU9EKbyWfVBUByTz+dDLBZDKpVCNpvFn//8Z1y+fBntdls6wHnfahFhFPtMHks3aAeo2Sm7aJaWlvD8889jfn4eCwsLmJ2dlZanQS3yDKv1eh2lUglXr17Fe++9h7/+9a+o1Wr3TYgPGr8aqpmwhEIhRKNRSYr49Il4PC4d1Qy3qviunuhrtVooFotSfkyn03IOJpvNynNwWJs9qCkWuLturs7ZoJo6N4PT6UQsFsPi4iIWFhakE
MDDXpTViOQ86+zz+aRJgy1nW1tb+Mc//iEHvah7DnK8breLTCZzzzV4JM9HVBMEfs+bm5ubQzKZxPLyMpaXl0V35OM92F5UrVaRzWaxubmJzc1NfPLJJ7h27Rrq9bqEw8/SrBsAuNNFQ+44MTEhxz/j8Tii0aicUVZDmK7r0qVTrVZx+/Zt7O7uIp1OI51Oy8k49gBaQ/FB1ML6t2GvVdFJ/Z7NwKlUCseOHcPMzAwikYhokXy4gd/vF56uafsd8tevX8fKygpu3ryJcrksCdswuYZR5gtzRACmZMH6e4rRU1NTiEQiJsGV5yBY/mLHCEMIEed+OcioZq18sHLAJlCWK6PRqDz5wOv1mrgVa7Fsqef5Fz60Uj1+4HA47spa1XkbJj1Z7SAHHURTNG3/PDeRjo8kAQC/3y9Pbet0OigUCshms8hkMqIYqDKZOmYVzVne/MIRUe0yUZsNOEhmY9ZuZ+BOP5v6BCpeb1Qy/yCmTgf5LgDJGJl88FEdDF28T94jx8wDYQzBdEJVlFbnR8301cgyirYIjBa66disQ1sd3crxVFRVX2MFGvX16vyN4oiP7LF0AExSAHcQT7wBEHkAgDylgY5IEVvtM2RIZnb3qM1a72ZixMYNflk3G0MSNyO/t4Z8qw1CPzrM/ZgaIlVnZEscABNI8HWq4qBueBVUrGChOh3Hq/5MJeVe9kifGKsuCAmzGpK48zhoLpBan+XvrCgyyAaR9gcdt9XUVi+rY/F91s+1Jgzql/V16rVYT+dj4tSKz6hm5eiDvrd2kFt7KrkOqvOqdWkrKlr1w0QiMfJ4H/n/s6JWXdQwMKh3b9DJL2t4AO4sHMVwOjUlI9XBGfpHfXLpMFOlGLVvzzrOUd5vDZ8ql6N0whN9LBXSEQdlxvy9dQMN4ufW2j2/V9+jbn5e24qwKsqq98Z1nZ2dxdTUFFZXV+8xs/v2ufyHPw+KUAfpf/ybepDd+sTW6elpnDp1CoZh4Ny5c5/J5z4KU0McO3R4Dlo9f0MbJhhbN7BVCVBDM7uw+Yg8aoiqQM+xqdTAiqpW6uL1erG0tAQAuH79Our1+khz8KX8n6e4W+v1umSszKgNw8Di4iLOnj2LSCSClZUVXLp06TP5TKuNqmOOErLZFMzjrUR4iuz3+mxrYjJoDJw3SlLWkp+qpfI91pYuFRkByJzbbDbMz89jeXkZ1WoVFy9evOtxfQfZl9IReeNslNB1HSdOnMDXvvY1zMzMoF6v4/z581hZWZH/KOdRZtr3Y+qi0gl5/pl1egCSFLEePOj9w649DLnVoxdqLZz/PYf1PBEzepr62Bj2K/p8Pvh8Pni9XpRKJVy6dAmZTMb0pI1R7Ev7P09Z5aFoNIpwOIy9vT2pp/IA0CA984uyQUkJD53xfHa320WtVpND8wypVjR6EFOz2n6/L4fRAoGAqWlYDdE8QKZKb4D53E+pVJIwrPL6RqMhTzg7yL6UiAiYmxv6/f0HjvNZzWpmqmbtX6QN43WA+X+yUg+DsRsJMEs+D8JVrWGbGzSfz8tDEKymdvWo9XM1uVL5I9FTRd1RxzoyIo5tbI/SDgdxGtv/exs74tgOhY0dcWyHwsaOOLZDYWNHHNuhsLEjju1Q2NgRx3YobOyIYzsUNnbEsR0K+z9hgrf7GF613gAAAABJRU5ErkJggg==", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Optimizers\n", + "optimizer = torch.optim.Adam(unet.parameters(), lr=5e-5)\n", + "\n", + "scaler_diffusion = GradScaler()\n", + "\n", + "max_epochs = 200\n", + "val_interval = 20\n", + "print_interval = 20\n", + "epoch_loss_list = []\n", + "val_epoch_loss_list = []\n", + "\n", + "for epoch in range(max_epochs):\n", + " unet.train()\n", + " autoencoderkl.eval()\n", + " epoch_loss = 0\n", + " for batch in train_loader:\n", + " images = batch[\"image\"].to(device)\n", + " low_res_image = batch[\"low_res_image\"].to(device)\n", + " optimizer.zero_grad(set_to_none=True)\n", + "\n", + " with autocast(\"cuda\", enabled=True):\n", + " with torch.no_grad():\n", + " latent = autoencoderkl.encode_stage_2_inputs(images) * scale_factor\n", + "\n", + " # Noise augmentation\n", + " noise = torch.randn_like(latent).to(device)\n", + " low_res_noise = torch.randn_like(low_res_image).to(device)\n", + " timesteps = torch.randint(0, scheduler.num_train_timesteps, (latent.shape[0],), device=latent.device).long()\n", + " low_res_timesteps = torch.randint(\n", + " 0, max_noise_level, (low_res_image.shape[0],), device=low_res_image.device\n", + " ).long()\n", + "\n", + " noisy_latent = scheduler.add_noise(original_samples=latent, noise=noise, timesteps=timesteps)\n", + " noisy_low_res_image = scheduler.add_noise(\n", + " original_samples=low_res_image, noise=low_res_noise, timesteps=low_res_timesteps\n", + " )\n", + " # Here we concatenate the noisy HR latent and the noisy low-resolution image along the channel axis.\n", + " latent_model_input = torch.cat([noisy_latent, noisy_low_res_image], dim=1)\n", + "\n", + " noise_pred = unet(x=latent_model_input, timesteps=timesteps, class_labels=low_res_timesteps)\n", + " loss = F.mse_loss(noise_pred.float(), noise.float())\n", + "\n", + " scaler_diffusion.scale(loss).backward()\n", + " scaler_diffusion.step(optimizer)\n", + " scaler_diffusion.update()\n", + "\n", + " epoch_loss += 
loss.item()\n", + "\n", + " msgs = [f\"epoch {epoch:d}/{max_epochs:d}:\", f\"loss: {epoch_loss / len(train_loader) :4f},\"]\n", + "\n", + " if epoch % print_interval == 0:\n", + " print(\",\".join(msgs))\n", + "\n", + " epoch_loss_list.append(epoch_loss / len(train_loader))\n", + "\n", + " if epoch % val_interval == 0:\n", + " unet.eval()\n", + " val_loss = 0\n", + " for batch in val_loader:\n", + " images = batch[\"image\"].to(device)\n", + " low_res_image = batch[\"low_res_image\"].to(device)\n", + "\n", + " with torch.no_grad():\n", + " with autocast(\"cuda\", enabled=True):\n", + " latent = autoencoderkl.encode_stage_2_inputs(images) * scale_factor\n", + " # Noise augmentation\n", + " noise = torch.randn_like(latent).to(device)\n", + " low_res_noise = torch.randn_like(low_res_image).to(device)\n", + " timesteps = torch.randint(\n", + " 0, scheduler.num_train_timesteps, (latent.shape[0],), device=latent.device\n", + " ).long()\n", + " low_res_timesteps = torch.randint(\n", + " 0, max_noise_level, (low_res_image.shape[0],), device=low_res_image.device\n", + " ).long()\n", + "\n", + " noisy_latent = scheduler.add_noise(original_samples=latent, noise=noise, timesteps=timesteps)\n", + " noisy_low_res_image = scheduler.add_noise(\n", + " original_samples=low_res_image, noise=low_res_noise, timesteps=low_res_timesteps\n", + " )\n", + "\n", + " latent_model_input = torch.cat([noisy_latent, noisy_low_res_image], dim=1)\n", + " noise_pred = unet(x=latent_model_input, timesteps=timesteps, class_labels=low_res_timesteps)\n", + " loss = F.mse_loss(noise_pred.float(), noise.float())\n", + "\n", + " val_loss += loss.item()\n", + "\n", + " val_loss /= len(val_loader)\n", + " val_epoch_loss_list.append(val_loss)\n", + " msgs = f\"Validation loss: {val_loss:4f},\"\n", + " print(msgs)\n", + "\n", + " # Sampling image during training\n", + " sampling_image = low_res_image[0].unsqueeze(0)\n", + " latents = torch.randn((1, 3, 16, 16)).to(device)\n", + " 
low_res_noise = torch.randn((1, 1, 16, 16)).to(device)\n", + " noise_level = 20\n", + " noise_level = torch.Tensor((noise_level,)).long().to(device)\n", + " noisy_low_res_image = scheduler.add_noise(\n", + " original_samples=sampling_image,\n", + " noise=low_res_noise,\n", + " timesteps=torch.Tensor((noise_level,)).long().to(device),\n", + " )\n", + "\n", + " scheduler.set_timesteps(num_inference_steps=1000)\n", + " for t in scheduler.timesteps:\n", + " with torch.no_grad():\n", + " with autocast(\"cuda\", enabled=True):\n", + " latent_model_input = torch.cat([latents, noisy_low_res_image], dim=1)\n", + " noise_pred = unet(\n", + " x=latent_model_input, timesteps=torch.Tensor((t,)).to(device), class_labels=noise_level\n", + " )\n", + " latents, _ = scheduler.step(noise_pred, t, latents)\n", + "\n", + " with torch.no_grad():\n", + " decoded = autoencoderkl.decode_stage_2_outputs(latents / scale_factor)\n", + "\n", + " low_res_bicubic = nn.functional.interpolate(sampling_image, (64, 64), mode=\"bicubic\")\n", + " plt.figure(figsize=(2, 2))\n", + " plt.style.use(\"default\")\n", + " plt.imshow(\n", + " torch.cat([images[0, 0].cpu(), low_res_bicubic[0, 0].cpu(), decoded[0, 0].cpu()], dim=1),\n", + " vmin=0,\n", + " vmax=1,\n", + " cmap=\"gray\",\n", + " )\n", + " plt.tight_layout()\n", + " plt.axis(\"off\")\n", + " plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "30f24595", + "metadata": {}, + "source": [ + "### Plotting sampling example" + ] + }, + { + "cell_type": "markdown", + "id": "1a2813d4-9087-459e-8913-bce174ac31cd", + "metadata": {}, + "source": [ + "As mentioned above, at inference time we only need to pass noise with the same shape as the latent, concatenated with the (noise-augmented) low-resolution image, to obtain the latent representation of the corresponding high-resolution image."
+ ] + }, + { + "cell_type": "code", + "execution_count": 47, + "id": "155be091", + "metadata": {}, + "outputs": [], + "source": [ + "# Sampling image during training\n", + "unet.eval()\n", + "num_samples = 3\n", + "validation_batch = first(val_loader)\n", + "\n", + "images = validation_batch[\"image\"].to(device)\n", + "sampling_image = validation_batch[\"low_res_image\"].to(device)[:num_samples]" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "id": "aaf61020", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|█████████████████████████████████████████████████████████████████████| 1000/1000 [00:32<00:00, 31.22it/s]\n" + ] + } + ], + "source": [ + "latents = torch.randn((num_samples, 3, 16, 16)).to(device)\n", + "low_res_noise = torch.randn((num_samples, 1, 16, 16)).to(device)\n", + "noise_level = 10\n", + "noise_level = torch.Tensor((noise_level,)).long().to(device)\n", + "noisy_low_res_image = scheduler.add_noise(\n", + " original_samples=sampling_image, noise=low_res_noise, timesteps=torch.Tensor((noise_level,)).long().to(device)\n", + ")\n", + "scheduler.set_timesteps(num_inference_steps=1000)\n", + "for t in tqdm(scheduler.timesteps, ncols=110):\n", + " with torch.no_grad():\n", + " with autocast(\"cuda\", enabled=True):\n", + " latent_model_input = torch.cat([latents, noisy_low_res_image], dim=1)\n", + " noise_pred = unet(x=latent_model_input, timesteps=torch.Tensor((t,)).to(device), class_labels=noise_level)\n", + "\n", + " # 2. compute previous image: x_t -> x_t-1\n", + " latents, _ = scheduler.step(noise_pred, t, latents)\n", + "\n", + "with torch.no_grad():\n", + " decoded = autoencoderkl.decode_stage_2_outputs(latents / scale_factor)" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "id": "32e16e69", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "low_res_bicubic = nn.functional.interpolate(sampling_image, (64, 64), mode=\"bicubic\")\n", + "fig, axs = plt.subplots(num_samples, 3, figsize=(8, 8))\n", + "axs[0, 0].set_title(\"Original image\")\n", + "axs[0, 1].set_title(\"Low-resolution Image\")\n", + "axs[0, 2].set_title(\"Outputted image\")\n", + "for i in range(0, num_samples):\n", + " axs[i, 0].imshow(images[i, 0].cpu(), vmin=0, vmax=1, cmap=\"gray\")\n", + " axs[i, 0].axis(\"off\")\n", + " axs[i, 1].imshow(low_res_bicubic[i, 0].cpu(), vmin=0, vmax=1, cmap=\"gray\")\n", + " axs[i, 1].axis(\"off\")\n", + " axs[i, 2].imshow(decoded[i, 0].cpu(), vmin=0, vmax=1, cmap=\"gray\")\n", + " axs[i, 2].axis(\"off\")\n", + "plt.tight_layout()" + ] + }, + { + "cell_type": "markdown", + "id": "7fa52acc", + "metadata": {}, + "source": [ + "### Clean-up data directory" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3a6f6d5a", + "metadata": {}, + "outputs": [], + "source": [ + "if directory is None:\n", + " shutil.rmtree(root_dir)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "formats": "ipynb,py:percent" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/generation/2d_super_resolution/2d_sd_super_resolution_lightning.ipynb b/generation/2d_super_resolution/2d_sd_super_resolution_lightning.ipynb new file mode 100644 index 000000000..817efe233 --- /dev/null +++ b/generation/2d_super_resolution/2d_sd_super_resolution_lightning.ipynb @@ -0,0 +1,1318 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": 
"51f79ac5-1e09-4933-9009-9ade92901b3c", + "metadata": {}, + "source": [ + "Copyright (c) MONAI Consortium
\n", + "Licensed under the Apache License, Version 2.0 (the \"License\");
\n", + "you may not use this file except in compliance with the License.
\n", + "You may obtain a copy of the License at
\n", + "http://www.apache.org/licenses/LICENSE-2.0
\n", + "Unless required by applicable law or agreed to in writing, software
\n", + "distributed under the License is distributed on an \"AS IS\" BASIS,
\n", + "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
\n", + "See the License for the specific language governing permissions and
\n", + "limitations under the License.
" + ] + }, + { + "cell_type": "markdown", + "id": "95c08725", + "metadata": {}, + "source": [ + "# Super-resolution using Stable Diffusion v2 Upscalers with PyTorch Lightning\n", + "\n", + "This tutorial follows '2d_sd_super_resolution' but uses PyTorch Lightning (https://lightning.ai/docs/pytorch/stable/).\n", + "\n", + "It illustrates super-resolution on medical images using Latent Diffusion Models (LDMs) [1]. We first train an autoencoder to obtain a latent representation of the high-resolution images. We then train a diffusion model to infer this latent representation when conditioned on a low-resolution image.\n", + "\n", + "To improve the performance of our models, we will use a method called \"noise conditioning augmentation\" (introduced in [2] and used in Stable Diffusion v2.0 and Imagen Video [3]). During training, we add noise to the low-resolution images using a random signal-to-noise ratio, and we condition the diffusion model on the amount of noise added. At sampling time, we use a fixed signal-to-noise ratio, representing a small amount of augmentation that helps remove artefacts in the samples.\n", + "\n", + "\n", + "[1] - Rombach et al. \"High-Resolution Image Synthesis with Latent Diffusion Models\" https://arxiv.org/abs/2112.10752\n", + "\n", + "[2] - Ho et al. \"Cascaded diffusion models for high fidelity image generation\" https://arxiv.org/abs/2106.15282\n", + "\n", + "[3] - Ho et al. 
\"High Definition Video Generation with Diffusion Models\" https://arxiv.org/abs/2210.02303" + ] + }, + { + "cell_type": "markdown", + "id": "b839bf2d", + "metadata": {}, + "source": [ + "## Setup environment" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "77f7e633", + "metadata": {}, + "outputs": [], + "source": [ + "!python -c \"import monai\" || pip install -q \"monai-weekly[tqdm]\"\n", + "!python -c \"import pytorch_lightning\" || pip install pytorch-lightning\n", + "!python -c \"import matplotlib\" || pip install -q matplotlib\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "id": "214066de", + "metadata": {}, + "source": [ + "## Setup imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "de71fe08", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import shutil\n", + "import tempfile\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn.functional as F\n", + "from monai import transforms\n", + "from monai.apps import MedNISTDataset\n", + "from monai.config import print_config\n", + "from monai.data import CacheDataset, ThreadDataLoader\n", + "from monai.utils import first, set_determinism\n", + "from torch.amp import autocast\n", + "from torch import nn\n", + "from tqdm.notebook import tqdm\n", + "\n", + "from monai.losses import PatchAdversarialLoss, PerceptualLoss\n", + "from monai.networks.nets import AutoencoderKL, DiffusionModelUNet, PatchDiscriminator\n", + "from monai.networks.schedulers import DDPMScheduler\n", + "\n", + "import pytorch_lightning as pl\n", + "from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint\n", + "\n", + "print_config()" + ] + }, + { + "cell_type": "markdown", + "id": "c0dde922", + "metadata": {}, + "source": [ + "## Setup a data directory and download dataset\n", + "Specify a MONAI_DATA_DIRECTORY variable, where the data will be downloaded. 
If not specified, a temporary directory will be used." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "ded618a7", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/tmp/tmpkazhiy23\n" + ] + } + ], + "source": [ + "directory = os.environ.get(\"MONAI_DATA_DIRECTORY\")\n", + "root_dir = tempfile.mkdtemp() if directory is None else directory\n", + "print(root_dir)" + ] + }, + { + "cell_type": "markdown", + "id": "e855e2b7-7e46-44d9-a567-3e91b5db2b6f", + "metadata": {}, + "source": [ + "## Set deterministic training for reproducibility" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "9f0a17bc", + "metadata": {}, + "outputs": [], + "source": [ + "# Set a fixed seed for reproducibility\n", + "set_determinism(42)" + ] + }, + { + "cell_type": "markdown", + "id": "aa76151c-0a37-471e-8312-10c2afcf11bc", + "metadata": {}, + "source": [ + "## Description of data and download the training set\n", + "\n", + "For this tutorial, we use the head CT dataset from MedNIST."
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "298d964a", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "MedNIST.tar.gz: 59.0MB [00:01, 38.7MB/s] \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2024-09-24 08:29:18,175 - INFO - Downloaded: /tmp/tmpkazhiy23/MedNIST.tar.gz\n", + "2024-09-24 08:29:18,286 - INFO - Verified 'MedNIST.tar.gz', md5: 0bc7306e7427e00ad1c5526a6677552d.\n", + "2024-09-24 08:29:18,286 - INFO - Writing into directory: /tmp/tmpkazhiy23.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading dataset: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 47164/47164 [00:16<00:00, 2894.30it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2024-09-24 08:29:39,365 - INFO - Verified 'MedNIST.tar.gz', md5: 0bc7306e7427e00ad1c5526a6677552d.\n", + "2024-09-24 08:29:39,365 - INFO - File exists: /tmp/tmpkazhiy23/MedNIST.tar.gz, skipped downloading.\n", + "2024-09-24 08:29:39,366 - INFO - Non-empty folder exists in /tmp/tmpkazhiy23/MedNIST, skipped extracting.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading dataset: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5895/5895 [00:01<00:00, 3004.12it/s]\n" + ] + } + ], + "source": [ + "train_data = MedNISTDataset(root_dir=root_dir, section=\"training\", download=True, seed=0)\n", + "train_datalist = [{\"image\": item[\"image\"]} for item in train_data.data if item[\"class_name\"] == \"HeadCT\"]\n", + "val_data = 
MedNISTDataset(root_dir=root_dir, section=\"validation\", download=True, seed=0)\n", + "val_datalist = [{\"image\": item[\"image\"]} for item in val_data.data if item[\"class_name\"] == \"HeadCT\"]" + ] + }, + { + "cell_type": "markdown", + "id": "46bafb78", + "metadata": {}, + "source": [ + "### Setup utils functions" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "4f8eff03", + "metadata": {}, + "outputs": [], + "source": [ + "def get_train_transforms():\n", + " image_size = 64\n", + " train_transforms = transforms.Compose(\n", + " [\n", + " transforms.LoadImaged(keys=[\"image\"]),\n", + " transforms.EnsureChannelFirstd(keys=[\"image\"]),\n", + " transforms.ScaleIntensityRanged(keys=[\"image\"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True),\n", + " transforms.RandAffined(\n", + " keys=[\"image\"],\n", + " rotate_range=[(-np.pi / 36, np.pi / 36), (-np.pi / 36, np.pi / 36)],\n", + " translate_range=[(-1, 1), (-1, 1)],\n", + " scale_range=[(-0.05, 0.05), (-0.05, 0.05)],\n", + " spatial_size=[image_size, image_size],\n", + " padding_mode=\"zeros\",\n", + " prob=0.5,\n", + " ),\n", + " transforms.CopyItemsd(keys=[\"image\"], times=1, names=[\"low_res_image\"]),\n", + " transforms.Resized(keys=[\"low_res_image\"], spatial_size=(16, 16)),\n", + " ]\n", + " )\n", + " return train_transforms\n", + "\n", + "\n", + "def get_val_transforms():\n", + " val_transforms = transforms.Compose(\n", + " [\n", + " transforms.LoadImaged(keys=[\"image\"]),\n", + " transforms.EnsureChannelFirstd(keys=[\"image\"]),\n", + " transforms.ScaleIntensityRanged(keys=[\"image\"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True),\n", + " transforms.CopyItemsd(keys=[\"image\"], times=1, names=[\"low_res_image\"]),\n", + " transforms.Resized(keys=[\"low_res_image\"], spatial_size=(16, 16)),\n", + " ]\n", + " )\n", + " return val_transforms\n", + "\n", + "\n", + "def get_datasets():\n", + " train_transforms = get_train_transforms()\n", + " val_transforms = 
get_val_transforms()\n", + " train_ds = CacheDataset(data=train_datalist[:320], transform=train_transforms)\n", + " val_ds = CacheDataset(data=val_datalist[:32], transform=val_transforms)\n", + " return train_ds, val_ds" + ] + }, + { + "cell_type": "markdown", + "id": "d80e045b", + "metadata": {}, + "source": [ + "## Define the LightningModule for AutoEncoder (transforms, network, loaders, etc)\n", + "The LightningModule contains a refactoring of your training code. The following module is a reformatting of the autoencoder training code in 2d_sd_super_resolution.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "d5d1caff", + "metadata": {}, + "outputs": [], + "source": [ + "class AutoEncoder(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.data_dir = root_dir\n", + " self.autoencoderkl = AutoencoderKL(\n", + " spatial_dims=2,\n", + " in_channels=1,\n", + " out_channels=1,\n", + " channels=(256, 512, 512),\n", + " latent_channels=3,\n", + " num_res_blocks=2,\n", + " norm_num_groups=32,\n", + " attention_levels=(False, False, True),\n", + " )\n", + " self.discriminator = PatchDiscriminator(spatial_dims=2, in_channels=1, num_layers_d=3, channels=64)\n", + " self.perceptual_loss = PerceptualLoss(spatial_dims=2, network_type=\"alex\")\n", + " self.perceptual_weight = 0.002\n", + " self.autoencoder_warm_up_n_epochs = 10\n", + " self.automatic_optimization = False\n", + " self.adv_loss = PatchAdversarialLoss(criterion=\"least_squares\")\n", + " self.adv_weight = 0.005\n", + " self.kl_weight = 1e-6\n", + "\n", + " def forward(self, z):\n", + " return self.autoencoderkl(z)\n", + "\n", + " def prepare_data(self):\n", + " self.train_ds, self.val_ds = get_datasets()\n", + "\n", + " def train_dataloader(self):\n", + " return ThreadDataLoader(self.train_ds, batch_size=16, shuffle=True, num_workers=4, persistent_workers=True)\n", + "\n", + " def val_dataloader(self):\n", + " return ThreadDataLoader(self.val_ds, 
batch_size=16, shuffle=False, num_workers=4)\n", + "\n", + " def _compute_loss_generator(self, images, reconstruction, z_mu, z_sigma):\n", + " recons_loss = F.l1_loss(reconstruction.float(), images.float())\n", + " p_loss = self.perceptual_loss(reconstruction.float(), images.float())\n", + " kl_loss = 0.5 * torch.sum(z_mu.pow(2) + z_sigma.pow(2) - torch.log(z_sigma.pow(2)) - 1, dim=[1, 2, 3])\n", + " kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]\n", + " loss_g = recons_loss + (self.kl_weight * kl_loss) + (self.perceptual_weight * p_loss)\n", + " return loss_g, recons_loss\n", + "\n", + " def _compute_loss_discriminator(self, images, reconstruction):\n", + " logits_fake = self.discriminator(reconstruction.contiguous().detach())[-1]\n", + " loss_d_fake = self.adv_loss(logits_fake, target_is_real=False, for_discriminator=True)\n", + " logits_real = self.discriminator(images.contiguous().detach())[-1]\n", + " loss_d_real = self.adv_loss(logits_real, target_is_real=True, for_discriminator=True)\n", + " discriminator_loss = (loss_d_fake + loss_d_real) * 0.5\n", + " loss_d = self.adv_weight * discriminator_loss\n", + " return loss_d, discriminator_loss\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " optimizer_g, optimizer_d = self.optimizers()\n", + " images = batch[\"image\"]\n", + " reconstruction, z_mu, z_sigma = self.forward(images)\n", + " loss_g, recons_loss = self._compute_loss_generator(images, reconstruction, z_mu, z_sigma)\n", + " self.log(\"recons_loss\", recons_loss, batch_size=16, prog_bar=True)\n", + "\n", + " if self.current_epoch > self.autoencoder_warm_up_n_epochs:\n", + " logits_fake = self.discriminator(reconstruction.contiguous().float())[-1]\n", + " generator_loss = self.adv_loss(logits_fake, target_is_real=True, for_discriminator=False)\n", + " loss_g += self.adv_weight * generator_loss\n", + " self.log(\"gen_loss\", generator_loss, batch_size=16, prog_bar=True)\n", + "\n", + " self.log(\"loss_g\", loss_g, batch_size=16, 
prog_bar=True)\n", + " self.manual_backward(loss_g)\n", + " optimizer_g.step()\n", + " optimizer_g.zero_grad()\n", + " self.untoggle_optimizer(optimizer_g)\n", + "\n", + " if self.current_epoch > self.autoencoder_warm_up_n_epochs:\n", + " loss_d, discriminator_loss = self._compute_loss_discriminator(images, reconstruction)\n", + " self.log(\"disc_loss\", loss_d, batch_size=16, prog_bar=True)\n", + " self.log(\"train_loss_d\", loss_d, batch_size=16, prog_bar=True)\n", + " self.manual_backward(loss_d)\n", + " optimizer_d.step()\n", + " optimizer_d.zero_grad()\n", + " self.untoggle_optimizer(optimizer_d)\n", + "\n", + " def validation_step(self, batch, batch_idx):\n", + " images = batch[\"image\"]\n", + " reconstruction, z_mu, z_sigma = self.autoencoderkl(images)\n", + " recons_loss = F.l1_loss(images.float(), reconstruction.float())\n", + " self.log(\"val_loss_d\", recons_loss, batch_size=1, prog_bar=True)\n", + " self.images = images\n", + " self.reconstruction = reconstruction\n", + "\n", + " def on_validation_epoch_end(self):\n", + " # Plot the first image of the last validation batch next to its reconstruction\n", + " plt.figure(figsize=(2, 2))\n", + " plt.imshow(\n", + " torch.cat([self.images[0, 0].cpu(), self.reconstruction[0, 0].cpu()], dim=1), vmin=0, vmax=1, cmap=\"gray\"\n", + " )\n", + " plt.tight_layout()\n", + " plt.axis(\"off\")\n", + " plt.show()\n", + "\n", + " def configure_optimizers(self):\n", + " optimizer_g = torch.optim.Adam(self.autoencoderkl.parameters(), lr=5e-5)\n", + " optimizer_d = torch.optim.Adam(self.discriminator.parameters(), lr=1e-4)\n", + " return [optimizer_g, optimizer_d], []" + ] + }, + { + "cell_type": "markdown", + "id": "c16de505", + "metadata": { + "lines_to_next_cell": 0 + }, + "source": [ + "## Train Autoencoder" + ] + }, + { + "cell_type": "markdown", + "id": "e740cb2d-5a57-42ed-806b-e8c720a6f922", + "metadata": {}, + "source": [ + "In this section, we train a spatial autoencoder to learn how to compress high-resolution images into a latent space representation. 
We need to ensure that the latent space spatial shape matches that of the low resolution images." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "9d903aaa", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/vf19/PycharmProjects/MONAI_tutorials/venv/lib/python3.10/site-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n", + " warnings.warn(\n", + "/home/vf19/PycharmProjects/MONAI_tutorials/venv/lib/python3.10/site-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=AlexNet_Weights.IMAGENET1K_V1`. You can also use `weights=AlexNet_Weights.DEFAULT` to get the most up-to-date weights.\n", + " warnings.warn(msg)\n", + "/home/vf19/PycharmProjects/MONAI_tutorials/venv/lib/python3.10/site-packages/lpips/lpips.py:107: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. 
Please open an issue on GitHub for any issues related to this experimental feature.\n", + " self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False)\n", + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "HPU available: False, using: 0 HPUs\n", + "Loading dataset: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 320/320 [00:00<00:00, 1478.97it/s]\n", + "Loading dataset: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 32/32 [00:00<00:00, 773.46it/s]\n", + "/home/vf19/PycharmProjects/MONAI_tutorials/venv/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /tmp/tmpkazhiy23 exists and is not empty.\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "\n", + " | Name | Type | Params | Mode \n", + "-----------------------------------------------------------------\n", + "0 | autoencoderkl | AutoencoderKL | 75.1 M | train\n", + "1 | discriminator | PatchDiscriminator | 2.8 M | train\n", + "2 | perceptual_loss | PerceptualLoss | 2.5 M | train\n", + "3 | adv_loss | PatchAdversarialLoss | 0 | train\n", + "-----------------------------------------------------------------\n", + "77.8 M Trainable params\n", + "2.5 M Non-trainable params\n", + "80.3 M Total params\n", + "321.225 Total estimated model params size (MB)\n", + "251 Modules in train mode\n", + "41 Modules in eval mode\n", + "/home/vf19/PycharmProjects/MONAI_tutorials/venv/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (20) is smaller than the logging 
interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e3c00d0cf81e4f5484624e5e3278d6bd", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: | …" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: | …" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: | …" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: | …" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: | …" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: | …" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: | …" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: | …" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`Trainer.fit` stopped: `max_epochs=75` reached.\n" + ] + } + ], + "source": [ + "max_epochs = 75\n", + "val_interval = 10\n", + "\n", + "\n", + "# initialise the LightningModule\n", + "ae_net = AutoEncoder()\n", + "\n", + "# set up checkpoints\n", + "\n", + "checkpoint_callback = ModelCheckpoint(dirpath=root_dir, filename=\"best_metric_model\")\n", + "\n", + "\n", + "# initialise Lightning's trainer.\n", + "trainer = pl.Trainer(\n", + " devices=1,\n", + " max_epochs=max_epochs,\n", + " check_val_every_n_epoch=val_interval,\n", + " num_sanity_val_steps=0,\n", + " callbacks=checkpoint_callback,\n", + " default_root_dir=root_dir,\n", + ")\n", + "\n", + "# train\n", + "trainer.fit(ae_net)" + ] + }, + { + "cell_type": "markdown", + "id": "c7108b87", + "metadata": {}, + "source": [ + "## Rescaling factor\n", + "\n", + "As mentioned in Rombach et al. [1], Sections 4.3.2 and D.1, the signal-to-noise ratio (induced by the scale of the latent space) is crucial in image-to-image translation models (such as the ones used for super-resolution). For this reason, we compute the standard deviation of the latents of one training batch and use its reciprocal as the scaling factor."
+ ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "ccb6ba9f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Scaling factor set to 0.6885251998901367\n" + ] + } + ], + "source": [ + "def get_scale_factor():\n", + " ae_net.eval()\n", + " device = torch.device(\"cuda:0\")\n", + " ae_net.to(device)\n", + "\n", + " train_loader = ae_net.train_dataloader()\n", + " check_data = first(train_loader)\n", + " z = ae_net.autoencoderkl.encode_stage_2_inputs(check_data[\"image\"].to(ae_net.device))\n", + " # compute the reciprocal of the latent standard deviation once and reuse it\n", + " scale_factor = 1 / torch.std(z)\n", + " print(f\"Scaling factor set to {scale_factor}\")\n", + " return scale_factor\n", + "\n", + "\n", + "scale_factor = get_scale_factor()" + ] + }, + { + "cell_type": "markdown", + "id": "3baa2b0f", + "metadata": {}, + "source": [ + "## Define the LightningModule for DiffusionModelUNet (transforms, network, loaders, etc.)\n", + "\n", + "The LightningModule wraps the training code in a single class. The following module is a reformatting of the code in the 2d_stable_diffusion_v2_super_resolution tutorial." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "731034ec", + "metadata": {}, + "outputs": [], + "source": [ + "class DiffusionUNET(pl.LightningModule):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.data_dir = root_dir\n", + " self.unet = DiffusionModelUNet(\n", + " spatial_dims=2,\n", + " in_channels=4,\n", + " out_channels=3,\n", + " num_res_blocks=2,\n", + " channels=(256, 256, 512, 1024),\n", + " attention_levels=(False, False, True, True),\n", + " num_head_channels=(0, 0, 64, 64),\n", + " )\n", + " self.max_noise_level = 350\n", + " self.scheduler = DDPMScheduler(\n", + " num_train_timesteps=1000, schedule=\"linear_beta\", beta_start=0.0015, beta_end=0.0195\n", + " )\n", + " self.z = ae_net.autoencoderkl.eval()\n", + "\n", + " def forward(self, x, timesteps, low_res_timesteps):\n", + " return self.unet(x=x, timesteps=timesteps, class_labels=low_res_timesteps)\n", + "\n", + " def prepare_data(self):\n", + " self.train_ds, self.val_ds = get_datasets()\n", + "\n", + " def train_dataloader(self):\n", + " return ThreadDataLoader(self.train_ds, batch_size=16, shuffle=True, num_workers=4, persistent_workers=True)\n", + "\n", + " def val_dataloader(self):\n", + " return ThreadDataLoader(self.val_ds, batch_size=16, shuffle=False, num_workers=4)\n", + "\n", + " def _calculate_loss(self, batch, batch_idx, plt_image=False):\n", + " images = batch[\"image\"]\n", + " low_res_image = batch[\"low_res_image\"]\n", + " with autocast(\"cuda\", enabled=True):\n", + " with torch.no_grad():\n", + " latent = self.z.encode_stage_2_inputs(images) * scale_factor\n", + "\n", + " # Noise augmentation\n", + " noise = torch.randn_like(latent)\n", + " low_res_noise = torch.randn_like(low_res_image)\n", + " timesteps = torch.randint(\n", + " 0, self.scheduler.num_train_timesteps, (latent.shape[0],), device=latent.device\n", + " ).long()\n", + " low_res_timesteps = torch.randint(\n", + " 0, self.max_noise_level, (low_res_image.shape[0],), 
device=latent.device\n", + " ).long()\n", + "\n", + " noisy_latent = self.scheduler.add_noise(original_samples=latent, noise=noise, timesteps=timesteps)\n", + " noisy_low_res_image = self.scheduler.add_noise(\n", + " original_samples=low_res_image, noise=low_res_noise, timesteps=low_res_timesteps\n", + " )\n", + "\n", + " latent_model_input = torch.cat([noisy_latent, noisy_low_res_image], dim=1)\n", + "\n", + " noise_pred = self.forward(latent_model_input, timesteps, low_res_timesteps)\n", + " loss = F.mse_loss(noise_pred.float(), noise.float())\n", + "\n", + " if plt_image:\n", + " # Sampling image during training\n", + " sampling_image = low_res_image[0].unsqueeze(0)\n", + " latents = torch.randn((1, 3, 16, 16)).to(sampling_image.device)\n", + " low_res_noise = torch.randn((1, 1, 16, 16)).to(sampling_image.device)\n", + " noise_level = 20\n", + " noise_level = torch.Tensor((noise_level,)).long().to(sampling_image.device)\n", + "\n", + " noisy_low_res_image = self.scheduler.add_noise(\n", + " original_samples=sampling_image,\n", + " noise=low_res_noise,\n", + " timesteps=noise_level,\n", + " )\n", + " self.scheduler.set_timesteps(num_inference_steps=1000)\n", + " for t in tqdm(self.scheduler.timesteps, ncols=110):\n", + " with autocast(\"cuda\", enabled=True):\n", + " with torch.no_grad():\n", + " latent_model_input = torch.cat([latents, noisy_low_res_image], dim=1)\n", + " noise_pred = self.forward(\n", + " latent_model_input, torch.Tensor((t,)).to(sampling_image.device), noise_level\n", + " )\n", + " latents, _ = self.scheduler.step(noise_pred, t, latents)\n", + " with torch.no_grad():\n", + " decoded = self.z.decode_stage_2_outputs(latents / scale_factor)\n", + " low_res_bicubic = nn.functional.interpolate(sampling_image, (64, 64), mode=\"bicubic\")\n", + " # plot images\n", + "\n", + " self.images = images\n", + " self.low_res_bicubic = low_res_bicubic\n", + " self.decoded = decoded\n", + "\n", + " return loss\n", + "\n", + " def _plot_image(self, images, 
low_res_bicubic, decoded):\n", + " plt.figure(figsize=(2, 2))\n", + " plt.style.use(\"default\")\n", + " plt.imshow(\n", + " torch.cat([images[0, 0].cpu(), low_res_bicubic[0, 0].cpu(), decoded[0, 0].cpu()], dim=1),\n", + " vmin=0,\n", + " vmax=1,\n", + " cmap=\"gray\",\n", + " )\n", + " plt.tight_layout()\n", + " plt.axis(\"off\")\n", + " plt.show()\n", + "\n", + " def training_step(self, batch, batch_idx):\n", + " loss = self._calculate_loss(batch, batch_idx)\n", + " self.log(\"train_loss\", loss, batch_size=16, prog_bar=True)\n", + " return loss\n", + "\n", + " def validation_step(self, batch, batch_idx):\n", + " loss = self._calculate_loss(batch, batch_idx, plt_image=True)\n", + " self.log(\"val_loss\", loss, batch_size=16, prog_bar=True)\n", + " return loss\n", + "\n", + " def on_validation_epoch_end(self):\n", + " self._plot_image(self.images, self.low_res_bicubic, self.decoded)\n", + "\n", + " def configure_optimizers(self):\n", + " optimizer = torch.optim.Adam(self.unet.parameters(), lr=5e-5)\n", + " return optimizer" + ] + }, + { + "cell_type": "markdown", + "id": "b386a0c2", + "metadata": {}, + "source": [ + "## Train Diffusion Model\n", + "\n", + "To train the diffusion model to perform super-resolution, we need to **concatenate the latent representation of the high-resolution image with the low-resolution image**. Therefore, the number of input channels to the diffusion model is the sum of the number of channels in the low-resolution image (1) and the number of channels in the latent representation of the high-resolution image (3). In this case, we create a diffusion model with `in_channels=4`. Since the model only needs to output a tensor with the shape of the latent representation, we set `out_channels=3`. \n", + "\n", + "**At inference time** we do not have a high-resolution image. 
Instead, we pass the concatenation of the low-resolution image and noise of the same shape as the latent representation.\n", + "\n", + "As mentioned, we will use noise conditioning augmentation (introduced in [2] Section 3 and used in Stable Diffusion Upscalers and Imagen Video [3] Section 2.5), as it has been shown to be critical for cascaded diffusion models as well as for super-resolution tasks. For this, we apply Gaussian noise augmentation to the low-resolution images. We use the scheduler to add this noise, with the timestep t (sampled below `max_noise_level`) defining the signal-to-noise ratio, and we use the t value to condition the diffusion model (passed through the `class_labels` argument)." + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "936bbb9c", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "HPU available: False, using: 0 HPUs\n", + "\n", + "Loading dataset: 0%| | 0/320 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: | …" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "17aa3327b5084a8f96c140125efd91e6", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1000 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: | …" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": 
"fa63aaeaed78402f8f478a467bb73cd5", + "version_major": 2, + 
"version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1000 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Validation: | …" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "66c3255d7df7478e9277de2e512ec1d7", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1000 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "`Trainer.fit` stopped: `max_epochs=200` reached.\n" + ] + } + ], + "source": [ + "max_epochs = 200\n", + "val_interval = 50\n", + "\n", + "\n", + "# initialise the LightningModule\n", + "d_net = DiffusionUNET()\n", + "\n", + "# set up checkpoints\n", + "\n", + "checkpoint_callback = ModelCheckpoint(dirpath=root_dir, filename=\"best_metric_model_dunet\")\n", + "\n", + "\n", + "# initialise Lightning's trainer.\n", + "trainer = pl.Trainer(\n", + " devices=1,\n", + " max_epochs=max_epochs,\n", + " check_val_every_n_epoch=val_interval,\n", + " num_sanity_val_steps=0,\n", + " callbacks=checkpoint_callback,\n", + " default_root_dir=root_dir,\n", + ")\n", + "\n", + "# train\n", + "trainer.fit(d_net)" + ] + }, + { + "cell_type": "markdown", + "id": "30f24595", + "metadata": {}, + "source": [ + "### Plotting sampling example" + ] + }, + { + "cell_type": "markdown", + "id": "19ba049e-fca6-4c76-b7b1-7e992d370583", + "metadata": {}, + "source": [ + "As mentioned above, at inference time, we only need to pass noise of the same shape of the latent concatenated to the low-resolution image, to get the latent representation of the corresponding high-resolution image." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "155be091", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "837b6b2a5e1f42ab862c9cf589dbb615", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/1000 [00:00 x_t-1\n", + " latents, _ = scheduler.step(noise_pred, t, latents)\n", + "\n", + " with torch.no_grad():\n", + " decoded = ae_net.autoencoderkl.decode_stage_2_outputs(latents / scale_factor)\n", + " return sampling_image, images, decoded\n", + "\n", + "\n", + "sampling_image, images, decoded = get_images_to_plot()" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "32e16e69", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "low_res_bicubic = nn.functional.interpolate(sampling_image, (64, 64), mode=\"bicubic\")\n", + "fig, axs = plt.subplots(num_samples, 3, figsize=(8, 8))\n", + "axs[0, 0].set_title(\"Original image\")\n", + "axs[0, 1].set_title(\"Low-resolution Image\")\n", + "axs[0, 2].set_title(\"Outputted image\")\n", + "for i in range(0, num_samples):\n", + " axs[i, 0].imshow(images[i, 0].cpu(), vmin=0, vmax=1, cmap=\"gray\")\n", + " axs[i, 0].axis(\"off\")\n", + " axs[i, 1].imshow(low_res_bicubic[i, 0].cpu(), vmin=0, vmax=1, cmap=\"gray\")\n", + " axs[i, 1].axis(\"off\")\n", + " axs[i, 2].imshow(decoded[i, 0].cpu().detach().numpy(), vmin=0, vmax=1, cmap=\"gray\")\n", + " axs[i, 2].axis(\"off\")\n", + "plt.tight_layout()" + ] + }, + { + "cell_type": "markdown", + "id": "7fa52acc", + "metadata": {}, + "source": [ + "### Clean-up data directory" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "3a6f6d5a", + "metadata": {}, + "outputs": [], + "source": [ + "if directory is None:\n", + " shutil.rmtree(root_dir)" + ] + } + ], + "metadata": { + "jupytext": { + "cell_metadata_filter": "-all", + "formats": "ipynb,py:percent" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/generation/README.md b/generation/README.md index d9125a861..351416fd1 100644 --- a/generation/README.md +++ b/generation/README.md @@ -72,3 +72,6 @@ Example shows the use cases of applying a spatial VAE to a 3D synthesis example. 
## Performing anomaly detection with diffusion models: [implicit guidance](./anomaly_detection/2d_classifierfree_guidance_anomalydetection_tutorial.ipynb), [using transformers](./anomaly_detection/anomaly_detection_with_transformers.ipynb) and [classifier free guidance](./anomaly_detection/anomalydetection_tutorial_classifier_guidance.ipynb) Examples show how to perform anomaly detection in 2D, using implicit guidance [2D-classifier free guiance](./anomaly_detection/2d_classifierfree_guidance_anomalydetection_tutorial.ipynb), transformers [using transformers](./anomaly_detection/anomaly_detection_with_transformers.ipynb) and [classifier free guidance](./anomalydetection_tutorial_classifier_guidance). + +## 2D super-resolution using diffusion models: [using PyTorch](./2d_super_resolution/2d_sd_super_resolution.ipynb) and [using PyTorch Lightning](./2d_super_resolution/2d_sd_super_resolution_lightning.ipynb) +Examples show how to perform super-resolution in 2D, using PyTorch and PyTorch Lightning.