mirror of https://github.com/davisking/dlib.git
Relaxed the requirements on the feature extractor interface and also
added some tests to make sure the code really does work with the relaxed interface.
This commit is contained in:
parent
6d48b166a2
commit
44c79bcb91
|
@ -137,7 +137,10 @@ namespace dlib
|
|||
set_feature(55,1);
|
||||
Therefore, the first argument to set_feature is the index of the feature
|
||||
to be set while the second argument is the value the feature should take.
|
||||
- This function only calls set_feature() once for each feature index.
|
||||
Additionally, note that calling set_feature() multiple times with the same
|
||||
feature index does NOT overwrite the old value, it adds to the previous
|
||||
value. For example, if you call set_feature(55) 3 times then it will
|
||||
result in feature 55 having a value of 3.
|
||||
- This function only calls set_feature() with feature_index values < num_features()
|
||||
!*/
|
||||
|
||||
|
|
|
@ -62,6 +62,47 @@ namespace
|
|||
}
|
||||
};
|
||||
|
||||
class feature_extractor_partial
|
||||
{
|
||||
public:
|
||||
typedef unsigned long sample_type;
|
||||
|
||||
unsigned long num_features() const
|
||||
{
|
||||
return num_label_states*num_label_states + num_label_states*num_sample_states;
|
||||
}
|
||||
|
||||
unsigned long order() const
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
|
||||
unsigned long num_labels() const
|
||||
{
|
||||
return num_label_states;
|
||||
}
|
||||
|
||||
template <typename feature_setter, typename EXP>
|
||||
void get_features (
|
||||
feature_setter& set_feature,
|
||||
const std::vector<sample_type>& x,
|
||||
const matrix_exp<EXP>& y,
|
||||
unsigned long position
|
||||
) const
|
||||
{
|
||||
if (y.size() > 1)
|
||||
{
|
||||
set_feature(y(1)*num_label_states + y(0), 0.5);
|
||||
set_feature(y(1)*num_label_states + y(0), 0.5);
|
||||
}
|
||||
|
||||
set_feature(num_label_states*num_label_states +
|
||||
y(0)*num_sample_states + x[position],0.4);
|
||||
set_feature(num_label_states*num_label_states +
|
||||
y(0)*num_sample_states + x[position],0.6);
|
||||
}
|
||||
};
|
||||
|
||||
bool called_rejct_labeling = false;
|
||||
class feature_extractor2
|
||||
{
|
||||
|
@ -324,6 +365,53 @@ namespace
|
|||
DLIB_TEST(std::abs(accuracy - 0.882) < 0.01);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
void test2()
|
||||
{
|
||||
/*
|
||||
The point of this test is to make sure calling set_feature() multiple
|
||||
times works the way it is supposed to.
|
||||
*/
|
||||
|
||||
print_spinner();
|
||||
std::vector<std::vector<unsigned long> > samples;
|
||||
std::vector<std::vector<unsigned long> > labels;
|
||||
|
||||
matrix<double> transition_probabilities(num_label_states, num_label_states);
|
||||
transition_probabilities = 0.05, 0.90, 0.05,
|
||||
0.05, 0.05, 0.90,
|
||||
0.90, 0.05, 0.05;
|
||||
|
||||
matrix<double> emission_probabilities(num_label_states,num_sample_states);
|
||||
emission_probabilities = 0.5, 0.5, 0.0,
|
||||
0.0, 0.5, 0.5,
|
||||
0.5, 0.0, 0.5;
|
||||
|
||||
|
||||
make_dataset(transition_probabilities,emission_probabilities,
|
||||
samples, labels, 1000);
|
||||
|
||||
dlog << LINFO << "samples.size(): "<< samples.size();
|
||||
|
||||
structural_sequence_labeling_trainer<feature_extractor> trainer;
|
||||
structural_sequence_labeling_trainer<feature_extractor_partial> trainer_part;
|
||||
trainer.set_c(4);
|
||||
trainer_part.set_c(4);
|
||||
trainer.set_num_threads(4);
|
||||
trainer_part.set_num_threads(4);
|
||||
|
||||
|
||||
|
||||
// Learn to do sequence labeling from the dataset
|
||||
sequence_labeler<feature_extractor> labeler = trainer.train(samples, labels);
|
||||
sequence_labeler<feature_extractor_partial> labeler_part = trainer_part.train(samples, labels);
|
||||
|
||||
// Both feature extractors should be equivalent.
|
||||
DLIB_TEST(length(labeler.get_weights() - labeler_part.get_weights()) < 1e-10);
|
||||
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
class sequence_labeler_tester : public tester
|
||||
|
@ -342,6 +430,8 @@ namespace
|
|||
DLIB_TEST(called_rejct_labeling == false);
|
||||
do_test<feature_extractor2>();
|
||||
DLIB_TEST(called_rejct_labeling == true);
|
||||
|
||||
test2();
|
||||
}
|
||||
} a;
|
||||
|
||||
|
|
Loading…
Reference in New Issue