-
Notifications
You must be signed in to change notification settings - Fork 0
Open
Description
Suggested replacement code (AI-generated), shown below.
Reference: https://jaxns.readthedocs.io/en/latest/examples/mvn_data_mvn_prior.html
import jax.numpy as jnp
from jax import vmap
from jax.scipy.linalg import solve_triangular
def mvn_log_prob_fast(X, mu, sigma):
    """Evaluate the log-density of every point under every Gaussian component.

    Each covariance is Cholesky-factored once, and Mahalanobis distances come
    from a triangular solve. This avoids forming an explicit inverse or
    determinant, which is cheaper and numerically more stable.

    Args:
        X: Data points, shape (N, D).
        mu: Component means, shape (K, D).
        sigma: Component covariance matrices, shape (K, D, D). Each must be
            symmetric positive definite. JAX's Cholesky does not raise on
            invalid input, so a non-PD matrix shows up as NaNs in the output
            rather than as an exception.

    Returns:
        Array of shape (N, K) whose entry [n, k] is
        log N(X[n] | mu[k], sigma[k]).
    """
    L = jnp.linalg.cholesky(sigma)  # (K, D, D), lower-triangular factors
    D = mu.shape[-1]
    # The normalising constant does not depend on the component, so compute it once.
    log_norm = -0.5 * D * jnp.log(2 * jnp.pi)

    def single_component(L_k, mu_k):
        diff = X - mu_k[None, :]  # (N, D)
        # Solve L_k y = diff^T; then ||y||^2 = diff^T sigma_k^{-1} diff.
        y = solve_triangular(L_k, diff.T, lower=True).T  # (N, D)
        maha_sq = jnp.sum(y**2, axis=-1)  # (N,) squared Mahalanobis distances
        # sigma_k = L_k L_k^T  =>  0.5 * log|sigma_k| = sum(log(diag(L_k))).
        half_log_det = jnp.sum(jnp.log(jnp.diag(L_k)))
        return -0.5 * maha_sq + log_norm - half_log_det

    # Map over the component axis of (L, mu); out_axes=1 stacks the per-component
    # (N,) results as columns, giving an (N, K) array.
    return vmap(single_component, in_axes=(0, 0), out_axes=1)(L, mu)
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels