-
Notifications
You must be signed in to change notification settings - Fork 2
/
loss_functions.py
355 lines (287 loc) · 13.7 KB
/
loss_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
import tensorflow as tf
import numpy as np
from utils import min_max_norm_tf, z_score_norm_tf
from clDice_func import soft_dice_cldice_loss
@tf.function
def reduce_mean(self, inputs, axis=None, keepdims=False):
    """
    Average `inputs` over `axis`, then divide the summed result by the
    global batch size.

    Dividing by `self.global_batch_size` (rather than letting TF average
    over the local replica batch) gives a correctly scaled mean under
    distributed training.

    Args:
        inputs: Tensor of values to average.
        axis: Dimensions to average over; all dimensions when None.
        keepdims: If True, retain reduced dimensions with length 1.

    Returns:
        A scalar tensor: sum of the per-axis means divided by
        `self.global_batch_size`.
    """
    per_axis_mean = tf.reduce_mean(inputs, axis=axis, keepdims=keepdims)
    return tf.reduce_sum(per_axis_mean) / self.global_batch_size
@tf.function
def MSLE(self, real, fake):
    """
    Per-sample mean squared logarithmic error (MSLE) between `real` and `fake`.

    Uses `tf.math.log1p(x)` instead of `tf.math.log(x + 1.)` — they are
    mathematically identical but log1p is numerically more accurate for
    values near zero.

    Args:
        real: Tensor of real values (assumed non-negative — TODO confirm).
        fake: Tensor of generated values (assumed non-negative — TODO confirm).

    Returns:
        A scalar tensor: the per-sample MSLE, globally averaged via
        `reduce_mean`.
    """
    log_diff = tf.math.log1p(real) - tf.math.log1p(fake)
    return reduce_mean(self, tf.square(log_diff),
                       axis=list(range(1, len(real.shape))))
@tf.function
def MAE(self, y_true, y_pred):
    """
    Per-sample mean absolute error (MAE) between `y_true` and `y_pred`.

    Args:
        y_true: Tensor of true values.
        y_pred: Tensor of predicted values.

    Returns:
        A scalar tensor: the per-sample MAE, globally averaged via
        `reduce_mean`.
    """
    # Average over every non-batch dimension.
    sample_axes = list(range(1, len(y_true.shape)))
    return reduce_mean(self, tf.abs(y_true - y_pred), axis=sample_axes)
@tf.function
def MSE(self, y_true, y_pred):
    """
    Per-sample mean squared error (MSE) between `y_true` and `y_pred`.

    Args:
        y_true: Tensor of true values.
        y_pred: Tensor of predicted values.

    Returns:
        A scalar tensor: the per-sample MSE, globally averaged via
        `reduce_mean`.
    """
    # Average over every non-batch dimension.
    sample_axes = list(range(1, len(y_true.shape)))
    return reduce_mean(self, tf.square(y_true - y_pred), axis=sample_axes)
@tf.function
def L4(self, y_true, y_pred):
    """
    Per-sample L4 loss (fourth power of the residual) between `y_true`
    and `y_pred`.

    Args:
        y_true: Tensor of true values.
        y_pred: Tensor of predicted values.

    Returns:
        A scalar tensor: the per-sample L4 loss, globally averaged via
        `reduce_mean`.
    """
    residual = y_true - y_pred
    # Average over every non-batch dimension.
    sample_axes = list(range(1, len(y_true.shape)))
    return reduce_mean(self, tf.math.pow(residual, 4), axis=sample_axes)
