mirror of https://github.com/davisking/dlib.git
Fixed a race condition in the BSP code and also simplified the
logic somewhat.
This commit is contained in:
parent
cf643ee99c
commit
d38723f4af
222
dlib/bsp/bsp.cpp
222
dlib/bsp/bsp.cpp
|
@ -97,33 +97,36 @@ namespace dlib
|
|||
|
||||
namespace impl2
|
||||
{
|
||||
// These control bytes are sent before each message nodes send to each other.
|
||||
// These control bytes are sent before each message between nodes. Note that many
|
||||
// of these are only sent between the control node (node 0) and the other nodes.
|
||||
// This is because the controller node is responsible for handling the
|
||||
// synchronization that needs to happen when all nodes block on calls to
|
||||
// receive_data()
|
||||
// at the same time.
|
||||
|
||||
// denotes a normal content message.
|
||||
const static char MESSAGE_HEADER = 0;
|
||||
|
||||
// sent back to sender, means message was returned by receive().
|
||||
// sent to the controller node when someone receives a message via receive_data().
|
||||
const static char GOT_MESSAGE = 1;
|
||||
|
||||
// broadcast when a node goes into a state where it has no outstanding sent
|
||||
// messages (i.e. it received GOT_MESSAGE for all its sent messages) and is waiting
|
||||
// on receive().
|
||||
const static char IN_WAITING_STATE = 2;
|
||||
// sent to the controller node when someone sends a message via send().
|
||||
const static char SENT_MESSAGE = 2;
|
||||
|
||||
// broadcast when no longer in IN_WAITING_STATE state.
|
||||
const static char NOT_IN_WAITING_STATE = 3;
|
||||
// sent to the controller node when someone enters a call to receive_data()
|
||||
const static char IN_WAITING_STATE = 3;
|
||||
|
||||
// broadcast when a node terminates itself.
|
||||
const static char NODE_TERMINATE = 4;
|
||||
const static char NODE_TERMINATE = 5;
|
||||
|
||||
// broadcast when a node finds out that all non-terminated nodes are in the
|
||||
// IN_WAITING_STATE state. sending this message puts a node into the
|
||||
// SEE_ALL_IN_WAITING_STATE where it will wait until it gets this message from all
|
||||
// others and then return from receive() once this happens.
|
||||
const static char SEE_ALL_IN_WAITING_STATE = 5;
|
||||
// broadcast by the controller node when it determines that all nodes are blocked
|
||||
// on calls to receive_data() and there aren't any messages in flight. This is also
|
||||
// what makes us go to the next epoch.
|
||||
const static char SEE_ALL_IN_WAITING_STATE = 6;
|
||||
|
||||
|
||||
const static char READ_ERROR = 6;
|
||||
// This isn't ever transmitted between nodes. It is used internally to indicate
|
||||
// that an error occurred.
|
||||
const static char READ_ERROR = 7;
|
||||
|
||||
// ------------------------------------------------------------------------------------
|
||||
|
||||
|
@ -131,7 +134,7 @@ namespace dlib
|
|||
impl1::bsp_con* con,
|
||||
unsigned long node_id,
|
||||
unsigned long sender_id,
|
||||
impl1::thread_safe_deque& msg_buffer
|
||||
impl1::thread_safe_message_queue& msg_buffer
|
||||
)
|
||||
{
|
||||
try
|
||||
|
@ -145,6 +148,7 @@ namespace dlib
|
|||
if (msg.msg_type == MESSAGE_HEADER)
|
||||
{
|
||||
msg.data.reset(new std::string);
|
||||
deserialize(msg.epoch, con->stream);
|
||||
deserialize(*msg.data, con->stream);
|
||||
}
|
||||
|
||||
|
@ -203,12 +207,15 @@ namespace dlib
|
|||
close_all_connections_gracefully(
|
||||
)
|
||||
{
|
||||
_cons.reset();
|
||||
while (_cons.move_next())
|
||||
if (node_id() != 0)
|
||||
{
|
||||
// tell the other end that we are intentionally dropping the connection
|
||||
serialize(impl2::NODE_TERMINATE,_cons.element().value()->stream);
|
||||
_cons.element().value()->stream.flush();
|
||||
_cons.reset();
|
||||
while (_cons.move_next())
|
||||
{
|
||||
// tell the other end that we are intentionally dropping the connection
|
||||
serialize(impl2::NODE_TERMINATE,_cons.element().value()->stream);
|
||||
_cons.element().value()->stream.flush();
|
||||
}
|
||||
}
|
||||
|
||||
impl1::msg_data msg;
|
||||
|
@ -219,20 +226,59 @@ namespace dlib
|
|||
throw dlib::socket_error("Error reading from msg_buffer in dlib::bsp_context.");
|
||||
|
||||
if (msg.msg_type == impl2::NODE_TERMINATE)
|
||||
{
|
||||
++num_terminated_nodes;
|
||||
_cons[msg.sender_id]->terminated = true;
|
||||
}
|
||||
else if (msg.msg_type == impl2::READ_ERROR)
|
||||
{
|
||||
throw dlib::socket_error(*msg.data);
|
||||
}
|
||||
else if (msg.msg_type == impl2::MESSAGE_HEADER)
|
||||
{
|
||||
throw dlib::socket_error("A BSP node received a message after it has terminated.");
|
||||
}
|
||||
else if (msg.msg_type == impl2::GOT_MESSAGE)
|
||||
{
|
||||
--num_waiting_nodes;
|
||||
--outstanding_messages;
|
||||
}
|
||||
else if (msg.msg_type == impl2::SENT_MESSAGE)
|
||||
{
|
||||
++outstanding_messages;
|
||||
}
|
||||
else if (msg.msg_type == impl2::IN_WAITING_STATE)
|
||||
{
|
||||
++num_waiting_nodes;
|
||||
}
|
||||
|
||||
|
||||
if (node_id() == 0 && num_waiting_nodes + num_terminated_nodes == _cons.size() && outstanding_messages == 0)
|
||||
{
|
||||
num_waiting_nodes = 0;
|
||||
broadcast_byte(impl2::SEE_ALL_IN_WAITING_STATE);
|
||||
++current_epoch;
|
||||
}
|
||||
}
|
||||
|
||||
if (outstanding_messages != 0)
|
||||
if (node_id() == 0)
|
||||
{
|
||||
std::ostringstream sout;
|
||||
sout << "A BSP job was allowed to terminate before all sent messages have been received.\n";
|
||||
sout << "There are at least " << outstanding_messages << " messages still in flight. Make sure all sent messages\n";
|
||||
sout << "have a corresponding call to receive().";
|
||||
throw dlib::socket_error(sout.str());
|
||||
_cons.reset();
|
||||
while (_cons.move_next())
|
||||
{
|
||||
// tell the other end that we are intentionally dropping the connection
|
||||
serialize(impl2::NODE_TERMINATE,_cons.element().value()->stream);
|
||||
_cons.element().value()->stream.flush();
|
||||
}
|
||||
|
||||
if (outstanding_messages != 0)
|
||||
{
|
||||
std::ostringstream sout;
|
||||
sout << "A BSP job was allowed to terminate before all sent messages have been received.\n";
|
||||
sout << "There are at least " << outstanding_messages << " messages still in flight. Make sure all sent messages\n";
|
||||
sout << "have a corresponding call to receive().";
|
||||
throw dlib::socket_error(sout.str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -263,6 +309,7 @@ namespace dlib
|
|||
outstanding_messages(0),
|
||||
num_waiting_nodes(0),
|
||||
num_terminated_nodes(0),
|
||||
current_epoch(1),
|
||||
_cons(cons_),
|
||||
_node_id(node_id_)
|
||||
{
|
||||
|
@ -288,95 +335,73 @@ namespace dlib
|
|||
unsigned long& sending_node_id
|
||||
)
|
||||
{
|
||||
if (outstanding_messages == 0)
|
||||
broadcast_byte(impl2::IN_WAITING_STATE);
|
||||
|
||||
unsigned long num_in_see_all_in_waiting_state = 0;
|
||||
bool sent_see_all_in_waiting_state = false;
|
||||
std::stack<impl1::msg_data> buf;
|
||||
notify_control_node(impl2::IN_WAITING_STATE);
|
||||
|
||||
while (true)
|
||||
{
|
||||
// if there aren't any nodes left to give us messages then return right now.
|
||||
if (num_terminated_nodes == _cons.size())
|
||||
// If there aren't any nodes left to give us messages then return right now.
|
||||
// We need to check the msg_buffer size to make sure there aren't any
|
||||
// unprocessed messages there. Recall that this can happen because status
|
||||
// messages always jump to the front of the message buffer. So we might have
|
||||
// learned about the node terminations before processing their messages for us.
|
||||
if (num_terminated_nodes == _cons.size() && msg_buffer.size() == 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// if all running nodes are currently blocking forever on receive_data()
|
||||
if (outstanding_messages == 0 && num_terminated_nodes + num_waiting_nodes == _cons.size())
|
||||
if (node_id() == 0 && outstanding_messages == 0 && num_terminated_nodes + num_waiting_nodes == _cons.size())
|
||||
{
|
||||
num_waiting_nodes = 0;
|
||||
sent_see_all_in_waiting_state = true;
|
||||
broadcast_byte(impl2::SEE_ALL_IN_WAITING_STATE);
|
||||
|
||||
// Note that the reason we have this epoch counter is so we can tell if a
|
||||
// sent message is from before or after one of these "all nodes waiting"
|
||||
// synchronization events. If we didn't have the epoch count we would have
|
||||
// a race condition where one node gets the SEE_ALL_IN_WAITING_STATE
|
||||
// message before others and then sends out a message to another node
|
||||
// before that node got the SEE_ALL_IN_WAITING_STATE message. Then that
|
||||
// node would think the normal message came before SEE_ALL_IN_WAITING_STATE
|
||||
// which would be bad.
|
||||
++current_epoch;
|
||||
return false;
|
||||
}
|
||||
|
||||
impl1::msg_data data;
|
||||
if (!msg_buffer.pop(data))
|
||||
if (!msg_buffer.pop(data, current_epoch))
|
||||
throw dlib::socket_error("Error reading from msg_buffer in dlib::bsp_context.");
|
||||
|
||||
if (sent_see_all_in_waiting_state)
|
||||
{
|
||||
// Once we have gotten one SEE_ALL_IN_WAITING_STATE, all we care about is
|
||||
// getting the rest of them. So the effect of this code is to always move
|
||||
// any SEE_ALL_IN_WAITING_STATE messages to the front of the message queue.
|
||||
if (data.msg_type != impl2::SEE_ALL_IN_WAITING_STATE)
|
||||
{
|
||||
buf.push(data);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
switch(data.msg_type)
|
||||
{
|
||||
case impl2::MESSAGE_HEADER: {
|
||||
item = data.data;
|
||||
sending_node_id = data.sender_id;
|
||||
|
||||
// if we would have sent the IN_WAITING_STATE message before getting to
|
||||
// this point then let other nodes know that we aren't waiting anymore.
|
||||
if (outstanding_messages == 0)
|
||||
broadcast_byte(impl2::NOT_IN_WAITING_STATE);
|
||||
|
||||
send_byte(impl2::GOT_MESSAGE, data.sender_id);
|
||||
|
||||
notify_control_node(impl2::GOT_MESSAGE);
|
||||
return true;
|
||||
|
||||
} break;
|
||||
|
||||
case impl2::IN_WAITING_STATE: {
|
||||
++num_waiting_nodes;
|
||||
} break;
|
||||
|
||||
case impl2::NOT_IN_WAITING_STATE: {
|
||||
case impl2::GOT_MESSAGE: {
|
||||
--outstanding_messages;
|
||||
--num_waiting_nodes;
|
||||
} break;
|
||||
|
||||
case impl2::GOT_MESSAGE: {
|
||||
--outstanding_messages;
|
||||
if (outstanding_messages == 0)
|
||||
broadcast_byte(impl2::IN_WAITING_STATE);
|
||||
case impl2::SENT_MESSAGE: {
|
||||
++outstanding_messages;
|
||||
} break;
|
||||
|
||||
case impl2::NODE_TERMINATE: {
|
||||
++num_terminated_nodes;
|
||||
_cons[data.sender_id]->terminated = true;
|
||||
if (num_terminated_nodes == _cons.size())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
} break;
|
||||
|
||||
case impl2::SEE_ALL_IN_WAITING_STATE: {
|
||||
++num_in_see_all_in_waiting_state;
|
||||
if (num_in_see_all_in_waiting_state + num_terminated_nodes == _cons.size())
|
||||
{
|
||||
// put stuff from buf back into msg_buffer
|
||||
while (buf.size() != 0)
|
||||
{
|
||||
msg_buffer.push_front(buf.top());
|
||||
buf.pop();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
++current_epoch;
|
||||
return false;
|
||||
} break;
|
||||
|
||||
case impl2::READ_ERROR: {
|
||||
|
@ -393,13 +418,36 @@ namespace dlib
|
|||
// ----------------------------------------------------------------------------------------
|
||||
|
||||
void bsp_context::
|
||||
send_byte (
|
||||
char val,
|
||||
unsigned long target_node_id
|
||||
notify_control_node (
|
||||
char val
|
||||
)
|
||||
{
|
||||
serialize(val, _cons[target_node_id]->stream);
|
||||
_cons[target_node_id]->stream.flush();
|
||||
if (node_id() == 0)
|
||||
{
|
||||
using namespace impl2;
|
||||
switch(val)
|
||||
{
|
||||
case SENT_MESSAGE: {
|
||||
++outstanding_messages;
|
||||
} break;
|
||||
|
||||
case GOT_MESSAGE: {
|
||||
--outstanding_messages;
|
||||
} break;
|
||||
|
||||
case IN_WAITING_STATE: {
|
||||
// nothing to do in this case
|
||||
} break;
|
||||
|
||||
default:
|
||||
DLIB_CASSERT(false,"This should never happen");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
serialize(val, _cons[0]->stream);
|
||||
_cons[0]->stream.flush();
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
@ -415,7 +463,8 @@ namespace dlib
|
|||
if (i == node_id() || _cons[i]->terminated)
|
||||
continue;
|
||||
|
||||
send_byte(val,i);
|
||||
serialize(val, _cons[i]->stream);
|
||||
_cons[i]->stream.flush();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -432,10 +481,11 @@ namespace dlib
|
|||
throw socket_error("Attempt to send a message to a node that has terminated.");
|
||||
|
||||
serialize(MESSAGE_HEADER, _cons[target_node_id]->stream);
|
||||
serialize(current_epoch, _cons[target_node_id]->stream);
|
||||
serialize(item, _cons[target_node_id]->stream);
|
||||
_cons[target_node_id]->stream.flush();
|
||||
|
||||
++outstanding_messages;
|
||||
notify_control_node(SENT_MESSAGE);
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------------------
|
||||
|
|
110
dlib/bsp/bsp.h
110
dlib/bsp/bsp.h
|
@ -12,7 +12,7 @@
|
|||
#include "../serialize.h"
|
||||
#include "../map.h"
|
||||
#include "../ref.h"
|
||||
#include <deque>
|
||||
#include <queue>
|
||||
#include <vector>
|
||||
|
||||
namespace dlib
|
||||
|
@ -210,15 +210,64 @@ namespace dlib
|
|||
shared_ptr<std::string> data;
|
||||
unsigned long sender_id;
|
||||
char msg_type;
|
||||
dlib::uint64 epoch;
|
||||
|
||||
msg_data() : sender_id(0xFFFFFFFF), msg_type(-1), epoch(0) {}
|
||||
};
|
||||
|
||||
// ------------------------------------------------------------------------------------
|
||||
|
||||
class thread_safe_deque
|
||||
class thread_safe_message_queue : noncopyable
|
||||
{
|
||||
public:
|
||||
thread_safe_deque() : sig(class_mutex),disabled(false) {}
|
||||
/*!
|
||||
WHAT THIS OBJECT REPRESENTS
|
||||
This is a simple message queue for msg_data objects. Note that it
|
||||
has the special property that, while messages will generally leave
|
||||
the queue in the order they are inserted, any message with a smaller
|
||||
epoch value will always be popped out first. But for all messages
|
||||
with equal epoch values the queue functions as a normal FIFO queue.
|
||||
!*/
|
||||
private:
|
||||
struct msg_wrap
|
||||
{
|
||||
msg_wrap(
|
||||
const msg_data& data_,
|
||||
const dlib::uint64& sequence_number_
|
||||
) : data(data_), sequence_number(sequence_number_) {}
|
||||
|
||||
~thread_safe_deque()
|
||||
msg_wrap() : sequence_number(0){}
|
||||
|
||||
msg_data data;
|
||||
dlib::uint64 sequence_number;
|
||||
|
||||
// Make it so that when msg_wrap objects are in a std::priority_queue,
|
||||
// messages with a smaller epoch number always come first. Then, within an
|
||||
// epoch, messages are ordered by their sequence number (so smaller first
|
||||
// there as well).
|
||||
bool operator<(const msg_wrap& item) const
|
||||
{
|
||||
if (data.epoch < item.data.epoch)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
else if (data.epoch > item.data.epoch)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (sequence_number < item.sequence_number)
|
||||
return false;
|
||||
else
|
||||
return true;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
public:
|
||||
thread_safe_message_queue() : sig(class_mutex),disabled(false),next_seq_num(1) {}
|
||||
|
||||
~thread_safe_message_queue()
|
||||
{
|
||||
disable();
|
||||
}
|
||||
|
@ -230,19 +279,16 @@ namespace dlib
|
|||
sig.broadcast();
|
||||
}
|
||||
|
||||
unsigned long size() const { return data.size(); }
|
||||
|
||||
void push_front( const msg_data& item)
|
||||
{
|
||||
unsigned long size() const
|
||||
{
|
||||
auto_mutex lock(class_mutex);
|
||||
data.push_front(item);
|
||||
sig.signal();
|
||||
return data.size();
|
||||
}
|
||||
|
||||
void push_and_consume( msg_data& item)
|
||||
{
|
||||
auto_mutex lock(class_mutex);
|
||||
data.push_back(item);
|
||||
data.push(msg_wrap(item, next_seq_num++));
|
||||
// do this here so that we don't have to worry about different threads touching the shared_ptr.
|
||||
item.data.reset();
|
||||
sig.signal();
|
||||
|
@ -266,17 +312,43 @@ namespace dlib
|
|||
if (disabled)
|
||||
return false;
|
||||
|
||||
item = data.front();
|
||||
data.pop_front();
|
||||
item = data.top().data;
|
||||
data.pop();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool pop (
    msg_data& item,
    const dlib::uint64& max_epoch
)
/*!
    ensures
        - Waits on sig until either this object has been disabled or the
          message at the top of the queue has an epoch <= max_epoch.
        - if (this function returns true) then
            - #item == the message that was at the top of the queue.  Its
              epoch is <= max_epoch and it has been removed from the queue.
        - else
            - this object is disabled and item was not modified.
!*/
{
    auto_mutex lock(class_mutex);

    for (;;)
    {
        // A disabled queue never hands out messages, even if one is ready.
        if (disabled)
            return false;

        // Stop waiting as soon as the head of the queue is eligible.
        if (data.size() != 0 && data.top().data.epoch <= max_epoch)
            break;

        sig.wait();
    }

    item = data.top().data;
    data.pop();

    return true;
}
|
||||
|
||||
private:
|
||||
std::deque<msg_data> data;
|
||||
std::priority_queue<msg_wrap> data;
|
||||
dlib::mutex class_mutex;
|
||||
dlib::signaler sig;
|
||||
bool disabled;
|
||||
dlib::uint64 next_seq_num;
|
||||
};
|
||||
|
||||
|
||||
|
@ -396,9 +468,8 @@ namespace dlib
|
|||
);
|
||||
|
||||
|
||||
void send_byte (
|
||||
char val,
|
||||
unsigned long target_node_id
|
||||
void notify_control_node (
|
||||
char val
|
||||
);
|
||||
|
||||
void broadcast_byte (
|
||||
|
@ -423,8 +494,9 @@ namespace dlib
|
|||
unsigned long outstanding_messages;
|
||||
unsigned long num_waiting_nodes;
|
||||
unsigned long num_terminated_nodes;
|
||||
dlib::uint64 current_epoch;
|
||||
|
||||
impl1::thread_safe_deque msg_buffer;
|
||||
impl1::thread_safe_message_queue msg_buffer;
|
||||
|
||||
impl1::map_id_to_con& _cons;
|
||||
const unsigned long _node_id;
|
||||
|
|
Loading…
Reference in New Issue