
Update jax sparse operations when possible #52

Open

DanPuzzuoli opened this issue Nov 4, 2021 · 2 comments

DanPuzzuoli (Collaborator) commented Nov 4, 2021

PR #51 introduces JAX sparse evaluation for models. Due to limited functionality of the jax BCOO sparse array type, some workarounds were required to implement a few things. Subsequent releases of JAX are expected to eliminate the need for these workarounds, and this issue is a reminder of what they are.

  • A function jsparse_linear_combo is defined at the beginning of operator_collections.py that takes a linear combination of sparse arrays (specified as a 3d BCOO array), with the coefficients given in a dense 1d array. This cannot be achieved using a sparsified version of jnp.tensordot, as the design convention in jax is that such operations always output dense arrays if at least one input is dense. Hence, jsparse_linear_combo multiplies the coefficients against the sparse array directly via broadcasting. However, sparse-dense element-wise multiplication is, at the time of writing, limited to arrays of the same shape, so it is necessary to explicitly expand the coefficient array into a dense array with the same shape as the sparse array (which is huge). I'm not sure whether this is done via views, in which case it may be acceptable, but it should be changed when possible regardless.
  • Setting up the operators in the jax sparse Lindblad operator collection is done using dense arrays, as sparse-sparse matrix multiplication is not yet implemented (a sketch of the current dense workaround follows this list). This is relatively minor, but it would be nice to do it sparsely when possible.
  • "Vectorized" products A @ B, where A and B are 3d arrays with one sparse and the other dense, are not reverse-mode differentiable. As a result, LindbladModel in jax sparse mode is not reverse-mode differentiable. It is, however, forward-mode differentiable. At some point this will change, and we will need to remove the caveat in LindbladModel.evaluation_mode that sparse mode with jax is not reverse-mode differentiable.
DanPuzzuoli (Collaborator, Author) commented:

Code samples for the above points:

Setup:

import jax.numpy as jnp
from jax import jit, grad
from jax.experimental import sparse as jsparse

# sparse versions of jax.numpy operations
jsparse_sum = jsparse.sparsify(jnp.sum)
jsparse_matmul = jsparse.sparsify(jnp.matmul)
jsparse_add = jsparse.sparsify(jnp.add)
jsparse_subtract = jsparse.sparsify(jnp.subtract)

# dense 1d coefficients and a 3d array of 2x2 matrices
coeffs = jnp.array([1., 2., 3.])
dense_array = jnp.array([[[0., 1.], [1., 0.]], [[0., 1.], [1j, 0.]], [[0., 1.], [0., 1.]]])

# batched BCOO: the leading axis is treated as a batch dimension
sparse_array = jsparse.BCOO.fromdense(dense_array, n_batch=1)
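
For orientation (my addition, not from the original comment), the batch structure of the BCOO array can be checked directly:

print(sparse_array.shape)    # (3, 2, 2)
print(sparse_array.n_batch)  # 1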

Test code for linear combo:

def jsparse_linear_combo(coeffs, mats):
    # broadcast coeffs against the batch dimension of mats and sum over it
    return jsparse_sum(coeffs[:, None, None] * mats, axis=0)

jsparse_linear_combo(coeffs, sparse_array)
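
As a sanity check (my addition), the result can be compared against the dense equivalent computed with jnp.tensordot:

expected = jnp.tensordot(coeffs, dense_array, axes=1)
result = jsparse_linear_combo(coeffs, sparse_array).todense()
print(jnp.allclose(result, expected))  # True (with jax >= 0.2.26, per the update below)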

Triple product reverse-mode differentiation test:

# A @ X @ B with sparse A and B and a dense X
jsparse_triple_product = jsparse.sparsify(lambda A, X, B: A @ X @ B)

def f(X):
    # scalar-valued function of the dense argument only
    return jsparse_triple_product(sparse_array, X, sparse_array).real.sum()

jit_grad_f = jit(grad(f))
jit_grad_f(jnp.eye(2, dtype=float))
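
For comparison (my addition), the forward-mode derivative, which was available even before reverse mode was supported:

from jax import jacfwd

jit_jacfwd_f = jit(jacfwd(f))
jit_jacfwd_f(jnp.eye(2, dtype=float))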

DanPuzzuoli (Collaborator, Author) commented:

Update:

As of jax 0.2.26 and jaxlib 0.1.75, the above code snippets work. PR #69 now removes the caveat that LindbladModel.evaluate_rhs cannot be reverse-mode autodiffed in sparse mode, and changes the autodiff test case to test reverse-mode autodiff again.

Updating jsparse_linear_combo in operator_collections.py still needs to be done: while the above snippet works, simply swapping it in results in several test failures, and the cause needs to be figured out. It's possible they're all just numpy.array vs. jax.numpy.array type errors in the test case setups.
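
If the failures are indeed type mismatches, the fix in the test setups would be along these lines (a hypothetical sketch; the actual test code is not shown here):

import numpy as np
import jax.numpy as jnp

# test fixtures built with plain numpy arrays...
coeffs_np = np.array([1., 2., 3.])

# ...converted to jax arrays before entering the jax-sparse code path
coeffs_jnp = jnp.asarray(coeffs_np)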
