Skip to content

Commit

Permalink
Surrogate ML model update (#78)
Browse files Browse the repository at this point in the history
* Added ML model update
* Moved most implementations of RMQ to its own .cpp file
* Added scripts to debug/test RabbitMQ in tools/
* Moved RMQ implementation to corresponding .cpp file
* We use random queues generated by RMQ instead of doing it manually in AMSlib
* Added JSON option db:update_surrogate (boolean) to control whether we update surrogate

Signed-off-by: Loic Pottier <[email protected]>
  • Loading branch information
lpottier committed Aug 8, 2024
1 parent 223d32e commit 6d4446b
Show file tree
Hide file tree
Showing 14 changed files with 2,066 additions and 990 deletions.
14 changes: 12 additions & 2 deletions examples/app/eos_ams.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
*/

#include <vector>

#include "eos_ams.hpp"

#include <vector>

template <typename FPType>
void callBack(void *cls,
long elements,
Expand Down Expand Up @@ -47,12 +47,22 @@ AMSEOS<FPType>::AMSEOS(EOS<FPType> *model,
uq_path,
"ideal_gas",
k_nearest);
#ifdef __ENABLE_MPI__
wf_ = AMSCreateDistributedExecutor(model_descr,
dtype,
res_type,
(AMSPhysicFn)callBack<FPType>,
MPI_COMM_WORLD,
mpi_task,
mpi_nproc);
#else
wf_ = AMSCreateExecutor(model_descr,
dtype,
res_type,
(AMSPhysicFn)callBack<FPType>,
mpi_task,
mpi_nproc);
#endif
}

template <typename FPType>
Expand Down
22 changes: 13 additions & 9 deletions src/AMSWorkflow/ams/rmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,13 @@ def header_format(self) -> str:
- 2 bytes are the output dimension. Limit max: 65535
- 2 bytes are for aligning memory to 8
|_Header_|_Datatype_|___Rank___|__DomainSize__|__#elems__|___InDim____|___OutDim___|_Pad_|.real data.|
|_Header_|_Datatype_|_Rank_|_DomainSize_|_#elems_|_InDim_|_OutDim_|_Pad_|_DomainName_|.Real_Data.|
Then the data starts at 16 and is structered as pairs of input/outputs.
Let K be the total number of elements, then we have K pairs of inputs/outputs (either float or double):
Then the data starts at byte 16 with the domain name, then the real data and
is structured as pairs of input/outputs. Let K be the total number of elements,
then we have K pairs of inputs/outputs (either float or double):
|__Header_(16B)__|__Input 1__|__Output 1__|...|__Input_K__|__Output_K__|
|__Header_(16B)__|_Domain_Name_|__Input 1__|__Output 1__|...|__Input_K__|__Output_K__|
"""
return "BBHHIHHH"
Expand All @@ -55,20 +56,23 @@ def endianness(self) -> str:
"""
return "="

def encode(self, num_elem: int, input_dim: int, output_dim: int, dtype_byte: int = 4) -> bytes:
def encode(self, num_elem: int, domain_name: str, input_dim: int, output_dim: int, dtype_byte: int = 4) -> bytes:
"""
For debugging and testing purposes, this function encode a message identical to what AMS would send
"""
header_format = self.endianness() + self.header_format()
header_format = self.ams_endianness() + self.ams_header_format()
hsize = struct.calcsize(header_format)
assert dtype_byte in [4, 8]
dt = "f" if dtype_byte == 4 else "d"
mpi_rank = 0
data = np.random.rand(num_elem * (input_dim + output_dim))
header_content = (hsize, dtype_byte, mpi_rank, data.size, input_dim, output_dim)
domain_name_size = len(domain_name)
domain_name = bytes(domain_name, "utf-8")
padding = 0
header_content = (hsize, dtype_byte, mpi_rank, domain_name_size, data.size, input_dim, output_dim, padding)
# float or double
msg_format = f"{header_format}{data.size}{dt}"
return struct.pack(msg_format, *header_content, *data)
msg_format = f"{header_format}{domain_name_size}s{data.size}{dt}"
return struct.pack(msg_format, *header_content, domain_name, *data)

def _parse_header(self, body: str) -> dict:
"""
Expand Down
17 changes: 12 additions & 5 deletions src/AMSlib/AMS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
*/

#include "AMS.h"

#include <limits.h>
#include <unistd.h>

Expand All @@ -17,7 +19,6 @@
#include <utility>
#include <vector>

#include "AMS.h"
#include "include/AMS.h"
#include "ml/uq.hpp"
#include "wf/basedb.hpp"
Expand Down Expand Up @@ -268,6 +269,7 @@ class AMSWrap
std::unordered_map<std::string, int> ams_candidate_models;
AMSDBType dbType = AMSDBType::AMS_NONE;
ams::ResourceManager &memManager;
int rId;

private:
void dumpEnv()
Expand Down Expand Up @@ -372,10 +374,13 @@ class AMSWrap
std::string rmq_user = getEntry<std::string>(rmq_entry, "rabbitmq-user");
std::string rmq_vhost = getEntry<std::string>(rmq_entry, "rabbitmq-vhost");
std::string rmq_cert = getEntry<std::string>(rmq_entry, "rabbitmq-cert");
std::string rmq_in_queue =
getEntry<std::string>(rmq_entry, "rabbitmq-inbound-queue");
std::string rmq_out_queue =
getEntry<std::string>(rmq_entry, "rabbitmq-outbound-queue");
std::string exchange =
getEntry<std::string>(rmq_entry, "rabbitmq-exchange");
std::string routing_key =
getEntry<std::string>(rmq_entry, "rabbitmq-routing-key");
bool update_surrogate = getEntry<bool>(entry, "update_surrogate");

auto &DB = ams::db::DBManager::getInstance();
DB.instantiate_rmq_db(port,
Expand All @@ -385,8 +390,10 @@ class AMSWrap
rmq_user,
rmq_vhost,
rmq_cert,
rmq_in_queue,
rmq_out_queue);
rmq_out_queue,
exchange,
routing_key,
update_surrogate);
}

void parseDatabase(json &jRoot)
Expand Down
3 changes: 3 additions & 0 deletions src/AMSlib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ if (WITH_HDF5)
list(APPEND AMS_LIB_SRC wf/hdf5db.cpp)
endif()

if (WITH_RMQ)
list(APPEND AMS_LIB_SRC wf/rmqdb.cpp)
endif()



Expand Down
4 changes: 2 additions & 2 deletions src/AMSlib/ml/surrogate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ class SurrogateModel
inline void _load(const std::string& model_path,
const std::string& device_name)
{
DBG(Surrogate, "Using model at double precision");
DBG(Surrogate, "Using model at double precision: %s", model_path.c_str());
_load_torch(model_path, torch::Device(device_name), torch::kFloat64);
}

Expand All @@ -145,7 +145,7 @@ class SurrogateModel
inline void _load(const std::string& model_path,
const std::string& device_name)
{
DBG(Surrogate, "Using model at single precision");
DBG(Surrogate, "Using model at single precision: %s", model_path.c_str());
_load_torch(model_path, torch::Device(device_name), torch::kFloat32);
}

Expand Down
Loading

0 comments on commit 6d4446b

Please sign in to comment.