@tf.function
def ssim_loss_3d(y_true, y_pred, max_val=1.0, filter_size=3, filter_sigma=1.5, k1=0.01, k2=0.03):
    """
    Per-voxel SSIM loss map (1 - SSIM) for 3D volumes.

    Computes the structural similarity map between `y_true` and `y_pred`
    using a separable 3D Gaussian window. NOTE: the result is NOT reduced —
    it has the same shape as the inputs, and callers are expected to
    average it themselves (e.g. via `reduce_mean`).

    Args:
        y_true: 5-D tensor, assumed (batch, D, H, W, 1) — single channel,
            since the conv filter is built with in/out channels of 1.
            TODO confirm against callers.
        y_pred: 5-D tensor of the same shape as `y_true`.
        max_val: Dynamic range of the input values (default 1.0).
        filter_size: Side length of the cubic Gaussian window.
        filter_sigma: Standard deviation of the Gaussian window.
        k1, k2: SSIM stabilization constants (standard values 0.01 / 0.03).

    Returns:
        Tensor of the same shape as the inputs holding 1 - SSIM per voxel.
    """
    # 1D Gaussian kernel, normalized so its coefficients sum to 1.
    def gaussian_filter(size, sigma):
        grid = tf.range(-size // 2 + 1, size // 2 + 1, dtype=tf.float32)
        gaussian_filter = tf.exp(-0.5 * (grid / sigma) ** 2) / (sigma * tf.sqrt(2.0 * np.pi))
        return gaussian_filter / tf.reduce_sum(gaussian_filter)
    # Build the separable 3D window as an outer product of the 1D kernel,
    # then add in/out channel dims for conv3d: (D, H, W, 1, 1).
    filter_3d = gaussian_filter(filter_size, filter_sigma)
    filter_3d = tf.einsum('i,j,k->ijk', filter_3d, filter_3d, filter_3d)
    filter_3d = filter_3d[:, :, :, tf.newaxis, tf.newaxis]
    # Local means via Gaussian smoothing.
    mu_true = tf.nn.conv3d(y_true, filter_3d, strides=[1, 1, 1, 1, 1], padding='SAME')
    mu_pred = tf.nn.conv3d(y_pred, filter_3d, strides=[1, 1, 1, 1, 1], padding='SAME')
    mu_true_sq = mu_true ** 2
    mu_pred_sq = mu_pred ** 2
    mu_true_pred = mu_true * mu_pred
    # Local (co)variances: E[x^2] - E[x]^2 under the same Gaussian window.
    sigma_true_sq = tf.nn.conv3d(y_true ** 2, filter_3d, strides=[1, 1, 1, 1, 1], padding='SAME') - mu_true_sq
    sigma_pred_sq = tf.nn.conv3d(y_pred ** 2, filter_3d, strides=[1, 1, 1, 1, 1], padding='SAME') - mu_pred_sq
    sigma_true_pred = tf.nn.conv3d(y_true * y_pred, filter_3d, strides=[1, 1, 1, 1, 1], padding='SAME') - mu_true_pred
    # Standard SSIM stabilization constants.
    c1 = (k1 * max_val) ** 2
    c2 = (k2 * max_val) ** 2
    ssim_map = (2 * mu_true_pred + c1) * (2 * sigma_true_pred + c2) / (
            (mu_true_sq + mu_pred_sq + c1) * (sigma_true_sq + sigma_pred_sq + c2))
    # Per-voxel loss map; no reduction is performed here.
    return 1.0 - ssim_map
@tf.function
def wasserstein_loss(prob_real_is_real, prob_fake_is_real):
    """
    Mean difference between critic scores on real and generated inputs.

    NOTE(review): this returns mean(real - fake), while the companion
    `wasserstein_discriminator_loss` below returns the negation — confirm
    which sign convention each caller expects.

    Args:
        prob_real_is_real: Critic's scores that real inputs are real.
        prob_fake_is_real: Critic's scores that generated inputs are real.

    Returns:
        A scalar tensor: mean(prob_real_is_real - prob_fake_is_real).
    """
    score_gap = prob_real_is_real - prob_fake_is_real
    return tf.reduce_mean(score_gap)
@tf.function
def matched_crop(self, stack, axis=None, rescale=False):
    """
    Randomly crop a tensor holding two image stacks concatenated along
    `axis` and split the crop back into the two matched halves.

    The crop shape doubles the size along `axis` (so both stacks are
    cropped at the same location) and collapses the orthogonal spatial
    axis to size 1, which is then squeezed out before splitting.

    Args:
        stack: Tensor containing two image stacks concatenated along `axis`.
        axis: 1 or 3 — the axis along which the two stacks are
            concatenated. Any other value raises ValueError.
        rescale: If True, min-max normalize the crop before splitting.

    Returns:
        A tuple of two tensors, the matched crops of each stack.

    Raises:
        ValueError: if `axis` is not 1 or 3.
    """
    if axis == 1:
        shape = (self.batch_size, 2 * self.img_size[1], self.img_size[2], 1, self.channels)
        raxis = 3
        # Squeezing axis 3 does not shift axis 1, so split along axis 1.
        # (The previous unconditional `axis -= 1` split along axis 0,
        # i.e. the batch dimension — a bug.)
        split_axis = 1
    elif axis == 3:
        shape = (self.batch_size, 1, self.img_size[2], 2 * self.img_size[3], self.channels)
        raxis = 1
        # Squeezing axis 1 shifts axis 3 down to axis 2.
        split_axis = 2
    else:
        # Previously any other axis fell through with unbound locals
        # (UnboundLocalError); fail explicitly instead.
        raise ValueError("axis must be 1 or 3, got %r" % (axis,))
    arr = tf.squeeze(tf.image.random_crop(stack, size=shape), axis=raxis)
    if rescale:
        arr = min_max_norm_tf(arr)
    return tf.split(arr, num_or_size_splits=2, axis=split_axis)
@tf.function
def cycle_loss(self, real_image, cycled_image, typ=None):
    """
    Cycle-consistency loss between real and cycled images, scaled by
    `self.lambda_cycle`.

    Args:
        self: Instance carrying `lambda_cycle` (and `global_batch_size`
            via `reduce_mean`).
        real_image: Tensor of real images.
        cycled_image: Tensor of cycled images.
        typ: Distance selector — None uses MAE, "mse" uses MSE, "L4" uses
            the L4 loss; any other value uses binary cross-entropy on
            min-max-normalized images.

    Returns:
        The scaled cycle-consistency loss tensor.
    """
    if typ is None:
        loss = MAE(self, real_image, cycled_image)
    elif typ == "mse":
        loss = MSE(self, real_image, cycled_image)
    elif typ == "L4":
        loss = L4(self, real_image, cycled_image)
    else:
        bce = tf.keras.losses.BinaryCrossentropy(from_logits=False,
                                                 reduction=tf.keras.losses.Reduction.NONE)
        real_norm = min_max_norm_tf(real_image, axis=(1, 2, 3, 4))
        cycled_norm = min_max_norm_tf(cycled_image, axis=(1, 2, 3, 4))
        loss = reduce_mean(self, bce(real_norm, cycled_norm))
    return loss * self.lambda_cycle
@tf.function
def cycle_reconstruction(self, real_image, cycled_image):
    """
    Per-sample cycle reconstruction loss using the 3D SSIM loss map,
    scaled by `self.lambda_reconstruction`.

    Args:
        real_image: Tensor of real images.
        cycled_image: Tensor of cycled images, same shape as `real_image`.

    Returns:
        A scalar tensor: the scaled SSIM-based reconstruction loss.
    """
    # Normalize both volumes to [0, 1] per sample before comparing.
    real_norm = min_max_norm_tf(real_image, axis=(1, 2, 3, 4))
    cycled_norm = min_max_norm_tf(cycled_image, axis=(1, 2, 3, 4))
    ssim_map = ssim_loss_3d(real_norm, cycled_norm, max_val=1.0)
    return reduce_mean(self, ssim_map) * self.lambda_reconstruction
@tf.function
def cycle_seg_loss(self, real_image, cycled_image):
    """
    Topology-aware segmentation loss (soft-Dice/clDice) between the real
    and cycled images.

    Args:
        real_image: Tensor of real images.
        cycled_image: Tensor of cycled images, same shape as `real_image`.

    Returns:
        A scalar tensor: the clDice loss scaled by
        `lambda_topology / n_devices` (per-device scaling — the loss
        object presumably averages over the local batch; verify).
    """
    # Normalize both volumes to [0, 1] per sample before comparing.
    real_norm = min_max_norm_tf(real_image, axis=(1, 2, 3, 4))
    cycled_norm = min_max_norm_tf(cycled_image, axis=(1, 2, 3, 4))
    topology_loss = soft_dice_cldice_loss()(real_norm, cycled_norm)
    return topology_loss * (self.lambda_topology / self.n_devices)
@tf.function
def identity_loss(self, real_image, same_image, typ=None):
    """
    Identity loss between a real image and its same-domain mapping,
    scaled by `self.lambda_identity`.

    Args:
        real_image: The real image tensor.
        same_image: The generator's output when fed its own domain.
        typ: "cldice" uses soft-Dice/clDice on min-max-normalized images;
            any other value (including None) uses MAE.

    Returns:
        The scaled identity loss tensor.
    """
    if typ == "cldice":
        real = min_max_norm_tf(real_image)
        same = min_max_norm_tf(same_image)
        loss_obj = soft_dice_cldice_loss()
        return reduce_mean(self, loss_obj(real, same)) * self.lambda_identity
    # Previously any typ other than None/"cldice" fell through and
    # implicitly returned None; the docstring says MAE is the fallback,
    # so apply it for every non-cldice typ.
    return self.lambda_identity * MAE(self, real_image, same_image)
@tf.function
def generator_loss_fn(self, fake_image, typ=None, from_logits=True):
    """
    Adversarial loss for the generator: how far the discriminator's
    output on fake images is from "real" (all ones).

    Args:
        self: Instance of the VANGAN class.
        fake_image: Discriminator output for generated images.
        typ: Loss type — None (default) uses MSE (LSGAN-style);
            "bce" uses binary cross-entropy; "bfce" uses binary focal
            cross-entropy. Any other value raises ValueError.
        from_logits: Whether `fake_image` holds logits (True, default) or
            probabilities; when False the input is min-max normalized first.

    Returns:
        A scalar tensor: the generator loss.

    Raises:
        ValueError: if `typ` is not None, "bce" or "bfce".
    """
    if typ is None:
        return MSE(self, tf.ones_like(fake_image), fake_image)
    if typ == "bce":
        loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=from_logits,
                                                      reduction=tf.keras.losses.Reduction.NONE)
    elif typ == "bfce":
        loss_obj = tf.keras.losses.BinaryFocalCrossentropy(from_logits=from_logits,
                                                           reduction=tf.keras.losses.Reduction.NONE)
    else:
        # Previously an unknown typ crashed later with UnboundLocalError
        # on loss_obj; fail fast with a clear message instead.
        raise ValueError("typ must be None, 'bce' or 'bfce', got %r" % (typ,))
    fake = fake_image
    if not from_logits:
        fake = min_max_norm_tf(fake, axis=(1, 2, 3, 4))
    loss = loss_obj(tf.ones_like(fake_image), fake)
    return reduce_mean(self, loss)
