diff --git a/examples/app/eos_ams.cpp b/examples/app/eos_ams.cpp index 6dcec2b4..cf7b9de5 100644 --- a/examples/app/eos_ams.cpp +++ b/examples/app/eos_ams.cpp @@ -5,10 +5,10 @@ * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception */ -#include - #include "eos_ams.hpp" +#include + template void callBack(void *cls, long elements, @@ -47,12 +47,22 @@ AMSEOS::AMSEOS(EOS *model, uq_path, "ideal_gas", k_nearest); +#ifdef __ENABLE_MPI__ + wf_ = AMSCreateDistributedExecutor(model_descr, + dtype, + res_type, + (AMSPhysicFn)callBack, + MPI_COMM_WORLD, + mpi_task, + mpi_nproc); +#else wf_ = AMSCreateExecutor(model_descr, dtype, res_type, (AMSPhysicFn)callBack, mpi_task, mpi_nproc); +#endif } template diff --git a/src/AMSWorkflow/ams/rmq.py b/src/AMSWorkflow/ams/rmq.py index dc90ea1b..40addf1d 100644 --- a/src/AMSWorkflow/ams/rmq.py +++ b/src/AMSWorkflow/ams/rmq.py @@ -38,12 +38,13 @@ def header_format(self) -> str: - 2 bytes are the output dimension. Limit max: 65535 - 2 bytes are for aligning memory to 8 - |_Header_|_Datatype_|___Rank___|__DomainSize__|__#elems__|___InDim____|___OutDim___|_Pad_|.real data.| + |_Header_|_Datatype_|_Rank_|_DomainSize_|_#elems_|_InDim_|_OutDim_|_Pad_|_DomainName_|.Real_Data.| - Then the data starts at 16 and is structered as pairs of input/outputs. - Let K be the total number of elements, then we have K pairs of inputs/outputs (either float or double): + Then the data starts at byte 16 with the domain name, then the real data and + is structured as pairs of input/outputs. Let K be the total number of elements, + then we have K pairs of inputs/outputs (either float or double): - |__Header_(16B)__|__Input 1__|__Output 1__|...|__Input_K__|__Output_K__| + |__Header_(16B)__|_Domain_Name_|__Input 1__|__Output 1__|...|__Input_K__|__Output_K__| """ return "BBHHIHHH" @@ -55,20 +56,23 @@ def endianness(self) -> str: """ return "=" - def encode(self, num_elem: int, input_dim: int, output_dim: int, dtype_byte: int = 4) -> bytes: + def encode(self, num_elem: int, domain_name: str, input_dim: int, output_dim: int, dtype_byte: int = 4) -> bytes: """ For debugging and testing purposes, this function encode a message identical to what AMS would send """ - header_format = self.endianness() + self.header_format() + header_format = self.ams_endianness() + self.ams_header_format() hsize = struct.calcsize(header_format) assert dtype_byte in [4, 8] dt = "f" if dtype_byte == 4 else "d" mpi_rank = 0 data = np.random.rand(num_elem * (input_dim + output_dim)) - header_content = (hsize, dtype_byte, mpi_rank, data.size, input_dim, output_dim) + domain_name_size = len(domain_name) + domain_name = bytes(domain_name, "utf-8") + padding = 0 + header_content = (hsize, dtype_byte, mpi_rank, domain_name_size, data.size, input_dim, output_dim, padding) # float or double - msg_format = f"{header_format}{data.size}{dt}" - return struct.pack(msg_format, *header_content, *data) + msg_format = f"{header_format}{domain_name_size}s{data.size}{dt}" + return struct.pack(msg_format, *header_content, domain_name, *data) def _parse_header(self, body: str) -> dict: """ diff --git a/src/AMSlib/AMS.cpp b/src/AMSlib/AMS.cpp index 632c2c6a..ef876eca 100644 --- a/src/AMSlib/AMS.cpp +++ b/src/AMSlib/AMS.cpp @@ -5,6 +5,8 @@ * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception */ +#include "AMS.h" + #include #include @@ -17,7 +19,6 @@ #include #include -#include "AMS.h" #include "include/AMS.h" #include "ml/uq.hpp" #include "wf/basedb.hpp" @@ -268,6 +269,7 @@ class AMSWrap std::unordered_map ams_candidate_models; AMSDBType dbType = AMSDBType::AMS_NONE; ams::ResourceManager &memManager; + int rId; private: void dumpEnv() @@ -372,10 +374,13 @@ class AMSWrap std::string rmq_user = getEntry(rmq_entry, "rabbitmq-user"); std::string rmq_vhost = getEntry(rmq_entry, "rabbitmq-vhost"); std::string rmq_cert = getEntry(rmq_entry, "rabbitmq-cert"); - std::string rmq_in_queue = - getEntry(rmq_entry, "rabbitmq-inbound-queue"); std::string rmq_out_queue = getEntry(rmq_entry, "rabbitmq-outbound-queue"); + std::string exchange = + getEntry(rmq_entry, "rabbitmq-exchange"); + std::string routing_key = + getEntry(rmq_entry, "rabbitmq-routing-key"); + bool update_surrogate = getEntry(entry, "update_surrogate"); auto &DB = ams::db::DBManager::getInstance(); DB.instantiate_rmq_db(port, @@ -385,8 +390,10 @@ class AMSWrap rmq_user, rmq_vhost, rmq_cert, - rmq_in_queue, - rmq_out_queue); + rmq_out_queue, + exchange, + routing_key, + update_surrogate); } void parseDatabase(json &jRoot) diff --git a/src/AMSlib/CMakeLists.txt b/src/AMSlib/CMakeLists.txt index ebcf6230..23b22ce0 100644 --- a/src/AMSlib/CMakeLists.txt +++ b/src/AMSlib/CMakeLists.txt @@ -17,6 +17,9 @@ if (WITH_HDF5) list(APPEND AMS_LIB_SRC wf/hdf5db.cpp) endif() +if (WITH_RMQ) + list(APPEND AMS_LIB_SRC wf/rmqdb.cpp) +endif() diff --git a/src/AMSlib/ml/surrogate.hpp b/src/AMSlib/ml/surrogate.hpp index b194de94..da90b815 100644 --- a/src/AMSlib/ml/surrogate.hpp +++ b/src/AMSlib/ml/surrogate.hpp @@ -135,7 +135,7 @@ class SurrogateModel inline void _load(const std::string& model_path, const std::string& device_name) { - DBG(Surrogate, "Using model at double precision"); + DBG(Surrogate, "Using model at double precision: %s", model_path.c_str()); _load_torch(model_path, torch::Device(device_name), torch::kFloat64); } @@ -145,7 +145,7 @@ class SurrogateModel inline void _load(const std::string& model_path, const std::string& device_name) { - DBG(Surrogate, "Using model at single precision"); + DBG(Surrogate, "Using model at single precision: %s", model_path.c_str()); _load_torch(model_path, torch::Device(device_name), torch::kFloat32); } diff --git a/src/AMSlib/wf/basedb.hpp b/src/AMSlib/wf/basedb.hpp index 334228d2..abee073e 100644 --- a/src/AMSlib/wf/basedb.hpp +++ b/src/AMSlib/wf/basedb.hpp @@ -64,6 +64,7 @@ namespace fs = std::experimental::filesystem; #include #include #include +#include #include #include @@ -85,12 +86,16 @@ class BaseDB { /** @brief unique id of the process running this simulation */ uint64_t id; + /** @brief True if surrogate model update is allowed */ + bool allowUpdate; public: BaseDB(const BaseDB&) = delete; BaseDB& operator=(const BaseDB&) = delete; - BaseDB(uint64_t id) : id(id) {} + BaseDB(uint64_t id) : id(id), allowUpdate(false) {} + + BaseDB(uint64_t id, bool allowUpdate) : id(id), allowUpdate(allowUpdate) {} virtual void close() {} @@ -125,11 +130,14 @@ class BaseDB std::vector& outputs, bool* predicate = nullptr) = 0; - uint64_t getId() const { return id; } + bool allowModelUpdate() { return allowUpdate; } + virtual bool updateModel() { return false; } + virtual std::string getLatestModel() { return {}; } + virtual bool storePredicate() const { return false; } }; @@ -644,6 +652,8 @@ class RedisDB : public BaseDB #ifdef __ENABLE_RMQ__ +enum RMQConnectionStatus { FAILED, CONNECTED, CLOSED, ERROR }; + /** * @brief AMS represents the header as follows: * The header is 16 bytes long: @@ -696,16 +706,7 @@ struct AMSMsgHeader { size_t num_elem, size_t in_dim, size_t out_dim, - size_t type_size) - : hsize(static_cast(AMSMsgHeader::size())), - dtype(static_cast(type_size)), - mpi_rank(static_cast(mpi_rank)), - domain_size(static_cast(domain_size)), - num_elem(static_cast(num_elem)), - in_dim(static_cast(in_dim)), - out_dim(static_cast(out_dim)) - { - } + size_t type_size); /** * @brief Constructor for AMSMsgHeader @@ -719,16 +720,7 @@ struct AMSMsgHeader { uint32_t num_elem, uint16_t in_dim, uint16_t out_dim, - uint8_t type_size) - : hsize(static_cast(AMSMsgHeader::size())), - dtype(type_size), - mpi_rank(mpi_rank), - domain_size(domain_size), - num_elem(num_elem), - in_dim(in_dim), - out_dim(out_dim) - { - } + uint8_t type_size); /** * @brief Return the size of a header in the AMS protocol. @@ -748,92 +740,14 @@ struct AMSMsgHeader { * @param[in] data_blob The buffer to fill * @return The number of bytes in the header or 0 if error */ - size_t encode(uint8_t* data_blob) - { - if (!data_blob) return 0; - - size_t current_offset = 0; - // Header size (should be 1 bytes) - data_blob[current_offset] = hsize; - current_offset += sizeof(hsize); - // Data type (should be 1 bytes) - data_blob[current_offset] = dtype; - current_offset += sizeof(dtype); - // MPI rank (should be 2 bytes) - std::memcpy(data_blob + current_offset, &(mpi_rank), sizeof(mpi_rank)); - current_offset += sizeof(mpi_rank); - // Domain Size (should be 2 bytes) - DBG(AMSMsgHeader, - "Generating domain name of size %d --- %d offset %d", - domain_size, - sizeof(domain_size), - current_offset); - std::memcpy(data_blob + current_offset, - &(domain_size), - sizeof(domain_size)); - current_offset += sizeof(domain_size); - // Num elem (should be 4 bytes) - std::memcpy(data_blob + current_offset, &(num_elem), sizeof(num_elem)); - current_offset += sizeof(num_elem); - // Input dim (should be 2 bytes) - std::memcpy(data_blob + current_offset, &(in_dim), sizeof(in_dim)); - current_offset += sizeof(in_dim); - // Output dim (should be 2 bytes) - std::memcpy(data_blob + current_offset, &(out_dim), sizeof(out_dim)); - current_offset += sizeof(out_dim); - - return AMSMsgHeader::size(); - } + size_t encode(uint8_t* data_blob); /** * @brief Return a valid header based on a pre-existing data buffer * @param[in] data_blob The buffer to fill * @return An AMSMsgHeader with the correct attributes */ - static AMSMsgHeader decode(uint8_t* data_blob) - { - size_t current_offset = 0; - // Header size (should be 1 bytes) - uint8_t new_hsize = data_blob[current_offset]; - CWARNING(AMSMsgHeader, - new_hsize != AMSMsgHeader::size(), - "buffer is likely not a valid AMSMessage (%d / %ld)", - new_hsize, - current_offset) - - current_offset += sizeof(uint8_t); - // Data type (should be 1 bytes) - uint8_t new_dtype = data_blob[current_offset]; - current_offset += sizeof(uint8_t); - // MPI rank (should be 2 bytes) - uint16_t new_mpirank = - (reinterpret_cast(data_blob + current_offset))[0]; - current_offset += sizeof(uint16_t); - - // Domain Size (should be 2 bytes) - uint16_t new_domain_size = - (reinterpret_cast(data_blob + current_offset))[0]; - current_offset += sizeof(uint16_t); - - // Num elem (should be 4 bytes) - uint32_t new_num_elem; - std::memcpy(&new_num_elem, data_blob + current_offset, sizeof(uint32_t)); - current_offset += sizeof(uint32_t); - // Input dim (should be 2 bytes) - uint16_t new_in_dim; - std::memcpy(&new_in_dim, data_blob + current_offset, sizeof(uint16_t)); - current_offset += sizeof(uint16_t); - // Output dim (should be 2 bytes) - uint16_t new_out_dim; - std::memcpy(&new_out_dim, data_blob + current_offset, sizeof(uint16_t)); - - return AMSMsgHeader(new_mpirank, - new_domain_size, - new_num_elem, - new_in_dim, - new_out_dim, - new_dtype); - } + static AMSMsgHeader decode(uint8_t* data_blob); }; @@ -846,7 +760,7 @@ class AMSMessage /** @brief message ID */ int _id; /** @brief The MPI rank (0 if MPI is not used) */ - int _rank; + uint64_t _rank; /** @brief The data represented as a binary blob */ uint8_t* _data; /** @brief The total size of the binary blob in bytes */ @@ -872,46 +786,29 @@ class AMSMessage { } - /** - * @brief Internal Method swapping for AMSMessage - * @param[in] other Message to swap - */ - void swap(const AMSMessage& other) - { - _id = other._id; - _rank = other._rank; - _num_elements = other._num_elements; - _input_dim = other._input_dim; - _output_dim = other._output_dim; - _total_size = other._total_size; - _data = other._data; - } - -public: /** * @brief Constructor * @param[in] id ID of the message + * @param[in] rId MPI Rank of the messages (0 default) * @param[in] num_elements Number of elements * @param[in] inputs Inputs * @param[in] outputs Outputs */ template AMSMessage(int id, + uint64_t rId, std::string& domain_name, size_t num_elements, const std::vector& inputs, const std::vector& outputs) : _id(id), - _rank(0), + _rank(rId), _num_elements(num_elements), _input_dim(inputs.size()), _output_dim(outputs.size()), _data(nullptr), _total_size(0) { -#ifdef __ENABLE_MPI__ - MPI_CALL(MPI_Comm_rank(MPI_COMM_WORLD, &_rank)); -#endif AMSMsgHeader header(_rank, domain_name.size(), _num_elements, @@ -939,38 +836,10 @@ class AMSMessage /** * @brief Constructor * @param[in] id ID of the message + * @param[in] rId MPI rank of the message * @param[in] data Pointer containing data */ - AMSMessage(int id, uint8_t* data) - : _id(id), - _num_elements(0), - _input_dim(0), - _output_dim(0), - _data(data), - _total_size(0) - { - auto header = AMSMsgHeader::decode(data); - - int current_rank = 0; -#ifdef __ENABLE_MPI__ - MPI_CALL(MPI_Comm_rank(MPI_COMM_WORLD, ¤t_rank)); -#endif - _rank = header.mpi_rank; - CWARNING(AMSMessage, - _rank != current_rank, - "MPI rank are not matching (using %d)", - _rank) - - _num_elements = header.num_elem; - _input_dim = header.in_dim; - _output_dim = header.out_dim; - _data = data; - auto type_value = header.dtype; - - _total_size = AMSMsgHeader::size() + getTotalElements() * type_value; - - DBG(AMSMessage, "Allocated message %d: %p", _id, _data); - } + AMSMessage(int id, uint64_t rId, uint8_t* data); AMSMessage(const AMSMessage& other) { @@ -978,6 +847,12 @@ class AMSMessage swap(other); }; + /** + * @brief Internal Method swapping for AMSMessage + * @param[in] other Message to swap + */ + void swap(const AMSMessage& other); + AMSMessage& operator=(const AMSMessage&) = delete; AMSMessage(AMSMessage&& other) noexcept { *this = std::move(other); } @@ -1065,61 +940,125 @@ class AMSMessage } }; // class AMSMessage -/** @brief Structure that represents a received RabbitMQ message. - * - The first field is the message content (body) - * - The second field is the RMQ exchange from which the message - * has been received - * - The third field is the routing key - * - The fourth is the delivery tag (ID of the message) - * - The fifth field is a boolean that indicates if that message - * has been redelivered by RMQ. + +/** + * @brief Structure that represents incoming RabbitMQ messages. */ -typedef std::tuple - inbound_msg; +class AMSMessageInbound +{ +public: + /** @brief Delivery tag (ID of the message) */ + uint64_t id; + /** @brief MPI rank */ + uint64_t rId; + /** @brief message content (body) */ + std::string body; + /** @brief RabbitMQ exchange from which the message has been received */ + std::string exchange; + /** @brief routing key */ + std::string routing_key; + /** @brief True if messages has been redelivered */ + bool redelivered; + + AMSMessageInbound() = default; + + AMSMessageInbound(AMSMessageInbound&) = default; + AMSMessageInbound& operator=(AMSMessageInbound&) = default; + + AMSMessageInbound(AMSMessageInbound&&) = default; + AMSMessageInbound& operator=(AMSMessageInbound&&) = default; + + AMSMessageInbound(uint64_t id, + uint64_t rId, + std::string body, + std::string exchange, + std::string routing_key, + bool redelivered); + + /** + * @brief Check if a message is empty. + * @return True if message is empty + */ + bool empty(); + + /** + * @brief Check if a message is empty. + * @return True if message is empty. + */ + bool isTraining(); + + /** + * @brief Get the model path from the message. + * @return Return model path or empty string if no model available. + */ + std::string getModelPath(); + +private: + /** + * @brief Check if a message is empty. + * @return True if message is empty + */ + std::vector splitString(std::string str, std::string delimiter); + +}; // class AMSMessageInbound + /** * @brief Specific handler for RabbitMQ connections based on libevent. */ -class RMQConsumerHandler final : public AMQP::LibEventHandler +class RMQHandler : public AMQP::LibEventHandler { -private: +protected: /** @brief Path to TLS certificate */ std::string _cacert; - /** @brief The MPI rank (0 if MPI is not used) */ - int _rank; + /** @brief MPI rank (0 if no MPI support) */ + uint64_t _rId; /** @brief LibEvent I/O loop */ std::shared_ptr _loop; - /** @brief main channel used to send data to the broker */ - std::shared_ptr _channel; - /** @brief RabbitMQ queue */ - std::string _queue; - /** @brief Queue that contains all the messages received on receiver queue */ - std::shared_ptr> _messages; + + std::promise establish_connection; + std::future established; + + std::promise close_connection; + std::future closed; + + std::promise error_connection; + std::future ftr_error; public: /** * @brief Constructor * @param[in] loop Event Loop + * @param[in] rId MPI rank * @param[in] cacert SSL Cacert - * @param[in] rank MPI rank */ - RMQConsumerHandler(std::shared_ptr loop, - std::string cacert, - std::string queue) - : AMQP::LibEventHandler(loop.get()), - _loop(loop), - _rank(0), - _cacert(std::move(cacert)), - _queue(queue), - _messages(std::make_shared>()), - _channel(nullptr) - { -#ifdef __ENABLE_MPI__ - MPI_CALL(MPI_Comm_rank(MPI_COMM_WORLD, &_rank)); -#endif - } + RMQHandler(uint64_t rId, + std::shared_ptr loop, + std::string cacert); - ~RMQConsumerHandler() = default; + ~RMQHandler() = default; + + /** + * @brief Wait (blocking call) until connection has been established or that ms * repeat is over. + * @param[in] ms Number of milliseconds the function will wait on the future + * @param[in] repeat Number of times the function will wait + * @return True if connection has been established + */ + bool waitToEstablish(unsigned ms, int repeat = 1); + + /** + * @brief Wait (blocking call) until connection has been closed or that ms * repeat is over. + * @param[in] ms Number of milliseconds the function will wait on the future + * @param[in] repeat Number of times the function will wait + * @return True if connection has been closed + */ + bool waitToClose(unsigned ms, int repeat = 1); + + /** + * @brief Check if the connection can be used to send messages. + * @return True if connection is valid (i.e., can send messages) + */ + bool connectionValid(); private: /** @@ -1133,34 +1072,7 @@ class RMQConsumerHandler final : public AMQP::LibEventHandler * @return bool True to proceed / accept the connection, false * to break up */ - virtual bool onSecuring(AMQP::TcpConnection* connection, SSL* ssl) override - { - ERR_clear_error(); - unsigned long err; -#if OPENSSL_VERSION_NUMBER < 0x10100000L - int ret = SSL_use_certificate_file(ssl, _cacert.c_str(), SSL_FILETYPE_PEM); -#else - int ret = SSL_use_certificate_chain_file(ssl, _cacert.c_str()); -#endif - // FIXME: with openssl 3.0 - // Set => SSL_set_options(ssl, SSL_OP_IGNORE_UNEXPECTED_EOF); - - if (ret != 1) { - std::string error("openssl: error loading ca-chain (" + _cacert + - ") + from ["); - SSL_get_error(ssl, ret); - if ((err = ERR_get_error())) { - error += std::string(ERR_reason_error_string(err)); - } - error += "]"; - throw std::runtime_error(error); - } else { - DBG(RMQConsumerHandler, - "Success logged with ca-chain %s", - _cacert.c_str()) - return true; - } - } + virtual bool onSecuring(AMQP::TcpConnection* connection, SSL* ssl) override; /** * @brief Method that is called when the secure TLS connection has been @@ -1173,112 +1085,7 @@ class RMQConsumerHandler final : public AMQP::LibEventHandler * @return bool True if connection can be used */ virtual bool onSecured(AMQP::TcpConnection* connection, - const SSL* ssl) override - { - DBG(RMQConsumerHandler, - "[rank=%d] Secured TLS connection has been established.", - _rank) - return true; - } - - /** - * @brief Method that is called by the AMQP library when the login attempt - * succeeded. After this the connection is ready to use. - * @param[in] connection The connection that can now be used - */ - virtual void onReady(AMQP::TcpConnection* connection) override - { - DBG(RMQConsumerHandler, - "[rank=%d] Sucessfuly logged in. Connection ready to use.", - _rank) - - _channel = std::make_shared(connection); - _channel->onError([&](const char* message) { - CFATAL(RMQConsumerHandler, - false, - "[rank=%d] Error on channel: %s", - _rank, - message) - }); - - _channel->declareQueue(_queue) - .onSuccess([&](const std::string& name, - uint32_t messagecount, - uint32_t consumercount) { - if (messagecount > 0 || consumercount > 1) { - CWARNING(RMQConsumerHandler, - _rank == 0, - "[rank=%d] declared queue: %s (messagecount=%d, " - "consumercount=%d)", - _rank, - _queue.c_str(), - messagecount, - consumercount) - } - // We can now install callback functions for when we will consumme messages - // callback function that is called when the consume operation starts - auto startCb = [](const std::string& consumertag) { - DBG(RMQConsumerHandler, - "consume operation started with tag: %s", - consumertag.c_str()) - }; - - // callback function that is called when the consume operation failed - auto errorCb = [](const char* message) { - CFATAL(RMQConsumerHandler, - false, - "consume operation failed: %s", - message); - }; - // callback operation when a message was received - auto messageCb = [&](const AMQP::Message& message, - uint64_t deliveryTag, - bool redelivered) { - // acknowledge the message - _channel->ack(deliveryTag); - std::string msg(message.body(), message.bodySize()); - DBG(RMQConsumerHandler, - "message received [tag=%lu] : '%s' of size %lu B from " - "'%s'/'%s'", - deliveryTag, - msg.c_str(), - message.bodySize(), - message.exchange().c_str(), - message.routingkey().c_str()) - _messages->push_back(std::make_tuple(std::move(msg), - message.exchange(), - message.routingkey(), - deliveryTag, - redelivered)); - }; - - /* callback that is called when the consumer is cancelled by RabbitMQ (this - * only happens in rare situations, for example when someone removes the queue - * that you are consuming from) - */ - auto cancelledCb = [](const std::string& consumertag) { - WARNING(RMQConsumerHandler, - "consume operation cancelled by the RabbitMQ server: %s", - consumertag.c_str()) - }; - - // start consuming from the queue, and install the callbacks - _channel->consume(_queue) - .onReceived(messageCb) - .onSuccess(startCb) - .onCancelled(cancelledCb) - .onError(errorCb); - }) - .onError([&](const char* message) { - CFATAL(RMQConsumerHandler, - false, - "[ERROR][rank=%d] Error while creating broker queue (%s): " - "%s", - _rank, - _queue.c_str(), - message) - }); - } + const SSL* ssl) override; /** * Method that is called when the AMQP protocol is ended. This is the @@ -1287,10 +1094,7 @@ class RMQConsumerHandler final : public AMQP::LibEventHandler * active, and you will also receive calls to onLost() and onDetached() * @param connection The connection over which the AMQP protocol ended */ - virtual void onClosed(AMQP::TcpConnection* connection) override - { - DBG(RMQConsumerHandler, "[rank=%d] Connection is closed.\n", _rank) - } + virtual void onClosed(AMQP::TcpConnection* connection) override; /** * @brief Method that is called by the AMQP library when a fatal error occurs @@ -1302,26 +1106,107 @@ class RMQConsumerHandler final : public AMQP::LibEventHandler * @param[in] message A human readable error message */ virtual void onError(AMQP::TcpConnection* connection, - const char* message) override - { - DBG(RMQConsumerHandler, - "[rank=%d] fatal error when establishing TCP connection: %s\n", - _rank, - message) - } + const char* message) override; /** - * Final method that is called. This signals that no further calls to your + * @brief Final method that is called. This signals that no further calls to your * handler will be made about the connection. * @param connection The connection that can be destructed */ - virtual void onDetached(AMQP::TcpConnection* connection) override - { - // add your own implementation, like cleanup resources or exit the application - DBG(RMQConsumerHandler, "[rank=%d] Connection is detached.\n", _rank) - } + virtual void onDetached(AMQP::TcpConnection* connection) override; + + bool waitFuture(std::future& future, + unsigned ms, + int repeat); +}; // class RMQHandler + +/** + * @brief Specific handler for RabbitMQ connections based on libevent. + * + * Each MPI rank has its RMQConsumerHandler managing its own RabbitMQ queue. + * RabbitMQ will generate random queue name, this queue will be bound + * to the exchange provided. + * + * Important, if the exchange already exist for a given ExchangeType (from + * a previous run for example), then trying to create an exchange with the + * same name but with a different ExchangeType will lead to a crash. In that + * case, either you remove the exchange manually on the RabbitMQ server or + * you use an exchange name that does not exist (different name). + * + * Note that, if messages are sent to that exchange before a queue is bound, + * these messages are lost. RabbitMQ can notify the sender that these messages + * never arrived if the sender uses publication confirmation. + */ +class RMQConsumerHandler final : public RMQHandler +{ +private: + /** @brief main channel used to send data to the broker */ + std::shared_ptr _channel; + /** @brief RabbitMQ queue (internal use only) */ + std::string _queue; + /** @brief RabbitMQ exchange */ + std::string _exchange; + /** @brief RabbitMQ routing key */ + std::string _routing_key; + /** @brief Type of the exchange used (AMQP::topic, AMQP::fanout, AMQP::direct) */ + AMQP::ExchangeType _extype; + /** @brief Queue that contains all the messages received on receiver queue */ + std::shared_ptr> _messages; + +public: + /** + * @brief Constructor + * @param[in] loop Event Loop + * @param[in] cacert SSL Cacert + * @param[in] routing_key Routing key + * @param[in] exchange Exchange + */ + RMQConsumerHandler(uint64_t rId, + std::shared_ptr loop, + std::string cacert, + std::string exchange, + std::string routing_key, + AMQP::ExchangeType extype = AMQP::fanout); + + /** + * @brief Delete the message with given ID + * @param[in] delivery_tag Delivery tag that will be deleted (if found) + */ + void delMessage(uint64_t delivery_tag) { getMessages(delivery_tag, true); } + + /** + * @brief Check if messages received contains new model paths + * @return Return a tuple with the ID and path of the latest model available or ID=0 and empty string if no model available + */ + std::tuple getLatestModel(); + + /** + * @brief Return the most recent messages and delete it + * @return A structure AMSMessageInbound which is a std::tuple (see typedef) + */ + AMSMessageInbound popMessages(); + + /** + * @brief Return the message corresponding to the delivery tag. Do not delete the + * message. + * @param[in] delivery_tag Delivery tag that will be returned (if found) + * @param[in] erase if True, the element will also be deleted from underyling structure + * @return A structure AMSMessageInbound which is a std::tuple (see typedef) + */ + AMSMessageInbound getMessages(uint64_t delivery_tag, bool erase); + + ~RMQConsumerHandler() = default; + +private: + /** + * @brief Method that is called by the AMQP library when the login attempt + * succeeded. After this the connection is ready to use. + * @param[in] connection The connection that can now be used + */ + virtual void onReady(AMQP::TcpConnection* connection) override; }; // class RMQConsumerHandler + /** * @brief Class that manages a RabbitMQ broker and handles connection, event * loop and set up various handlers. @@ -1331,130 +1216,99 @@ class RMQConsumer private: /** @brief Connection to the broker */ AMQP::TcpConnection* _connection; - /** @brief name of the queue to send data */ - std::string _queue; + /** @brief name of the exchange */ + std::string _exchange; + /** @brief name of the routing binded to exchange */ + std::string _routing_key; /** @brief TLS certificate file */ std::string _cacert; /** @brief MPI rank (if MPI is used, otherwise 0) */ - int _rank; + uint64_t _rId; /** @brief The event loop for sender (usually the default one in libevent) */ std::shared_ptr _loop; /** @brief The handler which contains various callbacks for the sender */ std::shared_ptr _handler; /** @brief Queue that contains all the messages received on receiver queue (messages can be popped in) */ - std::vector _messages; + std::vector _messages; public: RMQConsumer(const RMQConsumer&) = delete; RMQConsumer& operator=(const RMQConsumer&) = delete; - RMQConsumer(const AMQP::Address& address, + RMQConsumer(uint64_t rId, + const AMQP::Address& address, std::string cacert, - std::string queue) - : _rank(0), _queue(queue), _cacert(cacert), _handler(nullptr) - { -#ifdef __ENABLE_MPI__ - MPI_CALL(MPI_Comm_rank(MPI_COMM_WORLD, &_rank)); -#endif -#ifdef EVTHREAD_USE_PTHREADS_IMPLEMENTED - evthread_use_pthreads(); -#endif - CDEBUG(RMQConsumer, - _rank == 0, - "Libevent %s (LIBEVENT_VERSION_NUMBER = %#010x)", - event_get_version(), - event_get_version_number()); - CDEBUG(RMQConsumer, - _rank == 0, - "%s (OPENSSL_VERSION_NUMBER = %#010x)", - OPENSSL_VERSION_TEXT, - OPENSSL_VERSION_NUMBER); -#if OPENSSL_VERSION_NUMBER < 0x10100000L - SSL_library_init(); -#else - OPENSSL_init_ssl(0, NULL); -#endif - CINFO(RMQConsumer, - _rank == 0, - "RabbitMQ address: %s:%d/%s (queue = %s)", - address.hostname().c_str(), - address.port(), - address.vhost().c_str(), - _queue.c_str()) - - _loop = std::shared_ptr(event_base_new(), - [](struct event_base* event) { - event_base_free(event); - }); - _handler = std::make_shared(_loop, _cacert, _queue); - _connection = new AMQP::TcpConnection(_handler.get(), address); - } + std::string routing_key, + std::string exchange); /** - * @brief Start the underlying I/O loop (blocking call) + * @brief Start the underlying I/O loop (blocking call) */ - void start() { event_base_dispatch(_loop.get()); } + void start(); /** - * @brief Stop the underlying I/O loop + * @brief Stop the underlying I/O loop */ - void stop() { event_base_loopexit(_loop.get(), NULL); } + void stop(); /** - * @brief Return the most recent messages and delete it - * @return A structure inbound_msg which is a std::tuple (see typedef) + * @brief Check if the underlying RabbitMQ connection is ready and usable + * @return True if the publisher is ready to publish */ - inbound_msg pop_messages() - { - if (!_messages.empty()) { - inbound_msg msg = _messages.back(); - _messages.pop_back(); - return msg; - } - return std::make_tuple("", "", "", -1, false); - } + bool ready(); /** - * @brief Return the message corresponding to the delivery tag. Do not delete the - * message. - * @param[in] delivery_tag Delivery tag that will be returned (if found) - * @return A structure inbound_msg which is a std::tuple (see typedef) + * @brief Wait that the connection is ready (blocking call) + * @param[in] ms Number of milliseconds to wait between each tentative + * @param[in] repeat Number of tentatives + * @return True if the publisher is ready to publish */ - inbound_msg get_messages(uint64_t delivery_tag) - { - if (!_messages.empty()) { - auto it = std::find_if(_messages.begin(), - _messages.end(), - [&delivery_tag](const inbound_msg& e) { - return std::get<3>(e) == delivery_tag; - }); - if (it != _messages.end()) return *it; - } - return std::make_tuple("", "", "", -1, false); - } + bool waitToEstablish(unsigned ms, int repeat = 1); - ~RMQConsumer() - { - _connection->close(false); - delete _connection; - } -}; // class RMQConsumer + /** + * @brief Return the most recent messages and delete it + * @return A structure AMSMessageInbound which is a std::tuple (see typedef) + */ + AMSMessageInbound popMessages(); -enum RMQConnectionStatus { FAILED, CONNECTED, CLOSED, ERROR }; + /** + * @brief Delete the message with given ID + * @param[in] delivery_tag Delivery tag that will be deleted (if found) + */ + void delMessage(uint64_t delivery_tag); + + /** + * @brief Return the message corresponding to the delivery tag. Do not delete the + * message. + * @param[in] delivery_tag Delivery tag that will be returned (if found) + * @param[in] erase if True, the element will also be deleted from underyling structure + * @return A structure AMSMessageInbound which is a std::tuple (see typedef) + */ + AMSMessageInbound getMessages(uint64_t delivery_tag, bool erase = false); + + /** + * @brief Return the path of latest ML model available + * @return Tuple with ID of new model and ML model path or empty string if no model available + */ + std::tuple getLatestModel(); + + /** + * @brief Close the unerlying connection + * @param[in] ms Number of milliseconds to wait between each tentative + * @param[in] repeat Number of tentatives + * @return True if connection was closed properly + */ + bool close(unsigned ms, int repeat = 1); + + ~RMQConsumer(); +}; // class RMQConsumer /** * @brief Specific handler for RabbitMQ connections based on libevent. */ -class RMQPublisherHandler final : public AMQP::LibEventHandler +class RMQPublisherHandler final : public RMQHandler { private: - /** @brief Path to TLS certificate */ - std::string _cacert; - /** @brief The MPI rank (0 if MPI is not used) */ - int _rank; - /** @brief LibEvent I/O loop */ - std::shared_ptr _loop; - /** @brief main channel used to send data to the broker */ std::shared_ptr _channel; /** @brief AMQP reliable channel (wrapper of classic channel with added functionalities) */ std::shared_ptr> _rchannel; @@ -1464,16 +1318,7 @@ class RMQPublisherHandler final : public AMQP::LibEventHandler int _nb_msg; /** @brief Number of messages successfully acknowledged */ int _nb_msg_ack; - - std::promise establish_connection; - std::future established; - - std::promise close_connection; - std::future closed; - - std::promise _error_connection; - std::future _ftr_error; - + /** @brief Mutex to protect multithread accesses to _messages */ std::mutex _mutex; /** @brief Messages that have not been successfully acknowledged */ std::vector _messages; @@ -1485,26 +1330,10 @@ class RMQPublisherHandler final : public AMQP::LibEventHandler * @param[in] cacert SSL Cacert * @param[in] rank MPI rank */ - RMQPublisherHandler(std::shared_ptr loop, + RMQPublisherHandler(uint64_t rId, + std::shared_ptr loop, std::string cacert, - std::string queue) - : AMQP::LibEventHandler(loop.get()), - _loop(loop), - _rank(0), - _cacert(std::move(cacert)), - _queue(queue), - _nb_msg_ack(0), - _nb_msg(0), - _channel(nullptr), - _rchannel(nullptr) - { -#ifdef __ENABLE_MPI__ - MPI_CALL(MPI_Comm_rank(MPI_COMM_WORLD, &_rank)); -#endif - established = establish_connection.get_future(); - closed = close_connection.get_future(); - _ftr_error = _error_connection.get_future(); - } + std::string queue); ~RMQPublisherHandler() = default; @@ -1512,358 +1341,63 @@ class RMQPublisherHandler final : public AMQP::LibEventHandler * @brief Publish data on RMQ queue. * @param[in] msg The AMSMessage to publish */ - void publish(AMSMessage&& msg) - { - { - const std::lock_guard lock(_mutex); - _messages.push_back(msg); - } - if (_rchannel) { - // publish a message via the reliable-channel - // onAck : message has been explicitly ack'ed by RabbitMQ - // onNack : message has been explicitly nack'ed by RabbitMQ - // onError : error occurred before any ack or nack was received - // onLost : messages that have either been nack'ed, or lost - _rchannel - ->publish("", _queue, reinterpret_cast(msg.data()), msg.size()) - .onAck([this, - &_nb_msg_ack = _nb_msg_ack, - id = msg.id(), - data = msg.data(), - &_messages = this->_messages]() mutable { - DBG(RMQPublisherHandler, - "[rank=%d] message #%d (Addr:%p) got acknowledged " - "successfully " - "by " - "RMQ " - "server", - _rank, - id, - data) - this->free_ams_message(id, _messages); - _nb_msg_ack++; - }) - .onNack([this, id = msg.id(), data = msg.data()]() mutable { - WARNING(RMQPublisherHandler, - "[rank=%d] message #%d (%p) received negative " - "acknowledged " - "by " - "RMQ " - "server", - _rank, - id, - data) - }) - .onError([this, id = msg.id(), data = msg.data()]( - const char* err_message) mutable { - WARNING(RMQPublisherHandler, - "[rank=%d] message #%d (%p) did not get send: %s", - _rank, - id, - data, - err_message) - }); - } else { - WARNING(RMQPublisherHandler, - "[rank=%d] The reliable channel was not ready for message #%d.", - _rank, - msg.id()) - } - _nb_msg++; - } - - /** - * @brief Wait (blocking call) until connection has been established or that ms * repeat is over. - * @param[in] ms Number of milliseconds the function will wait on the future - * @param[in] repeat Number of times the function will wait - * @return True if connection has been established - */ - bool waitToEstablish(unsigned ms, int repeat = 1) - { - if (waitFuture(established, ms, repeat)) { - auto status = established.get(); - DBG(RMQPublisherHandler, "Connection Status: %d", status); - return status == CONNECTED; - } - return false; - } - - /** - * @brief Wait (blocking call) until connection has been closed or that ms * repeat is over. - * @param[in] ms Number of milliseconds the function will wait on the future - * @param[in] repeat Number of times the function will wait - * @return True if connection has been closed - */ - bool waitToClose(unsigned ms, int repeat = 1) - { - if (waitFuture(closed, ms, repeat)) { - return closed.get() == CLOSED; - } - return false; - } - - /** - * @brief Check if the connection can be used to send messages. - * @return True if connection is valid (i.e., can send messages) - */ - bool connection_valid() - { - std::chrono::milliseconds span(1); - return _ftr_error.wait_for(span) != std::future_status::ready; - } + void publish(AMSMessage&& msg); /** * @brief Return the messages that have NOT been acknowledged by the RabbitMQ server. * @return A vector of AMSMessage */ - std::vector& internal_msg_buffer() { return _messages; } + std::vector& msgBuffer(); /** * @brief Free AMSMessages held by the handler */ - void cleanup() { free_all_messages(_messages); } + void cleanup(); /** * @brief Total number of messages sent * @return Number of messages */ - int msg_sent() const { return _nb_msg; } + int msgSent() const; /** * @brief Total number of messages successfully acknowledged * @return Number of messages */ - int msg_acknowledged() const { return _nb_msg_ack; } - - unsigned unacknowledged() const { return _rchannel->unacknowledged(); } - - void flush() - { - uint32_t tries = 0; - while (auto unAck = unacknowledged()) { - DBG(RMQPublisherHandler, - "Waiting for %lu messages to be acknowledged", - unAck); - - if (++tries > 10) break; - std::this_thread::sleep_for(std::chrono::milliseconds(50 * tries)); - } - free_all_messages(_messages); - } + int msgAcknowledged() const; -private: /** - * @brief Method that is called after a TCP connection has been set up, and - * right before the SSL handshake is going to be performed to secure the - * connection (only for amqps:// connections). This method can be overridden - * in user space to load client side certificates. - * @param[in] connection The connection for which TLS was just started - * @param[in] ssl Pointer to the SSL structure that can be - * modified - * @return bool True to proceed / accept the connection, false - * to break up + * @brief Total number of messages unacknowledged + * @return Number of messages unacknowledged */ - virtual bool onSecuring(AMQP::TcpConnection* connection, SSL* ssl) override - { - ERR_clear_error(); - unsigned long err; -#if OPENSSL_VERSION_NUMBER < 0x10100000L - int ret = SSL_use_certificate_file(ssl, _cacert.c_str(), SSL_FILETYPE_PEM); -#else - int ret = SSL_use_certificate_chain_file(ssl, _cacert.c_str()); -#endif - if (ret != 1) { - std::string error("openssl: error loading ca-chain (" + _cacert + - ") + from ["); - SSL_get_error(ssl, ret); - if ((err = ERR_get_error())) { - error += std::string(ERR_reason_error_string(err)); - } - error += "]"; - establish_connection.set_value(FAILED); - return false; - } else { - DBG(RMQPublisherHandler, - "Success logged with ca-chain %s", - _cacert.c_str()) - return true; - } - } + unsigned unacknowledged() const; /** - * @brief Method that is called when the secure TLS connection has been - * established. This is only called for amqps:// connections. It allows you to - * inspect whether the connection is secure enough for your liking (you can - * for example check the server certificate). The AMQP protocol still has - * to be started. - * @param[in] connection The connection that has been secured - * @param[in] ssl SSL structure from openssl library - * @return bool True if connection can be used + * @brief Flush the handler by waiting for all unacknowledged mesages. + * it will wait for a given amount of time until timeout. */ - virtual bool onSecured(AMQP::TcpConnection* connection, - const SSL* ssl) override - { - DBG(RMQPublisherHandler, - "[rank=%d] Secured TLS connection has been established.", - _rank) - return true; - } + void flush(); +private: /** * @brief Method that is called by the AMQP library when the login attempt * succeeded. After this the connection is ready to use. * @param[in] connection The connection that can now be used */ - virtual void onReady(AMQP::TcpConnection* connection) override - { - DBG(RMQPublisherHandler, - "[rank=%d] Sucessfuly logged in (connection %p). Connection ready to " - "use.", - _rank, - connection) - - _channel = std::make_shared(connection); - _channel->onError([&](const char* message) { - CFATAL(RMQPublisherHandler, - false, - "[rank=%d] Error on channel: %s", - _rank, - message) - }); - - _channel->declareQueue(_queue) - .onSuccess([&](const std::string& name, - uint32_t messagecount, - uint32_t consumercount) { - if (messagecount > 0 || consumercount > 1) { - CWARNING(RMQPublisherHandler, - _rank == 0, - "[rank=%d] declared queue: %s (messagecount=%d, " - "consumercount=%d)", - _rank, - _queue.c_str(), - messagecount, - consumercount) - } - // We can now instantiate the shared buffer between AMS and RMQ - DBG(RMQPublisherHandler, - "[rank=%d] declared queue: %s", - _rank, - _queue.c_str()) - _rchannel = - std::make_shared>(*_channel.get()); - establish_connection.set_value(CONNECTED); - }) - .onError([&](const char* message) { - CFATAL(RMQPublisherHandler, - false, - "[ERROR][rank=%d] Error while creating broker queue (%s): " - "%s", - _rank, - _queue.c_str(), - message) - establish_connection.set_value(FAILED); - }); - } - - /** - * Method that is called when the AMQP protocol is ended. This is the - * counter-part of a call to connection.close() to graceful shutdown - * the connection. Note that the TCP connection is at this time still - * active, and you will also receive calls to onLost() and onDetached() - * @param connection The connection over which the AMQP protocol ended - */ - virtual void onClosed(AMQP::TcpConnection* connection) override - { - DBG(RMQPublisherHandler, "[rank=%d] Connection is closed.", _rank) - } - - /** - * @brief Method that is called by the AMQP library when a fatal error occurs - * on the connection, for example because data received from RabbitMQ - * could not be recognized, or the underlying connection is lost. This - * call is normally followed by a call to onLost() (if the error occurred - * after the TCP connection was established) and onDetached(). - * @param[in] connection The connection on which the error occurred - * @param[in] message A human readable error message - */ - virtual void onError(AMQP::TcpConnection* connection, - const char* message) override - { - WARNING(RMQPublisherHandler, - "[rank=%d] fatal error on TCP connection: %s", - _rank, - message) - try { - _error_connection.set_value(ERROR); - } catch (const std::future_error& e) { - DBG(RMQPublisherHandler, "[rank=%d] future already set.", _rank) - } - } - - /** - * @brief Final method that is called. This signals that no further calls to your - * handler will be made about the connection. - * @param connection The connection that can be destructed - */ - virtual void onDetached(AMQP::TcpConnection* connection) override - { - // add your own implementation, like cleanup resources or exit the application - DBG(RMQPublisherHandler, "[rank=%d] Connection is detached.", _rank) - close_connection.set_value(CLOSED); - } - - bool waitFuture(std::future& future, - unsigned ms, - int repeat) - { - std::chrono::milliseconds span(ms); - int iters = 0; - std::future_status status; - while ((status = future.wait_for(span)) == std::future_status::timeout && - (iters++ < repeat)) - std::future established; - return status == std::future_status::ready; - } + virtual void onReady(AMQP::TcpConnection* connection) override; /** * @brief Free the data pointed pointer in a vector and update vector. * @param[in] addr Address of memory to free. * @param[in] buffer The vector containing memory buffers */ - void free_ams_message(int msg_id, std::vector& buf) - { - const std::lock_guard lock(_mutex); - auto it = - std::find_if(buf.begin(), buf.end(), [&msg_id](const AMSMessage& obj) { - return obj.id() == msg_id; - }); - CFATAL(RMQPublisherHandler, - it == buf.end(), - "Failed to deallocate msg #%d: not found", - msg_id) - auto& msg = *it; - auto& rm = ams::ResourceManager::getInstance(); - rm.deallocate(msg.data(), AMSResourceType::AMS_HOST); - - DBG(RMQPublisherHandler, "Deallocated msg #%d (%p)", msg.id(), msg.data()) - buf.erase(it); - } + void freeMessage(int msg_id, std::vector& buf); /** * @brief Free the data pointed by each pointer in a vector. * @param[in] buffer The vector containing memory buffers */ - void free_all_messages(std::vector& buffer) - { - const std::lock_guard lock(_mutex); - auto& rm = ams::ResourceManager::getInstance(); - for (auto& dp : buffer) { - DBG(RMQPublisherHandler, "deallocate msg #%d (%p)", dp.id(), dp.data()) - rm.deallocate(dp.data(), AMSResourceType::AMS_HOST); - } - buffer.clear(); - } + void freeAllMessages(std::vector& buffer); }; // class RMQPublisherHandler @@ -1877,6 +1411,8 @@ class RMQPublisher private: /** @brief Connection to the broker */ AMQP::TcpConnection* _connection; + /** @brief MPI rank (0 if no MPI support) */ + uint64_t _rId; /** @brief name of the queue to send data */ std::string _queue; /** @brief TLS certificate file */ @@ -1895,90 +1431,45 @@ class RMQPublisher RMQPublisher& operator=(const RMQPublisher&) = delete; RMQPublisher( + uint64_t rId, const AMQP::Address& address, std::string cacert, std::string queue, - std::vector&& msgs_to_send = std::vector()) - : _rank(0), - _queue(queue), - _cacert(cacert), - _handler(nullptr), - _buffer_msg(std::move(msgs_to_send)) - { -#ifdef __ENABLE_MPI__ - MPI_CALL(MPI_Comm_rank(MPI_COMM_WORLD, &_rank)); -#endif -#ifdef EVTHREAD_USE_PTHREADS_IMPLEMENTED - evthread_use_pthreads(); -#endif - CDEBUG(RMQPublisher, - _rank == 0, - "Libevent %s (LIBEVENT_VERSION_NUMBER = %#010x)", - event_get_version(), - event_get_version_number()); - CDEBUG(RMQPublisher, - _rank == 0, - "%s (OPENSSL_VERSION_NUMBER = %#010x)", - OPENSSL_VERSION_TEXT, - OPENSSL_VERSION_NUMBER); -#if OPENSSL_VERSION_NUMBER < 0x10100000L - SSL_library_init(); -#else - OPENSSL_init_ssl(0, NULL); -#endif - CINFO(RMQPublisher, - _rank == 0, - "RabbitMQ address: %s:%d/%s (queue = %s)", - address.hostname().c_str(), - address.port(), - address.vhost().c_str(), - _queue.c_str()) - - _loop = std::shared_ptr(event_base_new(), - [](struct event_base* event) { - event_base_free(event); - }); - - _handler = std::make_shared(_loop, _cacert, _queue); - _connection = new AMQP::TcpConnection(_handler.get(), address); - } + std::vector&& msgs_to_send = std::vector()); /** * @brief Check if the underlying RabbitMQ connection is ready and usable * @return True if the publisher is ready to publish */ - bool ready_publish() { return _connection->ready() && _connection->usable(); } + bool ready_publish(); /** * @brief Wait that the connection is ready (blocking call) * @return True if the publisher is ready to publish */ - bool waitToEstablish(unsigned ms, int repeat = 1) - { - return _handler->waitToEstablish(ms, repeat); - } + bool waitToEstablish(unsigned ms, int repeat = 1); /** * @brief Return the number of unacknowledged messages * @return Number of unacknowledged messages */ - unsigned unacknowledged() const { return _handler->unacknowledged(); } + unsigned unacknowledged() const; /** * @brief Start the underlying I/O loop (blocking call) */ - void start() { event_base_dispatch(_loop.get()); } + void start(); /** * @brief Stop the underlying I/O loop */ - void stop() { event_base_loopexit(_loop.get(), NULL); } + void stop(); /** * @brief Check if the underlying connection has no errors * @return True if no errors */ - bool connection_valid() { return _handler->connection_valid(); } + bool connectionValid(); /** * @brief Return the messages that have not been acknowledged. @@ -1986,57 +1477,33 @@ class RMQPublisher * acknowledgements have not arrived yet. * @return A vector of AMSMessage */ - std::vector& get_buffer_msgs() - { - return _handler->internal_msg_buffer(); - } + std::vector& getMsgBuffer(); /** * @brief Total number of messages successfully acknowledged * @return Number of messages */ - void cleanup() { _handler->cleanup(); } - - void publish(AMSMessage&& message) - { - // We have some messages to send first (from a potential restart) - if (_buffer_msg.size() > 0) { - for (auto& msg : _buffer_msg) { - DBG(RMQPublisher, - "Publishing backed up message %d: %p", - msg.id(), - msg.data()) - _handler->publish(std::move(msg)); - } - _buffer_msg.clear(); - } + void cleanup(); - DBG(RMQPublisher, "Publishing message %d: %p", message.id(), message.data()) - _handler->publish(std::move(message)); - } + void publish(AMSMessage&& message); /** * @brief Total number of messages sent * @return Number of messages */ - int msg_sent() const { return _handler->msg_sent(); } + int msgSent() const; /** * @brief Total number of messages successfully acknowledged * @return Number of messages */ - int msg_acknowledged() const { return _handler->msg_acknowledged(); } + int msgAcknowledged() const; /** * @brief Total number of messages successfully acknowledged * @return Number of messages */ - bool close(unsigned ms, int repeat = 1) - { - _handler->flush(); - _connection->close(false); - return _handler->waitToClose(ms, repeat); - } + bool close(unsigned ms, int repeat = 1); ~RMQPublisher() = default; @@ -2061,8 +1528,9 @@ class RMQPublisher * "service-port": 31495, * "service-host": "url.czapps.llnl.gov", * "rabbitmq-cert": "tls-cert.crt", - * "rabbitmq-inbound-queue": "test4", - * "rabbitmq-outbound-queue": "test3" + * "rabbitmq-outbound-queue": "test3", + * "rabbitmq-exchange": "ams-fanout", + * "rabbitmq-routing-key": "training" * } * * The TLS certificate must be generated by the user and the absolute paths are preferred. @@ -2099,10 +1567,14 @@ class RMQInterface private: /** @brief Path of the config file (JSON) */ std::string _config; + /** @brief MPI rank (0 if no MPI support) */ + uint64_t _rId; /** @brief name of the queue to send data */ std::string _queue_sender; - /** @brief name of the queue to receive data */ - std::string _queue_receiver; + /** @brief name of the exchange to receive data */ + std::string _exchange; + /** @brief name of the routing key to receive data */ + std::string _routing_key; /** @brief Address of the RabbitMQ server */ std::shared_ptr _address; /** @brief TLS certificate path */ @@ -2117,11 +1589,27 @@ class RMQInterface std::shared_ptr _consumer; /** @brief Thread in charge of the consumer */ std::thread _consumer_thread; - + /** @brief True if connected to RabbitMQ */ bool connected; public: - RMQInterface() : connected(false) {} + RMQInterface() : connected(false), _rId(0) {} + + /** + * @brief Connect to a RabbitMQ server + * @param[in] rmq_name The name of the RabbitMQ server + * @param[in] rmq_name The name of the RabbitMQ server + * @param[in] rmq_password The password + * @param[in] rmq_user Username + * @param[in] rmq_vhost Virtual host (by default RabbitMQ vhost = '/') + * @param[in] service_port The port number + * @param[in] service_host URL of RabbitMQ server + * @param[in] rmq_cert Path to TLS certificate + * @param[in] outbound_queue Name of the queue on which AMSlib publishes (send) messages + * @param[in] exchange Exchange for incoming messages + * @param[in] routing_key Routing key for incoming messages (must match what the AMS Python side is using) + * @return True if connection succeeded + */ bool connect(std::string rmq_name, std::string rmq_password, std::string rmq_user, @@ -2129,75 +1617,41 @@ class RMQInterface int service_port, std::string service_host, std::string rmq_cert, - std::string inbouund_queue, - std::string outbound_queue) - { - _queue_sender = outbound_queue; - _queue_receiver = inbouund_queue; - _cacert = rmq_cert; - - AMQP::Login login(rmq_user, rmq_password); - _address = std::make_shared(service_host, - service_port, - login, - rmq_vhost, - /*is_secure*/ true); - _publisher = - std::make_shared(*_address, _cacert, _queue_sender); - - _publisher_thread = std::thread([&]() { _publisher->start(); }); - - if (!_publisher->waitToEstablish(100, 10)) { - _publisher->stop(); - _publisher_thread.join(); - FATAL(RabbitMQInterface, "Could not establish connection"); - } - - connected = true; - return connected; - } + std::string outbound_queue, + std::string exchange, + std::string routing_key); + /** + * @brief Check if the RabbitMQ connection is connected. + * @return True if connected + */ bool isConnected() const { return connected; } - void restart(int rank) - { - std::vector messages = _publisher->get_buffer_msgs(); - - AMSMessage& msg_min = - *(std::min_element(messages.begin(), - messages.end(), - [](const AMSMessage& a, const AMSMessage& b) { - return a.id() < b.id(); - })); - - DBG(RMQPublisher, - "[rank=%d] we have %lu buffered messages that will get re-send " - "(starting from msg #%d).", - rank, - messages.size(), - msg_min.id()) - - // Stop the faulty publisher - _publisher->stop(); - _publisher_thread.join(); - _publisher.reset(); - connected = false; - - _publisher = std::make_shared(*_address, - _cacert, - _queue_sender, - std::move(messages)); - _publisher_thread = std::thread([&]() { _publisher->start(); }); - connected = true; - } + /** + * @brief Set the internal ID of the interface (usually MPI rank). + * @param[in] id The ID + */ + void setId(uint64_t id) { _rId = id; } + /** + * @brief Try to restart the RabbitMQ publisher (restart the thread managing messages publishing) + */ + void restartPublisher(); + + /** + * @brief Return the latest model and, by default, delete the corresponding message from the Consumer + * @param[in] domain_name The name of the domain + * @param[in] num_elements The number of elements for inputs/outputs + * @param[in] inputs A vector containing arrays of inputs, each array has num_elements elements + * @param[in] outputs A vector containing arrays of outputs, each array has num_elements elements + */ template void publish(std::string& domain_name, size_t num_elements, std::vector& inputs, std::vector& outputs) { - DBG(RabbitMQDB, + DBG(RMQInterface, "[tag=%d] stores %ld elements of input/output " "dimensions (%ld, %ld)", _msg_tag, @@ -2205,40 +1659,58 @@ class RMQInterface inputs.size(), outputs.size()) - AMSMessage msg(_msg_tag, domain_name, num_elements, inputs, outputs); + AMSMessage msg(_msg_tag, _rId, domain_name, num_elements, inputs, outputs); - if (!_publisher->connection_valid()) { - restart(msg._rank); + if (!_publisher->connectionValid()) { + connected = false; + restartPublisher(); bool status = _publisher->waitToEstablish(100, 10); if (!status) { _publisher->stop(); _publisher_thread.join(); - FATAL(RabbitMQDB, "Could not establish connection"); + FATAL(RMQInterface, + "Could not establish publisher RabbitMQ connection"); } + connected = true; } _publisher->publish(std::move(msg)); _msg_tag++; } - void close() + /** + * @brief Close the underlying connection + */ + void close(); + + /** + * @brief Check if a new ML model is available + * @return True if there is a valid ML model + */ + bool updateModel() + { + // NOTE: The architecture here is not great for now, we have redundant call to getLatestModel + // Solution: when switching to C++ use std::variant to return an std::optional + // the std::optional would be a string if a model is available otherwise it's a bool false + auto data = _consumer->getLatestModel(); + return !std::get<1>(data).empty(); + } + + /** + * @brief Return the latest model and, by default, delete the corresponding message from the Consumer + * @param[in] remove_msg if True, delete the message corresponding to the model + * @return The Path of the new model + */ + std::string getLatestModel(bool remove_msg = true) { - if (!_publisher_thread.joinable()) { - return; + auto res = _consumer->getLatestModel(); + bool empty = std::get<1>(res).empty(); + if (remove_msg && !empty) { + auto id = std::get<0>(res); + _consumer->delMessage(id); } - bool status = _publisher->close(100, 10); - CWARNING(RabbitMQDB, !status, "Could not gracefully close TCP connection") - DBG(RabbitMQInterface, "Number of messages sent: %d", _msg_tag) - DBG(RabbitMQInterface, - "Number of unacknowledged messages are %d", - _publisher->unacknowledged()) - _publisher->stop(); - //_consumer->stop(); - _publisher_thread.join(); - //_consumer_thread.join(); - connected = false; + return std::get<1>(res); } - ~RMQInterface() { if (connected) close(); @@ -2252,19 +1724,32 @@ class RMQInterface class RabbitMQDB final : public BaseDB { private: - /** @brief the application domain that stores the data*/ + /** @brief the application domain that stores the data */ std::string appDomain; - - /** An interface to RMQ to push the data to*/ + /** @brief An interface to RMQ to push the data to */ RMQInterface& interface; public: RabbitMQDB(const RabbitMQDB&) = delete; RabbitMQDB& operator=(const RabbitMQDB&) = delete; - RabbitMQDB(RMQInterface& interface, std::string& domain, uint64_t id) - : BaseDB(id), appDomain(domain), interface(interface) - { + RabbitMQDB(RMQInterface& interface, + std::string& domain, + uint64_t id, + bool allowModelUpdate) + : BaseDB(id, allowModelUpdate), appDomain(domain), interface(interface) + { + /* We set manually the MPI rank here because when + * RMQInterface was statically initialized, MPI was not + * necessarily initialized and ready. So we provide the + * option of setting the distributed ID afterward. + * + * Note: this ID is encoded into AMSMessage but for + * logging we use a randomly generated ID to stay + * consistent over time (some logging could happen + * before setId is called). + */ + interface.setId(id); } /** @@ -2275,6 +1760,7 @@ class RabbitMQDB final : public BaseDB * @param[in] inputs Vector of 1-D vectors containing the inputs to be sent * @param[in] outputs Vector of 1-D vectors, each 1-D vectors contains * 'num_elements' values to be sent + * @param[in] predicate (NOT SUPPORTED YET) Series of predicate */ PERFFASPECT() void store(size_t num_elements, @@ -2285,7 +1771,6 @@ class RabbitMQDB final : public BaseDB CFATAL(RMQDB, predicate != nullptr, "RMQ database does not support storing uq-predicates") - interface.publish(appDomain, num_elements, inputs, outputs); } @@ -2297,18 +1782,28 @@ class RabbitMQDB final : public BaseDB CFATAL(RMQDB, predicate != nullptr, "RMQ database does not support storing uq-predicates") - interface.publish(appDomain, num_elements, inputs, outputs); } - void restart() {} - /** * @brief Return the type of this broker * @return The type of the broker */ std::string type() override { return "rabbitmq"; } + /** + * @brief Check if the surrogate model can be updated (i.e., if + * RMQConsumer received a training message) + * @return True if the model can be updated + */ + bool updateModel() { return interface.updateModel(); } + + /** + * @brief Return the path of the latest surrogate model if available + * @return The path of the latest available surrogate model + */ + std::string getLatestModel() { return interface.getLatestModel(); } + /** * @brief Return the DB enumerationt type (File, Redis etc) */ @@ -2327,7 +1822,7 @@ class RMQInterface RMQInterface() : connected(false) {} bool connect() { - FATAL(RMQInterface, "RMQ Disabled yet we are Requesting to connect") + FATAL(RMQInterface, "RMQ Disabled yet we are requesting to connect") return false; } @@ -2388,8 +1883,11 @@ class DBManager private: std::unordered_map> db_instances; AMSDBType dbType; + uint64_t rId; + /** @brief If True, the DB is allowed to update the surrogate model */ + bool updateSurrogate; - DBManager() : dbType(AMSDBType::AMS_NONE){}; + DBManager() : dbType(AMSDBType::AMS_NONE), updateSurrogate(false){}; protected: RMQInterface rmq_interface; @@ -2402,7 +1900,6 @@ class DBManager return instance; } -public: ~DBManager() { for (auto& e : db_instances) { @@ -2471,7 +1968,10 @@ class DBManager #endif #ifdef __ENABLE_RMQ__ case AMSDBType::AMS_RMQ: - return std::make_shared(rmq_interface, domainName, rId); + return std::make_shared(rmq_interface, + domainName, + rId, + updateSurrogate); #endif default: return nullptr; @@ -2480,7 +1980,6 @@ class DBManager return nullptr; } - /** * @brief get a data base object referred by this string. * This should never be used for large scale simulations as txt/csv format will @@ -2562,8 +2061,10 @@ class DBManager std::string& rmq_user, std::string& rmq_vhost, std::string& rmq_cert, - std::string& inbouund_queue, - std::string& outbound_queue) + std::string& outbound_queue, + std::string& exchange, + std::string& routing_key, + bool update_surrogate) { fs::path Path(rmq_cert); std::error_code ec; @@ -2572,8 +2073,7 @@ class DBManager "Certificate file '%s' for RMQ server does not exist", rmq_cert.c_str()); dbType = AMSDBType::AMS_RMQ; - - + updateSurrogate = update_surrogate; #ifdef __ENABLE_RMQ__ rmq_interface.connect(rmq_name, rmq_pass, @@ -2582,8 +2082,9 @@ class DBManager port, host, rmq_cert, - inbouund_queue, - outbound_queue); + outbound_queue, + exchange, + routing_key); #else FATAL(DBManager, "Requsted RMQ database but AMS is not built with such support " diff --git a/src/AMSlib/wf/rmqdb.cpp b/src/AMSlib/wf/rmqdb.cpp new file mode 100644 index 00000000..a33cb3b6 --- /dev/null +++ b/src/AMSlib/wf/rmqdb.cpp @@ -0,0 +1,1032 @@ +/* + * Copyright 2021-2023 Lawrence Livermore National Security, LLC and other + * AMSLib Project Developers + * + * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + */ + +#include "wf/basedb.hpp" + +using namespace ams::db; + +/** + * AMSMsgHeader + */ + +AMSMsgHeader::AMSMsgHeader(size_t mpi_rank, + size_t domain_size, + size_t num_elem, + size_t in_dim, + size_t out_dim, + size_t type_size) + : hsize(static_cast(AMSMsgHeader::size())), + dtype(static_cast(type_size)), + mpi_rank(static_cast(mpi_rank)), + domain_size(static_cast(domain_size)), + num_elem(static_cast(num_elem)), + in_dim(static_cast(in_dim)), + out_dim(static_cast(out_dim)) +{ +} + +AMSMsgHeader::AMSMsgHeader(uint16_t mpi_rank, + uint16_t domain_size, + uint32_t num_elem, + uint16_t in_dim, + uint16_t out_dim, + uint8_t type_size) + : hsize(static_cast(AMSMsgHeader::size())), + dtype(type_size), + mpi_rank(mpi_rank), + domain_size(domain_size), + num_elem(num_elem), + in_dim(in_dim), + out_dim(out_dim) +{ +} + +size_t AMSMsgHeader::encode(uint8_t* data_blob) +{ + if (!data_blob) return 0; + + size_t current_offset = 0; + // Header size (should be 1 bytes) + data_blob[current_offset] = hsize; + current_offset += sizeof(hsize); + // Data type (should be 1 bytes) + data_blob[current_offset] = dtype; + current_offset += sizeof(dtype); + // MPI rank (should be 2 bytes) + std::memcpy(data_blob + current_offset, &(mpi_rank), sizeof(mpi_rank)); + current_offset += sizeof(mpi_rank); + // Domain Size (should be 2 bytes) + DBG(AMSMsgHeader, + "Generating domain name of size %d --- %d offset %d", + domain_size, + sizeof(domain_size), + current_offset); + std::memcpy(data_blob + current_offset, &(domain_size), sizeof(domain_size)); + current_offset += sizeof(domain_size); + // Num elem (should be 4 bytes) + std::memcpy(data_blob + current_offset, &(num_elem), sizeof(num_elem)); + current_offset += sizeof(num_elem); + // Input dim (should be 2 bytes) + std::memcpy(data_blob + current_offset, &(in_dim), sizeof(in_dim)); + current_offset += sizeof(in_dim); + // Output dim (should be 2 bytes) + std::memcpy(data_blob + current_offset, &(out_dim), sizeof(out_dim)); + current_offset += sizeof(out_dim); + + return AMSMsgHeader::size(); +} + +AMSMsgHeader AMSMsgHeader::decode(uint8_t* data_blob) +{ + size_t current_offset = 0; + // Header size (should be 1 bytes) + uint8_t new_hsize = data_blob[current_offset]; + CWARNING(AMSMsgHeader, + new_hsize != AMSMsgHeader::size(), + "buffer is likely not a valid AMSMessage (%d / %ld)", + new_hsize, + current_offset) + + current_offset += sizeof(uint8_t); + // Data type (should be 1 bytes) + uint8_t new_dtype = data_blob[current_offset]; + current_offset += sizeof(uint8_t); + // MPI rank (should be 2 bytes) + uint16_t new_mpirank = + (reinterpret_cast(data_blob + current_offset))[0]; + current_offset += sizeof(uint16_t); + + // Domain Size (should be 2 bytes) + uint16_t new_domain_size = + (reinterpret_cast(data_blob + current_offset))[0]; + current_offset += sizeof(uint16_t); + + // Num elem (should be 4 bytes) + uint32_t new_num_elem; + std::memcpy(&new_num_elem, data_blob + current_offset, sizeof(uint32_t)); + current_offset += sizeof(uint32_t); + // Input dim (should be 2 bytes) + uint16_t new_in_dim; + std::memcpy(&new_in_dim, data_blob + current_offset, sizeof(uint16_t)); + current_offset += sizeof(uint16_t); + // Output dim (should be 2 bytes) + uint16_t new_out_dim; + std::memcpy(&new_out_dim, data_blob + current_offset, sizeof(uint16_t)); + + return AMSMsgHeader(new_mpirank, + new_domain_size, + new_num_elem, + new_in_dim, + new_out_dim, + new_dtype); +} + +/** + * AMSMessage + */ + +void AMSMessage::swap(const AMSMessage& other) +{ + _id = other._id; + _rank = other._rank; + _num_elements = other._num_elements; + _input_dim = other._input_dim; + _output_dim = other._output_dim; + _total_size = other._total_size; + _data = other._data; +} + +AMSMessage::AMSMessage(int id, uint64_t rId, uint8_t* data) + : _id(id), + _num_elements(0), + _input_dim(0), + _output_dim(0), + _data(data), + _total_size(0) +{ + auto header = AMSMsgHeader::decode(data); + + int current_rank = rId; + _rank = header.mpi_rank; + CWARNING(AMSMessage, + _rank != current_rank, + "MPI rank are not matching (using %d)", + _rank) + + _num_elements = header.num_elem; + _input_dim = header.in_dim; + _output_dim = header.out_dim; + _data = data; + auto type_value = header.dtype; + + _total_size = AMSMsgHeader::size() + getTotalElements() * type_value; + + DBG(AMSMessage, "Allocated message %d: %p", _id, _data); +} + +/** + * AMSMessageInbound + */ + +AMSMessageInbound::AMSMessageInbound(uint64_t id, + uint64_t rId, + std::string body, + std::string exchange, + std::string routing_key, + bool redelivered) + : id(id), + rId(rId), + body(std::move(body)), + exchange(std::move(exchange)), + routing_key(std::move(routing_key)), + redelivered(redelivered){}; + + +bool AMSMessageInbound::empty() { return body.empty() || routing_key.empty(); } + +bool AMSMessageInbound::isTraining() +{ + auto split = splitString(body, ":"); + return split[0] == "UPDATE"; +} + +std::string AMSMessageInbound::getModelPath() +{ + auto split = splitString(body, ":"); + if (split[0] == "UPDATE") { + return split[1]; + } + return {}; +} + +std::vector AMSMessageInbound::splitString(std::string str, + std::string delimiter) +{ + size_t pos = 0; + std::string token; + std::vector res; + while ((pos = str.find(delimiter)) != std::string::npos) { + token = str.substr(0, pos); + res.push_back(token); + str.erase(0, pos + delimiter.length()); + } + res.push_back(str); + return res; +} + +/** + * RMQHandler + */ + +RMQHandler::RMQHandler(uint64_t rId, + std::shared_ptr loop, + std::string cacert) + : AMQP::LibEventHandler(loop.get()), + _rId(rId), + _loop(loop), + _cacert(std::move(cacert)) +{ + established = establish_connection.get_future(); + closed = close_connection.get_future(); + ftr_error = error_connection.get_future(); +} + +bool RMQHandler::waitToEstablish(unsigned ms, int repeat) +{ + if (waitFuture(established, ms, repeat)) { + auto status = established.get(); + DBG(RMQHandler, "Connection Status: %d", status); + return status == CONNECTED; + } + return false; +} + +bool RMQHandler::waitToClose(unsigned ms, int repeat) +{ + if (waitFuture(closed, ms, repeat)) { + return closed.get() == CLOSED; + } + return false; +} + +bool RMQHandler::connectionValid() +{ + std::chrono::milliseconds span(1); + return ftr_error.wait_for(span) != std::future_status::ready; +} + +bool RMQHandler::onSecuring(AMQP::TcpConnection* connection, SSL* ssl) +{ + ERR_clear_error(); + unsigned long err; +#if OPENSSL_VERSION_NUMBER < 0x10100000L + int ret = SSL_use_certificate_file(ssl, _cacert.c_str(), SSL_FILETYPE_PEM); +#else + int ret = SSL_use_certificate_chain_file(ssl, _cacert.c_str()); +#endif + if (ret != 1) { + std::string error("openssl: error loading ca-chain (" + _cacert + + ") + from ["); + SSL_get_error(ssl, ret); + if ((err = ERR_get_error())) { + error += std::string(ERR_reason_error_string(err)); + } + error += "]"; + establish_connection.set_value(FAILED); + return false; + } else { + DBG(RMQHandler, "Success logged with ca-chain %s", _cacert.c_str()) + return true; + } +} + +bool RMQHandler::onSecured(AMQP::TcpConnection* connection, const SSL* ssl) +{ + DBG(RMQHandler, "[r%d] Secured TLS connection has been established.", _rId) + return true; +} + +void RMQHandler::onClosed(AMQP::TcpConnection* connection) +{ + DBG(RMQHandler, "[r%d] Connection is closed.", _rId) +} + +void RMQHandler::onError(AMQP::TcpConnection* connection, const char* message) +{ + WARNING(RMQHandler, "[r%d] fatal error on TCP connection: %s", _rId, message) + try { + error_connection.set_value(ERROR); + } catch (const std::future_error& e) { + DBG(RMQHandler, "[r%d] future already set.", _rId) + } +} + +void RMQHandler::onDetached(AMQP::TcpConnection* connection) +{ + DBG(RMQHandler, "[r%d] Connection is detached.", _rId) + close_connection.set_value(CLOSED); +} + +bool RMQHandler::waitFuture(std::future& future, + unsigned ms, + int repeat) +{ + std::chrono::milliseconds span(ms); + int iters = 0; + std::future_status status; + while ((status = future.wait_for(span)) == std::future_status::timeout && + (iters++ < repeat)) + std::future established; + return status == std::future_status::ready; +} + +/** + * RMQConsumerHandler + */ + +RMQConsumerHandler::RMQConsumerHandler(uint64_t rId, + std::shared_ptr loop, + std::string cacert, + std::string exchange, + std::string routing_key, + AMQP::ExchangeType extype) + : RMQHandler(rId, loop, cacert), + _exchange(exchange), + _extype(extype), + _routing_key(routing_key), + _messages(std::make_shared>()), + _channel(nullptr) +{ +} + +std::tuple RMQConsumerHandler::getLatestModel() +{ + std::string model = ""; + uint64_t latest_tag = 0; + for (AMSMessageInbound& e : *_messages) { + if (latest_tag < e.id) { + model = e.getModelPath(); + latest_tag = e.id; + } + } + return std::make_tuple(latest_tag, model); +} + +AMSMessageInbound RMQConsumerHandler::popMessages() +{ + if (!_messages->empty()) { + AMSMessageInbound msg = _messages->back(); + _messages->pop_back(); + return msg; + } + return AMSMessageInbound(); +} + +AMSMessageInbound RMQConsumerHandler::getMessages(uint64_t delivery_tag, + bool erase) +{ + if (!_messages->empty()) { + auto it = std::find_if(_messages->begin(), + _messages->end(), + [&delivery_tag](const AMSMessageInbound& e) { + return e.id == delivery_tag; + }); + if (it != _messages->end()) { + AMSMessageInbound msg(std::move(*it)); + if (erase) _messages->erase(it); + return msg; + } + } + return AMSMessageInbound(); +} + +void RMQConsumerHandler::onReady(AMQP::TcpConnection* connection) +{ + DBG(RMQConsumerHandler, + "[r%d] Sucessfuly logged in. Connection ready to use.", + _rId) + + _channel = std::make_shared(connection); + _channel->onError([&](const char* message) { + WARNING(RMQConsumerHandler, "[r%d] Error on channel: %s", _rId, message) + establish_connection.set_value(FAILED); + }); + + // The exchange will be deleted once all bound queues are removed + _channel->declareExchange(_exchange, _extype, AMQP::autodelete) + .onSuccess([&, this]() { + DBG(RMQConsumerHandler, + "[r%d] declared exchange %s (type: %d)", + _rId, + _exchange.c_str(), + _extype) + establish_connection.set_value(CONNECTED); + _channel->declareQueue(AMQP::exclusive) + .onSuccess([&, this](const std::string& name, + uint32_t messagecount, + uint32_t consumercount) { + DBG(RMQConsumerHandler, + "[r%d] declared queue: %s (messagecount=%d, " + "consumercount=%d)", + _rId, + name.c_str(), + messagecount, + consumercount) + _channel->bindQueue(_exchange, name, _routing_key) + .onSuccess([&, name, this]() { + DBG(RMQConsumerHandler, + "[r%d] Bounded queue %s to exchange %s with " + "routing key = %s", + _rId, + name.c_str(), + _exchange.c_str(), + _routing_key.c_str()) + + // We can now install callback functions for when we will consumme messages + // callback function that is called when the consume operation starts + auto startCb = [&](const std::string& consumertag) { + DBG(RMQConsumerHandler, + "[r%d] consume operation started with tag: %s", + _rId, + consumertag.c_str()) + }; + + // callback function that is called when the consume operation failed + auto errorCb = [&](const char* message) { + WARNING(RMQConsumerHandler, + "[r%d] consume operation failed: %s", + _rId, + message); + }; + // callback operation when a message was received + auto messageCb = [&](const AMQP::Message& message, + uint64_t deliveryTag, + bool redelivered) { + // acknowledge the message + _channel->ack(deliveryTag); + // _on_message_received(message, deliveryTag, redelivered); + std::string msg(message.body(), message.bodySize()); + DBG(RMQConsumerHandler, + "[r%d] message received [tag=%d] : '%s' of size " + "%d B from " + "'%s'/'%s'", + _rId, + deliveryTag, + msg.c_str(), + message.bodySize(), + message.exchange().c_str(), + message.routingkey().c_str()) + _messages->push_back( + AMSMessageInbound(deliveryTag, + _rId, + msg, + message.exchange(), + message.routingkey(), + redelivered)); + }; + + /* callback that is called when the consumer is cancelled by RabbitMQ (this + * only happens in rare situations, for example when someone removes the queue + * that you are consuming from) + */ + auto cancelledCb = [&](const std::string& consumertag) { + WARNING(RMQConsumerHandler, + "[r%d] consume operation cancelled by the " + "RabbitMQ server: %s", + _rId, + consumertag.c_str()) + }; + + DBG(RMQConsumerHandler, + "[r%d] starting consume operation", + _rId) + + // start consuming from the queue, and install the callbacks + _channel->consume(name) + .onReceived(std::move(messageCb)) + .onSuccess(std::move(startCb)) + .onCancelled(std::move(cancelledCb)) + .onError(std::move(errorCb)); + }) //consume + .onError([&](const char* message) { + WARNING(RMQConsumerHandler, + "[r%d] creating queue: %s", + _rId, + message) + establish_connection.set_value(FAILED); + }); //consume + }) //bindQueue + .onError([&](const char* message) { + WARNING(RMQConsumerHandler, + "[r%d] failed to bind queue to exchange: %s", + _rId, + message) + }); //bindQueue + }) //declareExchange + .onError([&](const char* message) { + WARNING(RMQConsumerHandler, + "[r%d] failed to create exchange: %s", + _rId, + message) + establish_connection.set_value(FAILED); + }); //declareExchange +} + +/** + * RMQConsumer + */ + +RMQConsumer::RMQConsumer(uint64_t rId, + const AMQP::Address& address, + std::string cacert, + std::string exchange, + std::string routing_key) + : _rId(rId), + _cacert(cacert), + _routing_key(routing_key), + _exchange(exchange), + _handler(nullptr) +{ +#ifdef EVTHREAD_USE_PTHREADS_IMPLEMENTED + evthread_use_pthreads(); +#endif + DBG(RMQConsumer, + "Libevent %s (LIBEVENT_VERSION_NUMBER = %#010x)", + event_get_version(), + event_get_version_number()); + DBG(RMQConsumer, + "%s (OPENSSL_VERSION_NUMBER = %#010x)", + OPENSSL_VERSION_TEXT, + OPENSSL_VERSION_NUMBER); +#if OPENSSL_VERSION_NUMBER < 0x10100000L + SSL_library_init(); +#else + OPENSSL_init_ssl(0, NULL); +#endif + DBG(RMQConsumer, + "RabbitMQ address: %s:%d/%s (exchange = %s / routing key = %s)", + address.hostname().c_str(), + address.port(), + address.vhost().c_str(), + _exchange.c_str(), + _routing_key.c_str()) + + _loop = std::shared_ptr(event_base_new(), + [](struct event_base* event) { + event_base_free(event); + }); + _handler = std::make_shared( + rId, _loop, _cacert, _exchange, _routing_key, AMQP::fanout); + _connection = new AMQP::TcpConnection(_handler.get(), address); +} + +void RMQConsumer::start() { event_base_dispatch(_loop.get()); } + +void RMQConsumer::stop() { event_base_loopexit(_loop.get(), NULL); } + +bool RMQConsumer::ready() +{ + return _connection->ready() && _connection->usable(); +} + +bool RMQConsumer::waitToEstablish(unsigned ms, int repeat) +{ + return _handler->waitToEstablish(ms, repeat); +} + +AMSMessageInbound RMQConsumer::popMessages() +{ + return _handler->popMessages(); +}; + +void RMQConsumer::delMessage(uint64_t delivery_tag) +{ + _handler->delMessage(delivery_tag); +} + +AMSMessageInbound RMQConsumer::getMessages(uint64_t delivery_tag, bool erase) +{ + return _handler->getMessages(delivery_tag, erase); +} + +std::tuple RMQConsumer::getLatestModel() +{ + return _handler->getLatestModel(); +} + +bool RMQConsumer::close(unsigned ms, int repeat) +{ + _connection->close(false); + return _handler->waitToClose(ms, repeat); +} + +RMQConsumer::~RMQConsumer() +{ + _connection->close(false); + delete _connection; +} + +/** + * RMQPublisherHandler + */ + +RMQPublisherHandler::RMQPublisherHandler( + uint64_t rId, + std::shared_ptr loop, + std::string cacert, + std::string queue) + : RMQHandler(rId, loop, cacert), + _queue(queue), + _nb_msg_ack(0), + _nb_msg(0), + _channel(nullptr), + _rchannel(nullptr) +{ +} + +/** + * @brief Return the messages that have NOT been acknowledged by the RabbitMQ server. + * @return A vector of AMSMessage + */ +std::vector& RMQPublisherHandler::msgBuffer() { return _messages; } + +/** + * @brief Free AMSMessages held by the handler + */ +void RMQPublisherHandler::cleanup() { freeAllMessages(_messages); } + +/** + * @brief Total number of messages sent + * @return Number of messages + */ +int RMQPublisherHandler::msgSent() const { return _nb_msg; } + +/** + * @brief Total number of messages successfully acknowledged + * @return Number of messages + */ +int RMQPublisherHandler::msgAcknowledged() const { return _nb_msg_ack; } + +/** + * @brief Total number of messages unacknowledged + * @return Number of messages unacknowledged + */ +unsigned RMQPublisherHandler::unacknowledged() const +{ + return _rchannel->unacknowledged(); +} + +void RMQPublisherHandler::publish(AMSMessage&& msg) +{ + { + const std::lock_guard lock(_mutex); + _messages.push_back(msg); + } + if (_rchannel) { + // publish a message via the reliable-channel + // onAck : message has been explicitly ack'ed by RabbitMQ + // onNack : message has been explicitly nack'ed by RabbitMQ + // onError : error occurred before any ack or nack was received + // onLost : messages that have either been nack'ed, or lost + _rchannel + ->publish("", _queue, reinterpret_cast(msg.data()), msg.size()) + .onAck([this, + &_nb_msg_ack = _nb_msg_ack, + id = msg.id(), + data = msg.data(), + &_messages = this->_messages]() mutable { + DBG(RMQPublisherHandler, + "[r%d] message #%d (Addr:%p) got acknowledged " + "successfully " + "by " + "RMQ " + "server", + _rId, + id, + data) + this->freeMessage(id, _messages); + _nb_msg_ack++; + }) + .onNack([this, id = msg.id(), data = msg.data()]() mutable { + WARNING(RMQPublisherHandler, + "[r%d] message #%d (%p) received negative " + "acknowledged " + "by " + "RMQ " + "server", + _rId, + id, + data) + }) + .onError([this, id = msg.id(), data = msg.data()]( + const char* err_message) mutable { + WARNING(RMQPublisherHandler, + "[r%d] message #%d (%p) did not get send: %s", + _rId, + id, + data, + err_message) + }); + } else { + WARNING(RMQPublisherHandler, + "[r%d] The reliable channel was not ready for message #%d.", + _rId, + msg.id()) + } + _nb_msg++; +} + +void RMQPublisherHandler::onReady(AMQP::TcpConnection* connection) +{ + DBG(RMQPublisherHandler, + "[r%d] Sucessfuly logged in (connection %p). Connection ready to " + "use.", + _rId, + connection) + + _channel = std::make_shared(connection); + _channel->onError([&](const char* message) { + CFATAL( + RMQPublisherHandler, false, "[r%d] Error on channel: %s", _rId, message) + }); + + _channel->declareQueue(_queue) + .onSuccess([&](const std::string& name, + uint32_t messagecount, + uint32_t consumercount) { + DBG(RMQPublisherHandler, + "[r%d] declared queue: %s (messagecount=%d, " + "consumercount=%d)", + _rId, + _queue.c_str(), + messagecount, + consumercount) + // We can now instantiate the shared buffer between AMS and RMQ + _rchannel = + std::make_shared>(*_channel.get()); + establish_connection.set_value(CONNECTED); + }) + .onError([&](const char* message) { + CFATAL(RMQPublisherHandler, + false, + "[r%d] Error while creating broker queue (%s): " + "%s", + _rId, + _queue.c_str(), + message) + establish_connection.set_value(FAILED); + }); +} + +void RMQPublisherHandler::freeMessage(int msg_id, std::vector& buf) +{ + const std::lock_guard lock(_mutex); + auto it = + std::find_if(buf.begin(), buf.end(), [&msg_id](const AMSMessage& obj) { + return obj.id() == msg_id; + }); + CFATAL(RMQPublisherHandler, + it == buf.end(), + "Failed to deallocate msg #%d: not found", + msg_id) + auto& msg = *it; + auto& rm = ams::ResourceManager::getInstance(); + rm.deallocate(msg.data(), AMSResourceType::AMS_HOST); + + DBG(RMQPublisherHandler, "Deallocated msg #%d (%p)", msg.id(), msg.data()) + buf.erase(it); +} + +void RMQPublisherHandler::freeAllMessages(std::vector& buffer) +{ + const std::lock_guard lock(_mutex); + auto& rm = ams::ResourceManager::getInstance(); + for (auto& dp : buffer) { + DBG(RMQPublisherHandler, "deallocate msg #%d (%p)", dp.id(), dp.data()) + rm.deallocate(dp.data(), AMSResourceType::AMS_HOST); + } + buffer.clear(); +} + +void RMQPublisherHandler::flush() +{ + uint32_t tries = 0; + while (auto unAck = unacknowledged()) { + DBG(RMQPublisherHandler, + "Waiting for %lu messages to be acknowledged", + unAck); + + if (++tries > 10) break; + std::this_thread::sleep_for(std::chrono::milliseconds(50 * tries)); + } + freeAllMessages(_messages); +} + +/** + * RMQPublisher + */ + +RMQPublisher::RMQPublisher(uint64_t rId, + const AMQP::Address& address, + std::string cacert, + std::string queue, + std::vector&& msgs_to_send) + : _rId(rId), + _queue(queue), + _cacert(cacert), + _handler(nullptr), + _buffer_msg(std::move(msgs_to_send)) +{ +#ifdef EVTHREAD_USE_PTHREADS_IMPLEMENTED + evthread_use_pthreads(); +#endif + DBG(RMQPublisher, + "Libevent %s (LIBEVENT_VERSION_NUMBER = %#010x)", + event_get_version(), + event_get_version_number()); + DBG(RMQPublisher, + "%s (OPENSSL_VERSION_NUMBER = %#010x)", + OPENSSL_VERSION_TEXT, + OPENSSL_VERSION_NUMBER); +#if OPENSSL_VERSION_NUMBER < 0x10100000L + SSL_library_init(); +#else + OPENSSL_init_ssl(0, NULL); +#endif + DBG(RMQPublisher, + "RabbitMQ address: %s:%d/%s (queue = %s)", + address.hostname().c_str(), + address.port(), + address.vhost().c_str(), + _queue.c_str()) + + _loop = std::shared_ptr(event_base_new(), + [](struct event_base* event) { + event_base_free(event); + }); + + _handler = + std::make_shared(_rId, _loop, _cacert, _queue); + _connection = new AMQP::TcpConnection(_handler.get(), address); +} + +void RMQPublisher::publish(AMSMessage&& message) +{ + // We have some messages to send first (from a potential restart) + if (_buffer_msg.size() > 0) { + for (auto& msg : _buffer_msg) { + DBG(RMQPublisher, + "Publishing backed up message %d: %p", + msg.id(), + msg.data()) + _handler->publish(std::move(msg)); + } + _buffer_msg.clear(); + } + + DBG(RMQPublisher, "Publishing message %d: %p", message.id(), message.data()) + _handler->publish(std::move(message)); +} + +bool RMQPublisher::ready_publish() +{ + return _connection->ready() && _connection->usable(); +} + +bool RMQPublisher::waitToEstablish(unsigned ms, int repeat) +{ + return _handler->waitToEstablish(ms, repeat); +} + +unsigned RMQPublisher::unacknowledged() const +{ + return _handler->unacknowledged(); +} + +void RMQPublisher::start() { event_base_dispatch(_loop.get()); } + +void RMQPublisher::stop() { event_base_loopexit(_loop.get(), NULL); } + +bool RMQPublisher::connectionValid() { return _handler->connectionValid(); } + +std::vector& RMQPublisher::getMsgBuffer() +{ + return _handler->msgBuffer(); +} + +void RMQPublisher::cleanup() { _handler->cleanup(); } + +int RMQPublisher::msgSent() const { return _handler->msgSent(); } + +int RMQPublisher::msgAcknowledged() const +{ + return _handler->msgAcknowledged(); +} + +bool RMQPublisher::close(unsigned ms, int repeat) +{ + _handler->flush(); + _connection->close(false); + return _handler->waitToClose(ms, repeat); +} + +/** + * RMQInterface + */ + +bool RMQInterface::connect(std::string rmq_name, + std::string rmq_password, + std::string rmq_user, + std::string rmq_vhost, + int service_port, + std::string service_host, + std::string rmq_cert, + std::string outbound_queue, + std::string exchange, + std::string routing_key) +{ + _queue_sender = outbound_queue; + _exchange = exchange; + _routing_key = routing_key; + _cacert = rmq_cert; + + // Here we generate 64-bits wide random numbers to have a unique distributed ID + // WARNING: there is no guarantee of uniqueness here as each MPI rank will have its own generator + unsigned seed = std::chrono::system_clock::now().time_since_epoch().count(); + std::default_random_engine generator(seed); + std::uniform_int_distribution distrib(0, + std::numeric_limits::max()); + _rId = static_cast(distrib(generator)); + + AMQP::Login login(rmq_user, rmq_password); + _address = std::make_shared(service_host, + service_port, + login, + rmq_vhost, + /*is_secure*/ true); + _publisher = + std::make_shared(_rId, *_address, _cacert, _queue_sender); + + _publisher_thread = std::thread([&]() { _publisher->start(); }); + + if (!_publisher->waitToEstablish(100, 10)) { + _publisher->stop(); + _publisher_thread.join(); + FATAL(RabbitMQInterface, "Could not establish connection"); + } + + _consumer = std::make_shared( + _rId, *_address, _cacert, _exchange, _routing_key); + _consumer_thread = std::thread([&]() { _consumer->start(); }); + + if (!_consumer->waitToEstablish(100, 10)) { + _consumer->stop(); + _consumer_thread.join(); + FATAL(RabbitMQDB, "Could not establish consumer connection"); + } + + connected = true; + return connected; +} + +void RMQInterface::restartPublisher() +{ + std::vector messages = _publisher->getMsgBuffer(); + + AMSMessage& msg_min = + *(std::min_element(messages.begin(), + messages.end(), + [](const AMSMessage& a, const AMSMessage& b) { + return a.id() < b.id(); + })); + + DBG(RMQPublisher, + "[r%d] we have %lu buffered messages that will get re-send " + "(starting from msg #%d).", + _rId, + messages.size(), + msg_min.id()) + + // Stop the faulty publisher + _publisher->stop(); + _publisher_thread.join(); + _publisher.reset(); + connected = false; + + _publisher = std::make_shared( + _rId, *_address, _cacert, _queue_sender, std::move(messages)); + _publisher_thread = std::thread([&]() { _publisher->start(); }); + connected = true; +} + +void RMQInterface::close() +{ + if (!_publisher_thread.joinable() || !_consumer_thread.joinable()) { + DBG(RMQInterface, "Threads are not joinable") + return; + } + bool status = _publisher->close(100, 10); + CWARNING(RabbitMQDB, + !status, + "Could not gracefully close publisher TCP connection") + + DBG(RabbitMQInterface, "Number of messages sent: %d", _msg_tag) + DBG(RabbitMQInterface, + "Number of unacknowledged messages are %d", + _publisher->unacknowledged()) + _publisher->stop(); + _publisher_thread.join(); + + status = _consumer->close(100, 10); + CWARNING(RabbitMQDB, + !status, + "Could not gracefully close consumer TCP connection") + _consumer->stop(); + _consumer_thread.join(); + + connected = false; +} diff --git a/src/AMSlib/wf/workflow.hpp b/src/AMSlib/wf/workflow.hpp index 9743457e..68810d8e 100644 --- a/src/AMSlib/wf/workflow.hpp +++ b/src/AMSlib/wf/workflow.hpp @@ -86,7 +86,6 @@ class AMSWorkflow MPI_Comm comm; #endif - /** @brief Is the evaluate a distributed execution **/ bool isDistributed; @@ -100,7 +99,7 @@ class AMSWorkflow * @param[in] outputs vector to 1-D vectors storing num_elements * items to be stored in the database */ - void Store(size_t num_elements, + void store(size_t num_elements, std::vector &inputs, std::vector &outputs, bool *predicate = nullptr) @@ -152,7 +151,7 @@ class AMSWorkflow return; } - void Store(size_t num_elements, + void store(size_t num_elements, std::vector &inputs, std::vector &outputs, bool *predicate = nullptr) @@ -162,9 +161,26 @@ class AMSWorkflow mInputs.push_back(const_cast(I)); } - Store(num_elements, mInputs, outputs, predicate); + store(num_elements, mInputs, outputs, predicate); } + /** \brief Check if we can perform a surrogate model update. + * AMS can update surrogate model only when all MPI ranks have received + * the latest model from RabbitMQ. + * @return True if surrogate model can be updated + */ + bool updateModel() + { + if (!DB || !DB->allowModelUpdate()) return false; + bool local = DB->updateModel(); +#ifdef __ENABLE_MPI__ + bool global = false; + MPI_Allreduce(&local, &global, 1, MPI_CXX_BOOL, MPI_LAND, comm); + return global; +#else + return local; +#endif + } public: AMSWorkflow() @@ -199,7 +215,6 @@ class AMSWorkflow comm(MPI_COMM_NULL), #endif ePolicy(AMSExecPolicy::AMS_UBALANCED) - { DB = nullptr; auto &dbm = ams::db::DBManager::getInstance(); @@ -311,16 +326,23 @@ class AMSWorkflow CALIPER(CALI_MARK_END("PHYSICS MODULE");) if (DB) { CALIPER(CALI_MARK_BEGIN("DBSTORE");) - Store(totalElements, tmpIn, origOutputs); + store(totalElements, tmpIn, origOutputs); CALIPER(CALI_MARK_END("DBSTORE");) } CALIPER(CALI_MARK_END("AMSEvaluate");) return; } - if (DB && DB->updateModel()) { - UQModel->updateModel(""); + CALIPER(CALI_MARK_BEGIN("UPDATEMODEL");) + if (updateModel()) { + auto model = DB->getLatestModel(); + CINFO(Workflow, + rId == 0, + "Updating surrogate model with %s", + model.c_str()) + UQModel->updateModel(model); } + CALIPER(CALI_MARK_END("UPDATEMODEL");) // The predicate with which we will split the data on a later step bool *predicate = rm.allocate(totalElements, appDataLoc); @@ -416,12 +438,12 @@ class AMSWorkflow DBG(Workflow, "Storing data (#elements = %d) to database", packedElements); - Store(packedElements, packedInputs, packedOutputs); + store(packedElements, packedInputs, packedOutputs); } else { DBG(Workflow, "Storing data (#elements = %d) to database including predicates", totalElements); - Store(totalElements, origInputs, origOutputs, predicate); + store(totalElements, origInputs, origOutputs, predicate); } CALIPER(CALI_MARK_END("DBSTORE");) diff --git a/tests/AMSlib/json_configs/rmq.json.in b/tests/AMSlib/json_configs/rmq.json.in index c502a2af..9c29487a 100644 --- a/tests/AMSlib/json_configs/rmq.json.in +++ b/tests/AMSlib/json_configs/rmq.json.in @@ -1,7 +1,7 @@ { "db" : { "dbType" : "rmq", - "rmq_config" : { + "rmq_config" : { "service-port": , "service-host": "", "rabbitmq-erlang-cookie": "", @@ -10,9 +10,11 @@ "rabbitmq-user": "", "rabbitmq-vhost": "", "rabbitmq-cert": "", - "rabbitmq-inbound-queue": "", - "rabbitmq-outbound-queue": "" - } + "rabbitmq-outbound-queue": "", + "rabbitmq-exchange": "", + "rabbitmq-routing-key": "" + }, + "update_surrogate": false }, "ml_models" : { "random_50": { diff --git a/tools/AMSLogReader.py b/tools/AMSLogReader.py index 7ea0a33d..1a106802 100644 --- a/tools/AMSLogReader.py +++ b/tools/AMSLogReader.py @@ -1,3 +1,9 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other +# AMSLib Project Developers +# +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + import argparse from pathlib import Path import glob diff --git a/tools/rmq/README.md b/tools/rmq/README.md new file mode 100644 index 00000000..59a489e2 --- /dev/null +++ b/tools/rmq/README.md @@ -0,0 +1,75 @@ +# Tools to interact with RabbitMQ + +This folder contains several scripts to send and receive messages using +RabbitMQ. They are useful to test and interact with AMSlib. Each script +is completetly standalone and does not require the AMS Python package, +however they require `pika` and `numpy`. + +## Generate TLS certificate + +To use most of the tools related to RabbitMQ you might need to provide TLS certificates. +To generate such certificate you can use OpenSSL, for example: + +```bash + openssl s_client -connect $REMOTE_HOST:$REMOTE_PORT -showcerts < /dev/null 2>/dev/null | sed -ne '/-BEGIN CERTIFICATE-/,/-END CERTIFICATE-/p' > rmq-tls.crt +``` +where `REMOTE_HOST` is the hostname of the RabbitMQ server and `REMOTE_PORT` is the port. + +## Consume messages from AMSlib + +To receive, or consume, messages emitted by AMSlib you can use `recv_binary.py`: + +```bash +python3 recv_binary.py -c rmq-credentials.json -t rmq-tls.crt -q test3 +``` + +If the credentials match, every messages sent by a simulation integrated +with AMS will be received by `recv_binary.py`. + +## Send a message to AMSlib + +### Send string messages +To send a simple text message to AMSlib, for example to force AMS to update its +surrogate model, you can do it using `send.py`: + +```bash +python3 send.py -c rmq-credentials.json -t rmq-tls.crt -e ams-fanout -r training -n 1 -m +"UPDATE:ConstantOneModel_cpu.pt" +``` + +where `rmq-pds.json` contains the RabbitMQ credentials and `rmq-pds.crt` the +TLS certificate. See `send.py -h` for more options. + +The RabbitMQ credentials file must follow this template: +```json +{ + "rabbitmq-erlang-cookie": "", + "rabbitmq-name": "", + "rabbitmq-password": "", + "rabbitmq-user": "", + "rabbitmq-vhost": "", + "service-port": 0, + "service-host": "", +} +``` + +> Note that you can use `send.py` to send any type of string messages to any RabbitMQ +> server. + +### Send binary-compatible AMSlib messages + +To send a message that mimics what each MPI rank in AMSlib would send to +the AMS Python module, one can use `send_ams.py`. For example, + +```bash +python3 send_ams.py -c rmq-credentials.json -t rmq-tls.crt -r test3 -n 10 +``` + +In another terminal, to receive the message just sent, you can run: + +```bash +python3 recv_binary.py -c rmq-credentials.json -t rmq-tls.crt -q test3 +``` + +This tool is useful to test the AMS python workflow without +actually running a simulation. diff --git a/tools/rmq/recv_binary.py b/tools/rmq/recv_binary.py new file mode 100755 index 00000000..2aa628d0 --- /dev/null +++ b/tools/rmq/recv_binary.py @@ -0,0 +1,193 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other +# AMSLib Project Developers +# +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import argparse +import pika +import ssl +import sys +import os +import json +import copy +import numpy as np + +from typing import Tuple + +# CA Cert, can be generated with (where $REMOTE_HOST and $REMOTE_PORT can be found in the JSON file): +# openssl s_client -connect $REMOTE_HOST:$REMOTE_PORT -showcerts < /dev/null +# 2>/dev/null | sed -ne '/-BEGIN CERTIFICATE-/,/-END CERTIFICATE-/p' > rabbitmq-credentials.cacert + +nbmsg = 0 +all_messages = [] +byte_received = 0 + +def get_rmq_connection(json_file): + data = {} + with open(json_file, 'r') as f: + data = json.load(f) + return data + +def parse_header(body: str) -> dict: + """ + We encode the message as follow: + - 1 byte is the size of the header (here 16). Limit max: 255 + - 1 byte is the precision (4 for float, 8 for double). Limit max: 255 + - 2 bytes are the MPI rank (0 if AMS is not running with MPI). Limit max: 65535 + - 2 bytes to store the size of the MSG domain name. Limit max: 65535 + - 4 bytes are the number of elements in the message. Limit max: 2^32 - 1 + - 2 bytes are the input dimension. Limit max: 65535 + - 2 bytes are the output dimension. Limit max: 65535 + + |_Header_|_Datatype_|___Rank___|__DomainSize__|__#elems__|___InDim____|___OutDim___|.real data + + Then the data starts at 16 and is structered as pairs of input/outputs. + Let K be the total number of elements, then we have K pairs of inputs/outputs (either float or double): + + |__Header_(16B)__|__Input 1__|__Output 1__|...|__Input_K__|__Output_K__| + """ + + if len(body) == 0: + print(f"Empty message. skipping") + return {} + + header_size = np.frombuffer(body[0:1], dtype=np.uint8)[0] + res = {} + + if header_size != 16: + print(f"Incomplete message of size {len(body)}. Header size is {header_size}, it should be of size 16. skipping ({body})") + return {} + + try: + res["header_size"] = header_size + res["datatype"] = np.frombuffer(body[1:2], dtype=np.uint8)[0] + res["mpirank"] = np.frombuffer(body[2:4], dtype=np.uint16)[0] + res["domain_size"] = np.frombuffer(body[4:6], dtype=np.uint16)[0] + res["num_element"] = np.frombuffer(body[6:10], dtype=np.uint32)[0] + res["input_dim"] = np.frombuffer(body[10:12], dtype=np.uint16)[0] + res["output_dim"] = np.frombuffer(body[12:14], dtype=np.uint16)[0] + res["padding"] = np.frombuffer(body[14:16], dtype=np.uint16)[0] + # Theoritical size in Bytes for the incoming message (without the header) + # Int() is needed otherwise we might overflow here (because of uint16 / uint8) + res["data_size"] = int(res["datatype"]) * res["num_element"] * (int(res["input_dim"]) + int(res["output_dim"])+res["padding"]) + res["multiple_msg"] = len(body) != (header_size + res["data_size"]) + except ValueError as e: + return {} + return res + +def multiple_messages(body: str) -> bool: + return parse_header(body)["multiple_msg"] + +def parse_data(body: str, header_info: dict) -> Tuple[str, np.array, np.array]: + data = np.array([]) + if len(body) == 0: + return data + + header_size = header_info["header_size"] + data_size = header_info["data_size"] + domain_name_size = header_info["domain_size"] + domain_name = body[header_size : header_size + domain_name_size] + domain_name = domain_name.decode("utf-8") + + try: + if data_size == 4: #if datatype takes 4 bytes + data = np.frombuffer(body[header_siz+domain_name_size:header_size+domain_name_size+data_size], dtype=np.float32) + else: + data = np.frombuffer(body[header_size+domain_name_size:header_size+domain_name_size+data_size], dtype=np.float64) + except ValueError as e: + print(f"Error: {e} => {header_info}") + + idim = header_info["input_dim"] + odim = header_info["output_dim"] + data = data.reshape((-1, idim + odim)) + # Return input, output + return (domain_name, data[:, :idim], data[:, idim:]) + +def callback(ch, method, properties, body, args = None): + global nbmsg + global all_messages + global byte_received + nbmsg += 1 + byte_received += len(body) + + i = 1 + stream = copy.deepcopy(body) + while stream: + header_info = parse_header(stream) + if not header_info: + break + domain_name, data_input, data_output = parse_data(stream, header_info) + num_element = header_info["num_element"] + mpirank = header_info["mpirank"] + # total size of byte we read for that message + chunk_size = header_info["header_size"] + header_info["domain_size"] + header_info["data_size"] + + print( + f" [{nbmsg}/{i}] Received from exchange=\"{method.exchange}\" routing_key=\"{method.routing_key}\"\n" + f" > [r{mpirank}] ({domain_name}) : {len(stream)/(1024*1024)} MB / {num_element} elements\n") + + if data_input.size > 0: + all_messages.append(data_input) + # We remove the current message and keep going + stream = stream[chunk_size:] + i += 1 + +def main(credentials: str, cacert: str, queue: str): + conn = get_rmq_connection(credentials) + if cacert is None: + ssl_options = None + else: + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + context.verify_mode = ssl.CERT_REQUIRED + context.check_hostname = False + context.load_verify_locations(cacert) + ssl_options = pika.SSLOptions(context) + + credentials = pika.PlainCredentials(conn["rabbitmq-user"], conn["rabbitmq-password"]) + cp = pika.ConnectionParameters( + host=conn["service-host"], + port=conn["service-port"], + virtual_host=conn["rabbitmq-vhost"], + credentials=credentials, + ssl_options=ssl_options + ) + + connection = pika.BlockingConnection(cp) + channel = connection.channel() + + print(f"Connecting to {conn['service-host']} ...") + + # Warning: + # if no queue is specified then RabbitMQ will NOT hold messages that are not routed to queues. + # So in order to receive the message, the receiver will have to be started BEFORE the sender + # Otherwise the message will be lost. + + result = channel.queue_declare(queue=queue, exclusive=False) + queue_name = result.method.queue + channel.basic_consume(queue=queue_name, on_message_callback=callback, auto_ack=True) + print(f"Listening on queue = {queue_name}") + + print(" [*] Waiting for messages. To exit press CTRL+C") + channel.start_consuming() + +def parse_args(): + parser = argparse.ArgumentParser(description="Tools that consumes AMS-encoded messages from RabbitMQ queue") + parser.add_argument('-c', '--creds', help="Credentials file (JSON)", required=True) + parser.add_argument('-t', '--tls-cert', help="TLS certificate file", required=False) + parser.add_argument('-q', '--queue', help="Queue to listen to", required=True) + + args = parser.parse_args() + return args + +if __name__ == "__main__": + try: + args = parse_args() + main(credentials = args.creds, cacert = args.tls_cert, queue = args.queue) + except KeyboardInterrupt: + print("") + print("Done") + try: + sys.exit(0) + except SystemExit: + os._exit(0) diff --git a/tools/rmq/send.py b/tools/rmq/send.py new file mode 100755 index 00000000..ac90338e --- /dev/null +++ b/tools/rmq/send.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other +# AMSLib Project Developers +# +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import pika +import sys +import ssl +import json +import argparse + +# JSON file containing host, port, password etc.. +# PDS_JSON = "creds.json" +PDS_JSON = "rmq-pds.json" + +# CA Cert, can be generated with (where $REMOTE_HOST and $REMOTE_PORT can be found in the JSON file): +# openssl s_client -connect $REMOTE_HOST:$REMOTE_PORT -showcerts < /dev/null 2>/dev/null | sed -ne '/-BEGIN CERTIFICATE-/,/-END CERTIFICATE-/p' > rmq-pds.crt + +# CA_CERT = "creds.pem" +CA_CERT = "rmq-pds.crt" + +def get_rmq_connection(json_file): + data = {} + with open(json_file, 'r') as f: + data = json.load(f) + return data + +def callback(ch, method, properties, body, args): + data = body.decode() + print(properties) + print(f"Received \"{data}\" from exchange=\"{method.exchange}\" routing_key=\"{method.routing_key}\" args={args}") + + +def main(args): + conn = get_rmq_connection(args.creds) + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + context.verify_mode = ssl.CERT_REQUIRED + context.check_hostname = False + context.load_verify_locations(args.tls_cert) + + print(f"[send.py] Connecting to {conn['service-host']} ...") + + credentials = pika.PlainCredentials(conn["rabbitmq-user"], conn["rabbitmq-password"]) + cp = pika.ConnectionParameters( + host=conn["service-host"], + port=conn["service-port"], + virtual_host=conn["rabbitmq-vhost"], + credentials=credentials, + ssl_options=pika.SSLOptions(context) + ) + + connection = pika.BlockingConnection(cp) + channel = connection.channel() + + + # Turn on delivery confirmations + channel.confirm_delivery() + + result = channel.queue_declare(queue='', exclusive=False) + + queue_name = result.method.queue + for i in range(args.num_msg): + try: + channel.basic_publish(exchange=args.exchange, routing_key=args.routing_key, body=args.msg) + print(f" [{i}] Sent '{args.msg}' on exchange='{args.exchange}'/routing_key='{args.routing_key}'") + except pika.exceptions.UnroutableError: + print(f" [{i}] Message could not be confirmed") + connection.close() + +def parse_args(): + + parser = argparse.ArgumentParser(description="Tool that sends AMS-compatible messages to a RabbitMQ server") + parser.add_argument('-c', '--creds', help="Credentials file (JSON)", required=True) + parser.add_argument('-t', '--tls-cert', help="TLS certificate file", required=True) + parser.add_argument('-e', '--exchange', help="On which exchange to send messages (default = '')", default='', required=False) + parser.add_argument('-r', '--routing-key', help="Routing key for the messages", required=True) + parser.add_argument('-n', '--num-msg', type=int, help="Number of messages that will get sent (default: 1)", default=1) + parser.add_argument('-m', '--msg', type=str, help="Content of the message") + + args = parser.parse_args() + + return args + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/tools/rmq/send_ams.py b/tools/rmq/send_ams.py new file mode 100755 index 00000000..d686607c --- /dev/null +++ b/tools/rmq/send_ams.py @@ -0,0 +1,134 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Lawrence Livermore National Security, LLC and other +# AMSLib Project Developers +# +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + + +import argparse +import pika +import sys +import ssl +import json +import struct +from typing import Tuple + +import numpy as np + +def ams_header_format() -> str: + """ + This string represents the AMS format in Python pack format: + See https://docs.python.org/3/library/struct.html#format-characters + - 1 byte is the size of the header (here 12). Limit max: 255 + - 1 byte is the precision (4 for float, 8 for double). Limit max: 255 + - 2 bytes are the MPI rank (0 if AMS is not running with MPI). Limit max: 65535 + - 2 bytes to store the size of the MSG domain name. Limit max: 65535 + - 4 bytes are the number of elements in the message. Limit max: 2^32 - 1 + - 2 bytes are the input dimension. Limit max: 65535 + - 2 bytes are the output dimension. Limit max: 65535 + - 2 bytes are for aligning memory to 8 + + |_Header_|_Datatype_|_Rank_|_DomainSize_|_#elems_|_InDim_|_OutDim_|_Pad_|_DomainName_|.Real_Data.| + + Then the data starts at byte 16 with the domain name, then the real data and + is structured as pairs of input/outputs. Let K be the total number of elements, + then we have K pairs of inputs/outputs (either float or double): + + |__Header_(16B)__|_Domain_Name_|__Input 1__|__Output 1__|...|__Input_K__|__Output_K__| + + """ + return "BBHHIHHH" + +def ams_endianness() -> str: + """ + '=' means native endianness in standart size (system). + See https://docs.python.org/3/library/struct.html#format-characters + """ + return "=" + +def ams_encode_message(num_elem: int, domain_name: str, input_dim: int, output_dim: int, dtype_byte: int = 4) -> bytes: + """ + For debugging and testing purposes, this function encode a message identical to what AMS would send + """ + header_format = ams_endianness() + ams_header_format() + hsize = struct.calcsize(header_format) + assert dtype_byte in [4, 8] + dt = "f" if dtype_byte == 4 else "d" + mpi_rank = 0 + data = np.random.rand(num_elem * (input_dim + output_dim)) + domain_name_size = len(domain_name) + domain_name = bytes(domain_name, "utf-8") + padding = 0 + header_content = (hsize, dtype_byte, mpi_rank, domain_name_size, data.size, input_dim, output_dim, padding) + # float or double + msg_format = f"{header_format}{domain_name_size}s{data.size}{dt}" + return struct.pack(msg_format, *header_content, domain_name, *data) + +def ams_parse_creds(json_file: str) -> dict: + """ + Parse the credentials to retrieve connection informations + """ + data = {} + with open(json_file, 'r') as f: + data = json.load(f) + return data + +def main(args: dict): + conn = ams_parse_creds(args.creds) + if args.tls_cert is None: + ssl_context = None + else: + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + context.check_hostname = False + context.verify_mode = ssl.CERT_REQUIRED + context.load_verify_locations(args.tls_cert) + ssl_context = pika.SSLOptions(context) + + credentials = pika.PlainCredentials(conn["rabbitmq-user"], conn["rabbitmq-password"]) + cp = pika.ConnectionParameters( + host=conn["service-host"], + port=conn["service-port"], + virtual_host=conn["rabbitmq-vhost"], + credentials=credentials, + ssl_options=ssl_context + ) + + connection = pika.BlockingConnection(cp) + channel = connection.channel() + result = channel.queue_declare(queue = args.queue, exclusive = False) + queue_name = result.method.queue + + encoded_msg = ams_encode_message( + num_elem = args.num_elem, + domain_name = args.domain_name, + input_dim = args.input_dim, + output_dim = args.output_dim, + dtype_byte = args.data_type + ) + for i in range(1, args.num_msg+1): + channel.basic_publish(exchange='', routing_key = args.routing_key, body = encoded_msg) + print(f"[{i}/{args.num_msg}] Sent message with {args.num_elem} elements of dim=({args.input_dim},{args.output_dim}) elements on queue='{queue_name}'/routing_key='{args.routing_key}'") + connection.close() + +def parse_args() -> dict: + + parser = argparse.ArgumentParser(description="Tool that sends AMS-compatible messages to a RabbitMQ server") + parser.add_argument('-c', '--creds', help="Credentials file (JSON)", required=True) + parser.add_argument('-t', '--tls-cert', help="TLS certificate file", required=False) + parser.add_argument('-q', '--queue', help="On which queue to send messages (default = '')", default='', required=False) + parser.add_argument('-r', '--routing-key', help="Routing key for the messages", required=True) + + parser.add_argument('-m', '--num-msg', type=int, help="Number of messages that will get sent (default: 1)", default=1) + parser.add_argument('-n', '--num-elem', type=int, help="Number of elements per message", required=True) + parser.add_argument('-i', '--input-dim', type=int, help="Input dimensions (default: 2)", default=2) + parser.add_argument('-o', '--output-dim', type=int, help="Output dimensions (default: 4)", default=4) + parser.add_argument('-d', '--data-type', type=int, help="Data size in bytes: float (4) or double (8) (default: 4)", choices=[4, 8], default=4) + parser.add_argument('-x', '--domain-name', type=str, help="Domain name", default="domain_test") + + args = parser.parse_args() + + return args + +if __name__ == "__main__": + args = parse_args() + main(args)