-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_net.py
84 lines (70 loc) · 2.74 KB
/
test_net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# import _init_path
from test_engine import test_net
from config.base_config import cfg, print_cfg, get_models_dir, cfg_from_file
import caffe
# from networks import models
from networks.models import Net
import argparse
import pprint
import time, os, sys
import numpy as np
import os.path as osp
from utils.dictionary import Dictionary
def _str2bool(value):
    """
    Convert a command-line string such as 'True', 'false', '1' or 'no' to a bool.

    argparse's ``type=bool`` just calls ``bool(str)``, which is True for every
    non-empty string (including 'False'), so an explicit converter is needed.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


def parse_args(argv=None):
    """
    Parse input arguments.

    Args:
        argv: optional list of argument strings, without the program name.
            Defaults to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the attributes gpu_id, test_split, batchsize,
        vis_pred, test_net, pretrained_model and cfg_file.

    Exits via SystemExit after printing the help text when no arguments are
    given, or when argparse rejects the arguments (e.g. a missing
    --pretrained_model or an unparseable --vis_pred value).
    """
    if argv is None:
        argv = sys.argv[1:]
    parser = argparse.ArgumentParser(description='Test a Visual Grounding network')
    parser.add_argument('--gpu_id', help='gpu_id', default=0, type=int)
    parser.add_argument('--test_split', help='test_split', default='val', type=str)
    parser.add_argument('--batchsize', help='batchsize', default=64, type=int)
    # nargs='?' + const=True lets a bare `--vis_pred` mean True while still
    # accepting the explicit `--vis_pred True` / `--vis_pred False` form.
    parser.add_argument('--vis_pred', help='visualize prediction', default=False,
                        type=_str2bool, nargs='?', const=True)
    parser.add_argument('--test_net', help='Net', default=None, type=str)
    # The weights file is always loaded and its basename names the net, so
    # it cannot be omitted; fail here with a clear argparse error instead.
    parser.add_argument('--pretrained_model', help='pretrained_model', type=str,
                        required=True)
    # e.g. config/experiments/refcoco-kld-bbox_reg.yaml
    parser.add_argument(
        '--cfg',
        dest='cfg_file',
        help='optional config file',
        type=str
    )
    if len(argv) == 0:
        parser.print_help()
        sys.exit(1)
    return parser.parse_args(argv)
if __name__ == '__main__':
    # Script entry point: build (or load) the test network, evaluate it on
    # the requested split(s) and record each accuracy in a per-experiment log.
    opts = parse_args()
    if opts.cfg_file is not None:
        # Apply the optional config file on top of the base config.
        cfg_from_file(opts.cfg_file)
    print_cfg()

    if opts.test_net is None:
        # No prototxt supplied: generate one from the Net definition, using the
        # query dictionary at cfg.QUERY_DIR to obtain the vocabulary size.
        qdic = Dictionary(cfg.QUERY_DIR)
        qdic.load()
        vocab_size = qdic.size()
        test_model = Net(opts.test_split, vocab_size, opts)
        test_net_path = osp.join(get_models_dir(), 'test.prototxt')
        with open(test_net_path, 'w') as f:
            f.write(str(test_model))
    else:
        test_net_path = opts.test_net

    caffe.set_mode_gpu()
    caffe.set_device(opts.gpu_id)
    net = caffe.Net(test_net_path, opts.pretrained_model, caffe.TEST)
    # Name the net after the weights file (basename without its extension).
    net.name = osp.splitext(osp.basename(opts.pretrained_model))[0]

    log_file = osp.join(cfg.LOG_DIR, '%s_%s_%s_accuracy.txt'
                        % (cfg.IMDB_NAME, cfg.FEAT_TYPE, cfg.PROJ_NAME))
    # Create the log directory up front so the result of a (potentially long)
    # evaluation is not lost when the accuracy is finally written.
    log_dir = osp.dirname(log_file)
    if log_dir and not osp.isdir(log_dir):
        os.makedirs(log_dir)
    # Start each run with a fresh log; accuracies are appended per split below.
    if osp.exists(log_file):
        os.remove(log_file)

    # parse_args yields a single split string; a list of splits is also
    # accepted, so normalise to a list and evaluate each in turn.
    if isinstance(opts.test_split, list):
        test_splits = opts.test_split
    else:
        test_splits = [opts.test_split]
    for split in test_splits:
        accuracy = test_net(split, net, opts.batchsize, vis=opts.vis_pred)
        with open(log_file, 'a') as f:
            f.write('%s accuracy: %f\n' % (split, accuracy))