diff --git a/.Rbuildignore b/.Rbuildignore
index 999be025..aa881e6b 100644
--- a/.Rbuildignore
+++ b/.Rbuildignore
@@ -8,3 +8,5 @@
 ^pkgdown$
 ^doc$
 ^Meta$
+^cran-comments\.md$
+^CRAN-RELEASE$
diff --git a/CRAN-RELEASE b/CRAN-RELEASE
new file mode 100644
index 00000000..04cc61c9
--- /dev/null
+++ b/CRAN-RELEASE
@@ -0,0 +1,2 @@
+This package was submitted to CRAN on 2021-06-16.
+Once it is accepted, delete this file and tag the release (commit 1d00ae2).
diff --git a/DESCRIPTION b/DESCRIPTION
index 2dc1f294..9206f6d2 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -1,6 +1,6 @@
 Package: luz
 Title: Higher Level 'API' for 'torch'
-Version: 0.0.0.9000
+Version: 0.1.0
 Authors@R: c(
     person("Daniel", "Falbel", email = "daniel@rstudio.com", role = c("aut", "cre", "cph")),
     person(family = "RStudio", role = c("cph"))
@@ -12,9 +12,8 @@ Description: A high level interface for 'torch' providing utilities to reduce th
     Howard et al. (2020) , 'Keras' by Chollet et al. (2015) and
     'Pytorch Lightning' by Falcon et al. (2019) .
 License: MIT + file LICENSE
-URL: https://mlverse.github.io/luz, https://github.com/mlverse/luz
+URL: https://mlverse.github.io/luz/, https://github.com/mlverse/luz
 Encoding: UTF-8
-LazyData: true
 Roxygen: list(markdown = TRUE)
 RoxygenNote: 7.1.1
 Imports:
diff --git a/R/callbacks-interrupt.R b/R/callbacks-interrupt.R
index 6270fa58..fe9d3e1b 100644
--- a/R/callbacks-interrupt.R
+++ b/R/callbacks-interrupt.R
@@ -10,6 +10,12 @@ NULL
 #' @note In general you don't need to use these callback by yourself because it's always
 #' included by default in [fit.luz_module_generator()].
 #'
+#' @examples
+#' interrupt_callback <- luz_callback_interrupt()
+#'
+#' @returns
+#' A `luz_callback`
+#'
 #' @family luz_callbacks
 #' @export
 luz_callback_interrupt <- luz_callback(
diff --git a/R/callbacks-profile.R b/R/callbacks-profile.R
index 95782834..f8cd44e8 100644
--- a/R/callbacks-profile.R
+++ b/R/callbacks-profile.R
@@ -16,6 +16,12 @@
 #' @note In general you don't need to use these callback by yourself because it's always
 #' included by default in [fit.luz_module_generator()].
 #'
+#' @examples
+#' profile_callback <- luz_callback_profile()
+#'
+#' @returns
+#' A `luz_callback`
+#'
 #' @family luz_callbacks
 #' @export
 luz_callback_profile <- luz_callback(
diff --git a/R/callbacks.R b/R/callbacks.R
index a2cf5ef5..ca5f70a9 100644
--- a/R/callbacks.R
+++ b/R/callbacks.R
@@ -86,6 +86,10 @@ luz_callback <- function(name = NULL, ..., private = NULL, active = NULL, parent
 #' @note Printing can be disabled by passing `verbose=FALSE` to [fit.luz_module_generator()].
 #'
 #' @family luz_callbacks
+#'
+#' @returns
+#' A `luz_callback`
+#'
 #' @export
 luz_callback_progress <- luz_callback(
   "progress_callback",
@@ -203,6 +207,10 @@ luz_callback_progress <- luz_callback(
 #' used by default in [fit.luz_module_generator()].
 #'
 #' @family luz_callbacks
+#'
+#' @returns
+#' A `luz_callback`
+#'
 #' @export
 luz_callback_metrics <- luz_callback(
   "metrics_callback",
@@ -272,6 +280,9 @@ luz_callback_metrics <- luz_callback(
 #' @note In general you won't need to explicitly use the metrics callback as it's
 #' used by default in [fit.luz_module_generator()].
 #'
+#' @returns
+#' A `luz_callback`
+#'
 #' @family luz_callbacks
 #' @export
 luz_callback_train_valid <- luz_callback(
diff --git a/R/metrics.R b/R/metrics.R
index e17156e0..1216edb8 100644
--- a/R/metrics.R
+++ b/R/metrics.R
@@ -96,6 +96,11 @@ luz_metric <- function(name = NULL, ..., private = NULL, active = NULL,
 #' metric$compute()
 #' }
 #' @export
+#'
+#'
+#' @returns
+#' Returns new Luz metric.
+#'
 #' @family luz_metrics
 luz_metric_accuracy <- luz_metric(
   abbrev = "Acc",
@@ -131,6 +136,10 @@ luz_metric_accuracy <- luz_metric(
 #' metric$update(torch_rand(100), torch::torch_randint(0, 1, size = 100))
 #' metric$compute()
 #' }
+#'
+#' @returns
+#' Returns new Luz metric.
+#'
 #' @family luz_metrics
 #' @export
 luz_metric_binary_accuracy <- luz_metric(
@@ -172,6 +181,8 @@ luz_metric_binary_accuracy <- luz_metric(
 #' metric$update(torch_randn(100), torch::torch_randint(0, 1, size = 100))
 #' metric$compute()
 #' }
+#' @returns
+#' Returns new Luz metric.
 #'
 #' @family luz_metrics
 #' @export
@@ -240,6 +251,8 @@ luz_metric_loss_average <- luz_metric(
 #' metric$update(torch_randn(100), torch_randn(100))
 #' metric$compute()
 #' }
+#' @returns
+#' Returns new Luz metric.
 #'
 #' @family luz_metrics
 #' @export
@@ -289,6 +302,10 @@ luz_metric_mse <- luz_metric(
 #' Computes the root mean squared error.
 #'
 #' @family luz_metrics
+#'
+#' @returns
+#' Returns new Luz metric.
+#'
 #' @export
 luz_metric_rmse <- luz_metric(
   inherit = luz_metric_mse,
diff --git a/R/module-plot.R b/R/module-plot.R
index 83293a14..00635944 100644
--- a/R/module-plot.R
+++ b/R/module-plot.R
@@ -5,3 +5,5 @@ plot.luz_module_fitted <- function(x, ...) {
   p <- p + ggplot2::geom_point() + ggplot2::geom_line()
   p + ggplot2::facet_grid(metric ~ set, scales = "free_y")
 }
+
+globalVariables(c("epoch", "value"))
diff --git a/R/module-print.R b/R/module-print.R
index d4151fe1..2d9f73d9 100644
--- a/R/module-print.R
+++ b/R/module-print.R
@@ -43,3 +43,5 @@
   print(x$model)
 }
+
+
diff --git a/R/module.R b/R/module.R
index fa9379b9..dc56aba5 100644
--- a/R/module.R
+++ b/R/module.R
@@ -16,6 +16,9 @@
 #' @param metrics (`list`, optional) A list of metrics to be tracked during
 #'   the training procedure.
 #'
+#' @returns
+#' A luz module that can be trained with [fit()].
+#'
 #' @family training
 #'
 #' @export
@@ -68,6 +71,9 @@ setup <- function(module, loss = NULL, optimizer = NULL, metrics = NULL) {
 #'
 #' @family set_hparam
 #'
+#' @returns
+#' The same luz module
+#'
 #' @export
 set_hparams <- function(module, ...) {
   hparams <- rlang::list2(...)
@@ -87,6 +93,9 @@ set_hparams <- function(module, ...) {
 #' `optim_adam` function is called with `optim_adam(parameters, lr=0.1)` when fitting
 #' the model.
 #'
+#' @returns
+#' The same luz module
+#'
 #' @family set_hparam
 #' @export
 set_opt_hparams <- function(module, ...) {
@@ -132,9 +141,12 @@ get_opt_hparams <- function(module) {
 #'
 #' @param ... Currently unused,
 #'
+#' @returns
+#' A fitted object that can be saved with [luz_save()] and can be printed with
+#' [print()] and plotted with [plot()].
+#'
 #' @importFrom generics fit
 #' @export
-#'
 fit.luz_module_generator <- function(object, data, epochs = 10, callbacks = NULL,
                                      valid_data = NULL,
                                      accelerator = NULL,
                                      verbose = NULL, ...) {
diff --git a/README.md b/README.md
index 7b13bf6e..9c80eb5d 100644
--- a/README.md
+++ b/README.md
@@ -3,7 +3,7 @@
 [![R-CMD-check](https://github.com/mlverse/luz/workflows/R-CMD-check/badge.svg)](https://github.com/mlverse/luz/actions)
 [![Codecov test coverage](https://codecov.io/gh/mlverse/luz/branch/master/graph/badge.svg)](https://codecov.io/gh/mlverse/luz?branch=master)
-[![Discord](https://img.shields.io/discord/837019024499277855?logo=discord)](https://discord.gg/s3D5cKhBkx)
+[![Discord](https://img.shields.io/discord/837019024499277855?logo=discord)](https://discord.com/invite/s3D5cKhBkx)
 
 luz is a higher level API for torch providing abstractions to allow
 for much less verbose training loops.
diff --git a/cran-comments.md b/cran-comments.md
new file mode 100644
index 00000000..0b524c9f
--- /dev/null
+++ b/cran-comments.md
@@ -0,0 +1 @@
+First release.
diff --git a/man/fit.luz_module_generator.Rd b/man/fit.luz_module_generator.Rd
index 6b7056a8..0ac48981 100644
--- a/man/fit.luz_module_generator.Rd
+++ b/man/fit.luz_module_generator.Rd
@@ -43,6 +43,10 @@ to the console.}
 
 \item{...}{Currently unused,}
 }
+\value{
+A fitted object that can be saved with \code{\link[=luz_save]{luz_save()}} and can be printed with
+\code{\link[=print]{print()}} and plotted with \code{\link[=plot]{plot()}}.
+}
 \description{
 Fit a \code{nn_module}
 }
diff --git a/man/luz_callback_interrupt.Rd b/man/luz_callback_interrupt.Rd
index 77c9553d..5d7c5369 100644
--- a/man/luz_callback_interrupt.Rd
+++ b/man/luz_callback_interrupt.Rd
@@ -6,6 +6,9 @@
 \usage{
 luz_callback_interrupt()
 }
+\value{
+A \code{luz_callback}
+}
 \description{
 Adds a handler that allows interrupting the training loop using \code{ctrl + C}.
 Also registers a \code{on_interrupt} breakpoint so users can register callbacks to
@@ -14,6 +17,10 @@ be run on training loop interruption.
 \note{
 In general you don't need to use these callback by yourself because it's always
 included by default in \code{\link[=fit.luz_module_generator]{fit.luz_module_generator()}}.
+}
+\examples{
+interrupt_callback <- luz_callback_interrupt()
+
 }
 \seealso{
 Other luz_callbacks:
diff --git a/man/luz_callback_metrics.Rd b/man/luz_callback_metrics.Rd
index 5f03ff8f..2a70cb77 100644
--- a/man/luz_callback_metrics.Rd
+++ b/man/luz_callback_metrics.Rd
@@ -6,6 +6,9 @@
 \usage{
 luz_callback_metrics()
 }
+\value{
+A \code{luz_callback}
+}
 \description{
 Tracks metrics passed to \code{\link[=setup]{setup()}} during training and validation.
 }
diff --git a/man/luz_callback_profile.Rd b/man/luz_callback_profile.Rd
index a3689251..04cf8119 100644
--- a/man/luz_callback_profile.Rd
+++ b/man/luz_callback_profile.Rd
@@ -6,6 +6,9 @@
 \usage{
 luz_callback_profile()
 }
+\value{
+A \code{luz_callback}
+}
 \description{
 Computes the times for high-level operations in the training loops.
 }
@@ -24,6 +27,10 @@ the model step. (not including data acquisition and preprocessing)
 \note{
 In general you don't need to use these callback by yourself because it's always
 included by default in \code{\link[=fit.luz_module_generator]{fit.luz_module_generator()}}.
+}
+\examples{
+profile_callback <- luz_callback_profile()
+
 }
 \seealso{
 Other luz_callbacks:
diff --git a/man/luz_callback_progress.Rd b/man/luz_callback_progress.Rd
index 7e266e74..b0845df7 100644
--- a/man/luz_callback_progress.Rd
+++ b/man/luz_callback_progress.Rd
@@ -6,6 +6,9 @@
 \usage{
 luz_callback_progress()
 }
+\value{
+A \code{luz_callback}
+}
 \description{
 Responsible for printing progress during training.
 }
diff --git a/man/luz_callback_train_valid.Rd b/man/luz_callback_train_valid.Rd
index 10b9fc70..f2b09439 100644
--- a/man/luz_callback_train_valid.Rd
+++ b/man/luz_callback_train_valid.Rd
@@ -6,6 +6,9 @@
 \usage{
 luz_callback_train_valid()
 }
+\value{
+A \code{luz_callback}
+}
 \description{
 Switches important flags for training and evaluation modes.
 }
diff --git a/man/luz_metric_accuracy.Rd b/man/luz_metric_accuracy.Rd
index 60cf76f5..6c828736 100644
--- a/man/luz_metric_accuracy.Rd
+++ b/man/luz_metric_accuracy.Rd
@@ -6,6 +6,9 @@
 \usage{
 luz_metric_accuracy()
 }
+\value{
+Returns new Luz metric.
+}
 \description{
 Computes accuracy for multi-class classification problems.
 }
diff --git a/man/luz_metric_binary_accuracy.Rd b/man/luz_metric_binary_accuracy.Rd
index ce37c50e..c3b3adfc 100644
--- a/man/luz_metric_binary_accuracy.Rd
+++ b/man/luz_metric_binary_accuracy.Rd
@@ -9,6 +9,9 @@ luz_metric_binary_accuracy(threshold = 0.5)
 \arguments{
 \item{threshold}{value used to classifiy observations between 0 and 1.}
 }
+\value{
+Returns new Luz metric.
+}
 \description{
 Computes the accuracy for binary classification problems where the model
 returns probabilities. Commonly used when the loss is \code{\link[torch:nn_bce_loss]{torch::nn_bce_loss()}}.
@@ -21,6 +24,7 @@ metric <- metric$new()
 metric$update(torch_rand(100), torch::torch_randint(0, 1, size = 100))
 metric$compute()
 }
+
 }
 \seealso{
 Other luz_metrics:
diff --git a/man/luz_metric_binary_accuracy_with_logits.Rd b/man/luz_metric_binary_accuracy_with_logits.Rd
index 83c15b77..240b7d35 100644
--- a/man/luz_metric_binary_accuracy_with_logits.Rd
+++ b/man/luz_metric_binary_accuracy_with_logits.Rd
@@ -9,6 +9,9 @@ luz_metric_binary_accuracy_with_logits(threshold = 0.5)
 \arguments{
 \item{threshold}{value used to classifiy observations between 0 and 1.}
 }
+\value{
+Returns new Luz metric.
+}
 \description{
 Computes accuracy for binary classification problems where the model
 return logits. Commonly used together with \code{\link[torch:nn_bce_with_logits_loss]{torch::nn_bce_with_logits_loss()}}.
@@ -25,7 +28,6 @@ metric <- metric$new()
 metric$update(torch_randn(100), torch::torch_randint(0, 1, size = 100))
 metric$compute()
 }
-
 }
 \seealso{
 Other luz_metrics:
diff --git a/man/luz_metric_mae.Rd b/man/luz_metric_mae.Rd
index ed34603b..36b74fb2 100644
--- a/man/luz_metric_mae.Rd
+++ b/man/luz_metric_mae.Rd
@@ -6,6 +6,9 @@
 \usage{
 luz_metric_mae()
 }
+\value{
+Returns new Luz metric.
+}
 \description{
 Computes the mean absolute error.
 }
@@ -17,7 +20,6 @@ metric <- metric$new()
 metric$update(torch_randn(100), torch_randn(100))
 metric$compute()
 }
-
 }
 \seealso{
 Other luz_metrics:
diff --git a/man/luz_metric_rmse.Rd b/man/luz_metric_rmse.Rd
index 97d3d1cd..ea28544a 100644
--- a/man/luz_metric_rmse.Rd
+++ b/man/luz_metric_rmse.Rd
@@ -6,6 +6,9 @@
 \usage{
 luz_metric_rmse()
 }
+\value{
+Returns new Luz metric.
+}
 \description{
 Computes the root mean squared error.
 }
diff --git a/man/set_hparams.Rd b/man/set_hparams.Rd
index fe3ccd8a..028dc6f0 100644
--- a/man/set_hparams.Rd
+++ b/man/set_hparams.Rd
@@ -12,6 +12,9 @@ set_hparams(module, ...)
 \item{...}{The parameters set here will be used to initialize the \code{nn_module},
 ie they are passed unchanged to the \code{initialize} method of the base \code{nn_module}.}
 }
+\value{
+The same luz module
+}
 \description{
 This function is used to define hyper-parameters before calling \code{fit} for
 \code{luz_modules}.
diff --git a/man/set_opt_hparams.Rd b/man/set_opt_hparams.Rd
index 2e8c36c1..9ac0fc83 100644
--- a/man/set_opt_hparams.Rd
+++ b/man/set_opt_hparams.Rd
@@ -14,6 +14,9 @@ For example, if your optimizer is \code{optim_adam} and you pass \code{lr=0.1},
 \code{optim_adam} function is called with \code{optim_adam(parameters, lr=0.1)} when fitting
 the model.}
 }
+\value{
+The same luz module
+}
 \description{
 This function is used to define hyper-parameters for the optimizer initialization
 method.
diff --git a/man/setup.Rd b/man/setup.Rd
index c2b9d01d..22d6da9f 100644
--- a/man/setup.Rd
+++ b/man/setup.Rd
@@ -20,6 +20,9 @@ the model parameters.}
 \item{metrics}{(\code{list}, optional) A list of metrics to be tracked during
 the training procedure.}
 }
+\value{
+A luz module that can be trained with \code{\link[=fit]{fit()}}.
+}
 \description{
 The setup function is used to set important attributes and method for \code{nn_modules}
 to be used with Luz.
diff --git a/tests/testthat.R b/tests/testthat.R
index c1f0dd74..6bb2a3ff 100644
--- a/tests/testthat.R
+++ b/tests/testthat.R
@@ -1,4 +1,6 @@
 library(testthat)
 library(luz)
 
-test_check("luz")
+if (Sys.getenv("TORCH_TEST", unset = 0) == 1)
+  test_check("luz")
+