diff --git a/articles/examples/chargpt.html b/articles/examples/chargpt.html
new file mode 100644
index 00000000..98f09978
--- /dev/null
+++ b/articles/examples/chargpt.html
@@ -0,0 +1,234 @@
+CharGPT • luz

This example is inspired by the chargpt project by Andrej Karpathy. We are going to train a character-level language model on Shakespeare texts.

+

We first load the libraries that we plan to use:

+# torch and luz provide the unqualified functions used below
+library(torch)
+library(luz)

Next we define the torch dataset that pre-processes the data for the model. It splits the text into a character vector, with each element containing exactly one character.

+

It then lists all unique characters in the vocab attribute. The order of the characters in the vocabulary is used to encode each character as an integer value, which is later fed to the embedding layer.

+

The .getitem() method takes a chunk of block_size + 1 characters and encodes it into its integer representation: the first block_size values become the input x and the last block_size values become the target y.

+
+url <- "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
+
+char_dataset <- torch::dataset(
+    initialize = function(data, block_size = 128) {
+        self$block_size <- block_size
+        self$data <- stringr::str_split_1(data, "")
+
+        self$data_size <- length(self$data)
+        self$vocab <- unique(self$data)
+        self$vocab_size <- length(self$vocab)
+    },
+    .getitem = function(i) {
+        chunk <- self$data[i + seq_len(self$block_size + 1)]
+        idx <- match(chunk, self$vocab)
+        list(
+            x = head(idx, self$block_size),
+            y = tail(idx, self$block_size)
+        )
+    },
+    .length = function() {
+        self$data_size - self$block_size - 1L # leave room for the extra target character at the end
+    }
+)
+
+dataset <- char_dataset(readr::read_file(url))
+dataset[1] # this allows us to see an element of the dataset
+
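
The element returned by dataset[1] is a list with integer vectors x and y. As a quick, optional check, you can map those indices back through the vocabulary; the decoded y should read as x shifted forward by one character:

+el <- dataset[1]
+# decode the integer indices back into characters
+paste0(dataset$vocab[el$x], collapse = "")
+paste0(dataset$vocab[el$y], collapse = "")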

We then define the neural net we are going to train. Defining a GPT-2 model is quite verbose, so we use the minhub implementation directly. You can find the full model definition in the minhub repository; that file is entirely self-contained, so you don’t need to install minhub if you don’t want to.

+

We also implement a generate method for the model that produces completions. It applies the model in a loop; at each iteration it predicts the next character and appends it to the context.

+
+model <- torch::nn_module(
+    initialize = function(vocab_size) {
+        # remotes::install_github("mlverse/minhub")
+        self$gpt <- minhub::gpt2(
+            vocab_size = vocab_size,
+            n_layer = 6,
+            n_head = 6,
+            n_embd = 192
+        )
+    },
+    forward = function(x) {
+        # transpose so the vocabulary becomes the second (class) dimension,
+        # which is what nn_cross_entropy_loss() expects
+        self$gpt(x)$transpose(2,3)
+    },
+    generate = function(x, temperature = 1, iter = 50, top_k = 10) {
+        # samples from the model given a context vector.
+        for (i in seq_len(iter)) {
+            logits <- self$forward(x)[,,-1]
+            logits <- logits/temperature
+            c(prob, ind) %<-% logits$topk(top_k)
+            logits <- torch_full_like(logits, -Inf)$scatter_(-1, ind, prob)
+            logits <- nnf_softmax(logits, dim = -1)
+            id_next <- torch_multinomial(logits, num_samples = 1)
+            x <- torch_cat(list(x, id_next), dim = 2)
+        }
+        x
+    }
+)
+
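
The transpose in forward() matters because nn_cross_entropy_loss() expects the class (vocabulary) dimension to come second, while the GPT-2 head returns logits with shape (batch, sequence, vocab). A minimal shape check, assuming minhub is installed and the dataset defined above is available:

+net <- model(vocab_size = dataset$vocab_size)
+# a fake batch: 2 sequences of block_size integer token ids
+x <- torch_randint(1, dataset$vocab_size, size = c(2, 128), dtype = torch_long())
+net(x)$shape # batch x vocab_size x sequence length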

Next, we implement a callback that nicely displays generated samples during training:

+
+# samples from the model using the context.
+generate <- function(model, vocab, context, ...) {
+  local_no_grad() # disables gradient for sampling
+  x <- match(stringr::str_split_1(context, ""), vocab)
+  x <- torch_tensor(x)[NULL,]$to(device = model$device)
+  content <- as.integer(model$generate(x, ...)$cpu())
+  paste0(vocab[content], collapse = "")
+}
+
+display_cb <- luz_callback(
+  initialize = function(iter = 500) {
+    self$iter <- iter # print a sample every `iter` iterations
+  },
+  on_train_batch_end = function() {
+    if (!(ctx$iter %% self$iter == 0))
+      return()
+
+    ctx$model$eval()
+    with_no_grad({
+      # sample from the model...
+      context <- "O God, O God!"
+      text <- generate(ctx$model, dataset$vocab, context, iter = 100)
+      cli::cli_h3(paste0("Iter ", ctx$iter))
+      cli::cli_text(text)
+    })
+    ctx$model$train() # switch the model back to training mode
+  }
+)
+

Finally, you can train the model using fit:

+
+fitted <- model |>
+    setup(
+        loss = nn_cross_entropy_loss(),
+        optimizer = optim_adam
+    ) |>
+    set_opt_hparams(lr = 5e-4) |>
+    set_hparams(vocab_size = dataset$vocab_size) |>
+    fit(
+      dataset,
+      dataloader_options = list(batch_size = 128, shuffle = TRUE),
+      epochs = 1,
+      callbacks = list(
+        display_cb(iter = 500),
+        luz_callback_gradient_clip(max_norm = 1)
+      )
+    )
+

One epoch is reasonable for this dataset and takes about one hour on an M1 MacBook Pro. You can generate new samples with:

+
+context <- "O God, O God!"
+text <- generate(fitted$model, dataset$vocab, context, iter = 100)
+cat(text)
+
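
Since one epoch already takes a while, it can be worth persisting the fitted object so you can sample again later without retraining. A minimal sketch using luz's serialization helpers (the file name is just a placeholder):

+luz_save(fitted, "chargpt.luz")
+# in a later R session:
+fitted <- luz_load("chargpt.luz")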
+
diff --git a/articles/examples/chargpt_files/accessible-code-block-0.0.1/empty-anchor.js b/articles/examples/chargpt_files/accessible-code-block-0.0.1/empty-anchor.js
new file mode 100644
index 00000000..ca349fd6
--- /dev/null
+++ b/articles/examples/chargpt_files/accessible-code-block-0.0.1/empty-anchor.js
@@ -0,0 +1,15 @@
+// Hide empty <a> tag within highlighted CodeBlock for screen reader accessibility (see https://github.com/jgm/pandoc/issues/6352#issuecomment-626106786) -->
+// v0.0.1
+// Written by JooYoung Seo (jooyoung@psu.edu) and Atsushi Yasumoto on June 1st, 2020.
+
+document.addEventListener('DOMContentLoaded', function() {
+  const codeList = document.getElementsByClassName("sourceCode");
+  for (var i = 0; i < codeList.length; i++) {
+    var linkList = codeList[i].getElementsByTagName('a');
+    for (var j = 0; j < linkList.length; j++) {
+      if (linkList[j].innerHTML === "") {
+        linkList[j].setAttribute('aria-hidden', 'true');
+      }
+    }
+  }
+});
diff --git a/articles/examples/index.html b/articles/examples/index.html
index 280550f6..4714a06a 100644
--- a/articles/examples/index.html
+++ b/articles/examples/index.html
@@ -95,6 +95,20 @@
+CharGPT
+intermediate
+Train a character-level GPT-2 on Shakespeare texts.
+See code
Binary classification
basic
diff --git a/articles/index.html b/articles/index.html
index 423799e6..2bd5852c 100644
--- a/articles/index.html
+++ b/articles/index.html
@@ -63,6 +63,8 @@

All vignettes
Accelerator API
+CharGPT
Checkpointing your models
diff --git a/pkgdown.yml b/pkgdown.yml
index 4e3a701e..3b8ad870 100644
@@ -3,6 +3,7 @@ pkgdown: 2.0.7.9000
 pkgdown_sha: c9206802f2888992de92aa41f517ba7812f05331
 articles:
   accelerator: accelerator.html
+  chargpt: examples/chargpt.html
   checkpoints: checkpoints.html
   custom-loop: custom-loop.html
   dogs-vs-cats-binary-classification: examples/dogs-vs-cats-binary-classification.html
@@ -17,5 +18,5 @@ articles:
   mnist-triplet: examples/mnist-triplet.html
   pets-unet: examples/pets-unet.html
   text-classification: examples/text-classification.html
-last_built: 2023-09-12T13:53Z
+last_built: 2023-09-15T17:29Z
diff --git a/reference/lr_finder-1.png b/reference/lr_finder-1.png
index 0b597a58..6801c27f 100644
Binary files a/reference/lr_finder-1.png and b/reference/lr_finder-1.png differ
diff --git a/reference/luz_callback_auto_resume.html b/reference/luz_callback_auto_resume.html
index 5387fd5f..2013b39d 100644
--- a/reference/luz_callback_auto_resume.html
+++ b/reference/luz_callback_auto_resume.html
@@ -177,16 +177,16 @@

Examples
#> Caused by error in `self[[callback_nm]]()`:
#> ! Error on epoch 5
#> set metric epoch value
-#> 1 train loss 1 1.188841
-#> 2 train loss 2 1.083670
-#> 3 train loss 3 1.050834
-#> 4 train loss 4 1.042924
-#> 5 train loss 5 1.041700
-#> 6 train loss 6 1.034182
-#> 7 train loss 7 1.034796
-#> 8 train loss 8 1.032502
-#> 9 train loss 9 1.039884
-#> 10 train loss 10 1.036183
+#> 1 train loss 1 1.302326
+#> 2 train loss 2 1.141849
+#> 3 train loss 3 1.094023
+#> 4 train loss 4 1.082328
+#> 5 train loss 5 1.083923
+#> 6 train loss 6 1.072870
+#> 7 train loss 7 1.083111
+#> 8 train loss 8 1.079866
+#> 9 train loss 9 1.074621
+#> 10 train loss 10 1.075743