Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Delta UQ #16

Merged
merged 5 commits into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 20 additions & 11 deletions examples/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@

#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <mfem.hpp>
#include <random>
#include <stdexcept>
#include <string>
#include <umpire/strategy/QuickPool.hpp>
#include <unordered_set>
Expand Down Expand Up @@ -182,14 +184,18 @@ int run(const char *device_name,
dbType = AMSDBType::RMQ;
}

AMSUQPolicy uq_policy = (std::strcmp(uq_policy_opt, "max") == 0)
? AMSUQPolicy::FAISSMax
: AMSUQPolicy::FAISSMean;
AMSUQPolicy uq_policy;

if (uq_policy != AMSUQPolicy::FAISSMax)
uq_policy = ((std::strcmp(uq_policy_opt, "deltauq") == 0))
? AMSUQPolicy::DeltaUQ
: AMSUQPolicy::FAISSMean;
if (strcmp(uq_policy_opt, "faiss-max") == 0)
uq_policy = AMSUQPolicy::FAISS_Max;
else if (strcmp(uq_policy_opt, "faiss-mean") == 0)
uq_policy = AMSUQPolicy::FAISS_Mean;
else if (strcmp(uq_policy_opt, "deltauq-max") == 0)
uq_policy = AMSUQPolicy::DeltaUQ_Max;
else if (strcmp(uq_policy_opt, "deltauq-mean") == 0)
uq_policy = AMSUQPolicy::DeltaUQ_Mean;
else
throw std::runtime_error("Invalid UQ policy");

// set up a randomization seed
srand(seed + rId);
Expand Down Expand Up @@ -671,7 +677,7 @@ int main(int argc, char **argv)
const char *precision_opt = "double";
AMSDType precision = AMSDType::Double;

const char *uq_policy_opt = "mean";
const char *uq_policy_opt = "faiss-mean";
int k_nearest = 5;

int seed = 0;
Expand Down Expand Up @@ -795,11 +801,14 @@ int main(int argc, char **argv)
"-uq",
"--uqtype",
"Types of UQ to select from: \n"
"\t 'mean' Uncertainty is computed in comparison against the "
"\t 'faiss-mean' Uncertainty is computed in comparison "
"against the "
"mean distance of k-nearest neighbors\n"
"\t 'max': Uncertainty is computed in comparison with the "
"\t 'faiss-max': Uncertainty is computed in comparison with "
"the "
"k'st cluster \n"
"\t 'deltauq': Uncertainty through DUQ (not supported)\n");
"\t 'deltauq-mean': Uncertainty through DUQ using mean\n"
"\t 'deltauq-max': Uncertainty through DUQ using max\n");

args.AddOption(
&verbose, "-v", "--verbose", "-qu", "--quiet", "Print extra stuff");
Expand Down
10 changes: 7 additions & 3 deletions src/include/AMS.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,14 @@ typedef enum { UBALANCED = 0, BALANCED } AMSExecPolicy;

typedef enum { None = 0, CSV, REDIS, HDF5, RMQ } AMSDBType;

// TODO: create a cleaner interface that separates UQ type (FAISS, DeltaUQ) with policy (max, mean).
typedef enum {
FAISSMean = 0,
FAISSMax,
DeltaUQ // Not supported
AMSUQPolicy_BEGIN = 0,
FAISS_Mean,
FAISS_Max,
DeltaUQ_Mean,
DeltaUQ_Max,
koparasy marked this conversation as resolved.
Show resolved Hide resolved
AMSUQPolicy_END
} AMSUQPolicy;

typedef struct ams_conf {
Expand Down
21 changes: 11 additions & 10 deletions src/ml/hdcache.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class HDCache

const bool m_use_random;
const int m_knbrs = 0;
const AMSUQPolicy m_policy = AMSUQPolicy::FAISSMean;
const AMSUQPolicy m_policy = AMSUQPolicy::FAISS_Mean;

AMSResourceType cache_location;

Expand Down Expand Up @@ -209,6 +209,11 @@ class HDCache
return cache;
}

if (uqPolicy != AMSUQPolicy::FAISS_Mean &&
uqPolicy != AMSUQPolicy::FAISS_Max)
THROW(std::invalid_argument,
"Invalid UQ policy for hdcache" + std::to_string(uqPolicy));

