From f112398178bc86d7dbda907758cdb90ce446db33 Mon Sep 17 00:00:00 2001 From: Brandon Amos Date: Sun, 6 Mar 2016 19:46:33 -0500 Subject: [PATCH] DNN Training: Plot LFW accuracies for #100. --- training/plot-loss.py | 43 +++++++++++++++++++++++++++---------------- 1 file changed, 27 insertions(+), 16 deletions(-) diff --git a/training/plot-loss.py b/training/plot-loss.py index f66e961..735f931 100755 --- a/training/plot-loss.py +++ b/training/plot-loss.py @@ -31,19 +31,19 @@ workDir = os.path.join(scriptDir, 'work') def plot(workDirs): trainDfs = [] - # testDfs = [] + testDfs = [] for d in workDirs: - trainF = os.path.join(workDir, str(d), 'train.log') - # testF = os.path.join(workDir, str(d), 'test.log') + trainF = os.path.join(workDir, "{:03d}".format(d), 'train.log') + testF = os.path.join(workDir, "{:03d}".format(d), 'test.log') trainDfs.append(pd.read_csv(trainF, sep='\t')) - # testDfs.append(pd.read_csv(testF, sep='\t')) - # if len(trainDfs[-1]) != len(testDfs[-1]): - # print("Error: Train/test dataframe shapes " - # "for '{}' don't match: {}, {}".format( - # d, trainDfs[-1].shape, testDfs[-1].shape)) - # sys.exit(-1) + testDfs.append(pd.read_csv(testF, sep='\t')) + if len(trainDfs[-1]) != len(testDfs[-1]): + print("Error: Train/test dataframe shapes " + "for '{}' don't match: {}, {}".format( + d, trainDfs[-1].shape, testDfs[-1].shape)) + sys.exit(-1) trainDf = pd.concat(trainDfs, ignore_index=True) - # testDf = pd.concat(testDfs, ignore_index=True) + testDf = pd.concat(testDfs, ignore_index=True) # print("train, test:") # print("\n".join(["{:0.2e}, {:0.2e}".format(x, y) for (x, y) in @@ -52,18 +52,29 @@ def plot(workDirs): fig, ax = plt.subplots(1, 1) trainDf.index += 1 - # testDf.index += 1 - trainDf['avg triplet loss (train set)'].plot(legend='True', ax=ax) - # testDf['avg triplet loss (test set)'].plot(legend='True', ax=ax, alpha=0.6) - plt.legend(['Train loss, semi-hard triplets']) # 'Test loss, random triplets']) + trainDf['avg triplet loss (train set)'].plot(ax=ax) plt.xlabel("Epoch") - plt.ylabel("Loss") + plt.ylabel("Average Triplet Loss, Training") # plt.ylim(ymin=0) plt.xlim(xmin=1) plt.grid(b=True, which='major', color='k', linestyle='-') plt.grid(b=True, which='minor', color='k', linestyle='--', alpha=0.2) ax.set_yscale('log') - fig.savefig(os.path.join(plotDir, "loss.pdf")) + d = os.path.join(plotDir, "train-loss.pdf") + fig.savefig(d) + print("Created {}".format(d)) + + fig, ax = plt.subplots(1, 1) + testDf.index += 1 + testDf['lfwAcc'].plot(ax=ax) + plt.xlabel("Epoch") + plt.ylabel("LFW Accuracy") + plt.ylim(ymin=0, ymax=1) + # plt.xlim(xmin=1) + # ax.set_yscale('log') + d = os.path.join(plotDir, "lfw-accuracy.pdf") + fig.savefig(d) + print("Created {}".format(d)) if __name__ == '__main__': os.makedirs(plotDir, exist_ok=True)