Fixed a bug in the BSP code and added more tests

This commit is contained in:
Davis King 2012-10-21 18:21:20 -04:00
parent 7619b5308a
commit 56c6d33c4a
2 changed files with 121 additions and 8 deletions

View File

@ -222,6 +222,13 @@ namespace dlib
// now wait for all the other nodes to terminate
while (num_terminated_nodes < _cons.size() )
{
if (node_id() == 0 && num_waiting_nodes + num_terminated_nodes == _cons.size() && outstanding_messages == 0)
{
num_waiting_nodes = 0;
broadcast_byte(impl2::SEE_ALL_IN_WAITING_STATE);
++current_epoch;
}
if (!msg_buffer.pop(msg))
throw dlib::socket_error("Error reading from msg_buffer in dlib::bsp_context.");
@ -251,14 +258,6 @@ namespace dlib
{
++num_waiting_nodes;
}
if (node_id() == 0 && num_waiting_nodes + num_terminated_nodes == _cons.size() && outstanding_messages == 0)
{
num_waiting_nodes = 0;
broadcast_byte(impl2::SEE_ALL_IN_WAITING_STATE);
++current_epoch;
}
}
if (node_id() == 0)

View File

@ -5,6 +5,7 @@
#include <dlib/bsp.h>
#include <dlib/threads.h>
#include <dlib/pipe.h>
#include <dlib/matrix.h>
#include "tester.h"
@ -420,6 +421,118 @@ namespace
DLIB_TEST(val == 25);
}
// ----------------------------------------------------------------------------------------
double f ( double x)
{
return std::pow(x-2.0, 2.0);
}
void bsp_job_node_0 (
bsp_context& context,
double& min_value,
double& optimal_x
)
{
double left = -100;
double right = 100;
min_value = std::numeric_limits<double>::infinity();
double interval_width = std::abs(right-left);
// This is doing a BSP based grid search for the minimum of f(). Here we
// do 100 iterations where we keep shrinking the grid size.
for (int i = 0; i < 100; ++i)
{
context.broadcast(left);
context.broadcast(right);
for (unsigned int k = 1; k < context.number_of_nodes(); ++k)
{
std::pair<double,double> val;
context.receive(val);
if (val.second < min_value)
{
min_value = val.second;
optimal_x = val.first;
}
}
interval_width *= 0.5;
left = optimal_x - interval_width/2;
right = optimal_x + interval_width/2;
}
}
void bsp_job_other_nodes (
bsp_context& context
)
{
double left, right;
while (context.try_receive(left))
{
context.receive(right);
const double l = (context.node_id()-1)/(context.number_of_nodes()-1.0);
const double r = context.node_id() /(context.number_of_nodes()-1.0);
const double width = right-left;
matrix<double> values_to_check = linspace(left +l*width, left + r*width, 100);
double best_x;
double best_val = std::numeric_limits<double>::infinity();
for (long j = 0; j < values_to_check.size(); ++j)
{
double temp = f(values_to_check(j));
if (temp < best_val)
{
best_val = temp;
best_x = values_to_check(j);
}
}
context.send(make_pair(best_x, best_val), 0);
}
}
void dotest6()
{
dlog << LINFO << "start dotest6()";
print_spinner();
bool error_occurred = false;
{
dlib::pipe<unsigned short> ports(5);
thread_function t1(callfunct(bsp_job_other_nodes, 0, error_occurred, ports));
thread_function t2(callfunct(bsp_job_other_nodes, 0, error_occurred, ports));
thread_function t3(callfunct(bsp_job_other_nodes, 0, error_occurred, ports));
try
{
std::vector<network_address> hosts;
unsigned short port;
ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port;
ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port;
ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port;
double min_value = 10, optimal_x = 0;
bsp_connect(hosts, bsp_job_node_0, dlib::ref(min_value), dlib::ref(optimal_x));
dlog << LINFO << "min_value: " << min_value;
dlog << LINFO << "optimal_x: " << optimal_x;
DLIB_TEST(std::abs(min_value - 0) < 1e-14);
DLIB_TEST(std::abs(optimal_x - 2) < 1e-14);
}
catch (std::exception& e)
{
dlog << LERROR << "error during bsp_context: " << e.what();
DLIB_TEST(false);
}
}
DLIB_TEST(error_occurred == false);
}
// ----------------------------------------------------------------------------------------
class bsp_tester : public tester
@ -444,6 +557,7 @@ namespace
dotest3();
dotest4();
dotest5();
dotest6();
}
}
} a;