openface/util/tsne.py

49 lines
1.2 KiB
Python
Raw Normal View History

2015-11-12 00:20:27 +08:00
#!/usr/bin/env python2
2015-10-04 06:46:29 +08:00
import numpy as np
import pandas as pd
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.cm as cm
plt.style.use('bmh')
import argparse
2015-11-12 00:20:27 +08:00
# Remind the user of the ordering contract between --names and labels.csv
# before any work starts.
print("\n"
      "Note: This example assumes that `name i` corresponds to `label i`\n"
      "in `labels.csv`.\n")
2015-10-04 06:46:29 +08:00
# Command-line interface. `workDir` must contain labels.csv and reps.csv;
# `--names` supplies one display name per label, in label order (the first
# name is used for label 1, the second for label 2, and so on).
parser = argparse.ArgumentParser(
    description="Project the rows of reps.csv to 2-D with PCA + t-SNE and "
                "save a scatter plot to <workDir>/tsne.pdf.")
parser.add_argument('workDir', type=str,
                    help="Directory containing labels.csv and reps.csv; "
                         "tsne.pdf is written here.")
parser.add_argument('--names', type=str, nargs='+', required=True,
                    help="Display name for each label, in order "
                         "(name i is used for label i).")
args = parser.parse_args()
2015-10-12 23:30:29 +08:00
# Load the label ids and representations; row k of labels.csv belongs to
# row k of reps.csv, and column 0 of labels.csv holds the integer label.
# `header=None`: with the default header=0, pandas consumes the first data
# row as column names and silently drops one sample from each file. This
# assumes both CSVs are headerless (as OpenFace's batch-represent writes
# them) -- NOTE(review): confirm if the files come from another producer.
# `.values` replaces `DataFrame.as_matrix()`, which was removed in pandas 1.0.
y = pd.read_csv("{}/labels.csv".format(args.workDir), header=None).values[:, 0]
X = pd.read_csv("{}/reps.csv".format(args.workDir), header=None).values
target_names = np.array(args.names)
# One colour per requested name, sampled from [0, 0.7] of gnuplot2.
colors = cm.gnuplot2(np.linspace(0, 0.7, len(target_names)))
# Reduce to at most 50 dimensions before t-SNE. PCA rejects n_components
# larger than min(n_samples, n_features), so clamp it for small datasets.
n_components = min(50, X.shape[0], X.shape[1])
X_pca = PCA(n_components=n_components).fit_transform(X)
# Fixed random_state keeps the layout reproducible between runs.
tsne = TSNE(n_components=2, init='random', random_state=0)
X_r = tsne.fit_transform(X_pca)
2015-10-13 19:37:27 +08:00
# Draw one scatter series per name. Labels are 1-based, so the k-th entry of
# --names (k starting at 1) is paired with the points whose label is k.
for i, (c, target_name) in enumerate(zip(colors, target_names), start=1):
    # `color=` rather than `c=`: a single RGBA row passed as `c` can be
    # mistaken for per-point values to colormap when a class has 3 or 4
    # points, giving that class the wrong colours.
    plt.scatter(X_r[y == i, 0], X_r[y == i, 1],
                color=c, label=target_name)
2015-10-04 06:46:29 +08:00
# Build the legend on the current axes from the `label=` given to each
# scatter series above.
plt.gca().legend()
2015-12-13 05:30:16 +08:00
# Persist the finished figure as a PDF inside the working directory, then
# tell the user where it went.
out_path_template = "{}/tsne.pdf"
out = out_path_template.format(args.workDir)
plt.savefig(out)
saved_msg = "Saved to: {}".format(out)
print(saved_msg)