DBG(UQModule, "Generating new cache under (%s)", cache_path.c_str())
std::shared_ptr<HDCache<TypeInValue>> new_cache =
std::shared_ptr<HDCache<TypeInValue>>(new HDCache<TypeInValue>(
Expand All @@ -224,7 +229,7 @@ class HDCache
{
static std::string random_path("random");
std::shared_ptr<HDCache<TypeInValue>> cache = find_cache(
random_path, resource, AMSUQPolicy::FAISSMean, -1, threshold);
random_path, resource, AMSUQPolicy::FAISS_Mean, -1, threshold);
if (cache) {
DBG(UQModule, "Returning existing cache under (%s)", random_path.c_str())
return cache;
Expand Down Expand Up @@ -547,16 +552,13 @@ class HDCache
// compute means
if (cache_location == AMSResourceType::HOST) {
for (size_t i = 0; i < ndata; ++i) {
CFATAL(UQModule,
m_policy == AMSUQPolicy::DeltaUQ,
"DeltaUQ is not supported yet");
if (m_policy == AMSUQPolicy::FAISSMean) {
if (m_policy == AMSUQPolicy::FAISS_Mean) {
TypeValue mean_dist = std::accumulate(kdists + i * knbrs,
kdists + (i + 1) * knbrs,
0.) *
ook;
is_acceptable[i] = mean_dist < acceptable_error;
} else if (m_policy == AMSUQPolicy::FAISSMax) {
} else if (m_policy == AMSUQPolicy::FAISS_Max) {
// Take the furtherst cluster as the distance metric
TypeValue max_dist =
*std::max_element(&kdists[i * knbrs],
Expand All @@ -566,9 +568,8 @@ class HDCache
}
} else {
CFATAL(UQModule,
(m_policy == AMSUQPolicy::DeltaUQ) ||
(m_policy == AMSUQPolicy::FAISSMax),
"DeltaUQ is not supported yet");
m_policy == AMSUQPolicy::FAISS_Max,
"FAISS Max on device is not supported yet");

ams::Device::computePredicate(
kdists, is_acceptable, ndata, knbrs, acceptable_error);
Expand Down
95 changes: 80 additions & 15 deletions src/ml/surrogate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@
#ifndef __AMS_SURROGATE_HPP__
#define __AMS_SURROGATE_HPP__


#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>

#ifdef __ENABLE_TORCH__
#include <ATen/core/interned_strings.h>
#include <ATen/core/ivalue.h>
#include <torch/script.h> // One-stop header.
#endif

Expand All @@ -39,7 +39,7 @@ class SurrogateModel
private:
const std::string model_path;
AMSResourceType model_resource;

const bool _is_DeltaUQ;

#ifdef __ENABLE_TORCH__
// -------------------------------------------------------------------------
Expand Down Expand Up @@ -106,6 +106,30 @@ class SurrogateModel
}
}

PERFFASPECT()
inline void tensorToHostArray(at::Tensor tensor,
long numRows,
long numCols,
TypeInValue** array)
{
// Transpose to get continuous memory and
// perform single memcpy.
tensor = tensor.transpose(1, 0);
if (model_resource == AMSResourceType::HOST) {
for (long j = 0; j < numCols; j++) {
auto tmp = tensor[j].contiguous();
TypeInValue* ptr = tmp.data_ptr<TypeInValue>();
HtoHMemcpy(array[j], ptr, sizeof(TypeInValue) * numRows);
}
} else {
for (long j = 0; j < numCols; j++) {
auto tmp = tensor[j].contiguous();
TypeInValue* ptr = tmp.data_ptr<TypeInValue>();
DtoHMemcpy(array[j], ptr, sizeof(TypeInValue) * numRows);
}
}
}

// -------------------------------------------------------------------------
// loading a surrogate model!
// -------------------------------------------------------------------------
Expand Down Expand Up @@ -154,21 +178,37 @@ class SurrogateModel
size_t num_in,
size_t num_out,
const TypeInValue** inputs,
TypeInValue** outputs)
TypeInValue** outputs,
TypeInValue** outputs_stdev)
{
//torch::NoGradGuard no_grad;
c10::InferenceMode guard(true);
auto input = arrayToTensor(num_elements, num_in, inputs);
input.set_requires_grad(false);
at::Tensor output = module.forward({input}).toTensor().detach();
if (_is_DeltaUQ) {
assert(outputs_stdev && "Expected non-null outputs_stdev");
// The deltauq surrogate returns a tuple of (outputs, outputs_stdev)
auto output_tuple = module.forward({input}).toTuple();
at::Tensor output_mean_tensor =
output_tuple->elements()[0].toTensor().detach();
at::Tensor output_stdev_tensor =
output_tuple->elements()[1].toTensor().detach();
tensorToArray(output_mean_tensor, num_elements, num_out, outputs);
tensorToHostArray(output_stdev_tensor,
num_elements,
num_out,
outputs_stdev);
} else {
at::Tensor output = module.forward({input}).toTensor().detach();
tensorToArray(output, num_elements, num_out, outputs);
}

DBG(Surrogate,
"Evaluate surrogate model (%ld, %ld) -> (%ld, %ld)",
num_elements,
num_in,
num_elements,
num_out);
tensorToArray(output, num_elements, num_out, outputs);
}

#else
Expand All @@ -184,16 +224,20 @@ class SurrogateModel
long num_in,
size_t num_out,
const TypeInValue** inputs,
TypeInValue** outputs)
TypeInValue** outputs,
TypeInValue** outputs_stdev)
{
}

