Skip to content

Commit

Permalink
test: add an R test to run FIMS in parallel with {snowfall}
Browse files Browse the repository at this point in the history
Co-authored-by: Cole-Monnahan-NOAA <[email protected]>
  • Loading branch information
Bai-Li-NOAA and Cole-Monnahan-NOAA committed Jul 9, 2024
1 parent 0296e41 commit 4549c8a
Show file tree
Hide file tree
Showing 5 changed files with 306 additions and 1 deletion.
5 changes: 4 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,14 @@ Imports:
Suggests:
covr,
knitr,
parallel,
remotes,
rmarkdown,
snowfall,
testthat (>= 3.0.0),
tidyverse,
usethis
usethis,
withr
LinkingTo:
Rcpp,
RcppEigen,
Expand Down
Binary file not shown.
47 changes: 47 additions & 0 deletions tests/testthat/fixtures/simulate-integration-test-data.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Install the operating model repo from GitHub
remotes::install_github(
repo = "Bai-Li-NOAA/Age_Structured_Stock_Assessment_Model_Comparison"
)

working_dir <- getwd()

maindir <- tempdir()

# Save the initial OM input using ASSAMC package (sigmaR = 0.4)
model_input <- ASSAMC::save_initial_input()

# Configure the input parameters for the simulation
sim_num <- 100
FIMS_100iter <- ASSAMC::save_initial_input(
base_case = TRUE,
input_list = model_input,
maindir = maindir,
om_sim_num = sim_num,
keep_sim_num = sim_num,
figure_number = 1,
seed_num = 9924,
case_name = "FIMS_100iter"
)

# Run OM and generate om_input, om_output, and em_input
# using function from the model comparison project
ASSAMC::run_om(input_list = FIMS_100iter)

on.exit(unlink(maindir, recursive = TRUE), add = TRUE)

setwd(working_dir)
on.exit(setwd(working_dir), add = TRUE)

# Loop through each simulation to load the results from the corresponding
# .RData files and save them into one file
om_input_list <- om_output_list <- em_input_list <-
vector(mode = "list", length = sim_num)
for (i in 1:sim_num){
load(file.path(maindir, "FIMS_100iter", "output", "OM", paste0("OM", i, ".RData")))
om_input_list[[i]] <- om_input
om_output_list[[i]] <- om_output
em_input_list[[i]] <- em_input
}

