diff --git a/training/plot-loss.py b/training/plot-loss.py index b201dc0..eae203e 100755 --- a/training/plot-loss.py +++ b/training/plot-loss.py @@ -27,15 +27,15 @@ import sys scriptDir = os.path.dirname(os.path.realpath(__file__)) plotDir = os.path.join(scriptDir, 'plots') -workDir = os.path.join(scriptDir, 'work') +# workDir = os.path.join(scriptDir, 'work') def plot(workDirs): trainDfs = [] testDfs = [] for d in workDirs: - trainF = os.path.join(workDir, "{:03d}".format(d), 'train.log') - testF = os.path.join(workDir, "{:03d}".format(d), 'test.log') + trainF = os.path.join(d, 'train.log') + testF = os.path.join(d, 'test.log') trainDfs.append(pd.read_csv(trainF, sep='\t')) testDfs.append(pd.read_csv(testF, sep='\t')) if len(trainDfs[-1]) != len(testDfs[-1]): @@ -84,6 +84,6 @@ def plot(workDirs): if __name__ == '__main__': os.makedirs(plotDir, exist_ok=True) parser = argparse.ArgumentParser() - parser.add_argument('workDirs', type=int, nargs='+') + parser.add_argument('workDirs', type=str, nargs='+') args = parser.parse_args() plot(args.workDirs)