Quasisep Benchmark #221

Open
ywx649999311 opened this issue Sep 18, 2024 · 6 comments

Comments

@ywx649999311
Contributor

ywx649999311 commented Sep 18, 2024

Hi @dfm,

I noticed a factor of ~10 increase in the runtime of the quasisep kernels in the newest version of tinygp. So I reran the benchmark notebook from your documentation to check, and indeed the quasisep kernel is now about 10 times slower than celerite2. I am attaching the reproduced benchmark plot; note that I didn't rerun the GPU benchmark. Any idea what might be causing this?

I ran your notebook on two platforms (an M2 Mac and Google Colab), and the results are basically consistent. The plot shown was produced with:

  • M2 Pro
  • Python 3.11.9
  • tinygp v0.3

Thanks for your attention!

[benchmark plot]
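
For context, the core of the timing comparison looks roughly like the sketch below (this is not the exact notebook code; the kernel choice, data sizes, and parameters are illustrative):

import time

import jax
import numpy as np
from tinygp import GaussianProcess, kernels

jax.config.update("jax_enable_x64", True)

kernel = kernels.quasisep.Matern32(scale=1.5)

for n in [64, 256, 1024, 4096]:
    t = np.sort(np.random.default_rng(0).uniform(0, 10, n))
    y = np.sin(t)
    gp = GaussianProcess(kernel, t, diag=0.1)
    loglike = jax.jit(gp.log_probability)
    loglike(y).block_until_ready()  # trigger compilation before timing
    start = time.perf_counter()
    loglike(y).block_until_ready()
    print(n, time.perf_counter() - start)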

@dfm
Owner

dfm commented Sep 19, 2024

Hmmm. Are you sure it's the tinygp version, and not the JAX version, that makes this difference? Can you try to bisect the version combination where this goes wrong?

@ywx649999311
Contributor Author

I think it starts to go wrong with:

  • tinygp = 0.2.3

I tried a range of JAX versions, all the way back to 0.3.3, and they had no major impact on the benchmark. This makes me wonder whether the problem could be caused by the jax.debug call introduced in tinygp 0.2.3.

The JAX version itself has only a small impact on performance, a factor of <2 at N < 50. That effect starts at about JAX == 0.4.2.

@dfm
Owner

dfm commented Sep 21, 2024

Interesting - thanks for tracking that down! Can you try updating the benchmark you're running to use assume_sorted=True? (It should be sufficient to pass that as a keyword argument when building the GaussianProcess object with the quasisep kernel.) If that is set, the debug check should never be executed. It would be a little surprising if that debug call were having such an effect, but maybe it is. Let's get this figured out!
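
For concreteness, something along these lines (a sketch; the kernel and data here are placeholders):

import numpy as np
from tinygp import GaussianProcess, kernels

t = np.sort(np.random.default_rng(0).uniform(0, 10, 500))
gp = GaussianProcess(
    kernels.quasisep.Matern32(scale=1.5),
    t,  # inputs already sorted
    diag=0.1,
    assume_sorted=True,  # skip the sorted-input check (and its jax.debug call)
)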

@ywx649999311
Contributor Author

ywx649999311 commented Sep 23, 2024

I think this is more complicated than I thought earlier. Setting assume_sorted=True helped, but it is not the cause of the overall factor-of-ten slowdown. Below are the benchmarks for two version combinations.

  1. jax=0.4.31 and tinygp=0.3.0:
    [benchmark plot]

  2. jax=0.4.33 and tinygp=0.3.0:
    [benchmark plot]

You were right, the JAX version has a much bigger effect! Any thoughts?

P.S. Sorry for the misleading information above. I tried to make comparisons starting from the earliest version of tinygp (0.2.2) that I had tested in the past and worked my way up, but I had to stop at jax=0.4.28 because newer JAX releases dropped support for the Python version (3.9) that I was using for the tests.

@dfm
Owner

dfm commented Sep 23, 2024

Thanks for tracking this down! My next guess is that this has to do with a change in CPU runtime behavior that was introduced in v0.4.32. Can you try setting the following environment variable and running the benchmark again:

XLA_FLAGS=--xla_cpu_use_thunk_runtime=false

If that is the source of the issue, I don't know exactly how to solve it in the long term, but at least then we can start digging in. Thanks again!
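
For example, setting it in Python before JAX is first imported (an illustrative snippet, not the only way to set the flag):

import os

os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_cpu_use_thunk_runtime=false"
)

import jax  # import JAX only after XLA_FLAGS has been set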

@ywx649999311
Contributor Author

Setting the flag you recommended does not solve the issue by itself. I also have to disable the intra_op_parallelism_threads flag to bring the benchmark close to what it was before.

Here are my flag settings:

os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "")
    + " --xla_cpu_multi_thread_eigen=false"
    # + " --intra_op_parallelism_threads=1"
    + " --xla_cpu_use_thunk_runtime=false"
)

Here is the resulting benchmark:
[benchmark plot]

I assume you set intra_op_parallelism_threads for a reason, and I am not sure it is reasonable to disable it. Perhaps this is not an apples-to-apples comparison.