save(om_input_list, om_output_list, em_input_list,
file = test_path("fixtures", "integration_test_data.RData"))
186 changes: 186 additions & 0 deletions tests/testthat/helper-integration-tests-setup.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
# Set-up Rcpp modules and fix parameters to "true" values from the OM
setup_and_run_FIMS <- function(iter_id,
om_input_list,
om_output_list,
em_input_list,
estimation_mode = TRUE) {
# set.seed(seed = 123)

# Load operating model data
om_input <- om_input_list[[iter_id]]
om_output <- om_output_list[[iter_id]]
em_input <- em_input_list[[iter_id]]

clear()
# Recruitment
# create new module in the recruitment class (specifically Beverton-Holt,
# when there are other options, this would be where the option would be chosen)
recruitment <- new(BevertonHoltRecruitment)

# NOTE: in first set of parameters below (for recruitment),
# $is_random_effect (default is FALSE) and $estimated (default is FALSE)
# are defined even if they match the defaults in order to provide an example
# of how that is done. Other sections of the code below leave defaults in
# place as appropriate.

# set up logR_sd
# logR_sd is NOT logged. It needs to enter the model logged b/c the exp() is
# taken before the likelihood calculation
recruitment$log_sigma_recruit$value <- log(om_input$logR_sd)
recruitment$log_sigma_recruit$is_random_effect <- FALSE
recruitment$log_sigma_recruit$estimated <- FALSE
# set up log_rzero (equilibrium recruitment)
recruitment$log_rzero$value <- log(om_input$R0)
recruitment$log_rzero$is_random_effect <- FALSE
recruitment$log_rzero$estimated <- TRUE
# set up logit_steep
recruitment$logit_steep$value <- -log(1.0 - om_input$h) + log(om_input$h - 0.2)
recruitment$logit_steep$is_random_effect <- FALSE
recruitment$logit_steep$estimated <- FALSE
# turn on estimation of deviations
recruitment$estimate_log_devs <- TRUE
# recruit deviations should enter the model in normal space.
# The log is taken in the likelihood calculations
# alternative setting: recruitment$log_devs <- rep(0, length(om_input$logR.resid))
recruitment$log_devs <- om_input$logR.resid[-1]


# Data
catch <- em_input$L.obs$fleet1
# set fishing fleet catch data, need to set dimensions of data index
# currently FIMS only has a fleet module that takes index for both survey index and fishery catch
fishing_fleet_index <- new(Index, length(catch))
fishing_fleet_index$index_data <- catch
# set fishing fleet age comp data, need to set dimensions of age comps
fishing_fleet_age_comp <- new(AgeComp, length(catch), om_input$nages)
fishing_fleet_age_comp$age_comp_data <- c(t(em_input$L.age.obs$fleet1)) * em_input$n.L$fleet1

# repeat for surveys
survey_index <- em_input$surveyB.obs$survey1
survey_fleet_index <- new(Index, length(survey_index))
survey_fleet_index$index_data <- survey_index
survey_fleet_age_comp <- new(AgeComp, length(survey_index), om_input$nages)
survey_fleet_age_comp$age_comp_data <- c(t(em_input$survey.age.obs$survey1)) * em_input$n.survey$survey1

# Growth
ewaa_growth <- new(EWAAgrowth)
ewaa_growth$ages <- om_input$ages
ewaa_growth$weights <- om_input$W.mt

# Maturity
maturity <- new(LogisticMaturity)
maturity$inflection_point$value <- om_input$A50.mat
maturity$inflection_point$is_random_effect <- FALSE
maturity$inflection_point$estimated <- FALSE
maturity$slope$value <- om_input$slope
maturity$slope$is_random_effect <- FALSE
maturity$slope$estimated <- FALSE

# Fleet
# Create the fishing fleet
fishing_fleet_selectivity <- new(LogisticSelectivity)
fishing_fleet_selectivity$inflection_point$value <- om_input$sel_fleet$fleet1$A50.sel1
fishing_fleet_selectivity$inflection_point$is_random_effect <- FALSE
# turn on estimation of inflection_point
fishing_fleet_selectivity$inflection_point$estimated <- TRUE
fishing_fleet_selectivity$slope$value <- om_input$sel_fleet$fleet1$slope.sel1
# turn on estimation of slope
fishing_fleet_selectivity$slope$is_random_effect <- FALSE
fishing_fleet_selectivity$slope$estimated <- TRUE

fishing_fleet <- new(Fleet)
fishing_fleet$nages <- om_input$nages
fishing_fleet$nyears <- om_input$nyr
fishing_fleet$log_Fmort <- log(om_output$f)
fishing_fleet$estimate_F <- TRUE
fishing_fleet$random_F <- FALSE
fishing_fleet$log_q <- log(1.0)
fishing_fleet$estimate_q <- FALSE
fishing_fleet$random_q <- FALSE
fishing_fleet$log_obs_error <- rep(log(sqrt(log(em_input$cv.L$fleet1^2 + 1))), om_input$nyr)
fishing_fleet$estimate_obs_error <- FALSE
# Modules are linked together using module IDs
# Each module has a get_id() function that returns the unique ID for that module
# Each fleet uses the module IDs to link up the correct module to the correct fleet
# Note: Likelihoods not yet set up as a stand-alone modules, so no get_id()
fishing_fleet$SetAgeCompLikelihood(1)
fishing_fleet$SetIndexLikelihood(1)
fishing_fleet$SetSelectivity(fishing_fleet_selectivity$get_id())
fishing_fleet$SetObservedIndexData(fishing_fleet_index$get_id())
fishing_fleet$SetObservedAgeCompData(fishing_fleet_age_comp$get_id())

# Create the survey fleet
survey_fleet_selectivity <- new(LogisticSelectivity)
survey_fleet_selectivity$inflection_point$value <- om_input$sel_survey$survey1$A50.sel1
survey_fleet_selectivity$inflection_point$is_random_effect <- FALSE
# turn on estimation of inflection_point
survey_fleet_selectivity$inflection_point$estimated <- TRUE
survey_fleet_selectivity$slope$value <- om_input$sel_survey$survey1$slope.sel1
survey_fleet_selectivity$slope$is_random_effect <- FALSE
# turn on estimation of slope
survey_fleet_selectivity$slope$estimated <- TRUE

survey_fleet <- new(Fleet)
survey_fleet$is_survey <- TRUE
survey_fleet$nages <- om_input$nages
survey_fleet$nyears <- om_input$nyr
survey_fleet$estimate_F <- FALSE
survey_fleet$random_F <- FALSE
survey_fleet$log_q <- log(om_output$survey_q$survey1)
survey_fleet$estimate_q <- TRUE
survey_fleet$random_q <- FALSE
survey_fleet$log_obs_error <- rep(log(sqrt(log(em_input$cv.survey$survey1^2 + 1))), om_input$nyr)
survey_fleet$estimate_obs_error <- FALSE
survey_fleet$SetAgeCompLikelihood(1)
survey_fleet$SetIndexLikelihood(1)
survey_fleet$SetSelectivity(survey_fleet_selectivity$get_id())
survey_fleet$SetObservedIndexData(survey_fleet_index$get_id())
survey_fleet$SetObservedAgeCompData(survey_fleet_age_comp$get_id())

# Population
population <- new(Population)
population$log_M <- rep(log(om_input$M.age[1]), om_input$nyr * om_input$nages)
population$estimate_M <- FALSE
population$log_init_naa <- log(om_output$N.age[1, ])
population$estimate_init_naa <- TRUE
population$nages <- om_input$nages
population$ages <- om_input$ages
population$nfleets <- sum(om_input$fleet_num, om_input$survey_num)
population$nseasons <- 1
population$nyears <- om_input$nyr
population$SetMaturity(maturity$get_id())
population$SetGrowth(ewaa_growth$get_id())
population$SetRecruitment(recruitment$get_id())

# Set-up TMB
CreateTMBModel()
# Create parameter list from Rcpp modules
parameters <- list(p = get_fixed())
obj <- TMB::MakeADFun(data = list(), parameters, DLL = "FIMS", silent=TRUE)

if (estimation_mode == TRUE){

opt <- with(obj, optim(par, fn, gr,
method = "BFGS",
control = list(maxit = 1000000, reltol = 1e-15)
))
}
# Call report using MLE parameter values, or
# the initial values if optimization is skipped
report <- obj$report(obj$env$last.par.best)

sdr <- TMB::sdreport(obj)
sdr_report <- summary(sdr, "report")
sdr_fixed <- summary(sdr, "fixed")

clear()

# end of setup_fims function, returning test_env
return(list(
parameters = parameters,
obj = obj,
report = report,
sdr_report = sdr_report,
sdr_fixed = sdr_fixed
))
}
69 changes: 69 additions & 0 deletions tests/testthat/test-parallel-with-snowfall.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Ensure the latest precompiled version of FIMS is installed in R before
# running devtools. To do this, either run:
# - devtools::install() followed by devtools::test(), or
# - devtools::check()

