Skip to content

Commit

Permalink
Support CRC validation on entire message (#448)
Browse files Browse the repository at this point in the history
* So far, CRC validation has been done on the message header, excluding
the user payload like log entries.

* Added a new Asio service option to enable CRC validation on the entire
message. It is disabled by default for backward compatibility.

* Also added a callback function that can be invoked on any corrupted
message, so as to do more investigation on user application side.

* Removed direct buffer get/put API calls, instead refactored the code
with `buffer_serializer`.
  • Loading branch information
greensky00 committed Jul 16, 2023
1 parent a8ca233 commit 41cc11d
Show file tree
Hide file tree
Showing 4 changed files with 209 additions and 51 deletions.
19 changes: 19 additions & 0 deletions include/libnuraft/asio_service_options.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.

#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <system_error>

Expand All @@ -27,6 +28,7 @@ typedef struct ssl_ctx_st SSL_CTX;

namespace nuraft {

class buffer;
class req_msg;
class resp_msg;

Expand Down Expand Up @@ -120,6 +122,8 @@ struct asio_service_options {
, verify_sn_(nullptr)
, custom_resolver_(nullptr)
, replicate_log_timestamp_(false)
, crc_on_entire_message_(false)
, corrupted_msg_handler_(nullptr)
{}

/**
Expand Down Expand Up @@ -243,6 +247,21 @@ struct asio_service_options {
* this flag.
*/
bool replicate_log_timestamp_;

/**
* If `true`, NuRaft will validate the entire message with CRC.
* Otherwise, it validates the header part only.
*/
bool crc_on_entire_message_;

/**
* Callback function that will be invoked when the received message is corrupted.
* The first `buffer` contains the raw binary of message header,
* and the second `buffer` contains the user payload including metadata,
* if it is not null.
*/
std::function< void( std::shared_ptr<buffer>,
std::shared_ptr<buffer> ) > corrupted_msg_handler_;
};

}
Expand Down
210 changes: 163 additions & 47 deletions src/asio_service.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ limitations under the License.
// If set, each log entry will contain timestamp.
#define INCLUDE_LOG_TIMESTAMP (0x4)

// If set, CRC number represents the entire message.
#define CRC_ON_ENTIRE_MESSAGE (0x8)

// =======================

namespace nuraft {
Expand Down Expand Up @@ -238,6 +241,8 @@ class rpc_session
, src_id_(-1)
, is_leader_(false)
, cached_port_(0)
, crc_header_(0)
, crc_from_msg_(0)
{
p_tr("asio rpc session created: %p", this);
}
Expand Down Expand Up @@ -339,41 +344,67 @@ class rpc_session
// NOTE:
// due to async_read() above, header_ size will be always
// equal to or greater than RPC_REQ_HEADER_SIZE.
header_->pos(0);
byte* header_data = header_->data();
uint32_t crc_local = crc32_8( header_data,
RPC_REQ_HEADER_SIZE - CRC_FLAGS_LEN,
0 );

header_->pos(RPC_REQ_HEADER_SIZE - CRC_FLAGS_LEN);
uint64_t flags_and_crc = header_->get_ulong();
uint32_t crc_hdr = flags_and_crc & (uint32_t)0xffffffff;

// Deprecate `buffer::get` and use `buffer_serializer`.

// header_->pos(0);
buffer_serializer h_bs(header_);
byte* header_data = header_->data_begin();
crc_header_ = crc32_8( header_data,
RPC_REQ_HEADER_SIZE - CRC_FLAGS_LEN,
0 );

// header_->pos(RPC_REQ_HEADER_SIZE - CRC_FLAGS_LEN);
h_bs.pos(RPC_REQ_HEADER_SIZE - CRC_FLAGS_LEN);
// uint64_t flags_and_crc = header_->get_ulong();
uint64_t flags_and_crc = h_bs.get_u64();
crc_from_msg_ = flags_and_crc & (uint32_t)0xffffffff;
flags_ = (flags_and_crc >> 32);

// Verify CRC.
if (crc_local != crc_hdr) {
p_er("CRC mismatch: local calculation %x, from header %x",
crc_local, crc_hdr);
// Verify CRC (if entire message validation is disbaled).
if ( !(flags_ & CRC_ON_ENTIRE_MESSAGE) &&
crc_header_ != crc_from_msg_ ) {
p_er("header CRC mismatch: local calculation %x, from message %x",
crc_header_, crc_from_msg_);

if (impl_->get_options().corrupted_msg_handler_) {
impl_->get_options().corrupted_msg_handler_(header_, nullptr);
}

this->stop();
return;
}

header_->pos(0);
byte marker = header_->get_byte();
// header_->pos(0);
// byte marker = header_->get_byte();
h_bs.pos(0);
byte marker = h_bs.get_u8();
if (marker == 0x1) {
// Means that this is RPC_RESP, shouldn't happen.
p_er("Wrong packet: expected REQ, got RESP");

if (impl_->get_options().corrupted_msg_handler_) {
impl_->get_options().corrupted_msg_handler_(header_, nullptr);
}

this->stop();
return;
}

header_->pos(RPC_REQ_HEADER_SIZE - CRC_FLAGS_LEN - DATA_SIZE_LEN);
int32 data_size = header_->get_int();
// header_->pos(RPC_REQ_HEADER_SIZE - CRC_FLAGS_LEN - DATA_SIZE_LEN);
// int32 data_size = header_->get_int();
h_bs.pos(RPC_REQ_HEADER_SIZE - CRC_FLAGS_LEN - DATA_SIZE_LEN);
int32 data_size = h_bs.get_i32();
// Up to 1GB.
if (data_size < 0 || data_size > 0x40000000) {
p_er("bad log data size in the header %d, stop "
"this session to protect further corruption",
data_size);

if (impl_->get_options().corrupted_msg_handler_) {
impl_->get_options().corrupted_msg_handler_(header_, nullptr);
}

this->stop();
return;
}
Expand Down Expand Up @@ -465,14 +496,47 @@ class rpc_session
ptr<rpc_session> self = this->shared_from_this();

try {
hdr->pos(1);
msg_type t = (msg_type)hdr->get_byte();
int32 src = hdr->get_int();
int32 dst = hdr->get_int();
ulong term = hdr->get_ulong();
ulong last_term = hdr->get_ulong();
ulong last_idx = hdr->get_ulong();
ulong commit_idx = hdr->get_ulong();
// Deprecate `buffer::get` and use `buffer_serializer`.
// hdr->pos(1);
// msg_type t = (msg_type)hdr->get_byte();
// int32 src = hdr->get_int();
// int32 dst = hdr->get_int();
// ulong term = hdr->get_ulong();
// ulong last_term = hdr->get_ulong();
// ulong last_idx = hdr->get_ulong();
// ulong commit_idx = hdr->get_ulong();

buffer_serializer h_bs(header_);
h_bs.pos(1);
msg_type t = (msg_type)h_bs.get_u8();
int32 src = h_bs.get_i32();
int32 dst = h_bs.get_i32();
ulong term = h_bs.get_u64();
ulong last_term = h_bs.get_u64();
ulong last_idx = h_bs.get_u64();
ulong commit_idx = h_bs.get_u64();
int32 log_data_size = h_bs.get_i32();

if (flags_ & CRC_ON_ENTIRE_MESSAGE) {
// Calculate the CRC of `log_ctx`.
uint32_t crc_payload =
log_ctx
? crc32_8( log_ctx->data_begin(),
log_ctx->size(),
crc_header_ )
: crc_header_;
if (crc_payload != crc_from_msg_) {
p_er("request CRC mismatch: local calculation %x, from message %x",
crc_payload, crc_from_msg_);

if (impl_->get_options().corrupted_msg_handler_) {
impl_->get_options().corrupted_msg_handler_(header_, log_ctx);
}

this->stop();
return;
}
}

if (src_id_ == -1) {
// It means this is the first message on this session.
Expand Down Expand Up @@ -515,7 +579,7 @@ class rpc_session
std::string meta_str;
ptr<req_msg> req = cs_new<req_msg>
( term, t, src, dst, last_term, last_idx, commit_idx );
if (hdr->get_int() > 0 && log_ctx) {
if (log_data_size > 0 && log_ctx) {
buffer_serializer ss(log_ctx);
size_t log_ctx_size = log_ctx->size();

Expand All @@ -538,6 +602,11 @@ class rpc_session
// Possibly corrupted packet. Stop here.
p_wn("wrong log ctx size %zu pos %zu, stop this session",
log_ctx_size, ss.pos());

if (impl_->get_options().corrupted_msg_handler_) {
impl_->get_options().corrupted_msg_handler_(header_, log_ctx);
}

this->stop();
return;
}
Expand All @@ -551,6 +620,11 @@ class rpc_session
p_wn("wrong value size %zu log ctx %zu %zu, "
"stop this session",
val_size, log_ctx_size, ss.pos());

if (impl_->get_options().corrupted_msg_handler_) {
impl_->get_options().corrupted_msg_handler_(header_, log_ctx);
}

this->stop();
return;
}
Expand Down Expand Up @@ -744,6 +818,16 @@ class rpc_session

std::string cached_address_;
uint32_t cached_port_;

/**
* Locally calculated CRC number of the request header.
*/
uint32_t crc_header_;

/**
* CRC number from the request header.
*/
uint32_t crc_from_msg_;
};

// rpc listener implementation
Expand Down Expand Up @@ -1164,37 +1248,69 @@ class asio_rpc_client
ptr<buffer> req_buf =
buffer::alloc(RPC_REQ_HEADER_SIZE + meta_size + log_data_size);

req_buf->pos(0);
byte* req_buf_data = req_buf->data();

byte marker = 0x0;
req_buf->put(marker);
req_buf->put((byte)req->get_type());
req_buf->put(req->get_src());
req_buf->put(req->get_dst());
req_buf->put(req->get_term());
req_buf->put(req->get_last_log_term());
req_buf->put(req->get_last_log_idx());
req_buf->put(req->get_commit_idx());
req_buf->put((int32)meta_size + log_data_size);
// Deprecate `buffer::put` and use `buffer_serializer`.

// req_buf->pos(0);
// byte* req_buf_data = req_buf->data();

// byte marker = 0x0;
// req_buf->put(marker);
// req_buf->put((byte)req->get_type());
// req_buf->put(req->get_src());
// req_buf->put(req->get_dst());
// req_buf->put(req->get_term());
// req_buf->put(req->get_last_log_term());
// req_buf->put(req->get_last_log_idx());
// req_buf->put(req->get_commit_idx());
// req_buf->put((int32)meta_size + log_data_size);

buffer_serializer req_buf_bs(req_buf);
req_buf_bs.put_u8(0x0);
req_buf_bs.put_u8((byte)req->get_type());
req_buf_bs.put_i32(req->get_src());
req_buf_bs.put_i32(req->get_dst());
req_buf_bs.put_u64(req->get_term());
req_buf_bs.put_u64(req->get_last_log_term());
req_buf_bs.put_u64(req->get_last_log_idx());
req_buf_bs.put_u64(req->get_commit_idx());
req_buf_bs.put_i32((int32)meta_size + log_data_size);

// Calculate CRC32 on header-only.
uint32_t crc_val = crc32_8( req_buf_data,
RPC_REQ_HEADER_SIZE - CRC_FLAGS_LEN,
0 );
uint32_t crc_header = crc32_8( req_buf->data_begin(),
RPC_REQ_HEADER_SIZE - CRC_FLAGS_LEN,
0 );

uint64_t flags_and_crc = ((uint64_t)flags << 32) | crc_val;
req_buf->put((ulong)flags_and_crc);
uint64_t flags_and_crc = ((uint64_t)flags << 32) | crc_header;
// req_buf->put((ulong)flags_and_crc);
size_t crc_pos = req_buf_bs.pos();
req_buf_bs.put_u64(flags_and_crc);

// From now on, it will contain the payload (== meta + log entries).
// byte* req_buf_payload = req_buf->data();
size_t payload_pos = req_buf_bs.pos();

// Handling meta if the flag is set.
if (flags & INCLUDE_META) {
req_buf->put( (byte*)meta_str.data(), meta_str.size() );
// req_buf->put( (byte*)meta_str.data(), meta_str.size() );
req_buf_bs.put_bytes( (byte*)meta_str.data(), meta_str.size() );
}

for (auto& it: log_entry_bufs) {
req_buf->put(*(it));
// req_buf->put(*(it));
req_buf_bs.put_buffer(*(it));
}
// req_buf->pos(0);

if (impl_->get_options().crc_on_entire_message_) {
uint32_t crc_payload = crc32_8( req_buf->data_begin() + payload_pos,
meta_size + log_data_size,
crc_header );
// Overwrite CRC field.
flags |= CRC_ON_ENTIRE_MESSAGE;
flags_and_crc = ((uint64_t)flags << 32) | crc_payload;
req_buf_bs.pos(crc_pos);
req_buf_bs.put_u64(flags_and_crc);
}
req_buf->pos(0);

if (send_timeout_ms != 0)
{
Expand Down
Loading

0 comments on commit 41cc11d

Please sign in to comment.