Skip to content

Commit

Permalink
Merge pull request #70 from JuliaTrustworthyAI/cp-llm-new-attempt
Browse files Browse the repository at this point in the history
Cp llm new attempt
  • Loading branch information
pat-alt authored Jul 5, 2023
2 parents e21970b + 423858f commit 8fdfc48
Show file tree
Hide file tree
Showing 37 changed files with 6,819 additions and 3,380 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
version:
- '1.7'
- '1.8'
- '~1.9.0-0'
- '1.9'
- 'nightly'
os:
- ubuntu-latest
Expand Down
39 changes: 19 additions & 20 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
# 🏃 Quick Tour

![](dev/logo/wide_logo.png)

[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://juliatrustworthyai.github.io/ConformalPrediction.jl/stable/) [![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://juliatrustworthyai.github.io/ConformalPrediction.jl/dev/) [![Build Status](https://github.com/juliatrustworthyai/ConformalPrediction.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/juliatrustworthyai/ConformalPrediction.jl/actions/workflows/CI.yml?query=branch%3Amain) [![Coverage](https://codecov.io/gh/juliatrustworthyai/ConformalPrediction.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/juliatrustworthyai/ConformalPrediction.jl) [![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle) [![ColPrac: Contributor’s Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet.png)](https://github.com/SciML/ColPrac) [![Twitter Badge](https://img.shields.io/twitter/url/https/twitter.com/paltmey.svg?style=social&label=Follow%20%40paltmey)](https://twitter.com/paltmey)

`ConformalPrediction.jl` is a package for Predictive Uncertainty Quantification (UQ) through Conformal Prediction (CP) in Julia. It is designed to work with supervised models trained in [MLJ](https://alan-turing-institute.github.io/MLJ.jl/dev/) (Blaom et al. 2020). Conformal Prediction is easy-to-understand, easy-to-use and model-agnostic and it works under minimal distributional assumptions.

## 🏃 Quick Tour

> First time here? Take a quick interactive [tour](https://binder.plutojl.org/v0.19.12/open?url=https%253A%252F%252Fraw.githubusercontent.com%252Fpat-alt%252FConformalPrediction.jl%252Fmain%252Fdocs%252Fpluto%252Fintro.jl) to see what this package can do: [![Binder](https://mybinder.org/badge_logo.svg)](https://binder.plutojl.org/v0.19.12/open?url=https%253A%252F%252Fraw.githubusercontent.com%252Fpat-alt%252FConformalPrediction.jl%252Fmain%252Fdocs%252Fpluto%252Fintro.jl)
The button takes you to a [`Pluto.jl`](https://github.com/fonsp/Pluto.jl) 🎈 notebook hosted on [binder](https://mybinder.org/). In my own experience, this may take some time to load, certainly long enough to get yourself a hot beverage ☕. Alternatively, you can run the notebook locally or skip the tour for now and read on below.
Expand Down Expand Up @@ -106,11 +105,11 @@ ŷ[1:show_first]
```

5-element Vector{Tuple{Float64, Float64}}:
(0.3514065102722679, 2.4948272235282696)
(-0.36580206168104035, 1.7780775120607)
(0.13671800582612756, 2.2792132778975933)
(0.15237308545277795, 2.2801138611534326)
(0.19080981472120032, 2.3863592104933966)
(0.3633641966158244, 2.4931870917039434)
(-0.3996500917580523, 1.7928089786632433)
(0.09653821719666224, 2.284119083077198)
(0.13354256573784634, 2.260005698592606)
(0.21655224395842643, 2.434258746076169)

For simple models like this one, we can call a custom `Plots` recipe on our instance, fit result and data to generate the chart below:

Expand Down Expand Up @@ -138,16 +137,16 @@ println("SSC: $(round(_eval.measurement[2], digits=3))")
per_observation, fitted_params_per_fold,
report_per_fold, train_test_rows
Extract:
┌───────────────────────────────────────────────────────────┬───────────┬──────
│ measure │ operationmeas
├───────────────────────────────────────────────────────────┼───────────┼──────
│ emp_coverage (generic function with 1 method) │ predict │ 0.95
│ size_stratified_coverage (generic function with 1 method) │ predict │ 0.84
└───────────────────────────────────────────────────────────┴───────────┴──────
3 columns omitted
┌─────────────────────────────────────────────────────────┬─────────────┬──────
│ measure │ operation │ measurement1.9
├─────────────────────────────────────────────────────────┼─────────────┼──────
ConformalPrediction.emp_coverage │ predict │ 0.95 │ 0.0
ConformalPrediction.size_stratified_coverage │ predict │ 0.903 │ 0.0
└─────────────────────────────────────────────────────────┴─────────────┴──────
2 columns omitted

Empirical coverage: 0.95
SSC: 0.841
SSC: 0.903

## 📚 Read on

Expand Down Expand Up @@ -196,10 +195,11 @@ The package has been tested for the following supervised models offered by [MLJ]
keys(tested_atomic_models[:regression])
```

KeySet for a Dict{Symbol, Expr} with 4 entries. Keys:
:nearest_neighbor
KeySet for a Dict{Symbol, Expr} with 5 entries. Keys:
:ridge
:lasso
:evo_tree
:light_gbm
:nearest_neighbor
:linear

**Classification**:
Expand All @@ -208,10 +208,9 @@ keys(tested_atomic_models[:regression])
keys(tested_atomic_models[:classification])
```

KeySet for a Dict{Symbol, Expr} with 4 entries. Keys:
KeySet for a Dict{Symbol, Expr} with 3 entries. Keys:
:nearest_neighbor
:evo_tree
:light_gbm
:logistic

### Implemented Evaluation Metrics
Expand Down
2 changes: 1 addition & 1 deletion README.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ crossref:
fig-prefix: Figure
tbl-prefix: Table
bibliography: https://raw.githubusercontent.com/pat-alt/bib/main/bib.bib
jupyter: julia-1.8
jupyter: julia-1.9
---

![](dev/logo/wide_logo.png)
Expand Down
84 changes: 44 additions & 40 deletions README_files/figure-commonmark/cell-11-output-1.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
560 changes: 280 additions & 280 deletions README_files/figure-commonmark/cell-7-output-1.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
10 changes: 10 additions & 0 deletions _freeze/docs/src/how_to_guides/llm/execute-results/md.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"hash": "9ace0a1ec38b37c490957e4b679ebce9",
"result": {
"markdown": "---\ntitle: How to Build a Conformal Chatbot\n---\n\n\n``` @meta\nCurrentModule = ConformalPrediction\n```\n\n\n\n\nLarge Language Models are all the buzz right now. They are used for a variety of tasks, including text classification, question answering, and text generation. In this tutorial, we will show how to conformalize a transformer language model for text classification. We will use the [Banking77](https://arxiv.org/abs/2003.04807) dataset [@casanueva2020efficient], which consists of 13,083 queries from 77 intents. On the model side, we will use the [DistilRoBERTa](https://huggingface.co/mrm8488/distilroberta-finetuned-banking77) model, which is a distilled version of [RoBERTa](https://arxiv.org/abs/1907.11692) [@liu2019roberta] finetuned on the Banking77 dataset.\n\n## Data\n\nThe data was downloaded from [HuggingFace](https://huggingface.co/datasets/PolyAI/banking77) 🤗 (HF) and split into a proper training, calibration, and test set. All that's left to do is to load the data and preprocess it. We add 1 to the labels to make them 1-indexed (sorry Pythonistas 😜)\n\n\n::: {.cell execution_count=2}\n``` {.julia .cell-code}\n# Get labels:\ndf_labels = CSV.read(\"dev/artifacts/data/banking77/labels.csv\", DataFrame, drop=[1])\nlabels = df_labels[:,1]\n\n# Get data:\ndf_train = CSV.read(\"dev/artifacts/data/banking77/train.csv\", DataFrame, drop=[1])\ndf_cal = CSV.read(\"dev/artifacts/data/banking77/calibration.csv\", DataFrame, drop=[1])\ndf_full_train = vcat(df_train, df_cal)\ntrain_ratio = round(nrow(df_train)/nrow(df_full_train), digits=2)\ndf_test = CSV.read(\"dev/artifacts/data/banking77/test.csv\", DataFrame, drop=[1])\n\n# Preprocess data:\nqueries_train, y_train = collect(df_train.text), categorical(df_train.labels .+ 1)\nqueries_cal, y_cal = collect(df_cal.text), categorical(df_cal.labels .+ 1)\nqueries, y = collect(df_full_train.text), categorical(df_full_train.labels .+ 1)\nqueries_test, y_test = collect(df_test.text), categorical(df_test.labels .+ 1)\n```\n:::\n\n\n## HuggingFace Model\n\nThe model can be loaded from HF straight into our running Julia session using the [`Transformers.jl`](https://github.com/chengchingwen/Transformers.jl/tree/master) package. Below we load the tokenizer `tkr` and the model `mod`. The tokenizer is used to convert the text into a sequence of integers, which is then fed into the model. The model outputs a hidden state, which is then fed into a classifier to get the logits for each class. Finally, the logits are then passed through a softmax function to get the corresponding predicted probabilities. Below we run a few queries through the model to see how it performs.\n\n::: {.cell execution_count=3}\n``` {.julia .cell-code}\n# Load model from HF 🤗:\ntkr = hgf\"mrm8488/distilroberta-finetuned-banking77:tokenizer\"\nmod = hgf\"mrm8488/distilroberta-finetuned-banking77:ForSequenceClassification\"\n\n# Test model:\nquery = [\n \"What is the base of the exchange rates?\",\n \"Why is my card not working?\",\n \"My Apple Pay is not working, what should I do?\",\n]\na = encode(tkr, query)\nb = mod.model(a)\nc = mod.cls(b.hidden_state)\nd = softmax(c.logit)\n[labels[i] for i in Flux.onecold(d)]\n```\n\n::: {.cell-output .cell-output-display execution_count=4}\n```\n3-element Vector{String}:\n \"exchange_rate\"\n \"card_not_working\"\n \"apple_pay_or_google_pay\"\n```\n:::\n:::\n\n\n## `MLJ` Interface\n\nSince our package is interfaced to [`MLJ.jl`](https://alan-turing-institute.github.io/MLJ.jl/dev/), we need to define a wrapper model that conforms to the `MLJ` interface. In order to add the model for general use, we would probably go through [`MLJFlux.jl`](https://github.com/FluxML/MLJFlux.jl), but for this tutorial, we will make our life easy and simply overload the `MLJBase.fit` and `MLJBase.predict` methods. Since the model from HF is already pre-trained and we are not interested in further fine-tuning, we will simply return the model object in the `MLJBase.fit` method. The `MLJBase.predict` method will then take the model object and the query and return the predicted probabilities. We also need to define the `MLJBase.target_scitype` and `MLJBase.predict_mode` methods. The former tells `MLJ` what the output type of the model is, and the latter can be used to retrieve the label with the highest predicted probability.\n\n::: {.cell execution_count=4}\n``` {.julia .cell-code}\nstruct IntentClassifier <: MLJBase.Probabilistic\n tkr::TextEncoders.AbstractTransformerTextEncoder\n mod::HuggingFace.HGFRobertaForSequenceClassification\nend\n\nfunction IntentClassifier(;\n tokenizer::TextEncoders.AbstractTransformerTextEncoder, \n model::HuggingFace.HGFRobertaForSequenceClassification,\n)\n IntentClassifier(tkr, mod)\nend\n\nfunction get_hidden_state(clf::IntentClassifier, query::Union{AbstractString, Vector{<:AbstractString}})\n token = encode(clf.tkr, query)\n hidden_state = clf.mod.model(token).hidden_state\n return hidden_state\nend\n\n# This doesn't actually retrain the model, but it retrieves the classifier object\nfunction MLJBase.fit(clf::IntentClassifier, verbosity, X, y)\n cache=nothing\n report=nothing\n fitresult = (clf = clf.mod.cls, labels = levels(y))\n return fitresult, cache, report\nend\n\nfunction MLJBase.predict(clf::IntentClassifier, fitresult, Xnew)\n output = fitresult.clf(get_hidden_state(clf, Xnew))\n p̂ = UnivariateFinite(fitresult.labels,softmax(output.logit)',pool=missing)\n return p̂\nend\n\nMLJBase.target_scitype(clf::IntentClassifier) = AbstractVector{<:Finite}\n\nMLJBase.predict_mode(clf::IntentClassifier, fitresult, Xnew) = mode.(MLJBase.predict(clf, fitresult, Xnew))\n```\n:::\n\n\nTo test that everything is working as expected, we fit the model and generated predictions for a subset of the test data:\n\n::: {.cell execution_count=5}\n``` {.julia .cell-code}\nclf = IntentClassifier(tkr, mod)\ntop_n = 10\nfitresult, _, _ = MLJBase.fit(clf, 1, nothing, y_test[1:top_n])\n@time ŷ = MLJBase.predict(clf, fitresult, queries_test[1:top_n]);\n```\n:::\n\n\n## Conformal Chatbot\n\nTo turn the wrapped, pre-trained model into a conformal intent classifier, we can now rely on standard API calls. We first wrap our atomic model where we also specify the desired coverage rate and method. Since even simple forward passes are computationally expensive for our (small) LLM, we rely on Simple Inductive Conformal Classification.\n\n::: {.cell execution_count=6}\n``` {.julia .cell-code}\nconf_model = conformal_model(clf; coverage=0.95, method=:simple_inductive, train_ratio=train_ratio)\nmach = machine(conf_model, queries, y)\n@time fit!(mach)\nSerialization.serialize(\"dev/artifacts/models/banking77/simple_inductive.jls\", mach)\n```\n:::\n\n\nFinally, we use our conformal LLM to build a simple and yet powerful chatbot that runs directly in the Julia REPL. Without dwelling on the details too much, the `conformal_chatbot` works as follows:\n\n1. Prompt user to explain their intent.\n2. Feed user input through conformal LLM and present the output to the user.\n3. If the conformal prediction sets includes more than one label, prompt the user to either refine their input or choose one of the options included in the set.\n\n::: {.cell execution_count=7}\n``` {.julia .cell-code}\nmach = Serialization.deserialize(\"dev/artifacts/models/banking77/simple_inductive.jls\")\n\nfunction prediction_set(mach, query::String)\n p̂ = MLJBase.predict(mach, query)[1]\n probs = pdf.(p̂, collect(1:77))\n in_set = findall(probs .!= 0)\n labels_in_set = labels[in_set]\n probs_in_set = probs[in_set]\n _order = sortperm(-probs_in_set)\n plt = UnicodePlots.barplot(labels_in_set[_order], probs_in_set[_order], title=\"Possible Intents\")\n return labels_in_set, plt\nend\n\nfunction conformal_chatbot()\n println(\"👋 Hi, I'm a Julia, your conformal chatbot. I'm here to help you with your banking query. Ask me anything or type 'exit' to exit ...\\n\")\n completed = false\n queries = \"\"\n while !completed\n query = readline()\n queries = queries * \",\" * query\n labels, plt = prediction_set(mach, queries)\n if length(labels) > 1\n println(\"🤔 Hmmm ... I can think of several options here. If any of these applies, simply type the corresponding number (e.g. '1' for the first option). Otherwise, can you refine your question, please?\\n\")\n println(plt)\n else\n println(\"🥳 I think you mean $(labels[1]). Correct?\")\n end\n\n # Exit:\n if query == \"exit\"\n println(\"👋 Bye!\")\n break\n end\n if query ∈ string.(collect(1:77))\n println(\"👍 Great! You've chosen '$(labels[parse(Int64, query)])'. I'm glad I could help you. Have a nice day!\")\n completed = true\n end\n end\nend\n```\n:::\n\n\nBelow we show the output for two example queries. The first one is very ambiguous. As expected, the size of the prediction set is therefore large. \n\n::: {.cell execution_count=8}\n``` {.julia .cell-code}\nambiguous_query = \"transfer mondey?\"\nprediction_set(mach, ambiguous_query)[2]\n```\n:::\n\n\nThe more refined version of the prompt yields a smaller prediction set: less ambiguous prompts result in lower predictive uncertainty. \n\n::: {.cell execution_count=9}\n``` {.julia .cell-code}\nrefined_query = \"I tried to transfer money to my friend, but it failed.\"\nprediction_set(mach, refined_query)[2]\n```\n:::\n\n\nBelow we include a short demo video that shows the REPL-based chatbot in action.\n\n![](/docs/src/www/demo_llm.gif)\n\n## Final Remarks\n\nThis work was done in collaboration with colleagues at ING as part of the ING Analytics 2023 Experiment Week. Our team demonstrated that Conformal Prediction provides a powerful and principled alternative to top-*K* intent classification. We won the first prize by popular vote.\n\n",
"supporting": [
"llm_files"
],
"filters": []
}
}
Loading

0 comments on commit 8fdfc48

Please sign in to comment.