mirror of https://github.com/davisking/dlib.git
Added an example for the sequence_segmenter
This commit is contained in:
parent
885a27c070
commit
3330d34fd1
|
@ -69,6 +69,7 @@ add_example(running_stats_ex)
|
|||
add_example(rvm_ex)
|
||||
add_example(rvm_regression_ex)
|
||||
add_example(sequence_labeler_ex)
|
||||
add_example(sequence_segmenter_ex)
|
||||
add_example(server_http_ex)
|
||||
add_example(server_iostream_ex)
|
||||
add_example(sockets_ex)
|
||||
|
|
|
@ -0,0 +1,242 @@
|
|||
// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
|
||||
/*
|
||||
|
||||
This example shows how to use dlib to learn to perform sequence segmentation. In a
|
||||
sequence segmentation task we are given a sequence of objects (e.g. words in a
|
||||
sentence) and we are supposed to detect certain subsequences (e.g. the names of
|
||||
people). In the code below we create some very simple sequence/segmentation training
|
||||
examples and use them to learn a sequence segmentation model. In particular, our
|
||||
sequences will be arrays of strings and our task will be to learn to identify person
|
||||
names. Once we have our segmentation model we can use it to find names in new
|
||||
sentences as we will show.
|
||||
*/
|
||||
|
||||
|
||||
#include <iostream>
|
||||
#include <cctype>
|
||||
#include <dlib/svm_threaded.h>
|
||||
#include <dlib/string.h>
|
||||
|
||||
using namespace std;
|
||||
using namespace dlib;
|
||||
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
class feature_extractor
|
||||
{
|
||||
/*
|
||||
The sequence segmentation models we work with in this example are chain structured
|
||||
conditional random field style models. Therefore, central to a sequence
|
||||
segmentation model is a feature extractor object. This object defines all the
|
||||
properties of the model such as how many features it will use, and more importantly,
|
||||
how they are calculated.
|
||||
*/
|
||||
|
||||
public:
|
||||
// This should be the type used to represent an input sequence. It can be
|
||||
// anything so long as it has a .size() which returns the length of the sequence.
|
||||
typedef std::vector<std::string> sequence_type;
|
||||
|
||||
// The next four lines define high-level properties of the feature extraction model.
|
||||
// See the documentation for the sequence_labeler object for an extended discussion of
|
||||
// how they are used (note that the main body of the documentation is at the top of the
|
||||
// file documenting the sequence_labeler).
|
||||
const static bool use_BIO_model = true;
|
||||
const static bool use_high_order_features = true;
|
||||
const static bool allow_negative_weights = true;
|
||||
unsigned long window_size() const { return 3; }
|
||||
|
||||
// This function defines the dimensionality of the vectors output by the get_features()
|
||||
// function defined below.
|
||||
unsigned long num_features() const { return 1; }
|
||||
|
||||
template <typename feature_setter>
|
||||
void get_features (
|
||||
feature_setter& set_feature,
|
||||
const sequence_type& sentence,
|
||||
unsigned long position
|
||||
) const
|
||||
/*!
|
||||
requires
|
||||
- position < sentence.size()
|
||||
- set_feature is a function object which allows expressions of the form:
|
||||
- set_features((unsigned long)feature_index, (double)feature_value);
|
||||
- set_features((unsigned long)feature_index);
|
||||
ensures
|
||||
- This function computes a feature vector which should capture the properties
|
||||
of sentence[position] that are informative relative to the sequence
|
||||
segmentation task you are trying to perform.
|
||||
- The output feature vector is returned as a sparse vector by invoking set_feature().
|
||||
For example, to set the feature with an index of 55 to the value of 1
|
||||
this method would call:
|
||||
set_feature(55);
|
||||
Or equivalently:
|
||||
set_feature(55,1);
|
||||
Therefore, the first argument to set_feature is the index of the feature
|
||||
to be set while the second argument is the value the feature should take.
|
||||
Additionally, note that calling set_feature() multiple times with the
|
||||
same feature index does NOT overwrite the old value, it adds to the
|
||||
previous value. For example, if you call set_feature(55) 3 times then it
|
||||
will result in feature 55 having a value of 3.
|
||||
- This function only calls set_feature() with feature_index values < num_features()
|
||||
!*/
|
||||
{
|
||||
// The model in this example program is very simple. Our features only look at the
|
||||
// capitalization pattern of the words. So we have a single feature which checks
|
||||
// if the first letter is capitalized or not.
|
||||
|
||||
// if the first character of the word sentence[position] is capitalized then
|
||||
// we set our only feature 1, otherwise it remains set to 0.
|
||||
if (isupper(sentence[position][0]))
|
||||
set_feature(0);
|
||||
}
|
||||
};
|
||||
|
||||
// We need to define serialize() and deserialize() for our feature extractor if we want
|
||||
// to be able to serialize and deserialize our learned models. In this case the
|
||||
// implementation is empty since our feature_extractor doesn't have any state. But you
|
||||
// might define more complex feature extractors which have state that needs to be saved.
|
||||
void serialize(const feature_extractor&, std::ostream&) {}
|
||||
void deserialize(feature_extractor&, std::istream&) {}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
void make_training_examples (
|
||||
std::vector<std::vector<std::string> >& samples,
|
||||
std::vector<std::vector<std::pair<unsigned long, unsigned long> > >& segments
|
||||
)
|
||||
/*!
|
||||
ensures
|
||||
- This function fills samples with example sentences and segments with the
|
||||
locations of person names that should be segmented out.
|
||||
- #samples.size() == #segments.size()
|
||||
!*/
|
||||
{
|
||||
std::vector<std::pair<unsigned long, unsigned long> > names;
|
||||
|
||||
|
||||
// Here we make our first training example. split() turns the string into an array of
|
||||
// 10 words and then we store that into samples.
|
||||
samples.push_back(split("The other day I saw a man named Jim Smith"));
|
||||
// We want to detect person names. So we note that the name is located within the
|
||||
// range [8, 10). Note that we use half open ranges to identify segments. So in this
|
||||
// case, the segment identifies the string "Jim Smith".
|
||||
names.push_back(make_pair(8, 10));
|
||||
segments.push_back(names); names.clear();
|
||||
|
||||
// Now we add a few more example sentences
|
||||
|
||||
samples.push_back(split("Davis King is the main author of the dlib Library"));
|
||||
names.push_back(make_pair(0, 2));
|
||||
segments.push_back(names); names.clear();
|
||||
|
||||
|
||||
samples.push_back(split("Bob Jones is a name and so is George Clinton"));
|
||||
names.push_back(make_pair(0, 2));
|
||||
names.push_back(make_pair(8, 10));
|
||||
segments.push_back(names); names.clear();
|
||||
|
||||
|
||||
samples.push_back(split("My dog is named Bob Barker"));
|
||||
names.push_back(make_pair(4, 6));
|
||||
segments.push_back(names); names.clear();
|
||||
|
||||
|
||||
samples.push_back(split("ABC is an acronym but John James Smith is a name"));
|
||||
names.push_back(make_pair(5, 8));
|
||||
segments.push_back(names); names.clear();
|
||||
|
||||
|
||||
samples.push_back(split("No names in this sentence at all"));
|
||||
segments.push_back(names); names.clear();
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
void print_segment (
|
||||
const std::vector<std::string>& sentence,
|
||||
const std::pair<unsigned long,unsigned long>& segment
|
||||
)
|
||||
{
|
||||
// Recall that a segment is a half open range starting with .first and ending just
|
||||
// before .second.
|
||||
for (unsigned long i = segment.first; i < segment.second; ++i)
|
||||
cout << sentence[i] << " ";
|
||||
cout << endl;
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
int main()
|
||||
{
|
||||
// Finally we make it into the main program body. So the first thing we do is get our
|
||||
// training data.
|
||||
std::vector<std::vector<std::string> > samples;
|
||||
std::vector<std::vector<std::pair<unsigned long, unsigned long> > > segments;
|
||||
make_training_examples(samples, segments);
|
||||
|
||||
|
||||
// Next we use the structural_sequence_segmentation_trainer to learn our segmentation
|
||||
// model based on just the samples and segments. But first we setup some of its
|
||||
// parameters.
|
||||
structural_sequence_segmentation_trainer<feature_extractor> trainer;
|
||||
// This is the common SVM C parameter. Larger values encourage the trainer to attempt
|
||||
// to fit the data exactly but might overfit. In general, you determine this parameter
|
||||
// by cross-validation.
|
||||
trainer.set_c(10);
|
||||
// This trainer can use multiple CPU cores to speed up the training. So set this to
|
||||
// the number of available CPU cores.
|
||||
trainer.set_num_threads(4);
|
||||
|
||||
|
||||
// Learn to do sequence segmentation from the dataset
|
||||
sequence_segmenter<feature_extractor> segmenter = trainer.train(samples, segments);
|
||||
|
||||
|
||||
// Lets print out all the segments our segmenter detects.
|
||||
for (unsigned long i = 0; i < samples.size(); ++i)
|
||||
{
|
||||
// get all the detected segments in samples[i]
|
||||
std::vector<std::pair<unsigned long,unsigned long> > seg = segmenter(samples[i]);
|
||||
// Print each of them
|
||||
for (unsigned long j = 0; j < seg.size(); ++j)
|
||||
{
|
||||
print_segment(samples[i], seg[j]);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
std::vector<std::string> sentence(split("There once was a man from Nantucket whose name rhymed with Bob Bucket"));
|
||||
std::vector<std::pair<unsigned long,unsigned long> > seg = segmenter(sentence);
|
||||
for (unsigned long j = 0; j < seg.size(); ++j)
|
||||
{
|
||||
print_segment(sentence, seg[j]);
|
||||
}
|
||||
|
||||
|
||||
|
||||
// We can also test the accuracy of the segmenter on a dataset. This statement simply
|
||||
// tests on the training data. In this case we will see that it predicts everything
|
||||
// correctly.
|
||||
cout << "\nprecision, recall, f1-score: " << test_sequence_segmenter(segmenter, samples, segments);
|
||||
// Similarly, we can do 5-fold cross-validation and print the results. Just as before,
|
||||
// we see everything is predicted correctly.
|
||||
cout << "precision, recall, f1-score: " << cross_validate_sequence_segmenter(trainer, samples, segments, 5);
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
// Finally, the segmenter can be serialized to disk just like most dlib objects.
|
||||
ofstream fout("segmenter.dat", ios::binary);
|
||||
serialize(segmenter, fout);
|
||||
fout.close();
|
||||
|
||||
// recall from disk
|
||||
ifstream fin("segmenter.dat", ios::binary);
|
||||
deserialize(segmenter, fin);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
Loading…
Reference in New Issue