# Run FIMS in serial and parallel
# This test demonstrates how to run the FIMS model in both serial and parallel
# modes. The test compares the execution time and results of running the model
# in serial versus parallel. The parallel execution uses the {snowfall} package
# to parallelize the tasks across multiple CPU cores

# Load the model comparison operating model data from the fixtures folder
load(test_path("fixtures", "integration_test_data.RData"))

# Run the FIMS model in serial and record the execution time
estimation_results_serial <- vector(mode = "list", length = length(om_input_list))

start_time_serial <- Sys.time()
for (i in 1:length(om_input_list)) {
estimation_results_serial[[i]] <- setup_and_run_FIMS(
iter_id = i,
om_input_list = om_input_list,
om_output_list = om_output_list,
em_input_list = em_input_list,
estimation_mode = TRUE)
}
end_time_serial <- Sys.time()
estimation_time_serial <- end_time_serial - start_time_serial

test_that("Run FIMS in parallel using {snowfall}", {

core_num <- parallel::detectCores() - 1
snowfall::sfInit(parallel = TRUE, cpus = core_num)
start_time_parallel <- Sys.time()

results_parallel <- snowfall::sfLapply(
1:length(om_input_list),
setup_and_run_FIMS,
om_input_list,
om_output_list,
em_input_list,
TRUE)

end_time_parallel <- Sys.time()

time_parallel <- end_time_parallel - start_time_parallel

snowfall::sfStop()

# Compare execution times: verify that the execution time of the parallel run
# is less than the serial run.
expect_lt(object = time_parallel, expected = estimation_time_serial)

# Compare parameters in results:
# Verify that the results from both runs are equivalent.
expect_setequal(unname(unlist(lapply(results_parallel, `[[`, "parameters"))),
unname(unlist(lapply(estimation_results_serial, `[[`, "parameters"))))

# Compare sdr_fixed values in results:
# Verify that the results from both runs are equivalent.
expect_setequal(unlist(lapply(results_parallel, `[[`, "sdr_fixed")),
unlist(lapply(estimation_results_serial, `[[`, "sdr_fixed")))

# Compare sdr_report values in results:
# Verify that the results from both runs are equivalent.
expect_setequal(unlist(lapply(results_parallel, `[[`, "sdr_report")),
unlist(lapply(estimation_results_serial, `[[`, "sdr_report")))

})

0 comments on commit 4549c8a

Please sign in to comment.