Skip to content
This repository has been archived by the owner on Jul 24, 2024. It is now read-only.

Groundtruth refactor #435

Open
wants to merge 21 commits into
base: inference
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions R/pkgs/covidcommon/NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ export(create_prefix)
export(get_CSSE_US_data)
export(get_CSSE_US_matchGlobal_data)
export(get_CSSE_global_data)
export(get_LA_health_dpt_county_hosp_data)
export(get_USAFacts_data)
export(get_groundtruth_from_single_source)
export(get_groundtruth_from_source)
export(get_hhsCMU_allHosp_st_data)
export(get_hhsCMU_hospCurr_st_data)
Expand Down
83 changes: 71 additions & 12 deletions R/pkgs/covidcommon/R/DataUtils.R
Original file line number Diff line number Diff line change
Expand Up @@ -707,8 +707,51 @@ get_reichlab_cty_data <- function(cum_case_filename = "data/case_data/rlab_cum_c
##'
##' @export
##'
get_groundtruth_from_source <- function(source = "csse", scale = "US county", variables = c("Confirmed", "Deaths", "incidI", "incidDeath"),
incl_unass = FALSE, get_hosp = FALSE){
get_groundtruth_from_source <- function(source = c("csse", "csse", "csse", "csse"), scale = "US county", variables = c("Confirmed", "Deaths", "incidI", "incidDeath"), incl_unass = FALSE, adjust_for_variant = FALSE, variant_props_file = "data/variant/variant_props_long.csv", misc_data_filename = NULL) {
  # Pair each requested variable with the source it should be pulled from.
  # data.frame() recycles `source` when its length evenly divides the number
  # of variables, so a single source name applies to every variable.
  source_by_variable <- data.frame(
    data_source = source,
    variables = variables
  )

  # Worker handed to group_modify(): `source_rows` holds the rows of one
  # source (without the grouping column) and `source_key` is the one-row
  # table carrying that group's `data_source` value.
  pull_one_source <- function(source_rows, source_key) {
    get_groundtruth_from_single_source(
      source = source_key$data_source,
      scale = scale,
      variables = source_rows$variables,
      incl_unass = incl_unass,
      adjust_for_variant = adjust_for_variant,
      variant_props_file = variant_props_file,
      misc_data_filename = misc_data_filename
    )
  }

  # One call per distinct source; group_modify() binds the per-source results
  # by row and prepends the `data_source` column. The result stays grouped by
  # `data_source`.
  grouped_requests <- dplyr::group_by(source_by_variable, data_source)
  groundtruth <- dplyr::group_modify(grouped_requests, pull_one_source)

  return(groundtruth)
}

##'
##' Wrapper function to pull data from different sources
##'
##' Pulls a groundtruth dataset with the variables specified
##'
##' @param source name of a single data source (exactly one is allowed): reichlab, usafacts, csse, hhsCMU, LA health dpt
##' @param scale geographic scale: US county, US state, country (csse only), complete (csse only)
##' @param variables vector that may include one or more of the following variable names: Confirmed, Deaths, incidI, incidDeath, (hhsCMU source only: incidH_confirmed, incidH_all, hospCurr_confirmed, hospCurr_all)
##' @return data frame
##'
##' @importFrom magrittr %>%
##'
##' @export
##'
get_groundtruth_from_single_source <- function(source = "csse", scale = "US county", variables = c("Confirmed", "Deaths", "incidI", "incidDeath"), incl_unass = FALSE, adjust_for_variant = FALSE, variant_props_file = "data/variant/variant_props_long.csv", misc_data_filename = NULL) {

if(length(source) > 1) {
stop(paste(
"get_groundtruth_from_single_source only allows a single source, but",
paste(source, collapse = ", "),
"was provided"
))
}

if(source == "reichlab" & scale == "US county"){

Expand Down Expand Up @@ -773,25 +816,41 @@ get_groundtruth_from_source <- function(source = "csse", scale = "US county", va

} else if(source == "hhsCMU" & scale == "US state"){

rc <- get_hhsCMU_cleanHosp_st_data()
rc <- get_hhsCMU_incidH_st_data()
rc <- dplyr::mutate(rc, FIPS = paste0(FIPS, "000"))
rc <- dplyr::select(rc, Update, FIPS, source, !!variables)
rc <- tidyr::drop_na(rc, tidyselect::everything())

} else{
warning(print(paste("The combination of ", source, "and", scale, "is not valid. Returning NULL object.")))
rc <- NULL
}

if(get_hosp & scale == "US state") {
hosp <- get_hhsCMU_incidH_st_data()
hosp <- hosp %>% dplyr::select(-FIPS)
rc <- left_join(rc, hosp)
} else if ((source == "LA health dpt") && (scale == "US county")) {

rc <- get_LA_health_dpt_county_hosp_data(misc_data_filename)

} else {
warning(print(paste("The combination of ", source, "and", scale, "is not valid. Returning empty tibble.")))
rc <- dplyr::as_tibble(NULL)
}


return(rc)

}

##'
##' Pull LA county hospitalization data
##'
##' Reads a hospitalization spreadsheet and reshapes it into the
##' Update / FIPS / incidH layout.
##'
##' @param hosp_file_name path to an xlsx file that contains at least the columns `date` and `incidH_covid`; missing counts may be coded as the string "n/a"
##' @return data frame with the columns Update (Date), FIPS (always "06037") and incidH (numeric, NA where the count was missing)
##'
##' @export
get_LA_health_dpt_county_hosp_data <- function(hosp_file_name = "data/LACDPH/hospitalizations/20210813.xlsx"){
  # Fail early with a readable message, e.g. when a caller forwards a NULL
  # file name instead of relying on the default.
  if (!is.character(hosp_file_name) || length(hosp_file_name) != 1) {
    stop("hosp_file_name must be a single path to the hospitalization xlsx file")
  }

  dat <- readxl::read_xlsx(hosp_file_name) %>%
    dplyr::rename(incidH = incidH_covid) %>%
    dplyr::mutate(
      # Only a text column can contain the "n/a" marker. Comparing a numeric
      # column with a string is an error in recent dplyr versions, so numeric
      # columns are passed through untouched.
      incidH = if (is.character(incidH)) dplyr::na_if(incidH, "n/a") else incidH
    ) %>%
    dplyr::mutate(
      date = as.Date(date),
      incidH = as.numeric(incidH),
      FIPS = "06037"
    ) %>%
    dplyr::select(Update = date, FIPS, incidH) %>%
    dplyr::ungroup()

  return(dat)
}

##'
##' Pull CSSE US data in format similar to that of global data
##'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ test_that("get_groundtruth_from_source works", {
csse_st_ctyonly <- get_groundtruth_from_source(source = "csse", scale = "US state", incl_unass = FALSE)
fake <- get_groundtruth_from_source(source = "fakesource")

expect_null(fake)
expect_equal(nrow(fake),
0)
usaf_cty_processed <- usaf_cty %>%
dplyr::mutate(FIPS = stringr::str_sub(FIPS, 1, 2)) %>%
dplyr::group_by(Update, FIPS, source) %>%
Expand Down
5 changes: 3 additions & 2 deletions R/pkgs/inference/R/filter_MC_runner_funcs.R
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,13 @@ aggregate_and_calc_loc_likelihoods <- function(
this_location_log_likelihood <- 0
for (var in names(ground_truth_data[[location]])) {

observed_indices <- !is.na(ground_truth_data[[location]][[var]]$data_var)

this_location_log_likelihood <- this_location_log_likelihood +
## Actually compute likelihood for this location and statistic here:
sum(inference::logLikStat(
obs = ground_truth_data[[location]][[var]]$data_var,
sim = this_location_modeled_outcome[[var]]$sim_var,
obs = ground_truth_data[[location]][[var]]$data_var[observed_indices],
sim = this_location_modeled_outcome[[var]]$sim_var[observed_indices],
dist = config$filtering$statistics[[var]]$likelihood$dist,
param = config$filtering$statistics[[var]]$likelihood$param,
add_one = config$filtering$statistics[[var]]$add_one
Expand Down
160 changes: 124 additions & 36 deletions R/pkgs/inference/R/groundtruth.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,39 +2,72 @@
#' @param data_path Path where to write the data
#' @param cache logical indicating whether to cache the data (default = TRUE)
#' @param gt_source string indicating source of ground truth data. options include "csse" or "usafacts" (default csse)
#' @param gt_scale string indicating whether "US county" or "US state"-level data
#' @param gt_scale string indicating whether "US county" or "US state"-level data
#'
#' @return NULL
#'
#' @export
get_ground_truth_file <- function(data_path, cache = TRUE, gt_source = "csse", gt_scale = "US county") {
get_ground_truth_file <- function(data_path, cache = TRUE, gt_source = "csse", gt_scale = "US county", gt_vars = c("Confirmed", "Deaths", "incidI", "incidDeath"), new_vars = gt_vars, fips_column_name = "geoid", date_column_name = "date", misc_data_filename = NULL) {
data_dir <- dirname(data_path)

if(!dir.exists(data_dir)){
suppressWarnings(dir.create(data_dir,recursive=TRUE))
}

if (length(gt_vars) != length(new_vars)) {
stop(paste(
"groundtruth variables and new variables should have the same number of elements, got:",
"(", paste(gt_vars, collapse = ", "), ")",
"and",
"(", paste(new_vars, collapse = ", "), ")"
))
}

if(!isTRUE(all.equal(gt_vars, new_vars))){
warning("new_vars is deprecated, please adjust data_var to match gt_column_name")
}

if(!(file.exists(data_path) & cache)){
message(paste("*** Loading Data from", gt_source, "\n"))
cases_deaths <- suppressMessages(covidcommon::get_groundtruth_from_source(source = gt_source, scale = gt_scale, variables = c("Confirmed", "Deaths", "incidI", "incidDeath"), incl_unass = ifelse(gt_scale == "US state", TRUE, FALSE)))
cases_deaths <- dplyr::arrange(
dplyr::rename(
dplyr::mutate(
cases_deaths,
Update = lubridate::ymd(Update)
),
date = Update,
cumConfirmed = Confirmed,
cumDeaths = Deaths,
confirmed_incid = incidI,
death_incid = incidDeath
cases_deaths <- suppressMessages(covidcommon::get_groundtruth_from_source(
source = gt_source,
scale = gt_scale,
variables = gt_vars,
incl_unass = ifelse(gt_scale == "US state", TRUE, FALSE),
misc_data_filename = misc_data_filename
))
cases_deaths <- dplyr::arrange(
dplyr::mutate(
cases_deaths,
Update = lubridate::ymd(Update)
),
date
Update
)
if(any(is.na(cases_deaths$cumConfirmed))){
cases_deaths$cumConfirmed[is.na(cases_deaths$cumConfirmed)] <- 0

gt_vars <- c("Update", "FIPS", gt_vars)
new_vars <- c(date_column_name, fips_column_name, new_vars)

if(!all(gt_vars %in% names(cases_deaths))) {
stop(paste(
"Could not find all expected names. Looking for",
"(", paste(gt_vars, collapse = ", "), ")",
"found",
"(", paste(names(cases_deaths), collapse = ", "), ")"
))
}
if(any(is.na(cases_deaths$cumDeaths))){
cases_deaths$cumDeaths[is.na(cases_deaths$cumDeaths)] <- 0
if(!all(names(cases_deaths) %in% gt_vars)) {
warning(paste(
"Found more than the expected names. Looking for",
"(", paste(gt_vars, collapse = ", "), ")",
"found",
"(", paste(names(cases_deaths), collapse = ", "), ")",
"extra",
"(", paste(names(cases_deaths)[!(names(cases_deaths) %in% gt_vars)], collapse = ", "), ")"
))
}
names(cases_deaths)[names(cases_deaths) %in% gt_vars] <-
setNames(new_vars, gt_vars)[names(cases_deaths)[names(cases_deaths) %in% gt_vars]]

readr::write_csv(cases_deaths, data_path)
rm(cases_deaths)
message("*** DONE Loading Data \n")
Expand All @@ -49,24 +82,79 @@ get_ground_truth_file <- function(data_path, cache = TRUE, gt_source = "csse", g
#' @param data_path Path where to write the data
#'
#' @export
get_ground_truth <- function(
  data_path,
  fips_codes = NULL,
  fips_column_name = "geoid",
  date_column_name = "date",
  start_date = NULL,
  end_date = NULL,
  cache = TRUE,
  gt_source = "csse",
  gt_scale = "US county",
  gt_vars = c("Confirmed", "Deaths", "incidI", "incidDeath"),
  new_vars = gt_vars,
  misc_data_filename = NULL
) {
  # Loads the ground truth file (creating it first when needed) and returns
  # it restricted to the requested locations and to a date window per variable.
  #
  # data_path          csv file the ground truth is written to / read from
  # fips_codes         location ids to keep; NULL keeps every id in the file
  # fips_column_name   name of the location id column in the file and result
  # date_column_name   name of the date column in the file and result
  # start_date,        window bounds; either one value for all variables or
  # end_date           one value per variable; NULL uses the file's range
  # cache, gt_source, gt_scale, gt_vars, new_vars, misc_data_filename
  #                    forwarded unchanged to get_ground_truth_file()
  #
  # Returns a data frame with one column per entry of `new_vars`, plus the
  # id/date columns and a `geoid` copy of the id column.

  get_ground_truth_file(
    data_path = data_path,
    cache = cache,
    gt_source = gt_source,
    gt_scale = gt_scale,
    gt_vars = gt_vars,
    new_vars = new_vars,
    fips_column_name = fips_column_name,
    date_column_name = date_column_name,
    misc_data_filename = misc_data_filename
  )

  # get_ground_truth_file() stores the location id under `fips_column_name`,
  # so that is the column that has to be forced to character (a column type
  # keyed on "FIPS" would not match any column of the file).
  rc <- suppressMessages(readr::read_csv(
    data_path,
    col_types = stats::setNames(list(readr::col_character()), fips_column_name)
  ))

  # Defaults are taken from the file itself, using the configured column names.
  if (is.null(start_date)) {
    start_date <- min(rc[[date_column_name]])
  }

  if (is.null(end_date)) {
    end_date <- max(rc[[date_column_name]])
  }

  if (is.null(fips_codes)) {
    fips_codes <- unique(rc[[fips_column_name]])
  }

  # A single bound applies to every variable; otherwise one bound per variable
  # is expected. Variables without a bound get NA below and are filtered out.
  if (length(start_date) != length(gt_vars) && length(start_date) == 1) {
    start_date <- rep(start_date, length(gt_vars))
  } else if (length(start_date) != length(gt_vars)) {
    warning("No start date specified for at least one of the variables; the variable will be removed from the groundtruth")
  }

  if (length(end_date) != length(gt_vars) && length(end_date) == 1) {
    end_date <- rep(end_date, length(gt_vars))
  } else if (length(end_date) != length(gt_vars)) {
    warning("No end date specified for at least one of the variables; the variable will be removed from the groundtruth")
  }

  # NOTE(review): the helper columns start_date/end_date are still present
  # when pivot_wider() runs, so they act as id columns and remain in the
  # result; confirm this is intended when variables use different windows.
  rc <- rc %>%
    dplyr::filter(!!rlang::sym(fips_column_name) %in% fips_codes) %>%
    tidyr::pivot_longer(tidyselect::all_of(new_vars)) %>%
    dplyr::mutate(
      start_date = lubridate::ymd(start_date[match(name, new_vars)]),
      end_date = lubridate::ymd(end_date[match(name, new_vars)])
    ) %>%
    dplyr::filter(
      start_date <= !!rlang::sym(date_column_name),
      !!rlang::sym(date_column_name) <= end_date
    ) %>%
    tidyr::pivot_wider(names_from = name, values_from = value) %>%
    dplyr::right_join(
      tidyr::expand_grid(
        !!rlang::sym(fips_column_name) := fips_codes
      )
    ) %>%
    # Requested locations without any row in the window come out of the join
    # with a missing date and are dropped again here.
    dplyr::filter(!is.na(!!rlang::sym(date_column_name))) %>%
    dplyr::mutate(geoid = !!rlang::sym(fips_column_name))

  return(rc)
}
22 changes: 10 additions & 12 deletions R/pkgs/inference/tests/testthat/test-get_ground_truth_file.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ test_that("get_ground_truth_file creates a file",{
warning("Testing should not be using a file that already exists")
file.remove(data_path)
}
expect_error(get_ground_truth_file(data_path,FALSE),NA)
expect_error(get_ground_truth_file(data_path,cache=FALSE),NA)
expect_equal(file.exists(data_path),TRUE)

expect_error(get_ground_truth_file(data_path,TRUE),NA)
expect_error(get_ground_truth_file(data_path,cache=TRUE),NA)
expect_equal(file.exists(data_path),TRUE)
})

Expand All @@ -18,16 +18,14 @@ test_that("get_ground_truth returns an appropriate data frame",{
fips_column_name <- "test_fips_column"
start_date <- lubridate::ymd("2020-04-15")
end_date <- lubridate::ymd("2020-04-30")
expect_error({get_ground_truth(data_path,fips_codes,fips_column_name,start_date,end_date,FALSE)},NA)
expect_error({get_ground_truth(data_path,fips_codes,fips_column_name,start_date,end_date,TRUE)},NA)
new_vars <- c("cumConfirmed", "cumDeaths", "confirmed_incid","death_incid")
expect_error({get_ground_truth(data_path = data_path,fips_codes = fips_codes,fips_column_name = fips_column_name,start_date = start_date,end_date = end_date, new_vars = new_vars, cache=FALSE)},NA)
expect_error({get_ground_truth(data_path = data_path,fips_codes = fips_codes, fips_column_name = fips_column_name,start_date = start_date,end_date = end_date, new_vars = new_vars, cache=TRUE)},NA)
expect_equal({
all(c(fips_column_name,"date","confirmed_incid","death_incid", "cumConfirmed", "cumDeaths") %in% names(get_ground_truth(data_path,fips_codes,fips_column_name,start_date,end_date,TRUE)))
all(c(fips_column_name,"date","cumConfirmed", "cumDeaths", "confirmed_incid","death_incid") %in% names(get_ground_truth(data_path = data_path,fips_codes = fips_codes, fips_column_name = fips_column_name,start_date = start_date,end_date = end_date, new_vars = new_vars, cache=TRUE)))
},TRUE)
expect_equal({
all(c(fips_column_name,"date","confirmed_incid","death_incid", "cumConfirmed", "cumDeaths") %in% names(get_ground_truth(data_path,fips_codes,fips_column_name,start_date,end_date,TRUE)))
},TRUE)
expect_gt(nrow(get_ground_truth(data_path,fips_codes,fips_column_name,start_date,end_date,TRUE)),0)
expect_equal(all(get_ground_truth(data_path,fips_codes,fips_column_name,start_date,end_date,TRUE)[[fips_column_name]] %in% fips_codes),TRUE)
expect_equal(all(get_ground_truth(data_path,fips_codes,fips_column_name,start_date,end_date,TRUE)$date >= start_date),TRUE)
expect_equal(all(get_ground_truth(data_path,fips_codes,fips_column_name,start_date,end_date,TRUE)$date <= end_date),TRUE)
expect_gt(nrow(get_ground_truth(data_path = data_path,fips_codes = fips_codes, fips_column_name = fips_column_name,start_date = start_date,end_date = end_date, new_vars = new_vars,cache=TRUE)),0)
expect_equal(all(get_ground_truth(data_path = data_path,fips_codes = fips_codes, fips_column_name = fips_column_name,start_date = start_date,end_date = end_date, new_vars = new_vars,cache=TRUE)[[fips_column_name]] %in% fips_codes),TRUE)
expect_equal(all(get_ground_truth(data_path = data_path,fips_codes = fips_codes, fips_column_name = fips_column_name,start_date = start_date,end_date = end_date, new_vars = new_vars,cache=TRUE)$date >= start_date),TRUE)
expect_equal(all(get_ground_truth(data_path = data_path,fips_codes = fips_codes, fips_column_name = fips_column_name,start_date = start_date,end_date = end_date, new_vars = new_vars,cache=TRUE)$date <= end_date),TRUE)
})
Loading