cleaned this up a little

This commit is contained in:
Davis King 2011-11-03 19:17:53 -04:00
parent 29964d2858
commit 90c9d0be6e
1 changed file with 18 additions and 16 deletions

View File

@ -73,8 +73,8 @@ void deserialize(feature_extractor&, std::istream&) {}
// ----------------------------------------------------------------------------------------
void make_dataset (
const matrix<double>& emission_probabilities,
const matrix<double>& transition_probabilities,
const matrix<double>& emission_probabilities,
std::vector<std::vector<unsigned long> >& samples,
std::vector<std::vector<unsigned long> >& labels,
unsigned long dataset_size
@ -90,8 +90,10 @@ void make_dataset (
- This function randomly samples a bunch of sequences from the HMM defined by
transition_probabilities and emission_probabilities.
- The HMM is defined by:
- P(next_label |previous_label) == transition_probabilities(previous_label, next_label)
- P(next_sample|next_label) == emission_probabilities (next_label, next_sample)
- The probability of transitioning from hidden state H1 to H2
is given by transition_probabilities(H1,H2).
- The probability of a hidden state H producing an observed state
O is given by emission_probabilities(H,O).
- #samples.size() == labels.size() == dataset_size
- for all valid i:
- #labels[i] is a randomly sampled sequence of hidden states from the
@ -103,6 +105,10 @@ void make_dataset (
int main()
{
matrix<double> transition_probabilities(num_label_states, num_label_states);
transition_probabilities = 0.05, 0.90, 0.05,
0.05, 0.05, 0.90,
0.90, 0.05, 0.05;
// Set this up so that emission_probabilities(L,X) == the probability of a state
// with label L emitting an X.
@ -111,17 +117,11 @@ int main()
0.0, 0.5, 0.5,
0.5, 0.0, 0.5;
matrix<double> transition_probabilities(num_label_states, num_label_states);
transition_probabilities = 0.05, 0.90, 0.05,
0.05, 0.05, 0.90,
0.90, 0.05, 0.05;
std::vector<std::vector<unsigned long> > samples;
std::vector<std::vector<unsigned long> > labels;
make_dataset(emission_probabilities, transition_probabilities,
make_dataset(transition_probabilities,emission_probabilities,
samples, labels, 1000);
cout << "samples.size(): "<< samples.size() << endl;
@ -139,17 +139,19 @@ int main()
trainer.set_num_threads(4);
matrix<double> confusion_matrix;
// Learn to do sequence labeling from the dataset
sequence_labeler<feature_extractor> labeler = trainer.train(samples, labels);
confusion_matrix = test_sequence_labeler(labeler, samples, labels);
cout << "trained sequence labeler: " << endl;
cout << confusion_matrix;
cout << "label accuracy: "<< sum(diag(confusion_matrix))/sum(confusion_matrix) << endl;
std::vector<unsigned long> predicted_labels = labeler(samples[0]);
cout << "true hidden states: "<< trans(vector_to_matrix(labels[0]));
cout << "predicted hidden states: "<< trans(vector_to_matrix(predicted_labels));
// We can also do cross-validation
matrix<double> confusion_matrix;
confusion_matrix = cross_validate_sequence_labeler(trainer, samples, labels, 4);
cout << "\ncross-validation: " << endl;
cout << confusion_matrix;
@ -236,8 +238,8 @@ void sample_hmm (
// ----------------------------------------------------------------------------------------
void make_dataset (
const matrix<double>& emission_probabilities,
const matrix<double>& transition_probabilities,
const matrix<double>& emission_probabilities,
std::vector<std::vector<unsigned long> >& samples,
std::vector<std::vector<unsigned long> >& labels,
unsigned long dataset_size