openface/util/tsne.py

42 lines
1.1 KiB
Python
Raw Normal View History

2015-10-04 06:46:29 +08:00
#!/usr/bin/env python3
import numpy as np
import pandas as pd
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.cm as cm
plt.style.use('bmh')
import os
import sys
import argparse
# Command-line interface: a positional working directory and a mandatory
# --names option that takes one or more class names.
parser = argparse.ArgumentParser()
parser.add_argument('workDir', type=str)
parser.add_argument(
    '--names',
    type=str,
    nargs='+',
    required=True,
)
args = parser.parse_args()
2015-10-12 23:30:29 +08:00
# Per-sample class labels: first column of <workDir>/labels.csv.
# `.values` replaces DataFrame.as_matrix(), which was removed in pandas 1.0.
# NOTE(review): read_csv treats the first row as a header by default; if
# labels.csv has no header row, the first sample is silently dropped -- confirm.
y = pd.read_csv("{}/labels.csv".format(args.workDir)).values[:, 0]
2015-10-04 06:46:29 +08:00
# Feature matrix: every column of <workDir>/reps.csv.
# `.values` replaces DataFrame.as_matrix(), which was removed in pandas 1.0.
# NOTE(review): read_csv treats the first row as a header by default; if
# reps.csv has no header row, the first sample is silently dropped -- confirm.
X = pd.read_csv("{}/reps.csv".format(args.workDir)).values

# One display name and one colour per class; colours are sampled evenly from
# the lower 70% of the gnuplot2 colormap.
target_names = np.array(args.names)
colors = cm.gnuplot2(np.linspace(0, 0.7, len(target_names)))

# Reduce to 50 principal components first, then embed in 2-D with t-SNE.
# PCA is unsupervised, so no target argument is passed. random_state is fixed
# so repeated runs give the same embedding.
X_pca = PCA(n_components=50).fit_transform(X)
tsne = TSNE(n_components=2, init='random', random_state=0)
X_r = tsne.fit_transform(X_pca)
2015-10-13 19:37:27 +08:00
# Draw one scatter series per class. Class i (counted from 1, in the order the
# names were given) is drawn from the rows of the embedding whose label equals i.
for i, (c, target_name) in enumerate(zip(colors, target_names), start=1):
    mask = y == i
    # Use `color=` rather than `c=`: a single RGBA row passed as `c` is
    # ambiguous and would be value-mapped if a class had exactly 4 points.
    plt.scatter(X_r[mask, 0], X_r[mask, 1],
                color=c, label=target_name)
2015-10-04 06:46:29 +08:00
# Add the legend and write the figure to <workDir>/tsne.pdf.
plt.legend()
out_path = "{}/tsne.pdf".format(args.workDir)
plt.savefig(out_path)