name:
src/python/analyze.py
-rw-r--r--
6278
1import warnings
2import glob
3import matplotlib.pyplot as plt
4import matplotlib.cm as cm
5import numpy as np
6warnings.filterwarnings("ignore")
7from fingerprinter.reader import read_single
8from fingerprinter.fingerprint import Fingerprinter
9from fingerprinter.fingerprint_record import determine_match
10from database.fingerprint_db import FingerprintDatabase
11from etc.util import get_args_for_input_output_directories, get_file_name
12
# Help text handed to the argument parser for this script.
DESCRIPTION = """
This script will load all clips in a given directory (-d), fingerprint them, and look up the corresponding episode in
the fingerprint database. Clips that should match an episode are read from the "positive/" subdirectory of the input
directory, and clips that should not match any episode are read from its "negative/" subdirectory. It will then report
the true-positive, true-negative, false-positive, and false-negative counts per noise class, together with the derived
precision, recall, specificity, accuracy, and predicted positive condition rate, and display plots of them to rate the
effectiveness of the current parameters of the algorithm.

Files in the input directory should be in the format "{episode}.{play_head}.{noise_class}.wav", as written in the
clips.py script.
"""
# Forwarded to Fingerprinter(plot_fingerprint=...) for every clip analysed.
PLOT_FINGERPRINTS = False
# Maximum number of hashes per clip that are sent to the fingerprint database lookup.
MAX_HASH_LOOKUP = 512
25
def pull_metadata_from_file_path(file_path):
    """Parse noise class, episode name and play-head timestamp from a clip path.

    Clip files are expected to be named "{episode}.{play_head}.{noise_class}.wav";
    a clip without a recognised noise tag is classified as "NORMAL".

    :param file_path: path to the clip file; only its final component is parsed.
    :returns: tuple ``(wav_class, episode_name, play_head)``, where ``wav_class``
        is one of "HN", "MN", "LN" or "NORMAL" and ``play_head`` is the raw string.
    :raises ValueError: if the file name has no "." separated play-head field.
    """
    import os.path

    filename = os.path.basename(file_path)
    parts = filename.split(".")
    if len(parts) < 2:
        raise ValueError("Cannot parse clip metadata from file name: {!r}".format(filename))
    episode_name = parts[0]
    t = parts[1]
    # Only the fields after episode and play head are searched for the noise tag,
    # so an episode name that happens to contain "HN"/"MN"/"LN" (e.g. "JOHN") is
    # not misclassified. Tags are tested in the original precedence order.
    noise_fields = ".".join(parts[2:])
    wav_class = "NORMAL"
    for tag in ("HN", "MN", "LN"):
        if tag in noise_fields:
            wav_class = tag
            break
    return (wav_class, episode_name, t)
39
40
# Noise classes in the order they are reported.
NOISE_CLASSES = ("NORMAL", "LN", "MN", "HN")

# Scatter-plot styling per noise class.
GRAPHING_STYLE = {
    "LN": {"label": "LN", "color": "blue"},
    "MN": {"label": "MN", "color": "orange"},
    "HN": {"label": "HN", "color": "red"},
    "NORMAL": {"label": "NORMAL", "color": "green"},
}

# Metrics drawn in the grouped bar chart, in x-axis (alphabetical) order.
BAR_METRICS = ("accuracy", "ppcr", "specificity")


