From 7409f0997e86b794c690f52c2e46f7f36df16769 Mon Sep 17 00:00:00 2001
From: Christophe REGOUBY
Date: Fri, 29 Dec 2023 12:44:19 +0100
Subject: [PATCH] fix #140

---
 R/activation.R                   | 18 ++++++++++--------
 R/model.R                        | 10 ++++++++--
 man/nn_mb_wlu.Rd                 | 17 ++++++++++++++---
 man/tabnet_config.Rd             |  8 +++++++-
 tests/testthat/test-activation.R | 17 ++++++++++++++++-
 5 files changed, 55 insertions(+), 15 deletions(-)

diff --git a/R/activation.R b/R/activation.R
index 65b32078..8e849a83 100644
--- a/R/activation.R
+++ b/R/activation.R
@@ -5,7 +5,7 @@
 #' @param alpha (float) the weight of ELU activation component.
 #' @param beta (float) the weight of PReLU activation component.
 #' @param gamma (float) the weight of SiLU activation component.
-#' @param init (float): the initial value of \eqn{a} of PReLU. Default: 0.25.
+#' @param weight (torch_tensor) the initial value of the PReLU \eqn{weight}. Default: torch_tensor(0.25).
 #'
 #' @return an activation function computing
 #' \eqn{\mathbf{MBwLU(input) = \alpha \times ELU(input) + \beta \times PReLU(input) + \gamma \times SiLU(input)}}
@@ -20,14 +20,15 @@
 #' @export
 nn_mb_wlu <- torch::nn_module(
   "multibranch Weighted Linear Unit",
-  initialize = function(alpha = 0.6, beta = 0.2, gamma = 0.2, init = 0.25) {
+  initialize = function(alpha = 0.6, beta = 0.2, gamma = 0.2, weight = torch::torch_tensor(0.25)) {
+    stopifnot("weight must be a torch_tensor()" = inherits(weight, "torch_tensor"))
     self$alpha <- alpha
     self$beta <- beta
     self$gamma <- gamma
-    self$init <- init
+    self$weight <- weight
   },
   forward = function(input) {
-    nnf_mb_wlu(input, self$alpha, self$beta, self$gamma, self$init)
+    nnf_mb_wlu(input, self$alpha, self$beta, self$gamma, self$weight)
   }
 )
 
@@ -36,9 +37,10 @@ nn_mb_wlu <- torch::nn_module(
 #' @seealso [nn_mb_wlu()].
 #' @export
 #' @rdname nn_mb_wlu
-nnf_mb_wlu <- function(input, alpha = 0.6, beta = 0.2, gamma = 0.2, init = 0.25) {
-  alpha * torch::nnf_elu(input) +
-    beta * torch::nnf_prelu(input, init) +
-    gamma * torch::nnf_silu(input)
+nnf_mb_wlu <- function(input, alpha = 0.6, beta = 0.2, gamma = 0.2, weight = torch::torch_tensor(0.25)) {
+  stopifnot("weight and input must reside on the same device" = weight$device == input$device)
+  alpha * torch::nnf_elu(input) +
+    beta * torch::nnf_prelu(input, weight) +
+    gamma * torch::nnf_silu(input)
 }
 
diff --git a/R/model.R b/R/model.R
index fb1ad1e4..9c5c581e 100644
--- a/R/model.R
+++ b/R/model.R
@@ -256,12 +256,18 @@ interpretabnet_config <- function(mask_type = "entmax", mlp_hidden_multiplier = c(4,2),
                                   mlp_activation = NULL,
                                   encoder_activation = nn_mb_wlu(),
                                   ...) {
-  tabnet_config(mask_type = mask_type,
+  interpretabnet_conf <- tabnet_config(mask_type = mask_type,
                 mlp_hidden_multiplier = mlp_hidden_multiplier,
                 mlp_activation = mlp_activation,
                 encoder_activation = encoder_activation,
                 ...)
-
+  # align nn_mb_wlu weight device with the config device
+  device <- get_device_from_config(interpretabnet_conf)
+  if (!grepl(device, interpretabnet_conf$encoder_activation$weight$device)) {
+    # move the weight to the config device
+    interpretabnet_conf$encoder_activation$weight <- interpretabnet_conf$encoder_activation$weight$to(device = device)
+  }
+  interpretabnet_conf
 }
 
 get_constr_output <- function(x, R) {
diff --git a/man/nn_mb_wlu.Rd b/man/nn_mb_wlu.Rd
index 8bb989c7..cc8ebc84 100644
--- a/man/nn_mb_wlu.Rd
+++ b/man/nn_mb_wlu.Rd
@@ -5,9 +5,20 @@
 \alias{nnf_mb_wlu}
 \title{Multi-branch Weighted Linear Unit (MB-wLU) nn module.}
 \usage{
-nn_mb_wlu(alpha = 0.6, beta = 0.2, gamma = 0.2, init = 0.25)
+nn_mb_wlu(
+  alpha = 0.6,
+  beta = 0.2,
+  gamma = 0.2,
+  weight = torch::torch_tensor(0.25)
+)
 
-nnf_mb_wlu(input, alpha = 0.6, beta = 0.2, gamma = 0.2, init = 0.25)
+nnf_mb_wlu(
+  input,
+  alpha = 0.6,
+  beta = 0.2,
+  gamma = 0.2,
+  weight = torch::torch_tensor(0.25)
+)
 }
 \arguments{
 \item{alpha}{(float) the weight of ELU activation component.}
@@ -16,7 +27,7 @@ nnf_mb_wlu(input, alpha = 0.6, beta = 0.2, gamma = 0.2, init = 0.25)
 
 \item{gamma}{(float) the weight of SiLU activation component.}
 
-\item{init}{(float): the initial value of \eqn{a} of PReLU. Default: 0.25.}
+\item{weight}{(torch_tensor) the initial value of the PReLU \eqn{weight}. Default: torch_tensor(0.25).}
 
 \item{input}{(N,*) tensor, where * means, any number of additional dimensions}
 
diff --git a/man/tabnet_config.Rd b/man/tabnet_config.Rd
index a167c215..25802272 100644
--- a/man/tabnet_config.Rd
+++ b/man/tabnet_config.Rd
@@ -45,7 +45,13 @@ tabnet_config(
   skip_importance = FALSE
 )
 
-interpretabnet_config(...)
+interpretabnet_config(
+  mask_type = "entmax",
+  mlp_hidden_multiplier = c(4, 2),
+  mlp_activation = NULL,
+  encoder_activation = nn_mb_wlu(),
+  ...
+)
 }
 \arguments{
 \item{cat_emb_dim}{Size of the embedding of categorical features. If a single interger, all categorical
diff --git a/tests/testthat/test-activation.R b/tests/testthat/test-activation.R
index dd87c35f..66411058 100644
--- a/tests/testthat/test-activation.R
+++ b/tests/testthat/test-activation.R
@@ -1,4 +1,4 @@
-test_that("multibranch_weighted_linear_unit nn_module works", {
+test_that("multibranch_weighted_linear_unit activation works", {
   mb_wlu <- nn_mb_wlu()
   input <- torch::torch_tensor(c(-1.0, 0.0, 1.0))
   expected_output <- torch::torch_tensor(c(-0.48306063, 0.0, 0.94621176))
@@ -7,3 +7,18 @@
 })
 
 
+test_that("multibranch_weighted_linear_unit errors when weight is not a torch_tensor", {
+  expect_error(mb_wlu <- nn_mb_wlu(weight = 0.25),
+               regexp = "must be a torch_tensor")
+})
+
+test_that("multibranch_weighted_linear_unit errors when weight and input are not on the same device", {
+  skip_if_not(torch::cuda_is_available())
+  weight <- torch::torch_tensor(0.25)$to(device = "cpu")
+  z <- torch::torch_randn(c(2,2))$to(device = "cuda")
+
+  expect_no_error(mb_wlu <- nn_mb_wlu(weight = weight))
+  expect_error(mb_wlu(z),
+               regexp = "reside on the same device")
+})
+
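
Usage note: a minimal sketch of the new weight argument, assuming the nn_mb_wlu()/nnf_mb_wlu() exports patched above and, purely for illustration, a CUDA device; any torch device works the same way.

    library(torch)

    # weight is now expected to be a torch_tensor, not a bare float
    weight <- torch_tensor(0.25)
    input  <- torch_tensor(c(-1.0, 0.0, 1.0))

    # functional form: weight and input must reside on the same device
    tabnet::nnf_mb_wlu(input, weight = weight)

    # module form: move the stored weight alongside the input when changing
    # device, mirroring the alignment done inside interpretabnet_config()
    act <- tabnet::nn_mb_wlu(weight = weight)
    if (cuda_is_available()) {
      act$weight <- act$weight$to(device = "cuda")
      act(input$to(device = "cuda"))
    }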