Skip to content

Commit

Permalink
Merge branch 'main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
ericspod committed Sep 25, 2024
2 parents ca1cdfa + 203a999 commit 4b631e8
Show file tree
Hide file tree
Showing 16 changed files with 292 additions and 258 deletions.
18 changes: 9 additions & 9 deletions active_learning/liver_tumor_al/active_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def main():
# Model Definition
device = torch.device("cuda:0")
network = UNet(
dimensions=3,
spatial_dims=3,
in_channels=1,
out_channels=3,
channels=(16, 32, 64, 128, 256),
Expand Down Expand Up @@ -187,7 +187,7 @@ def main():
b_max=1.0,
clip=True,
),
CropForegroundd(keys=["image", "label"], source_key="image"),
CropForegroundd(keys=["image", "label"], source_key="image", allow_smaller=True),
SpatialPadd(keys=["image", "label"], spatial_size=(96, 96, 96)),
RandCropByPosNegLabeld(
keys=["image", "label"],
Expand Down Expand Up @@ -225,7 +225,7 @@ def main():
b_max=1.0,
clip=True,
),
CropForegroundd(keys=["image", "label"], source_key="image"),
CropForegroundd(keys=["image", "label"], source_key="image", allow_smaller=True),
EnsureTyped(keys=["image", "label"]),
]
)
Expand All @@ -240,7 +240,7 @@ def main():
mode=("bilinear"),
),
ScaleIntensityRanged(keys="image", a_min=-21, a_max=189, b_min=0.0, b_max=1.0, clip=True),
CropForegroundd(keys=("image"), source_key="image"),
CropForegroundd(keys=("image"), source_key="image", allow_smaller=True),
EnsureTyped(keys=["image"]),
]
)
Expand Down Expand Up @@ -315,7 +315,7 @@ def main():
unl_loader = DataLoader(unl_ds, batch_size=1)

# Calculation of Epochs based on steps
max_epochs = np.int(args.steps / (np.ceil(len(train_d) / args.batch_size)))
max_epochs = int(args.steps / (np.ceil(len(train_d) / args.batch_size)))
print("Epochs Estimated are {} for Active Iter {} with {} Vols".format(max_epochs, active_iter, len(train_d)))

# Model Training begins for one active iteration
Expand Down Expand Up @@ -393,7 +393,7 @@ def main():
prev_best_ckpt = os.path.join(active_model_dir, "model.pt")

device = torch.device("cuda:0")
ckpt = torch.load(prev_best_ckpt)
ckpt = torch.load(prev_best_ckpt, weights_only=True)
network.load_state_dict(ckpt)
network.to(device=device)

Expand Down Expand Up @@ -487,16 +487,16 @@ def main():

variance_dims = np.shape(variance)
score_list.append(np.nanmean(variance))
name_list.append(unl_data["image_meta_dict"]["filename_or_obj"][0])
name_list.append(unl_data["image"].meta["filename_or_obj"][0])
print(
"Variance for image: {} is: {}".format(
unl_data["image_meta_dict"]["filename_or_obj"][0], np.nanmean(variance)
unl_data["image"].meta["filename_or_obj"][0], np.nanmean(variance)
)
)

# Plot with matplotlib and save all slices
plt.figure(1)
plt.imshow(np.squeeze(variance[:, :, np.int(variance_dims[2] / 2)]))
plt.imshow(np.squeeze(variance[:, :, int(variance_dims[2] / 2)]))
plt.colorbar()
plt.title("Dropout Uncertainty")
fig_path = os.path.join(fig_base_dir, "active_{}_file_{}.png".format(active_iter, counter))
Expand Down
30 changes: 10 additions & 20 deletions active_learning/liver_tumor_al/results_uncertainty_analysis.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions active_learning/tool_tracking_al/active_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def main():
unl_loader = DataLoader(unl_ds, batch_size=1)

# Calculation of Epochs based on steps
max_epochs = np.int(args.steps / (np.ceil(len(train_d) / args.batch_size)))
max_epochs = int(args.steps / (np.ceil(len(train_d) / args.batch_size)))
print("Epochs Estimated are {} for Active Iter {} with {} Vols".format(max_epochs, active_iter, len(train_d)))

# Keep track of Best_metric, it is being used as IoU and not Dice
Expand Down Expand Up @@ -379,7 +379,7 @@ def main():
prev_best_ckpt = os.path.join(active_model_dir, "model.pt")

device = torch.device("cuda:0")
ckpt = torch.load(prev_best_ckpt)
ckpt = torch.load(prev_best_ckpt, weights_only=True)
network.load_state_dict(ckpt)
network.to(device=device)

Expand Down

Large diffs are not rendered by default.

175 changes: 61 additions & 114 deletions generation/maisi/maisi_diff_unet_training_tutorial.ipynb

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions generation/maisi/maisi_inference_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -364,25 +364,25 @@
"device = torch.device(\"cuda\")\n",
"\n",
"autoencoder = define_instance(args, \"autoencoder_def\").to(device)\n",
"checkpoint_autoencoder = torch.load(args.trained_autoencoder_path)\n",
"checkpoint_autoencoder = torch.load(args.trained_autoencoder_path, weights_only=True)\n",
"autoencoder.load_state_dict(checkpoint_autoencoder)\n",
"\n",
"diffusion_unet = define_instance(args, \"diffusion_unet_def\").to(device)\n",
"checkpoint_diffusion_unet = torch.load(args.trained_diffusion_path)\n",
"checkpoint_diffusion_unet = torch.load(args.trained_diffusion_path, weights_only=False)\n",
"diffusion_unet.load_state_dict(checkpoint_diffusion_unet[\"unet_state_dict\"], strict=True)\n",
"scale_factor = checkpoint_diffusion_unet[\"scale_factor\"].to(device)\n",
"\n",
"controlnet = define_instance(args, \"controlnet_def\").to(device)\n",
"checkpoint_controlnet = torch.load(args.trained_controlnet_path)\n",
"checkpoint_controlnet = torch.load(args.trained_controlnet_path, weights_only=False)\n",
"monai.networks.utils.copy_model_state(controlnet, diffusion_unet.state_dict())\n",
"controlnet.load_state_dict(checkpoint_controlnet[\"controlnet_state_dict\"], strict=True)\n",
"\n",
"mask_generation_autoencoder = define_instance(args, \"mask_generation_autoencoder_def\").to(device)\n",
"checkpoint_mask_generation_autoencoder = torch.load(args.trained_mask_generation_autoencoder_path)\n",
"checkpoint_mask_generation_autoencoder = torch.load(args.trained_mask_generation_autoencoder_path, weights_only=True)\n",
"mask_generation_autoencoder.load_state_dict(checkpoint_mask_generation_autoencoder)\n",
"\n",
"mask_generation_diffusion_unet = define_instance(args, \"mask_generation_diffusion_def\").to(device)\n",
"checkpoint_mask_generation_diffusion_unet = torch.load(args.trained_mask_generation_diffusion_path)\n",
"checkpoint_mask_generation_diffusion_unet = torch.load(args.trained_mask_generation_diffusion_path, weights_only=True)\n",
"mask_generation_diffusion_unet.load_state_dict(checkpoint_mask_generation_diffusion_unet[\"unet_state_dict\"])\n",
"mask_generation_scale_factor = checkpoint_mask_generation_diffusion_unet[\"scale_factor\"]\n",
"\n",
Expand Down
Loading

0 comments on commit 4b631e8

Please sign in to comment.