
Adding parallel implementations of (some?) quasisep algorithms #210

Draft · wants to merge 6 commits into main

Conversation

dfm (Owner) commented Apr 3, 2024

The quasisep solver is fast on CPU, but its performance is very bad on GPU (and probably TPU) because of the extensive use of lax.scan. It's possible to rewrite at least some of these operations using lax.associative_scan, which is (at least in principle) more accelerator friendly. This approach is similar in spirit to the algorithms derived in https://arxiv.org/abs/1905.13002.
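To make the idea concrete, here's a minimal, self-contained sketch of the standard parallel-prefix trick for an affine recurrence `f_n = A_n @ f_{n-1} + b_n` (this is illustrative, not the code in this PR; names like `parallel_recurrence` are made up). The key observation is that pairs `(A, b)` compose associatively, so `lax.associative_scan` can compute every partial state in O(log N) depth instead of the O(N) depth of `lax.scan`:

```python
import jax
import jax.numpy as jnp


def combine(elem1, elem2):
    # Composing the affine maps f -> A1 @ f + b1 and then f -> A2 @ f + b2
    # gives f -> (A2 @ A1) @ f + (A2 @ b1 + b2). This composition is
    # associative, which is all associative_scan requires. The einsum keeps
    # everything vectorized over the batch axis that associative_scan
    # passes through.
    A1, b1 = elem1
    A2, b2 = elem2
    return A2 @ A1, jnp.einsum("...ij,...j->...i", A2, b1) + b2


def parallel_recurrence(As, bs):
    # As: (N, J, J) transition matrices; bs: (N, J) inputs. Returns all
    # states f_n = A_n @ f_{n-1} + b_n (with f_0 = 0) in O(log N) depth.
    _, fs = jax.lax.associative_scan(combine, (As, bs))
    return fs


def sequential_recurrence(As, bs):
    # Reference implementation of the same recurrence using lax.scan.
    def step(f, elem):
        A, b = elem
        f_next = A @ f + b
        return f_next, f_next

    _, fs = jax.lax.scan(step, jnp.zeros(As.shape[-1]), (As, bs))
    return fs
```

The two versions should agree up to floating-point error; the associative version reassociates all of the sums and products, which is presumably related to the precision issues mentioned below:

```python
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
As = 0.1 * jax.random.normal(key_a, (1000, 3, 3))
bs = jax.random.normal(key_b, (1000, 3))
assert jnp.allclose(
    parallel_recurrence(As, bs), sequential_recurrence(As, bs), atol=1e-4
)
```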

This PR is a WIP to add some of these operations. So far, I've just implemented a parallel matrix multiplication. There are still some precision issues to work out, but the initial performance looks good:

(Screenshot from Apr 3, 2024: timing comparison of the scan and associative_scan matmuls on CPU and GPU.)

On CPU, the scan and associative_scan matmuls take 1.65 ms and 3.59 ms respectively, for a J = 3 lower triangular matrix with N = 50,000 data points. On the GPU, the same computations cost 685 ms and 1.32 ms respectively. In other words, moving to the GPU makes the scan version roughly 400x slower, while the associative_scan version gets about 3x faster. These GPU numbers aren't impressive in absolute terms, but it might be worth investigating further in case someone wants to use this solver as part of a larger model that benefits from hardware acceleration.
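For context on how a matmul fits this mold, here's a sketch of the reduction for a matvec `y = L @ x` with a lower-triangular quasiseparable matrix parameterized (in one common convention, not necessarily the one used in tinygp or in this PR) by a diagonal `d_i` and generators `L[i, j] = p_i^T A_{i-1} ... A_{j+1} q_j` for `i > j`. It reuses the hypothetical `parallel_recurrence` helper sketched above:

```python
def quasisep_matvec(d, p, q, A, x):
    # Accumulate g_i = A_i @ g_{i-1} + q_i * x_i, so that
    # y_i = d_i * x_i + p_i . g_{i-1}. Using parallel_recurrence here
    # makes the whole matvec O(log N) depth.
    gs = parallel_recurrence(A, q * x[:, None])
    # Row i consumes the state from *before* its own update, so shift
    # the states forward by one and pad with zeros.
    gs = jnp.concatenate([jnp.zeros_like(gs[:1]), gs[:-1]], axis=0)
    return d * x + jnp.einsum("nj,nj->n", p, gs)
```

When timing either version, it's worth compiling first and blocking on the result so that compilation and JAX's asynchronous dispatch don't leak into the measurement; something along these lines (with `d`, `p`, `q`, `A`, and `x` as illustrative placeholders):

```python
import timeit

fn = jax.jit(quasisep_matvec)
fn(d, p, q, A, x).block_until_ready()  # warm up to exclude compile time
dt = timeit.timeit(lambda: fn(d, p, q, A, x).block_until_ready(), number=50)
print(f"{1e3 * dt / 50:.2f} ms per call")
```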