def safe_ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is zero.

    A zero denominator means the metric is undefined for that class (for example,
    no clips of that class were evaluated); NaN is reported instead of crashing.
    """
    return numerator / denominator if denominator else float("nan")


def new_results():
    """Return a fresh per-noise-class confusion-count table, all counts zero."""
    return {classification: {"tp": 0, "tn": 0, "fp": 0, "fn": 0} for classification in NOISE_CLASSES}


def compute_metrics(counts):
    """Derive rate metrics from one class's confusion counts.

    :param counts: dict with numeric "tp", "fp", "tn" and "fn" entries.
    :returns: dict with "precision", "recall", "specificity", "accuracy" and
        "ppcr" (share of all clips the matcher reported as a match), in that order.
        Undefined ratios are NaN.
    """
    tp = float(counts["tp"])
    fp = float(counts["fp"])
    tn = float(counts["tn"])
    fn = float(counts["fn"])
    total = tp + fp + tn + fn
    return {
        "precision": safe_ratio(tp, tp + fp),
        "recall": safe_ratio(tp, tp + fn),
        "specificity": safe_ratio(tn, tn + fp),
        "accuracy": safe_ratio(tp + tn, total),
        "ppcr": safe_ratio(tp + fp, total),
    }


def match_clip(fingerprinter, db, file_path):
    """Fingerprint one .wav clip and return the episode the database matches it to.

    The negative-clip check treats a ``None`` result as "no match".
    """
    channel, frame_rate = read_single(file_path)
    hashes = list(fingerprinter.fingerprint(channel, frame_rate))
    # Only the first MAX_HASH_LOOKUP hashes are sent to the database lookup.
    hash_numbers = [hash_pair[1] for hash_pair in hashes[:MAX_HASH_LOOKUP]]
    fingerprint_records = db.find_records_by_hashes(hash_numbers)
    return determine_match(fingerprint_records)


def evaluate_clips(directory, fingerprinter, db):
    """Match every clip under directory/positive and directory/negative.

    :returns: confusion-count table as built by ``new_results``.
    """
    import os

    results = new_results()
    # os.path.join tolerates a directory given with or without a trailing slash.
    for file_path in glob.glob(os.path.join(directory, "positive", "*.wav")):
        print("Current working positive clip: {}".format(file_path))
        wav_class, episode_name, tstamp = pull_metadata_from_file_path(file_path)
        print("Positive Clip: {} {} {}".format(wav_class, episode_name, tstamp))
        actual_match_name = match_clip(fingerprinter, db, file_path)
        print("expected {}, actual {}".format(episode_name, actual_match_name))
        results[wav_class]["tp" if actual_match_name == episode_name else "fn"] += 1
    for file_path in glob.glob(os.path.join(directory, "negative", "*.wav")):
        print("Current working negative clip: {}".format(file_path))
        wav_class, episode_name, tstamp = pull_metadata_from_file_path(file_path)
        print("Negative Clip: {} {} {}".format(wav_class, episode_name, tstamp))
        should_be_none = match_clip(fingerprinter, db, file_path)
        print("expected None, actual {}".format(should_be_none))
        results[wav_class]["tn" if should_be_none is None else "fp"] += 1
    return results


def plot_precision_recall(all_metrics):
    """Show one hollow scatter point (recall, precision) per noise class."""
    for classification in NOISE_CLASSES:
        metrics = all_metrics[classification]
        style = GRAPHING_STYLE[classification]
        # Hollow markers: transparent face, class-colored edge.
        plt.scatter([metrics["recall"]], [metrics["precision"]], marker="o", facecolors="none",
                    edgecolors=style["color"], s=40, linewidths=2, label=style["label"])
    plt.title("Recall vs Precision")
    plt.ylabel("Precision")
    plt.xlabel("Recall")
    plt.xlim(-1, 1)
    plt.ylim(-1, 1)
    plt.legend(scatterpoints=1, loc='lower right')
    plt.axvline(x=0, color='b', linestyle=':')
    plt.axhline(y=0, color='b', linestyle=':')
    plt.tight_layout()
    plt.show()


def plot_metric_bars(all_metrics, space=0.3):
    """Show a grouped bar chart of BAR_METRICS, grouped by metric, colored by class.

    Each class's values are looked up by metric name, so every bar is drawn above
    the tick label of the metric it represents.

    :param all_metrics: mapping of noise class -> metrics dict from compute_metrics.
    :param space: fraction of each x-axis unit left empty between metric groups.
    """
    classes = sorted(all_metrics)
    n = len(classes)
    group_width = 1.0 - space
    bar_width = group_width / n
    tick_positions = list(range(1, len(BAR_METRICS) + 1))

    fig = plt.figure()
    ax = fig.add_subplot(111)
    for i, classification in enumerate(classes):
        values = [all_metrics[classification][name] for name in BAR_METRICS]
        # align="edge" makes each bar start at its x position, centering the group on the tick.
        positions = [tick - group_width / 2.0 + i * bar_width for tick in tick_positions]
        ax.bar(positions, values, width=bar_width, align="edge", label=classification,
               color=cm.Accent(float(i) / n))
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(BAR_METRICS)
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles[::-1], labels[::-1])
    ax.set_ylabel("Rate")
    # The x axis groups bars by metric; sample classes are shown in the legend.
    ax.set_xlabel("Metric")
    plt.show()


def main():
    """Evaluate all clips, print confusion counts and metrics, then plot them."""
    args = get_args_for_input_output_directories(DESCRIPTION)
    print("Loading .wav files for analysis from {}".format(args.directory))
    db = FingerprintDatabase()
    fingerprinter = Fingerprinter(plot_fingerprint=PLOT_FINGERPRINTS)
    results = evaluate_clips(args.directory, fingerprinter, db)
    print(results)

    all_metrics = {classification: compute_metrics(results[classification])
                   for classification in NOISE_CLASSES}
    for classification in NOISE_CLASSES:
        for name, value in all_metrics[classification].items():
            print(classification, name, value)
        print("")

    plot_precision_recall(all_metrics)
    plot_metric_bars(all_metrics)


if __name__ == "__main__":
    main()