From 7bb6b1b05680a61749ee22d74f883145fab23223 Mon Sep 17 00:00:00 2001
From: Cole Brokamp
Date: Thu, 7 Mar 2024 07:54:20 -0500
Subject: [PATCH] cv-model-performance vignette (#56)

---
Review notes (text between the `---` line and the first `diff --git` is not
part of the commit). Line breaks and indentation below were restored from a
whitespace-collapsed copy, so context-line whitespace should be checked
against the target tree before applying.

- inst/train_model.R: the "saved training_data.rds" message now reports
  `train_file_output_path`. It previously used `file_output_path`, which in
  the hunks shown here is only assigned later, in the "saving GRF" step, and
  names the model file rather than the training data.
- vignettes/cv-model-performance.Rmd: the "load model and predictions" chunk
  reads from `grf_file` / `train_file` instead of rebuilding the same paths.
- The `index` hashes are unchanged from the original patch and therefore no
  longer describe the edited blobs.

 DESCRIPTION                        |   3 +
 inst/train_model.R                 |   7 +-
 justfile                           |   6 +-
 tests/testthat/test-merra-daily.R  |   1 +
 vignettes/cv-model-performance.Rmd | 242 +++++++++++++++++++++++++++++
 5 files changed, 255 insertions(+), 4 deletions(-)
 create mode 100644 vignettes/cv-model-performance.Rmd

diff --git a/DESCRIPTION b/DESCRIPTION
index d2ca1d8..b93de72 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -33,8 +33,11 @@ Imports:
     withr
 RoxygenNote: 7.2.3
 Suggests:
+    ggplot2,
     knitr,
     rmarkdown,
+    curl,
+    viridis,
     testthat (>= 3.0.0)
 Config/testthat/edition: 3
 Config/testthat/parallel: true
diff --git a/inst/train_model.R b/inst/train_model.R
index b769acf..40cb6aa 100644
--- a/inst/train_model.R
+++ b/inst/train_model.R
@@ -45,6 +45,11 @@ d_train <- assemble_predictors(d$s2, d$dates, quiet = FALSE)
 
 d_train$conc <- unlist(d$conc)
 
+message("saving training_data")
+train_file_output_path <- fs::path(tools::R_user_dir("appc", "data"), "training_data.rds")
+saveRDS(d_train, train_file_output_path)
+message("saved training_data.rds (", fs::file_info(train_file_output_path)$size, ") to ", train_file_output_path)
+
 pred_names <- c(
   "x",
   "y",
@@ -85,7 +90,7 @@ grf <-
   )
 
 message("saving GRF")
-file_output_path <- fs::path_wd("rf_pm.rds")
+file_output_path <- fs::path(tools::R_user_dir("appc", "data"), "rf_pm.rds")
 saveRDS(grf, file_output_path)
 message("saved rf_pm.rds (", fs::file_info(file_output_path)$size, ") to ", file_output_path)
 
diff --git a/justfile b/justfile
index 566d950..42e3181 100644
--- a/justfile
+++ b/justfile
@@ -41,10 +41,10 @@ docker_tool:
 train_model:
     Rscript --verbose inst/train_model.R
 
-# upload grf model to current github release
+# upload grf model and training data to current github release
 release_model:
-    cp rf_pm.rds "{{geomarker_folder}}"/rf_pm.rds
-    gh release upload v{{pkg_version}} "rf_pm.rds"
+    gh release upload v{{pkg_version}} "{{geomarker_folder}}"/rf_pm.rds
+    gh release upload v{{pkg_version}} "{{geomarker_folder}}"/training_data.rds
 
 # create CV accuracy report
 create_report:
diff --git a/tests/testthat/test-merra-daily.R b/tests/testthat/test-merra-daily.R
index 610eda9..67205d8 100644
--- a/tests/testthat/test-merra-daily.R
+++ b/tests/testthat/test-merra-daily.R
@@ -3,6 +3,7 @@ earthdata_secrets <- Sys.getenv(c("EARTHDATA_USERNAME", "EARTHDATA_PASSWORD"), u
 skip_if(any(is.na(earthdata_secrets)), message = "no earthdata credentials found")
 skip_if_offline()
 skip_if(Sys.getenv("CI") == "", "not on a CI platform")
+skip_if(is.null(curl::nslookup("gesdisc.eosdis.nasa.gov", error = FALSE)), "NASA GES DISC not online")
 
 test_that("getting daily merra from GES DISC works", {
   # "normal" pattern
diff --git a/vignettes/cv-model-performance.Rmd b/vignettes/cv-model-performance.Rmd
new file mode 100644
index 0000000..85d9c6d
--- /dev/null
+++ b/vignettes/cv-model-performance.Rmd
@@ -0,0 +1,242 @@
+---
+title: "cv-model-performance"
+output: rmarkdown::html_vignette
+vignette: >
+  %\VignetteIndexEntry{cv-model-performance}
+  %\VignetteEngine{knitr::rmarkdown}
+  %\VignetteEncoding{UTF-8}
+---
+
+```{r setup, include=FALSE}
+knitr::opts_chunk$set(echo = FALSE, message = FALSE, warning = FALSE)
+
+# load development version if developing (instead of currently installed version)
+if (file.exists("./inst") || basename(getwd()) == "inst") {
+  devtools::load_all()
+} else {
+  library(appc)
+}
+library(grf)
+library(dplyr)
+library(ggplot2)
+
+the_theme <-
+  ggplot2::theme_light(base_size = 11) +
+  ggplot2::theme(
+    panel.background = ggplot2::element_rect(fill = "white", colour = NA),
+    panel.border = ggplot2::element_rect(fill = NA, colour = "grey20"),
+    panel.grid.major = ggplot2::element_line(colour = "grey92"),
+    panel.grid.minor = ggplot2::element_blank(),
+    strip.background = ggplot2::element_rect(fill = "grey92", colour = "grey20"),
+    strip.text = ggplot2::element_text(color = "grey20"), legend.key = ggplot2::element_rect(fill = "white", colour = NA),
+    complete = TRUE
+  )
+```
+
+```{r "load model and predictions"}
+grf_file <- fs::path(tools::R_user_dir("appc", "data"), "rf_pm.rds")
+if (!file.exists(grf_file)) appc:::install_released_data("rf_pm.rds") # presumably fetches the release asset into the user data dir; confirm
+grf <- readRDS(grf_file)
+
+train_file <- fs::path(tools::R_user_dir("appc", "data"), "training_data.rds")
+if (!file.exists(train_file)) appc:::install_released_data("training_data.rds")
+d_train <- readRDS(train_file)
+```
+
+## Variable Importance
+
+```{r "variable importance"}
+pred_names <- names(grf$X.orig)
+tibble(importance = round(variable_importance(grf), 3),
+       variable = pred_names) |>
+  arrange(desc(importance)) |> # most important predictors first
+  knitr::kable()
+```
+
+## LOLO Model Accuracy
+
+```{r "estimate variance of predictions"}
+d <-
+  grf |>
+  predict(estimate.variance = TRUE) |> # no newdata supplied; the text below describes these as out-of-bag predictions
+  tibble::as_tibble() |>
+  transmute(
+    pred = signif(predictions, 2), # rounded to 2 significant digits
+    se = signif(sqrt(variance.estimates), 2),
+    conc = grf$Y.orig) |>
+  bind_cols(select(d_train, s2, date)) # assumes d_train rows are in the same order as the forest's training rows; confirm
+d <- d |>
+  mutate(lci = pred - se * qnorm(0.025, lower.tail = FALSE), # qnorm(0.025, lower.tail = FALSE) is the upper 2.5% normal quantile
+         uci = pred + se * qnorm(0.025, lower.tail = FALSE),
+         ci_covered = conc < uci & conc > lci)
+```
+
+```{r "add temporal components"}
+d$year <- as.numeric(format(d$date, "%Y"))
+d$month <- as.numeric(format(d$date, "%m"))
+d$week <- as.numeric(format(d$date, "%U")) # week of the year, weeks starting on Sunday
+```
+
+Leave-one-location-out (LOLO) accuracy is calculated by using out of bag predictions from the trained random forest with resample clustering by the location.
+
+Accuracy metrics are calculated for each left out location and then summarized using the median accuracy statistic across all locations. This most closely captures the performance in a real-world scenario where we are trying to predict air pollution between 2017 and 2023 in a place where it was not measured.
+
+Each left-out location, or AQS monitor, contains a variable number of days with air pollution measurements. This depends on the frequency of the daily measurements as well as when the monitoring station was initiated or deprecated. Some stations-time groupings only have a single measurement; exclude any station or station-time grouping that has 4 or less observations.
+
+### Daily
+
+```{r "daily accuracy"}
+d |>
+  nest_by(s2) |> # one row per `s2` value, with that location's rows nested in `data`
+  mutate(n_obs = c(nrow(data))) |>
+  filter(n_obs > 4) |> # drop groupings with 4 or fewer rows
+  mutate(mae = c(median(abs(data$conc - data$pred))), # "mae" here is the median absolute error
+         rho = c(cor.test(data$conc, data$pred, method = "spearman", exact = FALSE)$estimate),
+         ci_coverage = c(sum(data$ci_covered) / length(data$ci_covered))) |> # share of rows whose interval contains `conc`
+  ungroup() |>
+  summarize(mae = median(mae), # medians of the per-location statistics
+            rho = median(rho),
+            ci_coverage = scales::percent(median(ci_coverage)),
+            median_n_obs_per_grouping = median(n_obs)) |>
+  knitr::kable(digits = 2)
+```
+
+#### Actual PM2.5 Concentrations vs LOLO Daily Predictions
+
+```{r "daily pred vs actual plot"}
+d |>
+  ggplot(aes(conc, pred)) +
+  stat_bin_hex(binwidth = c(0.05, 0.05)) +
+  viridis::scale_fill_viridis(option = "C", trans = "log10", name = "Number \nof points") + # log10 fill scale for bin counts
+  geom_abline(slope = 1, intercept = 0, lty = 2, alpha = 0.8, color = "darkgrey") + # identity line
+  scale_x_log10(limits = c(1, 650)) + scale_y_log10(limits = c(1, 650)) +
+  xlab(expression(Observed ~ paste(PM[2.5], " (", mu, "g/", m^{3}, ") "))) +
+  ylab(expression(CV ~ Predicted ~ paste(PM[2.5], " (", mu, "g/", m^{3}, ") "))) +
+  the_theme +
+  theme(legend.position = c(0.85, 0.2)) +
+  coord_fixed()
+```
+
+#### Daily Prediction Accuracies per Calendar Year
+
+```{r "daily accuracies per year"}
+d |>
+  nest_by(s2, year) |> # one row per location-year
+  mutate(n_obs = c(nrow(data))) |>
+  filter(n_obs > 4) |>
+  mutate(mae = c(median(abs(data$conc - data$pred))),
+         rho = c(cor.test(data$conc, data$pred, method = "spearman", exact = FALSE)$estimate),
+         ci_coverage = c(sum(data$ci_covered) / length(data$ci_covered))) |>
+  group_by(year) |> # summarize across locations within each year
+  summarize(mae = median(mae),
+            rho = median(rho),
+            ci_coverage = scales::percent(median(ci_coverage)),
+            median_n_obs_per_grouping = median(n_obs)) |>
+  knitr::kable(digits = 2)
+```
+
+### Monthly
+
+Exclude stations with 4 or less total monthly observations.
+
+```{r "monthly accuracies"}
+d |>
+  group_by(s2, year, month) |> # average to one row per location-month
+  summarize(pred = mean(pred),
+            conc = mean(conc),
+            se = mean(sqrt(se^2))) |> # NOTE(review): sqrt(se^2) returns se for non-negative se; confirm a plain mean of se is intended
+  mutate(lci = pred - se * qnorm(0.025, lower.tail = FALSE),
+         uci = pred + se * qnorm(0.025, lower.tail = FALSE),
+         ci_covered = conc < uci & conc > lci) |>
+  ungroup() |>
+  nest_by(s2) |> # n_obs below counts location-months per location
+  mutate(n_obs = c(nrow(data))) |>
+  filter(n_obs > 4) |>
+  mutate(mae = c(median(abs(data$conc - data$pred))),
+         rho = c(cor.test(data$conc, data$pred, method = "spearman", exact = FALSE)$estimate),
+         ci_coverage = c(sum(data$ci_covered) / length(data$ci_covered))) |>
+  ungroup() |>
+  summarize(mae = median(mae),
+            rho = median(rho),
+            ci_coverage = scales::percent(median(ci_coverage)),
+            median_n_obs_per_grouping = median(n_obs)) |>
+  knitr::kable(digits = 2)
+```
+
+### Annual
+
+Exclude stations with 4 or less total annual observations.
+
+```{r "yearly accuracies"}
+d |>
+  group_by(s2, year) |> # average to one row per location-year
+  summarize(pred = mean(pred),
+            conc = mean(conc),
+            se = mean(sqrt(se^2))) |> # NOTE(review): sqrt(se^2) returns se for non-negative se; confirm a plain mean of se is intended
+  mutate(lci = pred - se * qnorm(0.025, lower.tail = FALSE),
+         uci = pred + se * qnorm(0.025, lower.tail = FALSE),
+         ci_covered = conc < uci & conc > lci) |>
+  ungroup() |>
+  nest_by(s2) |> # n_obs below counts location-years per location
+  mutate(n_obs = c(nrow(data))) |>
+  filter(n_obs > 4) |> # drop locations with 4 or fewer annual rows
+  mutate(mae = c(median(abs(data$conc - data$pred))), # "mae" here is the median absolute error
+         rho = c(cor.test(data$conc, data$pred, method = "spearman", exact = FALSE)$estimate),
+         ci_coverage = c(sum(data$ci_covered) / length(data$ci_covered))) |>
+  ungroup() |>
+  summarize(mae = median(mae),
+            rho = median(rho),
+            ci_coverage = scales::percent(median(ci_coverage)),
+            median_n_obs_per_grouping = median(n_obs)) |>
+  knitr::kable(digits = 2)
+```
+
+## Median LOLO Accuracy Per Spatial Aggregation Period
+
+```{r "spatial variation in accuracies"}
+library(s2)
+
+the_map_theme <-
+  the_theme +
+  ggplot2::theme(
+    axis.ticks = ggplot2::element_blank(),
+    axis.text = ggplot2::element_blank(),
+    axis.title = ggplot2::element_blank(),
+    rect = ggplot2::element_blank(),
+    line = ggplot2::element_blank(),
+    panel.grid = ggplot2::element_blank(),
+    plot.margin = ggplot2::margin(1, 1, 1, 1, "cm"),
+    legend.key.height = ggplot2::unit(1, "cm"),
+    legend.key.width = ggplot2::unit(0.3, "cm")
+  )
+
+d |>
+  mutate(s2_ = s2_cell_parent(s2, 5)) |> # level-5 parent cell used as the map unit
+  nest_by(s2, s2_) |>
+  mutate(n_obs = c(nrow(data))) |>
+  filter(n_obs > 4) |>
+  mutate(mae = c(median(abs(data$conc - data$pred))),
+         rho = c(cor.test(data$conc, data$pred, method = "spearman", exact = FALSE)$estimate),
+         ci_coverage = c(sum(data$ci_covered) / length(data$ci_covered))) |>
+  group_by(s2_) |> # median of per-location MAE within each parent cell
+  summarize(mae = median(mae),
+            s2_ = unique(s2_)) |>
+  mutate(geometry = s2_cell_polygon(s2_)) |>
+  sf::st_as_sf() |>
+  ggplot() +
+  geom_sf(aes(fill = mae), size = 0) +
+  coord_sf(crs = 5072) + # EPSG:5072, NAD83 / Conus Albers
+  the_map_theme +
+  scale_fill_viridis_c() +
+  theme(legend.position = c(0.25, 0.1),
+        legend.direction = "horizontal",
+        legend.title = element_text(size = 11, family = "sans"),
+        legend.text = element_text(size = 11),
+        legend.box = "horizontal", # theme() documents "horizontal" or "vertical" for legend.box
+        legend.key.height = unit(4, "mm"),
+        legend.key.width = unit(9, "mm"),
+        strip.text.x = element_text(size = 11, face = "bold", vjust = 1),
+        strip.text.y = element_text(size = 11, face = "bold")) +
+  labs(fill = expression(paste(MAE~(ug/m^3)))) +
+  guides(fill = guide_colorbar(title.position = "top", title.hjust = 0.5))
+```