-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
598b870
commit 7bb6b1b
Showing
5 changed files
with
255 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,242 @@ | ||
--- | ||
title: "cv-model-performance" | ||
output: rmarkdown::html_vignette | ||
vignette: > | ||
%\VignetteIndexEntry{cv-model-performance} | ||
%\VignetteEngine{knitr::rmarkdown} | ||
%\VignetteEncoding{UTF-8} | ||
--- | ||
|
||
```{r setup, include=FALSE} | ||
knitr::opts_chunk$set(echo = FALSE, message = FALSE, warning = FALSE) | ||
# load development version if developing (instead of currently installed version) | ||
if (file.exists("./inst") || basename(getwd()) == "inst") { | ||
devtools::load_all() | ||
} else { | ||
library(appc) | ||
} | ||
library(grf) | ||
library(dplyr) | ||
library(ggplot2) | ||
the_theme <- | ||
ggplot2::theme_light(base_size = 11) + | ||
ggplot2::theme( | ||
panel.background = ggplot2::element_rect(fill = "white", colour = NA), | ||
panel.border = ggplot2::element_rect(fill = NA, colour = "grey20"), | ||
panel.grid.major = ggplot2::element_line(colour = "grey92"), | ||
panel.grid.minor = ggplot2::element_blank(), | ||
strip.background = ggplot2::element_rect(fill = "grey92", colour = "grey20"), | ||
strip.text = ggplot2::element_text(color = "grey20"), legend.key = ggplot2::element_rect(fill = "white", colour = NA), | ||
complete = TRUE | ||
) | ||
``` | ||
|
||
```{r "load model and predictions"} | ||
grf_file <- fs::path(tools::R_user_dir("appc", "data"), "rf_pm.rds") | ||
if(!file.exists(grf_file)) appc:::install_released_data("rf_pm.rds") | ||
grf <- readRDS(fs::path(tools::R_user_dir("appc", "data"), "rf_pm.rds")) | ||
train_file <-fs::path(tools::R_user_dir("appc", "data"), "training_data.rds") | ||
if(!file.exists(train_file)) appc:::install_released_data("training_data.rds") | ||
d_train <- readRDS(fs::path(tools::R_user_dir("appc", "data"), "training_data.rds")) | ||
``` | ||
|
||
## Variable Importance | ||
|
||
```{r "variable importance"} | ||
pred_names <- names(grf$X.orig) | ||
tibble(importance = round(variable_importance(grf), 3), | ||
variable = pred_names) |> | ||
arrange(desc(importance)) |> | ||
knitr::kable() | ||
``` | ||
|
||
## LOLO Model Accuracy | ||
|
||
```{r "estimate variance of predictions"} | ||
d <- | ||
grf |> | ||
predict(estimate.variance = TRUE) |> | ||
tibble::as_tibble() |> | ||
transmute( | ||
pred = signif(predictions, 2), | ||
se = signif(sqrt(variance.estimates), 2), | ||
conc = grf$Y.orig) |> | ||
bind_cols(select(d_train, s2, date)) | ||
d <- d |> | ||
mutate(lci = pred - se * qnorm(0.025, lower.tail = FALSE), | ||
uci = pred + se * qnorm(0.025, lower.tail = FALSE), | ||
ci_covered = conc < uci & conc > lci) | ||
``` | ||
|
||
```{r "add temporal components"} | ||
d$year <- as.numeric(format(d$date, "%Y")) | ||
d$month <- as.numeric(format(d$date, "%m")) | ||
d$week <- as.numeric(format(d$date, "%U")) | ||
``` | ||
|
||
Leave-one-location-out (LOLO) accuracy is calculated by using out of bag predictions from the trained random forest with resample clustering by the location. | ||
|
||
Accuracy metrics are calculated for each left out location and then summarized using the median accuracy statistic across all locations. This most closely captures the performance in a real-world scenario where we are trying to predict air pollution between 2017 and 2023 in a place where it was not measured. | ||
|
||
Each left-out location, or AQS monitor, contains a variable number of days with air pollution measurements. This depends on the frequency of the daily measurements as well as when the monitoring station was initiated or deprecated. Some stations-time groupings only have a single measurement; exclude any station or station-time grouping that has 4 or less observations. | ||
|
||
### Daily | ||
|
||
```{r "daily accuracy"} | ||
d |> | ||
nest_by(s2) |> | ||
mutate(n_obs = c(nrow(data))) |> | ||
filter(n_obs > 4) |> | ||
mutate(mae = c(median(abs(data$conc - data$pred))), | ||
rho = c(cor.test(data$conc, data$pred, method = "spearman", exact = FALSE)$estimate), | ||
ci_coverage = c(sum(data$ci_covered) / length(data$ci_covered))) |> | ||
ungroup() |> | ||
summarize(mae = median(mae), | ||
rho = median(rho), | ||
ci_coverage = scales::percent(median(ci_coverage)), | ||
median_n_obs_per_grouping = median(n_obs)) |> | ||
knitr::kable(digits = 2) | ||
``` | ||
|
||
#### Actual PM2.5 Concentrations vs LOLO Daily Predictions | ||
|
||
```{r "daily pred vs actual plot"} | ||
d |> | ||
ggplot(aes(conc, pred)) + | ||
stat_bin_hex(binwidth = c(0.05, 0.05)) + | ||
viridis::scale_fill_viridis(option = "C", trans = "log10", name = "Number \nof points") + | ||
geom_abline(slope = 1, intercept = 0, lty = 2, alpha = 0.8, color = "darkgrey") + | ||
scale_x_log10(limits = c(1, 650)) + scale_y_log10(limits = c(1, 650)) + | ||
xlab(expression(Observed ~ paste(PM[2.5], " (", mu, "g/", m^{3}, ") "))) + | ||
ylab(expression(CV ~ Predicted ~ paste(PM[2.5], " (", mu, "g/", m^{3}, ") "))) + | ||
the_theme + | ||
theme(legend.position = c(0.85, 0.2)) + | ||
coord_fixed() | ||
``` | ||
|
||
#### Daily Prediction Accuracies per Calendar Year | ||
|
||
```{r "daily accuracies per year"} | ||
d |> | ||
nest_by(s2, year) |> | ||
mutate(n_obs = c(nrow(data))) |> | ||
filter(n_obs > 4) |> | ||
mutate(mae = c(median(abs(data$conc - data$pred))), | ||
rho = c(cor.test(data$conc, data$pred, method = "spearman", exact = FALSE)$estimate), | ||
ci_coverage = c(sum(data$ci_covered) / length(data$ci_covered))) |> | ||
group_by(year) |> | ||
summarize(mae = median(mae), | ||
rho = median(rho), | ||
ci_coverage = scales::percent(median(ci_coverage)), | ||
median_n_obs_per_grouping = median(n_obs)) |> | ||
knitr::kable(digits = 2) | ||
``` | ||
|
||
### Monthly | ||
|
||
Exclude stations with 4 or less total monthly observations. | ||
|
||
```{r "monthly accuracies"} | ||
d |> | ||
group_by(s2, year, month) |> | ||
summarize(pred = mean(pred), | ||
conc = mean(conc), | ||
se = mean(sqrt(se^2))) |> | ||
mutate(lci = pred - se * qnorm(0.025, lower.tail = FALSE), | ||
uci = pred + se * qnorm(0.025, lower.tail = FALSE), | ||
ci_covered = conc < uci & conc > lci) |> | ||
ungroup() |> | ||
nest_by(s2) |> | ||
mutate(n_obs = c(nrow(data))) |> | ||
filter(n_obs > 4) |> | ||
mutate(mae = c(median(abs(data$conc - data$pred))), | ||
rho = c(cor.test(data$conc, data$pred, method = "spearman", exact = FALSE)$estimate), | ||
ci_coverage = c(sum(data$ci_covered) / length(data$ci_covered))) |> | ||
ungroup() |> | ||
summarize(mae = median(mae), | ||
rho = median(rho), | ||
ci_coverage = scales::percent(median(ci_coverage)), | ||
median_n_obs_per_grouping = median(n_obs)) |> | ||
knitr::kable(digits = 2) | ||
``` | ||
|
||
### Annual | ||
|
||
Exclude stations with 4 or less total annual observations. | ||
|
||
```{r "yearly accuracies"} | ||
d |> | ||
group_by(s2, year) |> | ||
summarize(pred = mean(pred), | ||
conc = mean(conc), | ||
se = mean(sqrt(se^2))) |> | ||
mutate(lci = pred - se * qnorm(0.025, lower.tail = FALSE), | ||
uci = pred + se * qnorm(0.025, lower.tail = FALSE), | ||
ci_covered = conc < uci & conc > lci) |> | ||
ungroup() |> | ||
nest_by(s2) |> | ||
mutate(n_obs = c(nrow(data))) |> | ||
filter(n_obs > 4) |> | ||
mutate(mae = c(median(abs(data$conc - data$pred))), | ||
rho = c(cor.test(data$conc, data$pred, method = "spearman", exact = FALSE)$estimate), | ||
ci_coverage = c(sum(data$ci_covered) / length(data$ci_covered))) |> | ||
ungroup() |> | ||
summarize(mae = median(mae), | ||
rho = median(rho), | ||
ci_coverage = scales::percent(median(ci_coverage)), | ||
median_n_obs_per_grouping = median(n_obs)) |> | ||
knitr::kable(digits = 2) | ||
``` | ||
|
||
## Median LOLO Accuracy Per Spatial Aggregation Period | ||
|
||
```{r "spatial variation in accuracies"} | ||
library(s2) | ||
the_map_theme <- | ||
the_theme + | ||
ggplot2::theme( | ||
axis.ticks = ggplot2::element_blank(), | ||
axis.text = ggplot2::element_blank(), | ||
axis.title = ggplot2::element_blank(), | ||
rect = ggplot2::element_blank(), | ||
line = ggplot2::element_blank(), | ||
panel.grid = ggplot2::element_blank(), | ||
plot.margin = ggplot2::margin(1, 1, 1, 1, "cm"), | ||
legend.key.height = ggplot2::unit(1, "cm"), | ||
legend.key.width = ggplot2::unit(0.3, "cm") | ||
) | ||
d |> | ||
mutate(s2_ = s2_cell_parent(s2, 5)) |> | ||
nest_by(s2, s2_) |> | ||
mutate(n_obs = c(nrow(data))) |> | ||
filter(n_obs > 4) |> | ||
mutate(mae = c(median(abs(data$conc - data$pred))), | ||
rho = c(cor.test(data$conc, data$pred, method = "spearman", exact = FALSE)$estimate), | ||
ci_coverage = c(sum(data$ci_covered) / length(data$ci_covered))) |> | ||
group_by(s2_) |> | ||
summarize(mae = median(mae), | ||
s2_ = unique(s2_)) |> | ||
mutate(geometry = s2_cell_polygon(s2_)) |> | ||
sf::st_as_sf() |> | ||
ggplot() + | ||
geom_sf(aes(fill = mae), size = 0) + | ||
coord_sf(crs = 5072) + | ||
the_map_theme + | ||
scale_fill_viridis_c() + | ||
theme(legend.position = c(0.25, 0.1), | ||
legend.direction = "horizontal", | ||
legend.title = element_text(size = 11, family = "sans"), | ||
legend.text = element_text(size = 11), | ||
legend.box = "hortizontal", | ||
legend.key.height = unit(4, "mm"), | ||
legend.key.width = unit(9, "mm"), | ||
strip.text.x = element_text(size = 11, face = "bold", vjust = 1), | ||
strip.text.y = element_text(size = 11, face = "bold")) + | ||
labs(fill = expression(paste(MAE~(ug/m^3)))) + | ||
guides(fill = guide_colorbar(title.position = "top", title.hjust = 0.5)) | ||
``` |