mirror of https://github.com/davisking/dlib.git
Added the rls_filter object.
This commit is contained in:
parent
cea4f73d3c
commit
9802fe64f8
|
@ -4,6 +4,7 @@
|
|||
#define DLIB_FILTERiNG_HEADER
|
||||
|
||||
#include "filtering/kalman_filter.h"
|
||||
#include "filtering/rls_filter.h"
|
||||
|
||||
#endif // DLIB_FILTERiNG_HEADER
|
||||
|
||||
|
|
|
@ -0,0 +1,147 @@
|
|||
// Copyright (C) 2012 Davis E. King (davis@dlib.net)
|
||||
// License: Boost Software License See LICENSE.txt for the full license.
|
||||
#ifndef DLIB_RLS_FiLTER_H__
|
||||
#define DLIB_RLS_FiLTER_H__
|
||||
|
||||
#include "rls_filter_abstract.h"
|
||||
#include "../svm/rls.h"
|
||||
#include <vector>
|
||||
#include "../matrix.h"
|
||||
#include "../sliding_buffer.h"
|
||||
|
||||
namespace dlib
|
||||
{
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
class rls_filter
|
||||
{
|
||||
/*!
|
||||
CONVENTION
|
||||
- data.size() == the number of variables in a measurement
|
||||
- data[i].size() == data[j].size() for all i and j.
|
||||
- data[i].size() == get_window_size()
|
||||
- data[i][0] == most recent measurement of i-th variable given to update.
|
||||
- data[i].back() == oldest measurement of i-th variable given to update
|
||||
(or zero if we haven't seen this much data yet).
|
||||
|
||||
- if (count <= 2) then
|
||||
- count == number of times update(z) has been called
|
||||
!*/
|
||||
public:
|
||||
|
||||
rls_filter()
|
||||
{
|
||||
size = 5;
|
||||
count = 0;
|
||||
filter = rls(0.8, 100);
|
||||
}
|
||||
|
||||
explicit rls_filter (
|
||||
unsigned long size_,
|
||||
double forget_factor = 0.8,
|
||||
double C = 100
|
||||
)
|
||||
{
|
||||
size = size_;
|
||||
count = 0;
|
||||
filter = rls(forget_factor, C);
|
||||
}
|
||||
|
||||
double get_c(
|
||||
) const
|
||||
{
|
||||
return filter.get_c();
|
||||
}
|
||||
|
||||
double get_forget_factor(
|
||||
) const
|
||||
{
|
||||
return filter.get_forget_factor();
|
||||
}
|
||||
|
||||
unsigned long get_window_size (
|
||||
) const
|
||||
{
|
||||
return size;
|
||||
}
|
||||
|
||||
void update (
|
||||
)
|
||||
{
|
||||
if (filter.get_w().size() == 0)
|
||||
return;
|
||||
|
||||
for (unsigned long i = 0; i < data.size(); ++i)
|
||||
{
|
||||
// Put old predicted value into the circular buffer as if it was
|
||||
// the measurement we just observed. But don't update the rls filter.
|
||||
data[i].push_front(next(i));
|
||||
}
|
||||
|
||||
// predict next state
|
||||
for (long i = 0; i < next.size(); ++i)
|
||||
next(i) = filter(vector_to_matrix(data[i]));
|
||||
}
|
||||
|
||||
template <typename EXP>
|
||||
void update (
|
||||
const matrix_exp<EXP>& z
|
||||
)
|
||||
{
|
||||
// initialize data if necessary
|
||||
if (data.size() == 0)
|
||||
{
|
||||
data.resize(z.size());
|
||||
for (long i = 0; i < z.size(); ++i)
|
||||
data[i].assign(size, 0);
|
||||
}
|
||||
|
||||
|
||||
for (unsigned long i = 0; i < data.size(); ++i)
|
||||
{
|
||||
// Once there is some stuff in the circular buffer, start
|
||||
// showing it to the rls filter so it can do its thing.
|
||||
if (count >= 2)
|
||||
{
|
||||
filter.train(vector_to_matrix(data[i]), z(i));
|
||||
}
|
||||
|
||||
// keep track of the measurements in our circular buffer
|
||||
data[i].push_front(z(i));
|
||||
}
|
||||
|
||||
// Don't bother with the filter until we have seen two samples
|
||||
if (count >= 2)
|
||||
{
|
||||
for (long i = 0; i < z.size(); ++i)
|
||||
next(i) = filter(vector_to_matrix(data[i]));
|
||||
}
|
||||
else
|
||||
{
|
||||
++count;
|
||||
next = matrix_cast<double>(z);
|
||||
}
|
||||
}
|
||||
|
||||
const matrix<double,0,1>& get_predicted_next_state(
|
||||
)
|
||||
{
|
||||
return next;
|
||||
}
|
||||
|
||||
private:
|
||||
|
||||
unsigned long count;
|
||||
unsigned long size;
|
||||
rls filter;
|
||||
matrix<double,0,1> next;
|
||||
std::vector<circular_buffer<double> > data;
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
}
|
||||
|
||||
#endif // DLIB_RLS_FiLTER_H__
|
||||
|
|
@ -0,0 +1,77 @@
|
|||
// Copyright (C) 2012 Davis E. King (davis@dlib.net)
|
||||
// License: Boost Software License See LICENSE.txt for the full license.
|
||||
#undef DLIB_RLS_FiLTER_ABSTRACT_H__
|
||||
#ifdef DLIB_RLS_FiLTER_ABSTRACT_H__
|
||||
|
||||
#include "../svm/rls_abstract.h"
|
||||
#include "../matrix/matrix_abstract.h"
|
||||
|
||||
namespace dlib
|
||||
{
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
class rls_filter
|
||||
{
|
||||
/*!
|
||||
WHAT THIS OBJECT REPRESENTS
|
||||
!*/
|
||||
|
||||
public:
|
||||
|
||||
rls_filter(
|
||||
);
|
||||
/*!
|
||||
ensures
|
||||
- #get_window_size() == 5
|
||||
- #get_c() == 100
|
||||
- #get_forget_factor() == 0.8
|
||||
!*/
|
||||
|
||||
explicit rls_filter (
|
||||
unsigned long size,
|
||||
double forget_factor = 0.8,
|
||||
double C = 100
|
||||
);
|
||||
/*!
|
||||
requires
|
||||
- 0 < forget_factor <= 1
|
||||
- 0 < C
|
||||
- size >= 2
|
||||
ensures
|
||||
- #get_window_size() == size
|
||||
- #get_forget_factor() == forget_factor
|
||||
- #get_c() == C
|
||||
!*/
|
||||
|
||||
double get_c(
|
||||
) const;
|
||||
/*!
|
||||
!*/
|
||||
|
||||
double get_forget_factor(
|
||||
) const;
|
||||
|
||||
unsigned long get_window_size (
|
||||
) const;
|
||||
|
||||
void update (
|
||||
);
|
||||
|
||||
template <typename EXP>
|
||||
void update (
|
||||
const matrix_exp<EXP>& z
|
||||
);
|
||||
|
||||
const matrix<double,0,1>& get_predicted_next_state(
|
||||
);
|
||||
|
||||
};
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
}
|
||||
|
||||
#endif // DLIB_RLS_FiLTER_ABSTRACT_H__
|
||||
|
||||
|
Loading…
Reference in New Issue