Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

268 add mpi support for benchmarking #270

Merged
merged 36 commits into from
Aug 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
2d292cc
this is already looking good
pat-alt Aug 23, 2023
058a83c
not sure any of this makes sense
pat-alt Aug 24, 2023
67aa9b2
maybe more of a global wrapper?
pat-alt Aug 24, 2023
76fa6df
hmmm
pat-alt Aug 24, 2023
b1cf8bf
that's sort of working
pat-alt Aug 24, 2023
ae86659
a possible (?) solution using traits
pat-alt Aug 24, 2023
65acdff
somewhat incredible that this seems to alraedy be working
pat-alt Aug 24, 2023
c2fb6a6
trait for evaluate function needs to be fixed
pat-alt Aug 24, 2023
0cb5177
progress on docs
pat-alt Aug 25, 2023
3745e63
not as easy as I thought
pat-alt Aug 25, 2023
27ad0ef
looks like a macro is the way to go after all?
pat-alt Aug 25, 2023
87a1ba2
fuck this
pat-alt Aug 25, 2023
0cc0cb8
macro is functional
pat-alt Aug 25, 2023
2e22e58
small tings
pat-alt Aug 29, 2023
674cd28
progress
pat-alt Aug 29, 2023
cb645ca
macro also doesn't seem to resolve the issue that mpiexec just doesn'…
pat-alt Aug 29, 2023
2680825
finally
pat-alt Aug 29, 2023
99eaa64
seems to be working now
pat-alt Aug 29, 2023
0f02c69
problem is that now using parallelize returns nothing in some case
pat-alt Aug 29, 2023
e3543e3
finally working
pat-alt Aug 29, 2023
3d4cd6a
just needs polishing and proper testing
pat-alt Aug 29, 2023
002dad8
just needs polishing and proper testing
pat-alt Aug 29, 2023
0311295
done
pat-alt Aug 30, 2023
02d2630
good old formatting
pat-alt Aug 30, 2023
66eec6e
moved RCall into extension
pat-alt Aug 30, 2023
9010843
good old formatting
pat-alt Aug 30, 2023
faee9ed
moved MPI into an extension
pat-alt Aug 30, 2023
76b74eb
good old formatting
pat-alt Aug 30, 2023
c5e0e9c
also added PythonCall as extension
pat-alt Aug 30, 2023
914942b
good old formatting
pat-alt Aug 30, 2023
484cbc1
removed SliceMap
pat-alt Aug 30, 2023
626e0a9
sorted out errors
pat-alt Aug 30, 2023
7508653
good old formatting
pat-alt Aug 30, 2023
bbb448e
another attempt
pat-alt Aug 30, 2023
8362b62
now using empty functions instead
pat-alt Aug 30, 2023
139aae0
moved extensions into single file to avoid issues with file rewrite o…
pat-alt Aug 30, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 0 additions & 40 deletions .github/workflows/CI-development.yml

This file was deleted.

2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,5 @@ dev/artifacts/upload/
dev/resources/build

/.luarc.json

**/LocalPreferences.toml
24 changes: 17 additions & 7 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "CounterfactualExplanations"
uuid = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0"
authors = ["Patrick Altmeyer <[email protected]>"]
version = "0.1.14"
version = "0.1.15"

[deps]
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
Expand All @@ -14,7 +14,6 @@ DecisionTree = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
EvoTrees = "f6006082-12f8-11e9-0c9c-0d5d367ab1e5"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
LaplaceRedux = "c52c1a26-f7c5-402b-80be-ba1e638ad478"
LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -25,20 +24,28 @@ MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411"
NearestNeighborModels = "636a865e-7cf4-491e-846c-de09b730eb36"
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
RCall = "6f49c342-dc21-5d91-9882-a32aef131414"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
SliceMap = "82cb661a-3f19-5665-9e27-df437c7e54c8"
SnoopPrecompile = "66db9d55-30c0-4569-8b51-7e840670fc0c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
UMAP = "c4f8c510-2410-5be4-91d7-4fbaeb39457e"

[weakdeps]
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
RCall = "6f49c342-dc21-5d91-9882-a32aef131414"

[extensions]
MPIExt = "MPI"
PythonCallExt = "PythonCall"
RCallExt = "RCall"

