blog-acoustic-fingerprinting
ARCHIVED - acoustic fingerprinting television shows with python
git clone https://git.vogt.world/blog-acoustic-fingerprinting.git
Log | Files | README.md
← All files
name: src/python/analyze.py
-rw-r--r--
6278
  1import warnings
  2import glob
  3import matplotlib.pyplot as plt
  4import matplotlib.cm as cm
  5import numpy as np
  6warnings.filterwarnings("ignore")
  7from fingerprinter.reader import read_single
  8from fingerprinter.fingerprint import Fingerprinter
  9from fingerprinter.fingerprint_record import determine_match
 10from database.fingerprint_db import FingerprintDatabase
 11from etc.util import get_args_for_input_output_directories, get_file_name
 12
 13DESCRIPTION = """
 14This script will load all clips in a given directory (-d), fingerprint them, and lookup the corresponding episode in
 15the fingerprint database. It will then plot the positive, negative, false-positive, and false-negative matches to rate
 16the effectiveness of the current parameters of the algorithm. These plots will be saved as .png files in a given
 17output directory.
 18
 19Files in the input directory should be in the format "{episode}.{play_head}.{noise_class}.wav", as written in the
 20clips.py script.
 21"""
 22PLOT_FINGERPRINTS = False
 23MAX_HASH_LOOKUP = 512
 24
 25
 26def pull_metadata_from_file_path(file_path):
 27  filename = file_path.split("/")[-1:][0]
 28  sp = filename.split('.')
 29  episode_name = sp[0]
 30  t = sp[1]
 31  wav_class = "NORMAL"
 32  if "HN" in filename:
 33    wav_class = "HN"
 34  elif "MN" in filename:
 35    wav_class = "MN"
 36  elif "LN" in filename:
 37    wav_class = "LN"
 38  return (wav_class, episode_name, t)
 39
 40
 41if __name__ == "__main__":
 42  args = get_args_for_input_output_directories(DESCRIPTION)
 43  print("Loading .wav files for analysis from {}".format(args.directory))
 44  db = FingerprintDatabase()
 45  results = {
 46    "NORMAL": {
 47      "tp": 0,
 48      "tn": 0,
 49      "fp": 0,
 50      "fn": 0
 51    },
 52    "LN": {
 53      "tp": 0,
 54      "tn": 0,
 55      "fp": 0,
 56      "fn": 0
 57    },
 58    "MN": {
 59      "tp": 0,
 60      "tn": 0,
 61      "fp": 0,
 62      "fn": 0
 63    },
 64    "HN": {
 65      "tp": 0,
 66      "tn": 0,
 67      "fp": 0,
 68      "fn": 0
 69    }
 70  }
 71  f = Fingerprinter(plot_fingerprint=PLOT_FINGERPRINTS)
 72  for filename in glob.glob(args.directory + "positive/*.wav"):
 73    print("Current working positive clip: {}".format(filename))
 74    wav_class, episode_name, tstamp = pull_metadata_from_file_path(filename)
 75    print("Positive Clip: {} {} {}".format(wav_class, episode_name, tstamp))
 76    channel, frame_rate = read_single(filename)
 77    hashes = [x for x in f.fingerprint(channel, frame_rate)]
 78    hash_numbers = [hash_pair[1] for hash_pair in hashes[0:MAX_HASH_LOOKUP]]
 79    fingerprint_records = db.find_records_by_hashes(hash_numbers)
 80    actual_match_name = determine_match(fingerprint_records)
 81    print("expected {}, actual {}".format(episode_name, actual_match_name))
 82    results[wav_class]["tp" if actual_match_name == episode_name else "fn"] += 1
 83  for filename in glob.glob(args.directory + "negative/*.wav"):
 84    print("Current working negative clip: {}".format(filename))
 85    wav_class, episode_name, tstamp = pull_metadata_from_file_path(filename)
 86    print("Negative Clip: {} {} {}".format(wav_class, episode_name, tstamp))
 87    channel, frame_rate = read_single(filename)
 88    hashes = [x for x in f.fingerprint(channel, frame_rate)]
 89    hash_numbers = [hash_pair[1] for hash_pair in hashes[0:MAX_HASH_LOOKUP]]
 90    fingerprint_records = db.find_records_by_hashes(hash_numbers)
 91    should_be_none = determine_match(fingerprint_records)
 92    print("expected None, actual {}".format(should_be_none))
 93    results[wav_class]["tn" if should_be_none is None else "fp"] += 1
 94  print(results)
 95  graphing_dict = {
 96    "LN": {
 97      "label": "LN",
 98      "color": "blue"
 99    },
100    "MN": {
101      "label": "MN",
102      "color": "orange"
103    },
104    "HN": {
105      "label": "HN",
106      "color": "red"
107    },
108    "NORMAL": {
109      "label": "NORMAL",
110      "color": "green"
111    }
112  }
113  for classification in ["NORMAL", "LN", "MN", "HN"]:
114    tp = float(results[classification]["tp"])
115    fp = float(results[classification]["fp"])
116    tn = float(results[classification]["tn"])
117    fn = float(results[classification]["fn"])
118    precision = tp / (tp + fp)
119    recall = tp / (tp + fn)
120    specificity = tn / (tn + fp)
121    accuracy = (tp + tn) / (fp + tp + fn + tn)
122    ppcr = (tp + fp) / (fp + tp + fn + tn)
123    print(classification, "precision", precision)
124    print(classification, "recall", recall)
125    print(classification, "specificity", specificity)
126    print(classification, "accuracy", accuracy)
127    print(classification, "ppcr", ppcr)
128    print("")
129    plt.scatter([recall], [precision], marker="o", facecolors="None", color=graphing_dict[classification]["color"], s=40, linewidths=2, label=graphing_dict[classification]["label"])
130  plt.title("Recall vs Precision")
131  plt.ylabel("Precision")
132  plt.xlabel("Recall")
133  plt.xlim(-1, 1)
134  plt.ylim(-1, 1)
135  plt.legend(scatterpoints=1, loc='lower right')
136  plt.axvline(x=0, color='b', linestyle=':')
137  plt.axhline(y=0, color='b', linestyle=':')
138  plt.tight_layout()
139  plt.show()
140
141  # grouping specificity, accuracy, ppcr
142  rows = []
143  for classification in ["NORMAL", "LN", "MN", "HN"]:
144    tp = float(results[classification]["tp"])
145    fp = float(results[classification]["fp"])
146    tn = float(results[classification]["tn"])
147    fn = float(results[classification]["fn"])
148    precision = tp / (tp + fp)
149    recall = tp / (tp + fn)
150    specificity = tn / (tn + fp)
151    accuracy = (tp + tn) / (fp + tp + fn + tn)
152    ppcr = (tp + fp) / (fp + tp + fn + tn)
153    rows.append([classification, "specificity", specificity])
154    rows.append([classification, "accuracy", accuracy])
155    rows.append([classification, "ppcr", ppcr])
156  dpoints = np.array(rows)
157  fig = plt.figure()
158  ax = fig.add_subplot(111)
159  space = 0.3
160  conditions = np.unique(dpoints[:, 0])
161  categories = np.unique(dpoints[:, 1])
162  n = len(conditions)
163  width = (1 - space) / (len(conditions))
164  for i, cond in enumerate(conditions):
165    indeces = range(1, len(categories) + 1)
166    vals = dpoints[dpoints[:, 0] == cond][:, 2].astype(np.float)
167    pos = [j - (1 - space) / 2. + i * width for j in range(1, len(categories) + 1)]
168    ax.bar(pos, vals, width=width, label=cond, color=cm.Accent(float(i) / n))
169    ax.set_xticks(indeces)
170  ax.set_xticklabels(categories)
171  handles, labels = ax.get_legend_handles_labels()
172  ax.legend(handles[::-1], labels[::-1])
173  plt.setp(plt.xticks()[1])
174  ax.set_ylabel("Rate")
175  ax.set_xlabel("Sample Class")
176  plt.show()