Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/yardstick #12

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ Suggests:
knitr,
rmarkdown,
testthat (>= 3.0.0),
covr
covr,
yardstick
VignetteBuilder: knitr
Config/testthat/edition: 3
Collate:
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ export(luz_metric_accuracy)
export(luz_metric_binary_accuracy)
export(luz_metric_binary_accuracy_with_logits)
export(luz_metric_mae)
export(luz_metric_yardstick)
export(luz_metric_mse)
export(luz_metric_rmse)
export(luz_save)
Expand Down
99 changes: 88 additions & 11 deletions R/metrics.R
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ luz_metric_accuracy <- luz_metric(
#' @export
luz_metric_binary_accuracy <- luz_metric(
abbrev = "Acc",
inherit = luz_metric_accuracy,
initialize = function(threshold = 0.5) {
self$correct <- 0
self$total <- 0
Expand Down Expand Up @@ -174,20 +175,13 @@ luz_metric_binary_accuracy_with_logits <- luz_metric(
}
)

#' Internal metric that is used to track the loss
#' @noRd
luz_metric_loss_average <- luz_metric(
abbrev = "Loss",
luz_metric_average <- luz_metric(
name = "average",
initialize = function() {
self$values <- list()
},
update = function(preds, targets) {
if (length(ctx$loss) == 1)
loss <- ctx$loss[[1]]
else
loss <- ctx$loss

self$values[[length(self$values) + 1]] <- loss
update = function(values, ...) {
self$values[[length(self$values) + 1]] <- values
},
average_metric = function(x) {
if (is.numeric(x[[1]]) || inherits(x[[1]], "torch_tensor"))
Expand Down Expand Up @@ -218,6 +212,21 @@ luz_metric_loss_average <- luz_metric(
}
)

#' Internal metric that is used to track the loss
#' @noRd
luz_metric_loss_average <- luz_metric(
  abbrev = "Loss",
  inherit = luz_metric_average,
  update = function(preds, targets) {
    # `ctx$loss` can be a length-1 container; unwrap it so the running
    # average accumulates the scalar loss itself, otherwise keep it as-is.
    loss <- if (length(ctx$loss) == 1) ctx$loss[[1]] else ctx$loss
    super$update(loss)
  }
)

#' Mean absolute error
#'
#' Computes the mean absolute error.
Expand Down Expand Up @@ -250,6 +259,73 @@ luz_metric_mae <- luz_metric(
}
)


#' Computes the average for any yardstick metric
#'
#' Allows using any yardstick metric with luz.
#'
#' @param metric_nm Name of the metric from yardstick (without the `_vec`).
#'   For example `'accuracy'`, `'mae'`, etc.
#' @param transform A function of `preds` and `targets` that will be applied
#'   to the values before computing the metric. This function is called after
#'   moving `preds` and `targets` to R vectors.
#' @param ... Additional parameters forwarded to the metric implementation in
#'   yardstick.
#'
#' @section Warning:
#' The only transformation we apply to the predicted values is moving them
#' to R with [torch::as_array()]. However, many metrics in yardstick
#' expect that values are factors, or in other formats. In that case you can use
#' the `transform` argument to specify a transformation.
#'
#' @examples
#' if (torch::torch_is_installed()) {
#' x <- torch::torch_randn(100)
#' y <- torch::torch_randn(100)
#'
#' m <- luz_metric_yardstick("mae")
#' m <- m$new()
#'
#' m$update(x, y)
#' o <- m$compute()
#' }
#' @returns
#' A luz metric object.
#'
#' @export
luz_metric_yardstick <- luz_metric(
  name = "yardstick_metric",
  inherit = luz_metric_average,
  initialize = function(metric_nm, transform = NULL, ...) {
    # yardstick is only a suggested dependency; fail early with a clear
    # message instead of an obscure `getFromNamespace()` error.
    if (!requireNamespace("yardstick", quietly = TRUE))
      stop("The 'yardstick' package must be installed to use `luz_metric_yardstick()`.",
           call. = FALSE)

    self$abbrev <- metric_nm
    # Resolve e.g. "mae" to yardstick::mae_vec without attaching yardstick.
    self$metric_fn <- utils::getFromNamespace(paste0(metric_nm, "_vec"), "yardstick")
    self$args <- rlang::list2(...)
    self$transform <- transform
  },
  update = function(preds, targets) {
    # Move tensors to the CPU and convert them to plain R arrays before
    # handing them to yardstick.
    preds <- as.array(preds$cpu())
    targets <- as.array(targets$cpu())

    if (!is.null(self$transform))
      transformed <- self$transform(preds, targets)
    else
      transformed <- list(preds, targets)

    # yardstick `*_vec()` functions take `truth` and `estimate`; forward any
    # extra arguments the user supplied at construction time.
    values <- do.call(
      self$metric_fn,
      append(
        self$args,
        list(
          truth = transformed[[2]],
          estimate = transformed[[1]]
        )
      )
    )
    super$update(values)
  }
)


#' Mean squared error
#'
#' Computes the mean squared error
Expand All @@ -275,6 +351,7 @@ luz_metric_mse <- luz_metric(
}
)


#' Root mean squared error
#'
#' Computes the root mean squared error.
Expand Down
45 changes: 45 additions & 0 deletions man/luz_metric_yardstick.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 23 additions & 0 deletions tests/testthat/test-metrics.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,25 @@ test_that("mae works", {
expect_equal(o, eo, tolerance = 1e-5)
})

test_that("yardstick metrics", {

  x <- torch::torch_randn(100)
  y <- torch::torch_randn(100)

  # mae: compare against the closed-form mean absolute error.
  m <- luz_metric_yardstick("mae")
  m <- m$new()
  m$update(x, y)
  o <- m$compute()
  eo <- mean(abs(as.array(x) - as.array(y)))
  expect_equal(o, eo, tolerance = 1e-5)

  # rmse: a fresh metric instance must be created via `$new()` before use.
  m <- luz_metric_yardstick("rmse")
  m <- m$new()
  m$update(x, y)
  o <- m$compute()
  eo <- sqrt(mean((as.array(x) - as.array(y))^2))
  expect_equal(o, eo, tolerance = 1e-5)
})

test_that("mse works", {

x <- torch::torch_randn(100, 100)
Expand All @@ -56,6 +75,7 @@ test_that("mse works", {
eo <- mean((as.array(x) - as.array(y))^2)

expect_equal(o, eo, tolerance = 1e-5)

})

test_that("rmse works", {
Expand All @@ -64,13 +84,15 @@ test_that("rmse works", {
y <- torch::torch_randn(100, 100)

m <- luz_metric_rmse()

m <- m$new()

m$update(x, y)
o <- m$compute()
eo <- sqrt(mean((as.array(x) - as.array(y))^2))

expect_equal(o, eo, tolerance = 1e-5)

})

test_that("binary accuracy with logits", {
Expand All @@ -87,4 +109,5 @@ test_that("binary accuracy with logits", {
mean(as.array(x > 0) == as.array(y))
)


})