For the JAX implementation, on line 210 of spe.py, should the sum be taken over axis -1 instead of -2? With axis=-2, the last output dimension has size num_realizations rather than the size of the query/key dimension:
return (spe[:, :keys.shape[1]] * keys[..., None]).sum(axis=-1)
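To make the shapes concrete, here is a minimal sketch of what I am seeing. The array layouts and sizes are my assumptions, inferred from the expression above rather than taken from spe.py: `spe` is assumed to be `(1, max_len, heads, key_dim, num_realizations)` and `keys` to be `(batch, length, heads, key_dim)`.

```python
import jax.numpy as jnp

# Hypothetical sizes, chosen so that every axis is distinguishable.
batch, length, max_len, heads, key_dim, num_realizations = 2, 5, 8, 3, 4, 7

# Assumed layouts (not confirmed against spe.py).
spe = jnp.ones((1, max_len, heads, key_dim, num_realizations))
keys = jnp.ones((batch, length, heads, key_dim))

# Broadcast product: (batch, length, heads, key_dim, num_realizations).
product = spe[:, :keys.shape[1]] * keys[..., None]

sum_last = product.sum(axis=-1)         # reduces num_realizations
sum_second_last = product.sum(axis=-2)  # reduces key_dim

print("product:", product.shape)
print("axis=-1:", sum_last.shape)
print("axis=-2:", sum_second_last.shape)
```

Under these assumptions, axis=-1 leaves key_dim as the last dimension and axis=-2 leaves num_realizations as the last dimension, so which one is right depends on which of the two the function is meant to return.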