-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_model.py
58 lines (39 loc) · 1.99 KB
/
test_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import os
from importlib.machinery import SourceFileLoader
import argparse
import utils
# if not sys_config.running_on_gpu_host:
# import matplotlib.pyplot as plt
from train_model import UNetModel
import shutil
from data.lidc_data import lidc_data
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')
structures_dict = {1: 'RV', 2: 'Myo', 3: 'LV'}
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Script for training")
parser.add_argument("EXP_PATH", type=str, help="Path to experiment config file")
parser.add_argument("LOCAL", type=str, help="Is this script run on the local machine or the BIWI cluster?")
args = parser.parse_args()
config_file = args.EXP_PATH
config_module = config_file.split('/')[-1].rstrip('.py')
if args.LOCAL == 'local':
print('Running with local configuration')
import config.local_config as sys_config
import matplotlib.pyplot as plt
else:
import config.system as sys_config
exp_config = SourceFileLoader(config_module, config_file).load_module()
log_dir = os.path.join(sys_config.log_root, exp_config.log_dir_name, exp_config.experiment_name)
utils.makefolder(log_dir)
shutil.copy(exp_config.__file__, log_dir)
basic_logger = utils.setup_logger('basic_logger', log_dir + '/test_log_lowest_ged.log')
basic_logger.info('Running experiment with script: {}'.format(config_file))
basic_logger.info('!!!! Copied exp_config file to experiment folder !!!!')
basic_logger.info('**************************************************************')
basic_logger.info(' *** Running Experiment: %s', exp_config.experiment_name)
basic_logger.info('**************************************************************')
model = UNetModel(exp_config, logger=basic_logger, tensorboard=False)
transform = None
data = exp_config.data_loader(sys_config=sys_config, exp_config=exp_config)
model.generate_images(data, sys_config=sys_config)