mirror of https://github.com/davisking/dlib.git
Fixed a bug in the BSP code and added more tests
This commit is contained in:
parent
7619b5308a
commit
56c6d33c4a
|
@ -222,6 +222,13 @@ namespace dlib
|
|||
// now wait for all the other nodes to terminate
|
||||
while (num_terminated_nodes < _cons.size() )
|
||||
{
|
||||
if (node_id() == 0 && num_waiting_nodes + num_terminated_nodes == _cons.size() && outstanding_messages == 0)
|
||||
{
|
||||
num_waiting_nodes = 0;
|
||||
broadcast_byte(impl2::SEE_ALL_IN_WAITING_STATE);
|
||||
++current_epoch;
|
||||
}
|
||||
|
||||
if (!msg_buffer.pop(msg))
|
||||
throw dlib::socket_error("Error reading from msg_buffer in dlib::bsp_context.");
|
||||
|
||||
|
@ -251,14 +258,6 @@ namespace dlib
|
|||
{
|
||||
++num_waiting_nodes;
|
||||
}
|
||||
|
||||
|
||||
if (node_id() == 0 && num_waiting_nodes + num_terminated_nodes == _cons.size() && outstanding_messages == 0)
|
||||
{
|
||||
num_waiting_nodes = 0;
|
||||
broadcast_byte(impl2::SEE_ALL_IN_WAITING_STATE);
|
||||
++current_epoch;
|
||||
}
|
||||
}
|
||||
|
||||
if (node_id() == 0)
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
#include <dlib/bsp.h>
|
||||
#include <dlib/threads.h>
|
||||
#include <dlib/pipe.h>
|
||||
#include <dlib/matrix.h>
|
||||
|
||||
#include "tester.h"
|
||||
|
||||
|
@ -420,6 +421,118 @@ namespace
|
|||
DLIB_TEST(val == 25);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
double f ( double x)
|
||||
{
|
||||
return std::pow(x-2.0, 2.0);
|
||||
}
|
||||
|
||||
|
||||
void bsp_job_node_0 (
|
||||
bsp_context& context,
|
||||
double& min_value,
|
||||
double& optimal_x
|
||||
)
|
||||
{
|
||||
double left = -100;
|
||||
double right = 100;
|
||||
|
||||
min_value = std::numeric_limits<double>::infinity();
|
||||
double interval_width = std::abs(right-left);
|
||||
|
||||
// This is doing a BSP based grid search for the minimum of f(). Here we
|
||||
// do 100 iterations where we keep shrinking the grid size.
|
||||
for (int i = 0; i < 100; ++i)
|
||||
{
|
||||
context.broadcast(left);
|
||||
context.broadcast(right);
|
||||
|
||||
for (unsigned int k = 1; k < context.number_of_nodes(); ++k)
|
||||
{
|
||||
std::pair<double,double> val;
|
||||
context.receive(val);
|
||||
if (val.second < min_value)
|
||||
{
|
||||
min_value = val.second;
|
||||
optimal_x = val.first;
|
||||
}
|
||||
}
|
||||
|
||||
interval_width *= 0.5;
|
||||
left = optimal_x - interval_width/2;
|
||||
right = optimal_x + interval_width/2;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void bsp_job_other_nodes (
|
||||
bsp_context& context
|
||||
)
|
||||
{
|
||||
double left, right;
|
||||
while (context.try_receive(left))
|
||||
{
|
||||
context.receive(right);
|
||||
|
||||
const double l = (context.node_id()-1)/(context.number_of_nodes()-1.0);
|
||||
const double r = context.node_id() /(context.number_of_nodes()-1.0);
|
||||
|
||||
const double width = right-left;
|
||||
matrix<double> values_to_check = linspace(left +l*width, left + r*width, 100);
|
||||
|
||||
double best_x;
|
||||
double best_val = std::numeric_limits<double>::infinity();
|
||||
for (long j = 0; j < values_to_check.size(); ++j)
|
||||
{
|
||||
double temp = f(values_to_check(j));
|
||||
if (temp < best_val)
|
||||
{
|
||||
best_val = temp;
|
||||
best_x = values_to_check(j);
|
||||
}
|
||||
}
|
||||
|
||||
context.send(make_pair(best_x, best_val), 0);
|
||||
}
|
||||
}
|
||||
|
||||
void dotest6()
|
||||
{
|
||||
dlog << LINFO << "start dotest6()";
|
||||
print_spinner();
|
||||
bool error_occurred = false;
|
||||
{
|
||||
dlib::pipe<unsigned short> ports(5);
|
||||
thread_function t1(callfunct(bsp_job_other_nodes, 0, error_occurred, ports));
|
||||
thread_function t2(callfunct(bsp_job_other_nodes, 0, error_occurred, ports));
|
||||
thread_function t3(callfunct(bsp_job_other_nodes, 0, error_occurred, ports));
|
||||
|
||||
|
||||
try
|
||||
{
|
||||
std::vector<network_address> hosts;
|
||||
unsigned short port;
|
||||
ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port;
|
||||
ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port;
|
||||
ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port;
|
||||
double min_value = 10, optimal_x = 0;
|
||||
bsp_connect(hosts, bsp_job_node_0, dlib::ref(min_value), dlib::ref(optimal_x));
|
||||
|
||||
dlog << LINFO << "min_value: " << min_value;
|
||||
dlog << LINFO << "optimal_x: " << optimal_x;
|
||||
DLIB_TEST(std::abs(min_value - 0) < 1e-14);
|
||||
DLIB_TEST(std::abs(optimal_x - 2) < 1e-14);
|
||||
}
|
||||
catch (std::exception& e)
|
||||
{
|
||||
dlog << LERROR << "error during bsp_context: " << e.what();
|
||||
DLIB_TEST(false);
|
||||
}
|
||||
|
||||
}
|
||||
DLIB_TEST(error_occurred == false);
|
||||
}
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
class bsp_tester : public tester
|
||||
|
@ -444,6 +557,7 @@ namespace
|
|||
dotest3();
|
||||
dotest4();
|
||||
dotest5();
|
||||
dotest6();
|
||||
}
|
||||
}
|
||||
} a;
|
||||
|
|
Loading…
Reference in New Issue