Open
Description
Minimal reproducible example
import jax
print(jax.__version__)  # 0.5.0
from simplexity.generative_processes.builder import build_generalized_hidden_markov_model
tomq = build_generalized_hidden_markov_model("tom_quantum", alpha=1.07, beta=7.1)
Traceback (most recent call last):
  File "<stdin>", line 7, in <module>
  File "/home/mattylev/.pyenv/versions/simplex/lib/python3.12/site-packages/simplexity/generative_processes/builder.py", line 57, in build_generalized_hidden_markov_model
    return GeneralizedHiddenMarkovModel(transition_matrices)
  File "/home/mattylev/.pyenv/versions/simplex/lib/python3.12/site-packages/equinox/_module.py", line 186, in __call__
    self = super(_ModuleMeta, initable_cls).__call__(*args, **kwargs)
  File "/home/mattylev/.pyenv/versions/simplex/lib/python3.12/site-packages/equinox/_better_abstract.py", line 280, in __call__
    self = super().__call__(*args, **kwargs)
  File "/home/mattylev/.pyenv/versions/simplex/lib/python3.12/site-packages/simplexity/generative_processes/generalized_hidden_markov_model.py", line 30, in __init__
    eigenvalues, right_eigenvectors = jnp.linalg.eig(state_transition_matrix)
  File "/home/mattylev/.pyenv/versions/simplex/lib/python3.12/site-packages/jax/_src/numpy/linalg.py", line 746, in eig
    w, v = lax_linalg.eig(a, compute_left_eigenvectors=False)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: Nonsymmetric eigendecomposition is only implemented on the CPU backend. If your matrix is symmetric or Hermitian, you should use eigh instead.

Analysis
The constructors for HiddenMarkovModel and GeneralizedHiddenMarkovModel both call jnp.linalg.eig. I've adapted Adam's model training code, which relies on TransformerLens, and TransformerLens requires numpy < 2.0. This means we have to use jax/jaxlib < 0.5, and before version 0.5 eig isn't implemented on GPU. Could we do something like the code below and replace jnp.linalg.eig with eig_safe?
Further experiments (works on CPU)
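import jax
import jax.numpy as jnp
import numpy as np

# Quick version check (JAX >= 0.5 is needed for eig on GPU)
JAX_SUPPORTS_GPU_EIG = tuple(map(int, jax.__version__.split('.')[:2])) >= (0, 5)

def eig_safe(matrix):
    if JAX_SUPPORTS_GPU_EIG:
        return jnp.linalg.eig(matrix)
    # Fallback for JAX < 0.5.0: run eig on the CPU device inside a host callback.
    # Output dtype follows the input precision (complex64 unless x64 is enabled),
    # since pure_callback rejects 64-bit result types when jax_enable_x64 is off.
    out_dtype = jnp.result_type(matrix.dtype, jnp.complex64)

    def _eig_cpu(m):
        with jax.default_device(jax.devices("cpu")[0]):
            w, v = jax.jit(jnp.linalg.eig)(m)
        # Return host arrays whose dtypes match the declared result shapes
        return np.asarray(w, dtype=out_dtype), np.asarray(v, dtype=out_dtype)

    eigenvalues_shape = jax.ShapeDtypeStruct(matrix.shape[:-1], out_dtype)
    eigenvectors_shape = jax.ShapeDtypeStruct(matrix.shape, out_dtype)
    return jax.pure_callback(
        _eig_cpu,
        (eigenvalues_shape, eigenvectors_shape),
        matrix,
    )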
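Since the constructor in the traceback appears to run eagerly (outside jax.jit), a simpler workaround may be enough there: commit the input to the CPU device so the eig computation follows the data onto the CPU backend, then move the results back. This is a minimal sketch under that assumption; `eig_on_cpu` is a hypothetical helper name, not part of simplexity or JAX.

```python
import jax
import jax.numpy as jnp

def eig_on_cpu(matrix):
    """Hypothetical helper: run jnp.linalg.eig eagerly on the CPU device.

    Assumes it is called outside jax.jit (e.g. from a model constructor).
    """
    cpu = jax.devices("cpu")[0]
    # Committing the input to the CPU device makes eig execute on CPU,
    # even when the default backend is a GPU.
    eigenvalues, eigenvectors = jnp.linalg.eig(jax.device_put(matrix, cpu))
    # Move the results back to the default device (the GPU, if present).
    default_device = jax.devices()[0]
    return (jax.device_put(eigenvalues, default_device),
            jax.device_put(eigenvectors, default_device))

# Example: a 2x2 row-stochastic matrix with eigenvalues 1.0 and 0.4.
a = jnp.array([[0.9, 0.1], [0.5, 0.5]])
w, v = eig_on_cpu(a)
print(w)
```

This relies on eager dispatch; under jax.jit the whole computation is compiled for one backend, so a host callback like eig_safe is the safer route for traced code.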