Skip to content

Commit

Permalink
Implement Delta UQ
Browse files Browse the repository at this point in the history
  • Loading branch information
ggeorgakoudis committed Nov 7, 2023
1 parent 209f715 commit 80d3bb5
Show file tree
Hide file tree
Showing 10 changed files with 249 additions and 57 deletions.
31 changes: 20 additions & 11 deletions examples/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <mfem.hpp>
#include <random>
#include <stdexcept>
#include <string>
#include <umpire/strategy/QuickPool.hpp>
#include <unordered_set>
Expand Down Expand Up @@ -182,14 +184,18 @@ int run(const char *device_name,
dbType = AMSDBType::RMQ;
}

AMSUQPolicy uq_policy = (std::strcmp(uq_policy_opt, "max") == 0)
? AMSUQPolicy::FAISSMax
: AMSUQPolicy::FAISSMean;
AMSUQPolicy uq_policy;

if (uq_policy != AMSUQPolicy::FAISSMax)
uq_policy = ((std::strcmp(uq_policy_opt, "deltauq") == 0))
? AMSUQPolicy::DeltaUQ
: AMSUQPolicy::FAISSMean;
if (strcmp(uq_policy_opt, "faiss-max") == 0)
uq_policy = AMSUQPolicy::FAISS_Max;
else if (strcmp(uq_policy_opt, "faiss-mean") == 0)
uq_policy = AMSUQPolicy::FAISS_Mean;
else if (strcmp(uq_policy_opt, "deltauq-max") == 0)
uq_policy = AMSUQPolicy::DeltaUQ_Max;
else if (strcmp(uq_policy_opt, "deltauq-mean") == 0)
uq_policy = AMSUQPolicy::DeltaUQ_Mean;
else
throw std::runtime_error("Invalid UQ policy");

// set up a randomization seed
srand(seed + rId);
Expand Down Expand Up @@ -671,7 +677,7 @@ int main(int argc, char **argv)
const char *precision_opt = "double";
AMSDType precision = AMSDType::Double;

const char *uq_policy_opt = "mean";
const char *uq_policy_opt = "faiss-mean";
int k_nearest = 5;

int seed = 0;
Expand Down Expand Up @@ -795,11 +801,14 @@ int main(int argc, char **argv)
"-uq",
"--uqtype",
"Types of UQ to select from: \n"
"\t 'mean' Uncertainty is computed in comparison against the "
"\t 'faiss-mean' Uncertainty is computed in comparison "
"against the "
"mean distance of k-nearest neighbors\n"
"\t 'max': Uncertainty is computed in comparison with the "
"\t 'faiss-max': Uncertainty is computed in comparison with "
"the "
"k'st cluster \n"
"\t 'deltauq': Uncertainty through DUQ (not supported)\n");
"\t 'deltauq-mean': Uncertainty through DUQ using mean\n"
"\t 'deltauq-max': Uncertainty through DUQ using max\n");

args.AddOption(
&verbose, "-v", "--verbose", "-qu", "--quiet", "Print extra stuff");
Expand Down
10 changes: 7 additions & 3 deletions src/include/AMS.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,14 @@ typedef enum { UBALANCED = 0, BALANCED } AMSExecPolicy;

typedef enum { None = 0, CSV, REDIS, HDF5, RMQ } AMSDBType;

// TODO: create a cleaner interface that separates UQ type (FAISS, DeltaUQ) with policy (max, mean).
typedef enum {
FAISSMean = 0,
FAISSMax,
DeltaUQ // Not supported
AMSUQPolicy_BEGIN = 0,
FAISS_Mean,
FAISS_Max,
DeltaUQ_Mean,
DeltaUQ_Max,
AMSUQPolicy_END
} AMSUQPolicy;

typedef struct ams_conf {
Expand Down
21 changes: 11 additions & 10 deletions src/ml/hdcache.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class HDCache

const bool m_use_random;
const int m_knbrs = 0;
const AMSUQPolicy m_policy = AMSUQPolicy::FAISSMean;
const AMSUQPolicy m_policy = AMSUQPolicy::FAISS_Mean;

AMSResourceType cache_location;

Expand Down Expand Up @@ -209,6 +209,11 @@ class HDCache
return cache;
}

if (uqPolicy != AMSUQPolicy::FAISS_Mean &&
uqPolicy != AMSUQPolicy::FAISS_Max)
THROW(std::invalid_argument,
"Invalid UQ policy for hdcache" + std::to_string(uqPolicy));

