From cc9ff97a29898cdc8b620c3bdba31390847c0434 Mon Sep 17 00:00:00 2001 From: Davis King Date: Sun, 7 Jul 2013 12:28:31 -0400 Subject: [PATCH] Cleaned up python svm struct code a little. --- python_examples/svm_struct.py | 42 ++++++++++++++++++++------------- tools/python/src/svm_struct.cpp | 16 +++++++++++-- 2 files changed, 39 insertions(+), 19 deletions(-) diff --git a/python_examples/svm_struct.py b/python_examples/svm_struct.py index df1825d81..4a8fe3b64 100755 --- a/python_examples/svm_struct.py +++ b/python_examples/svm_struct.py @@ -1,7 +1,10 @@ #!/usr/bin/python # The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt # -# +# This is an example illustrating the use of the structural SVM solver from the dlib C++ +# Library. This example will briefly introduce it and then walk through an example showing +# how to use it to create a simple multi-class classifier. +# # # COMPILING THE DLIB PYTHON INTERFACE # Dlib comes with a compiled python interface for python 2.7 on MS Windows. If @@ -15,6 +18,7 @@ import dlib def dot(a, b): + "Compute the dot product between the two vectors a and b." return sum(i*j for i,j in zip(a,b)) @@ -23,30 +27,35 @@ class three_class_classifier_problem: be_verbose = True epsilon = 0.0001 + def __init__(self, samples, labels): self.num_samples = len(samples) self.num_dimensions = len(samples[0])*3 self.samples = samples self.labels = labels - def make_psi(self, psi, vector, label): + + def make_psi(self, vector, label): + psi = dlib.vector() psi.resize(self.num_dimensions) dims = len(vector) - if (label == 1): + if (label == 0): for i in range(0,dims): psi[i] = vector[i] - elif (label == 2): + elif (label == 1): for i in range(dims,2*dims): psi[i] = vector[i-dims] - else: + else: # the label must be 2 for i in range(2*dims,3*dims): psi[i] = vector[i-2*dims] + return psi - def get_truth_joint_feature_vector(self, idx, psi): - self.make_psi(psi, self.samples[idx], self.labels[idx]) + def get_truth_joint_feature_vector(self, idx): + return self.make_psi(self.samples[idx], self.labels[idx]) - def separation_oracle(self, idx, current_solution, psi): + + def separation_oracle(self, idx, current_solution): samp = samples[idx] dims = len(samp) scores = [0,0,0] @@ -56,29 +65,28 @@ class three_class_classifier_problem: scores[2] = dot(current_solution[2*dims:3*dims], samp) # Add in the loss-augmentation - if (labels[idx] != 1): + if (labels[idx] != 0): scores[0] += 1 - if (labels[idx] != 2): + if (labels[idx] != 1): scores[1] += 1 - if (labels[idx] != 3): + if (labels[idx] != 2): scores[2] += 1 - # Now figure out which classifier has the largest loss-augmented score. - max_scoring_label = scores.index(max(scores))+1 + max_scoring_label = scores.index(max(scores)) if (max_scoring_label == labels[idx]): loss = 0 else: loss = 1 - self.make_psi(psi, samp, max_scoring_label) + psi = self.make_psi(samp, max_scoring_label) - return loss + return loss,psi -samples = [ [0,0,1], [0,1,0], [1,0,0]]; -labels = [1, 2, 3] +samples = [[0,0,1], [0,1,0], [1,0,0]]; +labels = [0,1,2] problem = three_class_classifier_problem(samples, labels) weights = dlib.solve_structural_svm_problem(problem) diff --git a/tools/python/src/svm_struct.cpp b/tools/python/src/svm_struct.cpp index c0d29cfee..32542f300 100644 --- a/tools/python/src/svm_struct.cpp +++ b/tools/python/src/svm_struct.cpp @@ -37,7 +37,7 @@ public: feature_vector_type& psi ) const { - problem.attr("get_truth_joint_feature_vector")(idx,boost::ref(psi)); + psi = extract(problem.attr("get_truth_joint_feature_vector")(idx)); } virtual void separation_oracle ( @@ -47,7 +47,19 @@ public: feature_vector_type& psi ) const { - loss = extract(problem.attr("separation_oracle")(idx,boost::ref(current_solution),boost::ref(psi))); + object res = problem.attr("separation_oracle")(idx,boost::ref(current_solution)); + pyassert(len(res) == 2, "separation_oracle() must return two objects, the loss and the psi vector"); + // let the user supply the output arguments in any order. + if (extract(res[0]).check()) + { + loss = extract(res[0]); + psi = extract(res[1]); + } + else + { + psi = extract(res[0]); + loss = extract(res[1]); + } } private: