HiddenMarkovModel and GeneralizedHiddenMarkovModel constructors crash on jax/jaxlib < 0.5.0 #89

@good-epic

Description

Minimal reproducible example

import jax
print(jax.__version__) # 0.5.0
from simplexity.generative_processes.builder import build_generalized_hidden_markov_model
tomq = build_generalized_hidden_markov_model("tom_quantum", alpha=1.07, beta=7.1)
Traceback (most recent call last):
  File "<stdin>", line 7, in <module>
  File "/home/mattylev/.pyenv/versions/simplex/lib/python3.12/site-packages/simplexity/generative_processes/builder.py", line 57, in build_generalized_hidden_markov_model
    return GeneralizedHiddenMarkovModel(transition_matrices)
  File "/home/mattylev/.pyenv/versions/simplex/lib/python3.12/site-packages/equinox/_module.py", line 186, in __call__
    self = super(_ModuleMeta, initable_cls).__call__(*args, **kwargs)
  File "/home/mattylev/.pyenv/versions/simplex/lib/python3.12/site-packages/equinox/_better_abstract.py", line 280, in __call__
    self = super().__call__(*args, **kwargs)
  File "/home/mattylev/.pyenv/versions/simplex/lib/python3.12/site-packages/simplexity/generative_processes/generalized_hidden_markov_model.py", line 30, in __init__
    eigenvalues, right_eigenvectors = jnp.linalg.eig(state_transition_matrix)
  File "/home/mattylev/.pyenv/versions/simplex/lib/python3.12/site-packages/jax/_src/numpy/linalg.py", line 746, in eig
    w, v = lax_linalg.eig(a, compute_left_eigenvectors=False)
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: Nonsymmetric eigendecomposition is only implemented on the CPU backend. If your matrix is symmetric or Hermitian, you should use eigh instead.

Analysis

The constructors for HiddenMarkovModel and GeneralizedHiddenMarkovModel call jnp.linalg.eig. I've adapted Adam's model training code, which relies on TransformerLens, which in turn requires numpy < 2.0, so we have to use jax/jaxlib < 0.5. Before version 0.5, eig isn't implemented on GPU, which is what raises the NotImplementedError in the traceback above. Could we do something like the code below and replace the jnp.linalg.eig calls with eig_safe?

Further experiments (works on CPU)

import jax
import jax.numpy as jnp

# Quick version check: per the analysis above, eig needs jax >= 0.5 to run on GPU
JAX_SUPPORTS_GPU_EIG = tuple(map(int, jax.__version__.split('.')[:2])) >= (0, 5)

def eig_safe(matrix):
    if JAX_SUPPORTS_GPU_EIG:
        return jnp.linalg.eig(matrix)
    else:
        # Fallback to CPU for JAX < 0.5.0
        def _eig_cpu(matrix):
            with jax.default_device(jax.devices("cpu")[0]):
                return jax.jit(jnp.linalg.eig)(matrix)

        # Match the complex dtype to the input precision (complex64 unless x64 is enabled)
        complex_dtype = jnp.result_type(matrix.dtype, jnp.complex64)
        eigenvalues_shape = jax.ShapeDtypeStruct(matrix.shape[:-1], complex_dtype)
        eigenvectors_shape = jax.ShapeDtypeStruct(matrix.shape, complex_dtype)

        return jax.pure_callback(
            _eig_cpu,
            (eigenvalues_shape, eigenvectors_shape),
            matrix.astype(complex_dtype)
        )
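
A possible variant of the fallback path, sketched here as an assumption rather than a tested fix for simplexity: the host callback calls NumPy's eig instead of a jitted JAX function, since running JAX code inside a callback can be problematic. The name eig_via_host and the example matrix are hypothetical and do not come from the library.

```python
import jax
import jax.numpy as jnp
import numpy as np


def eig_via_host(matrix):
    """Nonsymmetric eigendecomposition computed on the host with NumPy.

    Hypothetical helper, not part of simplexity or JAX.
    """
    matrix = jnp.asarray(matrix)
    # Complex dtype matching the input precision (complex64 unless x64 is enabled).
    complex_dtype = jnp.result_type(matrix.dtype, jnp.complex64)

    def _eig_host(m):
        # Runs on the host, outside of JAX tracing.
        w, v = np.linalg.eig(np.asarray(m))
        # NumPy returns real arrays when all eigenvalues are real, so cast explicitly.
        return w.astype(complex_dtype), v.astype(complex_dtype)

    result_shapes = (
        jax.ShapeDtypeStruct(matrix.shape[:-1], complex_dtype),  # eigenvalues
        jax.ShapeDtypeStruct(matrix.shape, complex_dtype),       # right eigenvectors
    )
    return jax.pure_callback(_eig_host, result_shapes, matrix)


# Usage: a small row-stochastic matrix, called under jit like the constructors might be.
transition = jnp.array([[0.9, 0.1], [0.4, 0.6]])
eigenvalues, right_eigenvectors = jax.jit(eig_via_host)(transition)
```

Because the work happens in NumPy on the host, this path does not depend on which JAX backend is active, at the cost of a device-to-host transfer on every call.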
