Commit: MacOs fixes (#141)
* Trigger CI

* update gpu ci

* update container

* some debug info

* update image

* try that

* ask for CRAN torch

* fix mixed precision

* fix desc

* update snapshot

* Trigger ci

* Re-trigger CI
dfalbel committed Feb 29, 2024
1 parent eed4be8 commit fdc660b
Showing 8 changed files with 25 additions and 16 deletions.
12 changes: 7 additions & 5 deletions .github/workflows/R-CMD-check.yaml
@@ -53,18 +53,20 @@ jobs:
extra-packages: any::rcmdcheck
needs: check

- run: |
print(torch::torch_is_installed())
print(torch::backends_mps_is_available())
shell: Rscript {0}
- uses: r-lib/actions/check-r-package@v2
with:
error-on: '"error"'
args: 'c("--no-multiarch", "--no-manual", "--as-cran")'

GPU:
runs-on: ['self-hosted', 'gce', 'gpu']
runs-on: ['self-hosted', 'gpu-local']
name: 'gpu'

container:
image: nvidia/cuda:11.7.1-cudnn8-devel-ubuntu18.04
options: --runtime=nvidia --gpus all
container: {image: 'nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04', options: '--gpus all --runtime=nvidia'}

env:
R_REMOTES_NO_ERRORS_FROM_WARNINGS: true
4 changes: 2 additions & 2 deletions .github/workflows/test-coverage-pak.yaml
@@ -13,10 +13,10 @@ name: test-coverage
jobs:
test-coverage:

runs-on: ['self-hosted', 'gce', 'gpu']
runs-on: ['self-hosted', 'gpu-local']

container:
image: nvidia/cuda:11.7.1-cudnn8-devel-ubuntu18.04
image: nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04
options: --gpus all

env:
6 changes: 3 additions & 3 deletions DESCRIPTION
@@ -1,6 +1,6 @@
Package: luz
Title: Higher Level 'API' for 'torch'
Version: 0.4.0.9000
Version: 0.4.0.9002
Authors@R: c(
person("Daniel", "Falbel", email = "[email protected]", role = c("aut", "cre", "cph")),
person(family = "RStudio", role = c("cph"))
@@ -17,7 +17,7 @@ Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.3
Imports:
torch (>= 0.9.0),
torch (>= 0.11.9000),
magrittr,
zeallot,
rlang (>= 1.0.0),
@@ -69,4 +69,4 @@ Collate:
'reexports.R'
'serialization.R'
Remotes:
mlverse/torch
mlverse/torch
1 change: 1 addition & 0 deletions NEWS.md
@@ -5,6 +5,7 @@
* Fixed a bug when trying to resume models trained with learning rate schedulers. (#137)
* Added support for learning rate schedulers that take the current loss as arguments. (#140)


# luz 0.4.0

## Breaking changes
2 changes: 1 addition & 1 deletion R/accelerator.R
@@ -93,7 +93,7 @@ LuzAcceleratorState <- R6::R6Class(

if (torch::cuda_is_available())
paste0("cuda:", index)
else if (torch::backends_mps_is_available())
else if (can_use_mps())
"mps"
else
"cpu"
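Restated as a standalone sketch, the accelerator now resolves its device as follows (resolve_device() is a hypothetical wrapper added for illustration; only the cuda/mps/cpu branching comes from the diff above):

# Hypothetical wrapper illustrating device selection after this commit.
resolve_device <- function(index = 0) {
  if (torch::cuda_is_available())
    paste0("cuda:", index)    # prefer CUDA whenever a GPU is visible
  else if (can_use_mps())     # new helper: MPS only on arm64 Macs
    "mps"
  else
    "cpu"
}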
7 changes: 4 additions & 3 deletions R/callbacks-amp.R
@@ -19,7 +19,7 @@ NULL
luz_callback_mixed_precision <- luz_callback(
"mixed_precision_callback",
initialize = function(...) {
self$autocast_env <- rlang::new_environment()
self$autocast_context <- NULL
self$scaler <- torch::cuda_amp_grad_scaler(...)
},
on_fit_begin = function() {
@@ -30,10 +30,11 @@ luz_callback_mixed_precision <- luz_callback(
},
on_train_batch_begin = function() {
device_type <- if (grepl("cuda", ctx$device)) "cuda" else ctx$device
torch::local_autocast(device_type = device_type, .env = self$autocast_env)
self$autocast_context <- torch::set_autocast(device_type = device_type)
},
on_train_batch_after_loss = function() {
withr::deferred_run(self$autocast_env)
torch::unset_autocast(self$autocast_context)
self$autocast_context <- NULL
},
on_train_batch_before_backward = function() {
torch::with_enable_grad({
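For context, a minimal sketch of the autocast lifecycle the callback now follows, assuming a torch build that exports set_autocast()/unset_autocast() (the device value is illustrative; in luz it comes from ctx$device):

# The scoped local_autocast(.env = ...) call is replaced by an explicit
# set/unset pair around the forward pass.
ctx_device <- "cuda:0"  # illustrative
device_type <- if (grepl("cuda", ctx_device)) "cuda" else ctx_device
autocast_context <- torch::set_autocast(device_type = device_type)
# ... forward pass and loss are computed under autocast here ...
torch::unset_autocast(autocast_context)
autocast_context <- NULL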
7 changes: 6 additions & 1 deletion R/module.R
@@ -603,8 +603,13 @@ get_metrics.luz_module_evaluation <- function(object, ...) {
res[, c("metric", "value")]
}

can_use_mps <- function() {
arch <- Sys.info()["machine"]
"arm64" %in% arch && torch::backends_mps_is_available()
}

enable_mps_fallback <- function() {
if (!torch::backends_mps_is_available())
if (!can_use_mps())
return(invisible(NULL))

fallback <- Sys.getenv("PYTORCH_ENABLE_MPS_FALLBACK", unset = "")
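As a quick illustration of what the new helper checks (the expected values are assumptions about typical platforms, not output from this commit):

# can_use_mps() requires both an arm64 (Apple Silicon) machine and an
# available MPS backend, so Intel Macs and Linux boxes fall back to CPU.
Sys.info()["machine"]                # e.g. "arm64" on Apple Silicon
torch::backends_mps_is_available()   # whether torch reports MPS support
can_use_mps()                        # TRUE only when both conditions hold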
2 changes: 1 addition & 1 deletion tests/testthat/_snaps/module-plot/ggplot2-histogram.svg
