-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
329 lines (289 loc) · 9.6 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
"""Contains utility functions for training."""
from typing import List, Type, Union, Callable
from dataclasses import dataclass, asdict
import yaml
import tensorflow as tf
import tensorflow_addons as tfa
import math
import os
import logging
from wandb.keras import WandbCallback
from dataclasses import dataclass
from typing import List, Type, Union, Callable
PI = math.pi
logging.basicConfig(
format="%(asctime)s %(levelname)s : %(message)s",
datefmt="%d-%b-%y %H:%M:%S",
level=logging.INFO,
)
@dataclass
class PreprocessingConfig:
tfrecs_filepath: List[str]
batch_size: int
image_size: int
area_factor: float
crop_size: int
resize_pre_crop: int
augment_fn: Union[str, Callable]
num_classes: int
color_jitter: bool
mixup: bool
mixup_alpha: float
@dataclass
class RegNetYConfig:
"""
Dataclass for architecture configuration for RegNetY.
Args:
name: Name of the model eg. "RegNetY200MF"
flops: Flops of the model eg. "400MF" (Processing one image requires
400 million floating point operations (multiplication, addition))
depths: List of depths for every stage
widths: List of widths (number of channels) after every stage
group_width: Integer denoting groups in every convolution layer
bottleneck_ratio: Integer specifying bottleneck ratio
se_ratio: Float denoting squeeze and excite ratio
wa: Integer, slope used in linear parameterization
w0: Integer, inital value used in linear parameterization
wm: Float, quantization parameter
"""
name: str
flops: str
num_classes: int
depths: List[int]
widths: List[int]
group_width: int
bottleneck_ratio: int
se_ratio: float
wa: int
w0: int
wm: float
@dataclass
class TrainConfig:
"""
Dataclass of training configuration for RegNetY
Args:
optimizer: One of "sgd", "adam", "adamw"
base_lr: Base learning rate for training
warmup_epochs: Number of epochs used for warmup
warmup_factor: Gradual linear warmup factor
total_epochs: Number of training epochs
weight_decay: Weight decay to be used in optimizer
momentum: Momentum to be used in optimizer
lr_schedule: One of "constant" or "half_cos"
log_dir: Path to store logs
model_dir: Path to store model checkpoints
"""
optimizer: str
base_lr: float
warmup_epochs: int
warmup_factor: float
total_epochs: int
weight_decay: float
momentum: float
label_smoothing: float
lr_schedule: str
log_dir: str
model_dir: str
def get_preprocessing_config(
tfrecs_filepath: List[str] = None,
batch_size: int = 1024,
image_size: int = 512,
area_factor: float = 0.08,
crop_size: int = 224,
resize_pre_crop: int = 320,
augment_fn: Union[str, Callable] = "default",
num_classes: int = 1000,
color_jitter: bool = False,
mixup: bool = True,
mixup_alpha: float = 0.2
):
return PreprocessingConfig(
tfrecs_filepath=tfrecs_filepath,
batch_size=batch_size,
image_size=image_size,
area_factor=area_factor,
crop_size=crop_size,
resize_pre_crop=resize_pre_crop,
augment_fn=augment_fn,
num_classes=num_classes,
color_jitter=color_jitter,
mixup=mixup,
mixup_alpha=mixup_alpha,
)
def get_train_config(
optimizer: str = "adamw",
base_lr: float = 0.001 * 8,
warmup_epochs: int = 5,
warmup_factor: float = 0.1,
total_epochs: int = 100,
weight_decay: float = 5e-5,
momentum: float = 0.9,
label_smoothing: float = 0.0,
lr_schedule: str = "half_cos",
log_dir: str = "",
model_dir: str = "",
):
return TrainConfig(
optimizer=optimizer,
base_lr=base_lr,
warmup_epochs=warmup_epochs,
warmup_factor=warmup_factor,
total_epochs=total_epochs,
weight_decay=weight_decay,
label_smoothing=label_smoothing,
momentum=momentum,
lr_schedule=lr_schedule,
log_dir=log_dir,
model_dir=model_dir,
)
def get_optimizer(cfg):
if cfg.optimizer == "sgd":
opt = tfa.optimizers.SGDW(
weight_decay=cfg.weight_decay,
learning_rate=cfg.base_lr,
momentum=cfg.momentum,
nesterov=True,
)
# opt = tfa.optimizers.MovingAverage(
# opt,
# average_decay=0.99999,
# start_step=6250,
# )
return opt
elif cfg.optimizer == "adam":
return tf.keras.optimizers.Adam(
learning_rate=cfg.base_lr,
)
elif cfg.optimizer == "adamw":
return tfa.optimizers.AdamW(
weight_decay=cfg.weight_decay, learning_rate=cfg.base_lr
)
elif cfg.optimizer == "nadamw":
opt = tfa.optimizers.extend_with_decoupled_weight_decay(tf.keras.optimizers.Nadam)
return opt(
weight_decay=cfg.weight_decay, learning_rate=cfg.base_lr
)
else:
raise NotImplementedError(
f"Optimizer choice not supported: {cfg.optimizer}")
def get_train_schedule(cfg):
if cfg.lr_schedule == "half_cos":
def half_cos_schedule(epoch, lr):
# Taken from pycls/pycls/core/optimizer.py, since not clear from paper.
if epoch < cfg.warmup_epochs:
new_lr = (
0.5
* (1.0 + tf.math.cos(PI * epoch / cfg.total_epochs))
* cfg.base_lr
)
alpha = epoch / cfg.warmup_epochs
warmup_factor = cfg.warmup_factor * (1.0 - alpha) + alpha
return new_lr * warmup_factor
else:
new_lr = (
0.5
* (1.0 + tf.math.cos(PI * epoch / cfg.total_epochs))
* cfg.base_lr
)
return new_lr
return half_cos_schedule
elif cfg.lr_schedule == "constant":
return cfg.base_lr
def get_callbacks(cfg, timestr):
lr_callback = tf.keras.callbacks.LearningRateScheduler(
get_train_schedule(cfg))
tboard_callback = tf.keras.callbacks.TensorBoard(
log_dir=os.path.join(cfg.log_dir, timestr), histogram_freq=1
) # profile_batch="0,1023"
best_model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=os.path.join(
cfg.model_dir,
timestr,
"best_model_epoch_{epoch:02d}_val_loss_{val_loss:.2f}",
),
save_weights_only=True,
monitor="val_accuracy",
mode="max",
save_best_only=True,
)
all_models_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=os.path.join(
cfg.model_dir,
timestr,
"all_model_epoch_{epoch:02d}",
),
save_weights_only=True,
monitor="val_accuracy",
mode="max",
save_best_only=False,
)
average_saving_callback = tfa.callbacks.AverageModelCheckpoint(
update_weights=False,
filepath=os.path.join(
cfg.model_dir,
timestr,
"average_model_epoch_{epoch:02d}",
),
save_weights_only=True,
monitor="val_accuracy",
mode="max",
save_best_only=False,
)
return [
lr_callback,
# tboard_callback,
# best_model_checkpoint_callback,
# average_saving_callback,
all_models_checkpoint_callback,
WandbCallback(),
]
def get_config_dict(train_prep_cfg, val_prep_cfg, train_cfg, misc=None):
config_dict = dict()
train_prep_dict = asdict(train_prep_cfg)
val_prep_dict = asdict(val_prep_cfg)
del train_prep_dict["tfrecs_filepath"]
del val_prep_dict["tfrecs_filepath"]
config_dict["train_prep"] = train_prep_dict
config_dict["val_prep"] = val_prep_dict
config_dict["train_cfg"] = asdict(train_cfg)
config_dict["misc"] = misc
return config_dict
# def make_model(flops, train_cfg):
# optim = get_optimizer(train_cfg)
# model = regnety.models.model.RegNetY(flops)
# model.compile(
# loss=tf.keras.losses.CategoricalCrossentropy(
# from_logits=True, label_smoothing=0.2),
# optimizer=optim,
# metrics=[
# tf.keras.metrics.CategoricalAccuracy(name="accuracy"),
# tf.keras.metrics.TopKCategoricalAccuracy(
# 5, name="top-5-accuracy"),
# ],
# )
# return model
def connect_to_tpu(tpu_address: str = None):
if tpu_address is not None:
cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
tpu=tpu_address
)
if tpu_address not in ("", "local"):
tf.config.experimental_connect_to_cluster(cluster_resolver)
tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
strategy = tf.distribute.TPUStrategy(cluster_resolver)
logging.info(f"Running on TPU {cluster_resolver.master()}")
logging.info(f"REPLICAS: {strategy.num_replicas_in_sync}")
return cluster_resolver, strategy
else:
try:
cluster_resolver = (
tf.distribute.cluster_resolver.TPUClusterResolver.connect()
)
strategy = tf.distribute.TPUStrategy(cluster_resolver)
logging.info(f"Running on TPU {cluster_resolver.master()}")
logging.info(f"REPLICAS: {strategy.num_replicas_in_sync}")
return cluster_resolver, strategy
except:
logging.warning("No TPU detected.")
mirrored_strategy = tf.distribute.MirroredStrategy()
return None, mirrored_strategy