DBG(UQModule, "Generating new cache under (%s)", cache_path.c_str())
std::shared_ptr<HDCache<TypeInValue>> new_cache =
std::shared_ptr<HDCache<TypeInValue>>(new HDCache<TypeInValue>(
Expand All @@ -224,7 +229,7 @@ class HDCache
{
static std::string random_path("random");
std::shared_ptr<HDCache<TypeInValue>> cache = find_cache(
random_path, resource, AMSUQPolicy::FAISSMean, -1, threshold);
random_path, resource, AMSUQPolicy::FAISS_Mean, -1, threshold);
if (cache) {
DBG(UQModule, "Returning existing cache under (%s)", random_path.c_str())
return cache;
Expand Down Expand Up @@ -547,16 +552,13 @@ class HDCache
// compute means
if (cache_location == AMSResourceType::HOST) {
for (size_t i = 0; i < ndata; ++i) {
CFATAL(UQModule,
m_policy == AMSUQPolicy::DeltaUQ,
"DeltaUQ is not supported yet");
if (m_policy == AMSUQPolicy::FAISSMean) {
if (m_policy == AMSUQPolicy::FAISS_Mean) {
TypeValue mean_dist = std::accumulate(kdists + i * knbrs,
kdists + (i + 1) * knbrs,
0.) *
ook;
is_acceptable[i] = mean_dist < acceptable_error;
} else if (m_policy == AMSUQPolicy::FAISSMax) {
} else if (m_policy == AMSUQPolicy::FAISS_Max) {
// Take the furtherst cluster as the distance metric
TypeValue max_dist =
*std::max_element(&kdists[i * knbrs],
Expand All @@ -566,9 +568,8 @@ class HDCache
}
} else {
CFATAL(UQModule,
(m_policy == AMSUQPolicy::DeltaUQ) ||
(m_policy == AMSUQPolicy::FAISSMax),
"DeltaUQ is not supported yet");
m_policy == AMSUQPolicy::FAISS_Max,
"FAISS Max on device is not supported yet");

ams::Device::computePredicate(
kdists, is_acceptable, ndata, knbrs, acceptable_error);
Expand Down
87 changes: 74 additions & 13 deletions src/ml/surrogate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#ifndef __AMS_SURROGATE_HPP__
#define __AMS_SURROGATE_HPP__


#include <ATen/core/ivalue.h>
#include <memory>
#include <stdexcept>
#include <string>
Expand Down Expand Up @@ -39,7 +39,7 @@ class SurrogateModel
private:
const std::string model_path;
AMSResourceType model_resource;

const bool _is_DeltaUQ;

#ifdef __ENABLE_TORCH__
// -------------------------------------------------------------------------
Expand Down Expand Up @@ -106,6 +106,30 @@ class SurrogateModel
}
}

PERFFASPECT()
inline void tensorToHostArray(at::Tensor tensor,
long numRows,
long numCols,
TypeInValue** array)
{
// Transpose to get continuous memory and
// perform single memcpy.
tensor = tensor.transpose(1, 0);
if (model_resource == AMSResourceType::HOST) {
for (long j = 0; j < numCols; j++) {
auto tmp = tensor[j].contiguous();
TypeInValue* ptr = tmp.data_ptr<TypeInValue>();
HtoHMemcpy(array[j], ptr, sizeof(TypeInValue) * numRows);
}
} else {
for (long j = 0; j < numCols; j++) {
auto tmp = tensor[j].contiguous();
TypeInValue* ptr = tmp.data_ptr<TypeInValue>();
DtoHMemcpy(array[j], ptr, sizeof(TypeInValue) * numRows);
}
}
}

// -------------------------------------------------------------------------
// loading a surrogate model!
// -------------------------------------------------------------------------
Expand Down Expand Up @@ -154,21 +178,36 @@ class SurrogateModel
size_t num_in,
size_t num_out,
const TypeInValue** inputs,
TypeInValue** outputs)
TypeInValue** outputs,
TypeInValue** outputs_stdev)
{
//torch::NoGradGuard no_grad;
c10::InferenceMode guard(true);
auto input = arrayToTensor(num_elements, num_in, inputs);
input.set_requires_grad(false);
at::Tensor output = module.forward({input}).toTensor().detach();
if (_is_DeltaUQ) {
assert(outputs_stdev && "Expected non-null outputs_stdev");
// The deltauq surrogate returns a tuple of (outputs, outputs_stdev)
auto output_tuple = module.forward({input}).toTuple();
at::Tensor output_mean_tensor = output_tuple->elements()[0].toTensor().detach();
at::Tensor output_stdev_tensor = output_tuple->elements()[1].toTensor().detach();
tensorToArray(output_mean_tensor, num_elements, num_out, outputs);
tensorToHostArray(output_stdev_tensor,
num_elements,
num_out,
outputs_stdev);
}
else {
at::Tensor output = module.forward({input}).toTensor().detach();
tensorToArray(output, num_elements, num_out, outputs);
}

DBG(Surrogate,
"Evaluate surrogate model (%ld, %ld) -> (%ld, %ld)",
num_elements,
num_in,
num_elements,
num_out);
tensorToArray(output, num_elements, num_out, outputs);
}

#else
Expand All @@ -190,10 +229,11 @@ class SurrogateModel

#endif

SurrogateModel(const char* model_path, AMSResourceType resource = AMSResourceType::HOST)
: model_path(model_path), model_resource(resource)
SurrogateModel(const char* model_path,
AMSResourceType resource = AMSResourceType::HOST,
bool is_DeltaUQ = false)
: model_path(model_path), model_resource(resource), _is_DeltaUQ(is_DeltaUQ)
{

if (resource != AMSResourceType::DEVICE)
_load<TypeInValue>(model_path, "cpu");
else
Expand Down Expand Up @@ -226,7 +266,8 @@ class SurrogateModel

static std::shared_ptr<SurrogateModel<TypeInValue>> getInstance(
const char* model_path,
AMSResourceType resource = AMSResourceType::HOST)
AMSResourceType resource = AMSResourceType::HOST,
bool is_DeltaUQ = false)
{
auto model =
SurrogateModel<TypeInValue>::instances.find(std::string(model_path));
Expand All @@ -238,6 +279,9 @@ class SurrogateModel
"Currently we are not supporting loading the same model file on "
"different devices.");

if(is_DeltaUQ != torch_model->is_DeltaUQ())
THROW(std::runtime_error, "Loaded model instance is not DeltaUQ");

if (!same_type<TypeInValue>(torch_model->is_double()))
throw std::runtime_error(
"Requesting model loading of different data types.");
Expand All @@ -252,7 +296,7 @@ class SurrogateModel
DBG(Surrogate, "Generating new model under (%s)", model_path);
std::shared_ptr<SurrogateModel<TypeInValue>> torch_model =
std::shared_ptr<SurrogateModel<TypeInValue>>(
new SurrogateModel<TypeInValue>(model_path, resource));
new SurrogateModel<TypeInValue>(model_path, resource, is_DeltaUQ));
instances.insert(std::make_pair(std::string(model_path), torch_model));
return torch_model;
};
Expand All @@ -268,9 +312,24 @@ class SurrogateModel
size_t num_in,
size_t num_out,
const TypeInValue** inputs,
TypeInValue** outputs)
TypeInValue** outputs,
TypeInValue **outputs_stdev = nullptr)
{
_evaluate(num_elements, num_in, num_out, inputs, outputs);
_evaluate(num_elements, num_in, num_out, inputs, outputs, outputs_stdev);
}

PERFFASPECT()
inline void evaluate(long num_elements,
std::vector<const TypeInValue*> inputs,
std::vector<TypeInValue*> outputs,
std::vector<TypeInValue*> outputs_stdev)
{
_evaluate(num_elements,
inputs.size(),
outputs.size(),
static_cast<const TypeInValue**>(inputs.data()),
static_cast<TypeInValue**>(outputs.data()),
static_cast<TypeInValue**>(outputs_stdev.data()));
}

PERFFASPECT()
Expand All @@ -282,7 +341,8 @@ class SurrogateModel
inputs.size(),
outputs.size(),
static_cast<const TypeInValue**>(inputs.data()),
static_cast<TypeInValue**>(outputs.data()));
static_cast<TypeInValue**>(outputs.data()),
nullptr);
}

#ifdef __ENABLE_TORCH__
Expand All @@ -295,6 +355,7 @@ class SurrogateModel
}
#endif

bool is_DeltaUQ() { return _is_DeltaUQ; }
};

template <typename T>
Expand Down
Loading

0 comments on commit 80d3bb5

Please sign in to comment.