Optimise DynamicPPL Slightly and Better Zero Adjoint Functionality #242

Merged (22 commits) on Sep 9, 2024

willtebbutt (Member) commented Sep 5, 2024

This PR has now grown to also tackle #241 and #243.
Ideally I would have tackled this in three separate PRs, but I got carried away fixing things.

So, it does several things:

  1. adds a function called simple_zero_adjoint, which is useful for creating rrules for things which don't need to be differentiated through
  2. uses simple_zero_adjoint everywhere that we can in the code base -- there are quite a number of instances in which it works
  3. uses simple_zero_adjoint to add a DynamicPPL-specific rule, in an extension, to avoid some annoying computation that was adding overhead to Tapir.jl in that context
  4. adds a function remove_dead_blocks(::BBCode), which removes any basic blocks which cannot be reached. This was the solution to making LKJCholesky work properly. I don't fully understand what was going on, but I think that the compiler makes some assumptions about the IRCode it will see in practice, and I wasn't producing code which conformed to them.
  5. adds a couple of simple_zero_adjoint rules for string- and symbol-related functionality that Tapir wasn't entirely happy with due to some ccalls
  6. adds default values for all kwargs of Tapir.TestUtils.test_rrule and makes use of them throughout the tests

ToDo:

  • add unit testing for remove_dead_blocks

I'm planning to finish this up on Monday.

edit: the increased code churn is due to a slowdown in CI that we were observing, caused by not fully caching an interpreter for the current world age. This PR now does that, and some code has changed as a result. Additionally, the implementation of _remove_unreachable_blocks (renamed from _remove_dead_blocks) has been heavily simplified, and the docstring associated with it improved substantially.
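
For readers unfamiliar with the idea behind item 4, here is a minimal, language-agnostic sketch of removing unreachable basic blocks from a control-flow graph. It is written in Python rather than Julia and is not Tapir's implementation: the Block class, its fields, and the remove_unreachable_blocks function below are hypothetical names chosen for illustration. The sketch assumes the first block is the entry block, and it also drops phi-node edges that come from removed predecessors, since leaving those behind would reference blocks that no longer exist.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Block:
    """Hypothetical basic block, used for illustration only."""
    id: int
    # Ids of the blocks this block may jump to.
    successors: List[int] = field(default_factory=list)
    # Maps a phi-node name to {predecessor block id: incoming value name}.
    phis: Dict[str, Dict[int, str]] = field(default_factory=dict)


def remove_unreachable_blocks(blocks: List[Block]) -> List[Block]:
    """Keep only the blocks reachable from the entry block (blocks[0])."""
    if not blocks:
        return []
    by_id = {b.id: b for b in blocks}

    # Depth-first traversal from the entry block to find reachable ids.
    reachable = set()
    stack = [blocks[0].id]
    while stack:
        bid = stack.pop()
        if bid in reachable:
            continue
        reachable.add(bid)
        stack.extend(by_id[bid].successors)

    kept = []
    for b in blocks:
        if b.id not in reachable:
            continue
        # Drop phi edges whose predecessor block was removed.
        phis = {
            name: {p: v for p, v in edges.items() if p in reachable}
            for name, edges in b.phis.items()
        }
        kept.append(Block(b.id, list(b.successors), phis))
    return kept


# Block 3 has no path from the entry block 1, so it is removed, and the
# phi node in block 4 loses its edge from block 3.
example = [
    Block(1, [2]),
    Block(2, [4]),
    Block(3, [4]),
    Block(4, [], {"x": {2: "a", 3: "b"}}),
]
cleaned = remove_unreachable_blocks(example)
print([b.id for b in cleaned])  # [1, 2, 4]
print(cleaned[-1].phis)         # {'x': {2: 'a'}}
```

The traversal is a plain reachability search, so it only removes blocks with no path from the entry; it does not attempt to prove that a branch condition is constant.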

codecov bot commented Sep 5, 2024

Codecov Report

Attention: Patch coverage is 94.59459% with 4 lines in your changes missing coverage. Please review.

Files with missing lines                          Patch %    Lines
src/rrules/builtins.jl                            83.33%     2 Missing ⚠️
src/interpreter/s2s_reverse_mode_ad.jl            93.33%     1 Missing ⚠️
src/rrules/misc.jl                                50.00%     1 Missing ⚠️

Files with missing lines                          Coverage Δ
ext/TapirDynamicPPLExt.jl                         100.00% <100.00%> (ø)
src/codual.jl                                     91.66% <100.00%> (+0.49%) ⬆️
src/interpreter/abstract_interpretation.jl        80.55% <100.00%> (+2.43%) ⬆️
src/interpreter/bbcode.jl                         95.13% <100.00%> (+0.15%) ⬆️
src/rrules/avoiding_non_differentiable_code.jl    100.00% <100.00%> (ø)
src/rrules/foreigncall.jl                         94.00% <100.00%> (-0.14%) ⬇️
src/rrules/tasks.jl                               72.72% <100.00%> (-0.61%) ⬇️
src/test_utils.jl                                 91.12% <100.00%> (-0.03%) ⬇️
src/interpreter/s2s_reverse_mode_ad.jl            93.04% <93.33%> (-0.15%) ⬇️
src/rrules/misc.jl                                97.46% <50.00%> (ø)
... and 1 more

github-actions bot (Contributor) commented Sep 5, 2024

Performance Ratio:
Ratio of the time taken to compute the gradient to the time taken to compute the function.
Warning: results are very approximate! See here for more context.

┌────────────────────────────┬────────┬─────────┬─────────────┬─────────┐
│                      Label │  Tapir │  Zygote │ ReverseDiff │  Enzyme │
│                     String │ String │  String │      String │  String │
├────────────────────────────┼────────┼─────────┼─────────────┼─────────┤
│                   sum_1000 │  116.0 │   0.786 │        4.81 │    1.71 │
│                  _sum_1000 │   7.85 │  1360.0 │        46.9 │  0.0841 │
│               sum_sin_1000 │   2.93 │    1.61 │        10.9 │    1.01 │
│              _sum_sin_1000 │   3.39 │   319.0 │        16.6 │    1.49 │
│                   kron_sum │   76.4 │    3.49 │       211.0 │    8.24 │
│              kron_view_sum │   85.3 │    10.8 │       231.0 │    8.05 │
│      naive_map_sin_cos_exp │   4.27 │ missing │        8.85 │    2.79 │
│            map_sin_cos_exp │   4.66 │    1.72 │        7.61 │    3.42 │
│      broadcast_sin_cos_exp │   4.68 │    2.64 │        1.66 │    2.85 │
│                 simple_mlp │   8.85 │    3.13 │        13.7 │     3.1 │
│                     gp_lml │   15.8 │    4.38 │     missing │ missing │
│ turing_broadcast_benchmark │   8.41 │ missing │        26.9 │ missing │
└────────────────────────────┴────────┴─────────┴─────────────┴─────────┘

willtebbutt changed the title from "Optimise DynamicPPL Slightly" to "Optimise DynamicPPL Slightly and Better Zero Adjoint Functionality" on Sep 6, 2024

willtebbutt (Member, Author) commented

CI was passing before I pushed a docstring tweak, so I should be fine to merge this in an hour or so.

yebai (Contributor) commented Sep 9, 2024

As a rule of thumb, we would like Tapir to be thoroughly tested against Julia's standard library, SciML, Distributions, Lux and Turing. So we should probably add LKJCholesky to the Distributions.jl integration tests, either in this PR or in a separate one.

willtebbutt (Member, Author) commented Sep 9, 2024

It's already in there -- see the diff associated to this PR :)

willtebbutt merged commit 3048683 into main on Sep 9, 2024 (19 checks passed)
willtebbutt deleted the wct/dynamicppl-optimisation branch on September 9, 2024 17:05
willtebbutt mentioned this pull request on Sep 9, 2024