
cv-model-performance vignette (#56)
cole-brokamp committed Mar 7, 2024
1 parent 598b870 commit 7bb6b1b
Showing 5 changed files with 255 additions and 4 deletions.
3 changes: 3 additions & 0 deletions DESCRIPTION
@@ -33,8 +33,11 @@ Imports:
withr
RoxygenNote: 7.2.3
Suggests:
ggplot2,
knitr,
rmarkdown,
curl,
viridis,
testthat (>= 3.0.0)
Config/testthat/edition: 3
Config/testthat/parallel: true
7 changes: 6 additions & 1 deletion inst/train_model.R
@@ -45,6 +45,11 @@ d_train <- assemble_predictors(d$s2, d$dates, quiet = FALSE)

d_train$conc <- unlist(d$conc)

message("saving training_data")
train_file_output_path <- fs::path(tools::R_user_dir("appc", "data"), "training_data.rds")
saveRDS(d_train, train_file_output_path)
message("saved training_data.rds (", fs::file_info(file_output_path)$size, ") to ", file_output_path)

pred_names <-
c(
"x", "y",
@@ -85,7 +90,7 @@ grf <-
)

message("saving GRF")
file_output_path <- fs::path_wd("rf_pm.rds")
file_output_path <- fs::path(tools::R_user_dir("appc", "data"), "rf_pm.rds")
saveRDS(grf, file_output_path)
message("saved rf_pm.rds (", fs::file_info(file_output_path)$size, ") to ", file_output_path)

6 changes: 3 additions & 3 deletions justfile
@@ -41,10 +41,10 @@ docker_tool:
train_model:
Rscript --verbose inst/train_model.R

# upload grf model to current github release
# upload grf model and training data to current github release
release_model:
cp rf_pm.rds "{{geomarker_folder}}"/rf_pm.rds
gh release upload v{{pkg_version}} "rf_pm.rds"
gh release upload v{{pkg_version}} "{{geomarker_folder}}"/rf_pm.rds
gh release upload v{{pkg_version}} "{{geomarker_folder}}"/training_data.rds

# create CV accuracy report
create_report:
1 change: 1 addition & 0 deletions tests/testthat/test-merra-daily.R
@@ -3,6 +3,7 @@ earthdata_secrets <- Sys.getenv(c("EARTHDATA_USERNAME", "EARTHDATA_PASSWORD"), u
skip_if(any(is.na(earthdata_secrets)), message = "no earthdata credentials found")
skip_if_offline()
skip_if(Sys.getenv("CI") == "", "not on a CI platform")
skip_if(is.null(curl::nslookup("gesdisc.eosdis.nasa.gov", error = FALSE)), "NASA GES DISC not online")

test_that("getting daily merra from GES DISC works", {
# "normal" pattern
242 changes: 242 additions & 0 deletions vignettes/cv-model-performance.Rmd
@@ -0,0 +1,242 @@
---
title: "cv-model-performance"
output: rmarkdown::html_vignette
vignette: >
%\VignetteIndexEntry{cv-model-performance}
%\VignetteEngine{knitr::rmarkdown}
%\VignetteEncoding{UTF-8}
---

```{r setup, include=FALSE}
knitr::opts_chunk$set(echo = FALSE, message = FALSE, warning = FALSE)
# load development version if developing (instead of currently installed version)
if (file.exists("./inst") || basename(getwd()) == "inst") {
devtools::load_all()
} else {
library(appc)
}
library(grf)
library(dplyr)
library(ggplot2)
the_theme <-
ggplot2::theme_light(base_size = 11) +
ggplot2::theme(
panel.background = ggplot2::element_rect(fill = "white", colour = NA),
panel.border = ggplot2::element_rect(fill = NA, colour = "grey20"),
panel.grid.major = ggplot2::element_line(colour = "grey92"),
panel.grid.minor = ggplot2::element_blank(),
strip.background = ggplot2::element_rect(fill = "grey92", colour = "grey20"),
strip.text = ggplot2::element_text(color = "grey20"), legend.key = ggplot2::element_rect(fill = "white", colour = NA),
complete = TRUE
)
```

```{r "load model and predictions"}
grf_file <- fs::path(tools::R_user_dir("appc", "data"), "rf_pm.rds")
if (!file.exists(grf_file)) appc:::install_released_data("rf_pm.rds")
grf <- readRDS(grf_file)
train_file <- fs::path(tools::R_user_dir("appc", "data"), "training_data.rds")
if (!file.exists(train_file)) appc:::install_released_data("training_data.rds")
d_train <- readRDS(train_file)
```

## Variable Importance

```{r "variable importance"}
pred_names <- names(grf$X.orig)
tibble(importance = round(variable_importance(grf), 3),
variable = pred_names) |>
arrange(desc(importance)) |>
knitr::kable()
```

## LOLO Model Accuracy

```{r "estimate variance of predictions"}
d <-
grf |>
predict(estimate.variance = TRUE) |>
tibble::as_tibble() |>
transmute(
pred = signif(predictions, 2),
se = signif(sqrt(variance.estimates), 2),
conc = grf$Y.orig) |>
bind_cols(select(d_train, s2, date))
d <- d |>
mutate(lci = pred - se * qnorm(0.025, lower.tail = FALSE),
uci = pred + se * qnorm(0.025, lower.tail = FALSE),
ci_covered = conc < uci & conc > lci)
```

```{r "add temporal components"}
d$year <- as.numeric(format(d$date, "%Y"))
d$month <- as.numeric(format(d$date, "%m"))
d$week <- as.numeric(format(d$date, "%U"))
```

Leave-one-location-out (LOLO) accuracy is calculated using out-of-bag predictions from the trained random forest, which resamples with clustering by location so that each observation is predicted only by trees grown without any data from its location.

Accuracy metrics are calculated for each left-out location and then summarized as the median across all locations. This most closely captures performance in the real-world scenario of predicting air pollution between 2017 and 2023 at a location where it was not measured.

Each left-out location, or AQS monitor, contains a variable number of days with air pollution measurements, depending on the frequency of the daily measurements and on when the monitoring station began or stopped operating. Some station-time groupings have only a single measurement, so any station or station-time grouping with four or fewer observations is excluded.
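
For context, the chunk below is a minimal sketch of how cluster-resampled out-of-bag predictions of this kind can be generated with grf. It is illustrative only and is not evaluated when the vignette is built; the package's actual training call lives in `inst/train_model.R`, and `pred_names`, the cluster construction from the `s2` cell identifiers, and `num.trees` are assumptions made for the example.

```{r "lolo oob sketch", eval = FALSE}
# illustrative only: resample entire monitor locations (s2 cells) rather than
# individual daily observations when growing the forest
grf_sketch <-
  grf::regression_forest(
    X = d_train[pred_names], # assumed: the predictor columns used in inst/train_model.R
    Y = d_train$conc,
    clusters = as.integer(factor(d_train$s2)), # one cluster per AQS monitor location
    num.trees = 2000 # assumed value for this sketch
  )
# predict() without new data returns out-of-bag predictions: each observation is
# predicted only by trees whose bootstrap sample excluded its location cluster
oob <- predict(grf_sketch, estimate.variance = TRUE)
head(oob$predictions)
```

The accuracy summaries below are computed from out-of-bag predictions of this kind rather than from a separately held-out test set.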

### Daily

```{r "daily accuracy"}
d |>
nest_by(s2) |>
mutate(n_obs = c(nrow(data))) |>
filter(n_obs > 4) |>
mutate(mae = c(median(abs(data$conc - data$pred))),
rho = c(cor.test(data$conc, data$pred, method = "spearman", exact = FALSE)$estimate),
ci_coverage = c(sum(data$ci_covered) / length(data$ci_covered))) |>
ungroup() |>
summarize(mae = median(mae),
rho = median(rho),
ci_coverage = scales::percent(median(ci_coverage)),
median_n_obs_per_grouping = median(n_obs)) |>
knitr::kable(digits = 2)
```

#### Actual PM2.5 Concentrations vs LOLO Daily Predictions

```{r "daily pred vs actual plot"}
d |>
ggplot(aes(conc, pred)) +
stat_bin_hex(binwidth = c(0.05, 0.05)) +
viridis::scale_fill_viridis(option = "C", trans = "log10", name = "Number \nof points") +
geom_abline(slope = 1, intercept = 0, lty = 2, alpha = 0.8, color = "darkgrey") +
scale_x_log10(limits = c(1, 650)) + scale_y_log10(limits = c(1, 650)) +
xlab(expression(Observed ~ paste(PM[2.5], " (", mu, "g/", m^{3}, ") "))) +
ylab(expression(CV ~ Predicted ~ paste(PM[2.5], " (", mu, "g/", m^{3}, ") "))) +
the_theme +
theme(legend.position = c(0.85, 0.2)) +
coord_fixed()
```

#### Daily Prediction Accuracies per Calendar Year

```{r "daily accuracies per year"}
d |>
nest_by(s2, year) |>
mutate(n_obs = c(nrow(data))) |>
filter(n_obs > 4) |>
mutate(mae = c(median(abs(data$conc - data$pred))),
rho = c(cor.test(data$conc, data$pred, method = "spearman", exact = FALSE)$estimate),
ci_coverage = c(sum(data$ci_covered) / length(data$ci_covered))) |>
group_by(year) |>
summarize(mae = median(mae),
rho = median(rho),
ci_coverage = scales::percent(median(ci_coverage)),
median_n_obs_per_grouping = median(n_obs)) |>
knitr::kable(digits = 2)
```

### Monthly

Exclude stations with four or fewer total monthly observations.

```{r "monthly accuracies"}
d |>
group_by(s2, year, month) |>
summarize(pred = mean(pred),
conc = mean(conc),
se = mean(sqrt(se^2))) |>
mutate(lci = pred - se * qnorm(0.025, lower.tail = FALSE),
uci = pred + se * qnorm(0.025, lower.tail = FALSE),
ci_covered = conc < uci & conc > lci) |>
ungroup() |>
nest_by(s2) |>
mutate(n_obs = c(nrow(data))) |>
filter(n_obs > 4) |>
mutate(mae = c(median(abs(data$conc - data$pred))),
rho = c(cor.test(data$conc, data$pred, method = "spearman", exact = FALSE)$estimate),
ci_coverage = c(sum(data$ci_covered) / length(data$ci_covered))) |>
ungroup() |>
summarize(mae = median(mae),
rho = median(rho),
ci_coverage = scales::percent(median(ci_coverage)),
median_n_obs_per_grouping = median(n_obs)) |>
knitr::kable(digits = 2)
```

### Annual

Exclude stations with four or fewer total annual observations.

```{r "yearly accuracies"}
d |>
group_by(s2, year) |>
summarize(pred = mean(pred),
conc = mean(conc),
se = mean(sqrt(se^2))) |>
mutate(lci = pred - se * qnorm(0.025, lower.tail = FALSE),
uci = pred + se * qnorm(0.025, lower.tail = FALSE),
ci_covered = conc < uci & conc > lci) |>
ungroup() |>
nest_by(s2) |>
mutate(n_obs = c(nrow(data))) |>
filter(n_obs > 4) |>
mutate(mae = c(median(abs(data$conc - data$pred))),
rho = c(cor.test(data$conc, data$pred, method = "spearman", exact = FALSE)$estimate),
ci_coverage = c(sum(data$ci_covered) / length(data$ci_covered))) |>
ungroup() |>
summarize(mae = median(mae),
rho = median(rho),
ci_coverage = scales::percent(median(ci_coverage)),
median_n_obs_per_grouping = median(n_obs)) |>
knitr::kable(digits = 2)
```

## Median LOLO Accuracy Per Spatial Aggregation Period

```{r "spatial variation in accuracies"}
library(s2)
the_map_theme <-
the_theme +
ggplot2::theme(
axis.ticks = ggplot2::element_blank(),
axis.text = ggplot2::element_blank(),
axis.title = ggplot2::element_blank(),
rect = ggplot2::element_blank(),
line = ggplot2::element_blank(),
panel.grid = ggplot2::element_blank(),
plot.margin = ggplot2::margin(1, 1, 1, 1, "cm"),
legend.key.height = ggplot2::unit(1, "cm"),
legend.key.width = ggplot2::unit(0.3, "cm")
)
d |>
mutate(s2_ = s2_cell_parent(s2, 5)) |>
nest_by(s2, s2_) |>
mutate(n_obs = c(nrow(data))) |>
filter(n_obs > 4) |>
mutate(mae = c(median(abs(data$conc - data$pred))),
rho = c(cor.test(data$conc, data$pred, method = "spearman", exact = FALSE)$estimate),
ci_coverage = c(sum(data$ci_covered) / length(data$ci_covered))) |>
group_by(s2_) |>
summarize(mae = median(mae)) |>
mutate(geometry = s2_cell_polygon(s2_)) |>
sf::st_as_sf() |>
ggplot() +
geom_sf(aes(fill = mae), size = 0) +
coord_sf(crs = 5072) +
the_map_theme +
scale_fill_viridis_c() +
theme(legend.position = c(0.25, 0.1),
legend.direction = "horizontal",
legend.title = element_text(size = 11, family = "sans"),
legend.text = element_text(size = 11),
legend.box = "hortizontal",
legend.key.height = unit(4, "mm"),
legend.key.width = unit(9, "mm"),
strip.text.x = element_text(size = 11, face = "bold", vjust = 1),
strip.text.y = element_text(size = 11, face = "bold")) +
labs(fill = expression(MAE ~ (mu * g / m^3))) +
guides(fill = guide_colorbar(title.position = "top", title.hjust = 0.5))
```
