-
Notifications
You must be signed in to change notification settings - Fork 1
/
evaluate_periodicity_predictions.py
182 lines (165 loc) · 6.71 KB
/
evaluate_periodicity_predictions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
#!/usr/bin/env python3
import argparse
from collections import defaultdict
import datetime
import re
import itertools
"""
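Evaluate the quality of event date predictions.

For every term (or combination of terms) in the predictions file, the earliest
predicted date is compared with the dates found in the identified-events file,
and accuracy tables by score threshold and by rank are written to the output
location.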
"""
# Command-line interface. The three value options are mandatory; the three
# switches default to False. "-d" is parsed here but is not read anywhere in
# the code shown in this file.
parser = argparse.ArgumentParser(description="")
for flag, helptext in (
    ("-i", "The file with identified events"),
    ("-p", "The predictions file"),
    ("-o", "The directory to write outcomes to"),
):
    parser.add_argument(flag, action="store", required=True, help=helptext)
for flag, helptext in (
    ("-y", "include only yearly patterns"),
    ("-d", "divide pattern types"),
    ("--std", "specify if predictions are made by std"),
):
    parser.add_argument(flag, action="store_true", help=helptext)
args = parser.parse_args()

# Both inputs are opened as UTF-8 text here and consumed further down.
eventsfile = open(args.i, "r", encoding="utf-8")
predictfile = open(args.p, "r", encoding="utf-8")
print("generating dicts")
# Generate the term_dates dict from the events file: every term maps to the
# list of dates (datetime objects) of the event lines it appears on.
term_dates = defaultdict(list)
# The events file is not needed after this step, so `with` closes it as soon
# as it has been read (also when a malformed line makes parsing raise).
with eventsfile:
    if args.std:
        # In --std mode only lines starting with "<digits>.<digits>" are
        # event lines; terms are in column 2 and dates in column 3.
        numbered_line = re.compile(r"^\d+\.\d+")
        for line in eventsfile:
            if not numbered_line.match(line):
                continue
            try:
                tokens = line.strip().split("\t")
                terms = tokens[1].split(", ")
                # Keep only the first 10 characters (YYYY-MM-DD) of each entry.
                dates_raw = [x[:10] for x in tokens[2].split(" > ")]
                for date in dates_raw:
                    entries = date.split("-")
                    event_date = datetime.datetime(
                        int(entries[0]), int(entries[1]), int(entries[2])
                    )
                    for term in terms:
                        term_dates[term].append(event_date)
            except (IndexError, ValueError):
                # Best effort, as before: a line with missing columns or an
                # unparsable date is skipped. Only these parse errors are
                # swallowed, so Ctrl-C and genuine bugs still surface.
                continue
    else:
        # Without --std, event lines start with "<"; terms are in column 2
        # and dates in column 4. Malformed lines raise, as they always did.
        for line in eventsfile:
            if not line.startswith("<"):
                continue
            tokens = line.strip().split("\t")
            terms = tokens[1].split(", ")
            for date in tokens[3].split(" > "):
                entries = date.split("-")
                event_date = datetime.datetime(
                    int(entries[0]), int(entries[1]), int(entries[2])
                )
                for term in terms:
                    term_dates[term].append(event_date)
# Generate the terms_predictions dict from the predictions file: the sorted,
# "_"-joined terms of a line map to every prediction made for them.
terms_predictions = defaultdict(list)
# The predictions file is not needed after this step, so `with` closes it.
with predictfile:
    for line in predictfile:
        stripped = line.strip()
        if not stripped:
            # A blank (e.g. trailing) line carries no prediction; skipping it
            # avoids an IndexError on the column lookups below.
            continue
        tokens = stripped.split("\t")
        terms = "_".join(sorted(tokens[0].split(", ")))
        pattern = tokens[1]
        # Column 3 is a ", "-separated list wrapped in one opening and one
        # closing character, which are dropped before splitting.
        fields = tokens[2][1:-1].split(", ")
        if args.y and pattern[1] == "v":
            # With -y, patterns whose second character is "v" are left out.
            continue
        # Year is the last four characters of field 5; month and day follow.
        predict_date = datetime.datetime(
            int(fields[5][-4:]), int(fields[6]), int(fields[7])
        )
        score = float(fields[11])
        if args.std:
            terms_predictions[terms].append((predict_date, pattern, score))
        else:
            coverage = float(fields[12])
            consistency = float(fields[13])
            terms_predictions[terms].append(
                (predict_date, pattern, score, coverage, consistency)
            )
print("scoring predictions")
# Match predictions with occurrences and list scores and accuracies.
# NOTE: output file names are appended directly to args.o, so a directory
# argument has to end with a path separator.
score_accuracies = []
coverage_accuracies = []
consistency_accuracies = []
# Opened here but only written (and closed) by the score-accuracy step later on.
scores_raw = open(args.o + "scores_raw.txt", "w", encoding="utf-8")
print("assessment")
with open(args.o + "results.txt", "w", encoding="utf-8") as resultsfile:
    for term, predictions in terms_predictions.items():
        # Only the earliest predicted date of a term is assessed. min() returns
        # the first of several equal dates, exactly like a stable sort's [0].
        prediction = min(predictions, key=lambda p: p[0])
        prdate = prediction[0]
        # Combined keys are "_"-joined terms: collect the dates of every
        # sub-term. A dedicated loop variable keeps `term` intact, so the full
        # key (not just its last sub-term) is what gets reported below.
        # .get() avoids inserting empty entries into the term_dates defaultdict.
        dates = set()
        for subterm in term.split("_"):
            dates.update(term_dates.get(subterm, ()))
        assessment = "Correct" if prdate in dates else "False"
        resultsfile.write(
            "\t".join([term, prediction[1], str(prdate), assessment]) + "\n"
        )
        score_accuracies.append([term, prediction[2], assessment])
        if not args.std:
            # Coverage and consistency are only present outside --std mode.
            coverage_accuracies.append([term, prediction[3], assessment])
            consistency_accuracies.append([term, prediction[4], assessment])
def score_accuracy(data, sdev):
    """Tabulate prediction accuracy over a sweep of score thresholds.

    Args:
        data: iterable of ``[term, score, assessment]`` rows; a row counts as
            correct when its assessment equals ``"Correct"``.
        sdev: when true, thresholds run upwards from 0.0 to 25.0 in steps of
            0.1 and a row is selected when ``score <= threshold``. Otherwise
            thresholds run downwards from 1.0 to 0.0 in steps of 0.01 and a
            row is selected when ``score >= threshold``.

    Returns:
        A list of ``[threshold, accuracy]`` pairs, both as strings, in sweep
        order. Thresholds that select no rows are omitted, because accuracy is
        undefined there (this used to raise ZeroDivisionError).
    """
    # Derive every threshold from an integer step instead of repeatedly
    # adding/subtracting a float: this avoids accumulated rounding drift
    # (e.g. 0.30000000000000004) and guarantees both end points are visited.
    if sdev:
        thresholds = [step / 10 for step in range(0, 251)]
    else:
        thresholds = [step / 100 for step in range(100, -1, -1)]

    rows = list(data)
    outlist = []
    for thresh in thresholds:
        if sdev:
            selected = [row for row in rows if row[1] <= thresh]
        else:
            selected = [row for row in rows if row[1] >= thresh]
        if not selected:
            continue
        correct = sum(1 for row in selected if row[2] == "Correct")
        outlist.append([str(thresh), str(correct / len(selected))])
    return outlist
def rank_accuracy(data, sdev):
    """Report cumulative accuracy at every 50th rank of the score-sorted data.

    Args:
        data: iterable of ``[term, score, assessment]`` rows; a row counts as
            correct when its assessment equals ``"Correct"``.
        sdev: when true the rows are ranked by ascending score, otherwise by
            descending score.

    Returns:
        A list of ``[rank, accuracy]`` string pairs for rank 50, 100, ... up
        to, but not including, the total number of rows.
    """
    ranked_data = sorted(data, key=lambda row: row[1], reverse=not sdev)
    total = len(ranked_data)

    outlist = []
    correct_so_far = 0
    for position, row in enumerate(ranked_data, start=1):
        if row[2] == "Correct":
            correct_so_far += 1
        # Emit a row at each multiple of 50 strictly below the total count.
        if position % 50 == 0 and position < total:
            outlist.append([str(position), str(correct_so_far / position)])
    return outlist
print("accuracy plots")
# Each accuracy table is written as one "<threshold-or-rank> <accuracy>" line
# per row. Every output handle is managed by `with`, so it is closed even when
# an accuracy computation raises.
print("score")
# scores_raw was opened earlier in the script; it is filled and closed here.
with scores_raw:
    accuracies_score = score_accuracy(score_accuracies, args.std)
    scores_raw.writelines(" ".join(row) + "\n" for row in accuracies_score)
print("rank")
with open(args.o + "accuracy_by_rank.txt", "w", encoding="utf-8") as ranked_file:
    accuracies_rank = rank_accuracy(score_accuracies, args.std)
    ranked_file.writelines(" ".join(row) + "\n" for row in accuracies_rank)
if not args.std:
    # Coverage and consistency values only exist outside --std mode; both are
    # swept with the descending (sdev=False) thresholds.
    with open(args.o + "coverage_raw.txt", "w", encoding="utf-8") as coverage_raw, \
            open(args.o + "consistency_raw.txt", "w", encoding="utf-8") as consistency_raw:
        print("coverage")
        accuracies_coverage = score_accuracy(coverage_accuracies, sdev=False)
        print("consistency")
        accuracies_consistency = score_accuracy(consistency_accuracies, sdev=False)
        coverage_raw.writelines(" ".join(row) + "\n" for row in accuracies_coverage)
        consistency_raw.writelines(
            " ".join(row) + "\n" for row in accuracies_consistency
        )