
Commit

fix #140
Christophe-Regouby committed Dec 29, 2023
1 parent fb431c5 commit 7409f09
Showing 5 changed files with 55 additions and 15 deletions.
18 changes: 10 additions & 8 deletions R/activation.R
@@ -5,7 +5,7 @@
#' @param alpha (float) the weight of ELU activation component.
#' @param beta (float) the weight of PReLU activation component.
#' @param gamma (float) the weight of SiLU activation component.
-#' @param init (float): the initial value of \eqn{a} of PReLU. Default: 0.25.
+#' @param weight (torch_tensor): the initial value of \eqn{weight} of PReLU. Default: 0.25.
#'
#' @return an activation function computing
#' \eqn{\mathbf{MBwLU(input) = \alpha \times ELU(input) + \beta \times PReLU(input) + \gamma \times SiLU(input)}}
@@ -20,14 +20,15 @@
#' @export
nn_mb_wlu <- torch::nn_module(
  "multibranch Weighted Linear Unit",
-  initialize = function(alpha = 0.6, beta = 0.2, gamma = 0.2, init = 0.25) {
+  initialize = function(alpha = 0.6, beta = 0.2, gamma = 0.2, weight = torch::torch_tensor(0.25)) {
+    stopifnot("weight must be a torch_tensor()" = inherits(weight, "torch_tensor"))
    self$alpha <- alpha
    self$beta <- beta
    self$gamma <- gamma
-    self$init <- init
+    self$weight <- weight
  },
  forward = function(input) {
-    nnf_mb_wlu(input, self$alpha, self$beta, self$gamma, self$init)
+    nnf_mb_wlu(input, self$alpha, self$beta, self$gamma, self$weight)
  }
)

@@ -36,9 +37,10 @@ nn_mb_wlu <- torch::nn_module(
#' @seealso [nn_mb_wlu()].
#' @export
#' @rdname nn_mb_wlu
-nnf_mb_wlu <- function(input, alpha = 0.6, beta = 0.2, gamma = 0.2, init = 0.25) {
-  alpha * torch::nnf_elu(input) +
-    beta * torch::nnf_prelu(input, init) +
-    gamma * torch::nnf_silu(input)
+nnf_mb_wlu <- function(input, alpha = 0.6, beta = 0.2, gamma = 0.2, weight = torch::torch_tensor(0.25)) {
+  stopifnot("weight and input must reside on the same device" = weight$device == input$device)
+  alpha * torch::nnf_elu(input) +
+    beta * torch::nnf_prelu(input, weight) +
+    gamma * torch::nnf_silu(input)

}
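
For illustration only, not part of the diff above: a minimal usage sketch of the new signature, assuming the package providing nn_mb_wlu() is attached and torch is installed. The expected values are taken from the test file changed below.

library(torch)
# module form: weight must now be a torch_tensor
act <- nn_mb_wlu(alpha = 0.6, beta = 0.2, gamma = 0.2, weight = torch_tensor(0.25))
x <- torch_tensor(c(-1.0, 0.0, 1.0))
act(x)                              # approx. -0.4831, 0.0000, 0.9462
# functional form, with a custom PReLU weight tensor
nnf_mb_wlu(x, weight = torch_tensor(0.1))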
10 changes: 8 additions & 2 deletions R/model.R
@@ -256,12 +256,18 @@ interpretabnet_config <- function(mask_type = "entmax",
                                  mlp_hidden_multiplier = c(4,2),
                                  mlp_activation = NULL,
                                  encoder_activation = nn_mb_wlu(), ...) {
-  tabnet_config(mask_type = mask_type,
+  interpretabnet_conf <- tabnet_config(mask_type = mask_type,
                mlp_hidden_multiplier = mlp_hidden_multiplier,
                mlp_activation = mlp_activation,
                encoder_activation = encoder_activation,
                ...)

+  # align the nn_mb_wlu weight device with the config device
+  device <- get_device_from_config(interpretabnet_conf)
+  if (!grepl(device, interpretabnet_conf$encoder_activation$weight$device)) {
+    # move the weight to the config device
+    interpretabnet_conf$encoder_activation$weight <- interpretabnet_conf$encoder_activation$weight$to(device = device)
+  }
+  interpretabnet_conf
}

get_constr_output <- function(x, R) {
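For illustration only, not part of the diff above: the same device-alignment pattern applied directly to an activation instance, assuming a CUDA device is available. get_device_from_config() is an internal helper of the package and is not reproduced here.

act <- nn_mb_wlu()
act$weight$device                   # the weight tensor starts on the CPU
if (torch::cuda_is_available()) {
  # move the PReLU weight to the device the model will run on,
  # mirroring what interpretabnet_config() now does
  act$weight <- act$weight$to(device = "cuda")
}
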
17 changes: 14 additions & 3 deletions man/nn_mb_wlu.Rd

(generated file; diff not rendered)

8 changes: 7 additions & 1 deletion man/tabnet_config.Rd

(generated file; diff not rendered)

17 changes: 16 additions & 1 deletion tests/testthat/test-activation.R
@@ -1,4 +1,4 @@
-test_that("multibranch_weighted_linear_unit nn_module works", {
+test_that("multibranch_weighted_linear_unit activation works", {
  mb_wlu <- nn_mb_wlu()
  input <- torch::torch_tensor(c(-1.0, 0.0, 1.0))
  expected_output <- torch::torch_tensor(c(-0.48306063, 0.0, 0.94621176))
@@ -7,3 +7,18 @@ test_that("multibranch_weighted_linear_unit nn_module works", {

})

test_that("multibranch_weighted_linear_unit correctly prevent weight not being a tensor", {
expect_error(mb_wlu <- nn_mb_wlu( weight = 0.25),
regexp = "must be a torch_tensor")
})

test_that("multibranch_weighted_linear_unit correctly prevent weight not being on the same device", {
skip_if_not(torch::backends_openmp_is_available())
weight <- torch::torch_tensor(0.25)$to(device = "cpu")
z <- torch::torch_randr(c(2,2))$to(device = "openmp")

expect_no_error(mb_wlu <- nn_mb_wlu( weight = weight))
expect_error(mb_wlu(z),
regexp = "reside on the same device")
})
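
For illustration only, not part of the diff above: one way to exercise the new checks from a package checkout, assuming testthat (and optionally devtools) is installed.

# run only the activation tests
testthat::test_file("tests/testthat/test-activation.R")
# or, from a package checkout:
# devtools::test(filter = "activation")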

