
Wrong axis in jax spe summation #13

@tomweingarten

Description

In the JAX implementation, on line 210 of `spe.py`, should the sum be taken over axis -1 instead of -2? With `axis=-2`, the last output dimension has size `num_realizations` rather than the query/key dimension:

```python
return (spe[:, :keys.shape[1]] * keys[..., None]).sum(axis=-1)
```
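A minimal sketch of the shape difference, assuming `spe` is laid out as `(batch, max_length, num_heads, key_dim, num_realizations)` and `keys` as `(batch, length, num_heads, key_dim)`. These layouts are inferred from the description above rather than read from `spe.py`, and all sizes are arbitrary placeholders.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes, chosen only for illustration.
batch, length, num_heads, key_dim, num_realizations = 2, 8, 4, 16, 64

# Assumed layouts (inferred from the issue, not taken from spe.py):
#   spe:  (batch, max_length, num_heads, key_dim, num_realizations)
#   keys: (batch, length, num_heads, key_dim)
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
spe = jax.random.normal(k1, (batch, 2 * length, num_heads, key_dim, num_realizations))
keys = jax.random.normal(k2, (batch, length, num_heads, key_dim))

# Crop the encoding to the key length and broadcast keys over the realization axis.
prod = spe[:, :keys.shape[1]] * keys[..., None]

# axis=-2 contracts key_dim: the last output dimension is num_realizations.
out_m2 = prod.sum(axis=-2)
# axis=-1 sums over realizations: the last output dimension is key_dim.
out_m1 = prod.sum(axis=-1)

print(out_m2.shape)  # (2, 8, 4, 64)
print(out_m1.shape)  # (2, 8, 4, 16)

# The axis=-2 version is the same contraction as this einsum over d (= key_dim).
ref_m2 = jnp.einsum(
    "blhdr,blhd->blhr",
    spe[:, :length],
    keys,
    precision=jax.lax.Precision.HIGHEST,
)
print(jnp.allclose(out_m2, ref_m2, atol=1e-4))
```

With `axis=-2`, each output entry is a dot product between a key vector and one realization of the encoding, so the feature dimension becomes `num_realizations`. With `axis=-1`, the realizations are summed away and each key feature is only rescaled elementwise, so the last dimension stays at `key_dim`.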