[compat]
CSV = "0.10"
CUDA = "3, 4"
Expand All @@ -50,28 +57,31 @@ DecisionTree = "0.12.3"
Distributions = "0.25.97"
EvoTrees = "0.14, 0.15, 0.16"
Flux = "0.12, 0.13, 0.14"
JuliaFormatter = "1.0.34"
LaplaceRedux = "0.1"
MLDatasets = "0.7"
MLJBase = "0.21"
MLJDecisionTreeInterface = "0.4.0"
MLJModels = "0.16"
MLUtils = "0.2, 0.3, 0.4"
MPI = "0.20"
MultivariateStats = "0.9, 0.10"
NearestNeighborModels = "0.2"
Parameters = "0.12"
Plots = "1.38.2"
ProgressMeter = "1"
PythonCall = "0.9.13"
RCall = "0.13.15"
SliceMap = "0.2"
SnoopPrecompile = "1.0"
StatsBase = "0.33, 0.34"
Tables = "1"
UMAP = "0.1"
julia = "1.6, 1.7, 1.8, 1.9"

[extras]
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
MPIPreferences = "3da0fdf6-3ccc-4f1b-acd9-58baa6c99267"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
RCall = "6f49c342-dc21-5d91-9882-a32aef131414"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
Expand Down
4 changes: 2 additions & 2 deletions _freeze/docs/src/tutorials/generators/execute-results/md.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"hash": "0d1259894dfddbd812a763bef8a29e18",
"hash": "82cbe7f5a19244545a3377d525b02874",
"result": {
"markdown": "---\ntitle: Handling Generators\n---\n\n\n```@meta\nCurrentModule = CounterfactualExplanations \n```\n\n\n\n\nGenerating Counterfactual Explanations can be seen as a generative modelling task because it involves generating samples in the input space: $x \\sim \\mathcal{X}$. In this tutorial, we will introduce how Counterfactual `GradientBasedGenerator`s are used. They are discussed in more detail in the explanatory section of the documentation.\n\n## Composable Generators \n\n| !!! warning \\\"Breaking Changes Expected\\\"\n| Work on this feature is still in its very early stages and breaking changes should be expected. \n\nOne of the key objectives for this package is **Composability**. It turns out that many of the various counterfactual generators that have been proposed in the literature, essentially do the same thing: they optimize an objective function. Formally we have,\n\n$$\n\\begin{aligned}\n\\mathbf{s}^\\prime &= \\arg \\min_{\\mathbf{s}^\\prime \\in \\mathcal{S}} \\left\\{ {\\text{yloss}(M(f(\\mathbf{s}^\\prime)),y^*)}+ \\lambda {\\text{cost}(f(\\mathbf{s}^\\prime)) } \\right\\} \n\\end{aligned} \n$$ {#eq-general}\n\nwhere $\\text{yloss}$ denotes the main loss function and $\\text{cost}$ is a penalty term [@altmeyer2023endogenous]. \n\nWithout going into further detail here, the important thing to mention is that @eq-general very closely describes how counterfactual search is actually implemented in the package. In other words, all off-the-shelf generators currently implemented work with that same objective. They just vary in the way that penalties are defined, for example. This gives rise to an interesting idea: \n\n> Why not compose generators that combine ideas from different off-the-shelf generators?\n\nThe [`GradientBasedGenerator`](@ref) class provides a straightforward way to do this, without requiring users to build custom `GradientBasedGenerator`s from scratch. It can be instantiated as follows:\n\n\n::: {.cell execution_count=2}\n``` {.julia .cell-code}\ngenerator = GradientBasedGenerator()\n```\n:::\n\n\nBy default, this creates a `generator` that simply performs gradient descent without any penalties. To modify the behaviour of the `generator`, you can define the counterfactual search objective function using the [`@objective`](@ref) macro:\n\n::: {.cell execution_count=3}\n``` {.julia .cell-code}\n@objective(generator, logitbinarycrossentropy + 0.1distance_l2 + 1.0ddp_diversity)\n```\n:::\n\n\nHere we have essentially created a version of the [`DiCEGenerator`](@ref):\n\n::: {.cell execution_count=4}\n``` {.julia .cell-code}\nce = generate_counterfactual(x, target, counterfactual_data, M, generator; num_counterfactuals=5)\nplot(ce)\n```\n\n::: {.cell-output .cell-output-display execution_count=5}\n![](generators_files/figure-commonmark/cell-5-output-1.svg){}\n:::\n:::\n\n\nMultiple macros can be chained using `Chains.jl` making it easy to create entirely new flavours of counterfactual generators. The following generator, for example, combines ideas from DiCE [@mothilal2020explaining] and REVISE [@joshi2019realistic]:\n\n::: {.cell execution_count=5}\n``` {.julia .cell-code}\n@chain generator begin\n @objective logitcrossentropy + 1.0ddp_diversity # DiCE (Mothilal et al. 2020)\n @with_optimiser Flux.Adam(0.1) \n @search_latent_space # REVISE (Joshi et al. 2019)\nend\n```\n:::\n\n\n\n\n\n\n::: {.cell execution_count=8}\n\n::: {.cell-output .cell-output-display execution_count=9}\nLet's take this generator to our MNIST dataset and generate a counterfactual explanation for turning a 0 into a 8. \n\n:::\n:::\n\n\n::: {.cell execution_count=9}\n\n::: {.cell-output .cell-output-display}\n![](generators_files/figure-commonmark/cell-10-output-1.svg){}\n:::\n:::\n\n\n## Off-the-Shelf Generators \n\nOff-the-shelf generators are just default recipes for counterfactual generators. Currently, the following off-the-shelf counterfactual generators are implemented in the package:\n\n::: {.cell execution_count=10}\n``` {.julia .cell-code}\ngenerator_catalogue\n```\n\n::: {.cell-output .cell-output-display execution_count=11}\n```\nDict{Symbol, Any} with 11 entries:\n :gravitational => GravitationalGenerator\n :growing_spheres => GrowingSpheresGenerator\n :revise => REVISEGenerator\n :clue => CLUEGenerator\n :probe => ProbeGenerator\n :dice => DiCEGenerator\n :feature_tweak => FeatureTweakGenerator\n :claproar => ClaPROARGenerator\n :wachter => WachterGenerator\n :generic => GenericGenerator\n :greedy => GreedyGenerator\n```\n:::\n:::\n\n\n\n\nTo specify the type of generator you want to use, you can simply instantiate it:\n\n::: {.cell execution_count=12}\n``` {.julia .cell-code}\n# Search:\ngenerator = GenericGenerator()\nce = generate_counterfactual(x, target, counterfactual_data, M, generator)\nplot(ce)\n```\n\n::: {.cell-output .cell-output-display execution_count=13}\n![](generators_files/figure-commonmark/cell-13-output-1.svg){}\n:::\n:::\n\n\nWe generally make an effort to follow the literature as closely as possible when implementing off-the-shelf generators. \n\n## References\n\n",
"markdown": "---\ntitle: Handling Generators\n---\n\n\n```@meta\nCurrentModule = CounterfactualExplanations \n```\n\n\n\n\nGenerating Counterfactual Explanations can be seen as a generative modelling task because it involves generating samples in the input space: $x \\sim \\mathcal{X}$. In this tutorial, we will introduce how Counterfactual `GradientBasedGenerator`s are used. They are discussed in more detail in the explanatory section of the documentation.\n\n## Composable Generators \n\n\n```{=commonmark}\n!!! warning \"Breaking Changes Expected\"\n Work on this feature is still in its very early stages and breaking changes should be expected. \n```\n\n\nOne of the key objectives for this package is **Composability**. It turns out that many of the various counterfactual generators that have been proposed in the literature, essentially do the same thing: they optimize an objective function. Formally we have,\n\n$$\n\\begin{aligned}\n\\mathbf{s}^\\prime &= \\arg \\min_{\\mathbf{s}^\\prime \\in \\mathcal{S}} \\left\\{ {\\text{yloss}(M(f(\\mathbf{s}^\\prime)),y^*)}+ \\lambda {\\text{cost}(f(\\mathbf{s}^\\prime)) } \\right\\} \n\\end{aligned} \n$$ {#eq-general}\n\nwhere $\\text{yloss}$ denotes the main loss function and $\\text{cost}$ is a penalty term [@altmeyer2023endogenous]. \n\nWithout going into further detail here, the important thing to mention is that @eq-general very closely describes how counterfactual search is actually implemented in the package. In other words, all off-the-shelf generators currently implemented work with that same objective. They just vary in the way that penalties are defined, for example. This gives rise to an interesting idea: \n\n> Why not compose generators that combine ideas from different off-the-shelf generators?\n\nThe [`GradientBasedGenerator`](@ref) class provides a straightforward way to do this, without requiring users to build custom `GradientBasedGenerator`s from scratch. It can be instantiated as follows:\n\n\n::: {.cell execution_count=2}\n``` {.julia .cell-code}\ngenerator = GradientBasedGenerator()\n```\n:::\n\n\nBy default, this creates a `generator` that simply performs gradient descent without any penalties. To modify the behaviour of the `generator`, you can define the counterfactual search objective function using the [`@objective`](@ref) macro:\n\n::: {.cell execution_count=3}\n``` {.julia .cell-code}\n@objective(generator, logitbinarycrossentropy + 0.1distance_l2 + 1.0ddp_diversity)\n```\n:::\n\n\nHere we have essentially created a version of the [`DiCEGenerator`](@ref):\n\n::: {.cell execution_count=4}\n``` {.julia .cell-code}\nce = generate_counterfactual(x, target, counterfactual_data, M, generator; num_counterfactuals=5)\nplot(ce)\n```\n\n::: {.cell-output .cell-output-display execution_count=5}\n![](generators_files/figure-commonmark/cell-5-output-1.svg){}\n:::\n:::\n\n\nMultiple macros can be chained using `Chains.jl` making it easy to create entirely new flavours of counterfactual generators. The following generator, for example, combines ideas from DiCE [@mothilal2020explaining] and REVISE [@joshi2019realistic]:\n\n::: {.cell execution_count=5}\n``` {.julia .cell-code}\n@chain generator begin\n @objective logitcrossentropy + 1.0ddp_diversity # DiCE (Mothilal et al. 2020)\n @with_optimiser Flux.Adam(0.1) \n @search_latent_space # REVISE (Joshi et al. 2019)\nend\n```\n:::\n\n\n\n\n\n\n::: {.cell execution_count=8}\n\n::: {.cell-output .cell-output-display execution_count=9}\nLet's take this generator to our MNIST dataset and generate a counterfactual explanation for turning a 0 into a 8. \n\n:::\n:::\n\n\n::: {.cell execution_count=9}\n\n::: {.cell-output .cell-output-display}\n![](generators_files/figure-commonmark/cell-10-output-1.svg){}\n:::\n:::\n\n\n## Off-the-Shelf Generators \n\nOff-the-shelf generators are just default recipes for counterfactual generators. Currently, the following off-the-shelf counterfactual generators are implemented in the package:\n\n::: {.cell execution_count=10}\n``` {.julia .cell-code}\ngenerator_catalogue\n```\n\n::: {.cell-output .cell-output-display execution_count=11}\n```\nDict{Symbol, Any} with 11 entries:\n :gravitational => GravitationalGenerator\n :growing_spheres => GrowingSpheresGenerator\n :revise => REVISEGenerator\n :clue => CLUEGenerator\n :probe => ProbeGenerator\n :dice => DiCEGenerator\n :feature_tweak => FeatureTweakGenerator\n :claproar => ClaPROARGenerator\n :wachter => WachterGenerator\n :generic => GenericGenerator\n :greedy => GreedyGenerator\n```\n:::\n:::\n\n\n\n\nTo specify the type of generator you want to use, you can simply instantiate it:\n\n::: {.cell execution_count=12}\n``` {.julia .cell-code}\n# Search:\ngenerator = GenericGenerator()\nce = generate_counterfactual(x, target, counterfactual_data, M, generator)\nplot(ce)\n```\n\n::: {.cell-output .cell-output-display execution_count=13}\n![](generators_files/figure-commonmark/cell-13-output-1.svg){}\n:::\n:::\n\n\nWe generally make an effort to follow the literature as closely as possible when implementing off-the-shelf generators. \n\n## References\n\n",
"supporting": [
"generators_files"
],
Expand Down
Loading
Loading