@tf.function
def discriminator_loss_fn(self, real_image, fake_image, typ=None, from_logits=True):
    """
    Adversarial loss for the discriminator: real outputs pushed toward
    ones, fake outputs toward zeros, averaged with equal weight.

    Args:
        self: Instance of the VANGAN model.
        real_image: Discriminator output for real images.
        fake_image: Discriminator output for generated images.
        typ: Loss type — None (default) uses MSE (LSGAN-style);
            "bce" uses binary cross-entropy; "bfce" uses binary focal
            cross-entropy. Any other value raises ValueError.
        from_logits: Whether inputs are logits (True, default) or
            probabilities; when False both inputs are min-max normalized.

    Returns:
        A scalar tensor: the discriminator loss.

    Raises:
        ValueError: if `typ` is not None, "bce" or "bfce".
    """
    if typ is None:
        return 0.5 * (
                MSE(self, tf.ones_like(real_image), real_image) + MSE(self, tf.zeros_like(fake_image), fake_image))
    if typ == "bce":
        loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=from_logits,
                                                      reduction=tf.keras.losses.Reduction.NONE)
    elif typ == "bfce":
        loss_obj = tf.keras.losses.BinaryFocalCrossentropy(from_logits=from_logits,
                                                           reduction=tf.keras.losses.Reduction.NONE)
    else:
        # Previously an unknown typ crashed later with UnboundLocalError
        # on loss_obj; fail fast with a clear message instead.
        raise ValueError("typ must be None, 'bce' or 'bfce', got %r" % (typ,))
    real = real_image
    fake = fake_image
    # NOTE: unlike generator_loss_fn, normalization here passes no axis —
    # preserved as-is; confirm whether per-sample axes were intended.
    if not from_logits:
        real = min_max_norm_tf(real)
        fake = min_max_norm_tf(fake)
    loss = (loss_obj(tf.ones_like(real), real) + loss_obj(tf.zeros_like(fake), fake)) * 0.5
    return reduce_mean(self, loss)
