Skip to content

Commit

Permalink
corrected rll implementation (#340)
Browse files Browse the repository at this point in the history
* corrected rll implementation

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
M-R-Schaefer and pre-commit-ci[bot] committed Sep 19, 2024
1 parent 6128749 commit 6da554a
Showing 1 changed file with 22 additions and 7 deletions.
29 changes: 22 additions & 7 deletions ipsuite/analysis/model/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,28 @@ def nlls(pred, std, true):
return nll


def comptue_rll(pred, std, true):
errors = np.abs(pred - true)
rmse = compute_rmse(errors)
numerator = np.sum(nlls(errors, std, true) - nlls(errors, rmse, true))
denominator = np.sum(nlls(errors, errors, true) - nlls(errors, rmse, true))
rll = numerator / denominator * 100
return rll
def nll(pred, std, true):
tmp = nlls(pred, std, true)
return np.mean(tmp)


def comptue_rll(inputs, std, target):
"""Compute relative log likelihood
Adapted from https://github.com/bananenpampe/DPOSE
"""

mse = np.mean((inputs - target) ** 2)
uncertainty_estimate = (inputs - target) ** 2

ll_best = nll(inputs, np.sqrt(uncertainty_estimate), target)

ll_worst_case_best_RMSE = nll(inputs, np.sqrt(np.ones_like(std) * mse), target)

ll_actual = nll(inputs, std, target)

rll = (ll_actual - ll_worst_case_best_RMSE) / (ll_best - ll_worst_case_best_RMSE)

return rll * 100


def compute_uncertainty_metrics(pred, std, true):
Expand Down

0 comments on commit 6da554a

Please sign in to comment.