#endif

SurrogateModel(const char* model_path, AMSResourceType resource = AMSResourceType::HOST)
: model_path(model_path), model_resource(resource)
SurrogateModel(const char* model_path,
AMSResourceType resource = AMSResourceType::HOST,
bool is_DeltaUQ = false)
: model_path(model_path),
model_resource(resource),
_is_DeltaUQ(is_DeltaUQ)
{

if (resource != AMSResourceType::DEVICE)
_load<TypeInValue>(model_path, "cpu");
else
Expand Down Expand Up @@ -226,18 +270,22 @@ class SurrogateModel

static std::shared_ptr<SurrogateModel<TypeInValue>> getInstance(
const char* model_path,
AMSResourceType resource = AMSResourceType::HOST)
AMSResourceType resource = AMSResourceType::HOST,
bool is_DeltaUQ = false)
{
auto model =
SurrogateModel<TypeInValue>::instances.find(std::string(model_path));
if (model != instances.end()) {
// Model Found
auto torch_model = model->second;
if ( resource != torch_model->model_resource)
if (resource != torch_model->model_resource)
throw std::runtime_error(
"Currently we are not supporting loading the same model file on "
"different devices.");

if (is_DeltaUQ != torch_model->is_DeltaUQ())
THROW(std::runtime_error, "Loaded model instance is not DeltaUQ");

if (!same_type<TypeInValue>(torch_model->is_double()))
throw std::runtime_error(
"Requesting model loading of different data types.");
Expand All @@ -252,7 +300,7 @@ class SurrogateModel
DBG(Surrogate, "Generating new model under (%s)", model_path);
std::shared_ptr<SurrogateModel<TypeInValue>> torch_model =
std::shared_ptr<SurrogateModel<TypeInValue>>(
new SurrogateModel<TypeInValue>(model_path, resource));
new SurrogateModel<TypeInValue>(model_path, resource, is_DeltaUQ));
instances.insert(std::make_pair(std::string(model_path), torch_model));
return torch_model;
};
Expand All @@ -268,9 +316,24 @@ class SurrogateModel
size_t num_in,
size_t num_out,
const TypeInValue** inputs,
TypeInValue** outputs)
TypeInValue** outputs,
TypeInValue** outputs_stdev = nullptr)
{
_evaluate(num_elements, num_in, num_out, inputs, outputs);
_evaluate(num_elements, num_in, num_out, inputs, outputs, outputs_stdev);
}

PERFFASPECT()
inline void evaluate(long num_elements,
std::vector<const TypeInValue*> inputs,
std::vector<TypeInValue*> outputs,
std::vector<TypeInValue*> outputs_stdev)
{
_evaluate(num_elements,
inputs.size(),
outputs.size(),
static_cast<const TypeInValue**>(inputs.data()),
static_cast<TypeInValue**>(outputs.data()),
static_cast<TypeInValue**>(outputs_stdev.data()));
}

PERFFASPECT()
Expand All @@ -282,7 +345,8 @@ class SurrogateModel
inputs.size(),
outputs.size(),
static_cast<const TypeInValue**>(inputs.data()),
static_cast<TypeInValue**>(outputs.data()));
static_cast<TypeInValue**>(outputs.data()),
nullptr);
}

#ifdef __ENABLE_TORCH__
Expand All @@ -295,6 +359,7 @@ class SurrogateModel
}
#endif

bool is_DeltaUQ() { return _is_DeltaUQ; }
};

template <typename T>
Expand Down
Loading