def wasserstein_discriminator_loss(self, prob_real_is_real, prob_fake_is_real):
    """Wasserstein-GAN loss as minimized by the discriminator (critic).

    From the WGAN paper: https://arxiv.org/pdf/1701.07875.pdf
    by Martin Arjovsky, Soumith Chintala and Léon Bottou.

    Args:
        prob_real_is_real: The critic's estimate that images drawn from
            the real domain are in fact real.
        prob_fake_is_real: The critic's estimate that generated images
            made to look real are real.

    Returns:
        The total W-GAN critic loss: -mean(real - fake), globally
        averaged via `reduce_mean`.
    """
    critic_gap = prob_real_is_real - prob_fake_is_real
    return -reduce_mean(self, critic_gap)
def wasserstein_generator_loss(self, prob_fake_is_real):
    """Wasserstein-GAN loss as minimized by the generator.

    From the WGAN paper: https://arxiv.org/pdf/1701.07875.pdf
    by Martin Arjovsky, Soumith Chintala and Léon Bottou.

    Args:
        prob_fake_is_real: The critic's estimate that generated images
            made to look real are real.

    Returns:
        The total W-GAN generator loss: the negated mean critic score on
        fakes, globally averaged via `reduce_mean`.
    """
    expected_fake_score = reduce_mean(self, prob_fake_is_real)
    return -expected_fake_score