25 changes: 25 additions & 0 deletions simplexity/generative_processes/generalized_hidden_markov_model.py
@@ -68,6 +68,31 @@ def initial_state(self) -> State:
"""The initial state of the model."""
return cast(State, self._initial_state)

@eqx.filter_vmap(in_axes=(None, 0, 0, None))
def generate_with_obs_dist(
self, state: State, key: chex.PRNGKey, sequence_len: int
) -> tuple[State, chex.Array, chex.Array]:
"""Generate a batch of sequences of observations from the generative process.

Inputs:
state: (batch_size, num_states)
key: (batch_size, 2)
Returns: tuple of (belief states, observations, observation probabilities) where:
Comment on lines +75 to +80
Copilot AI May 23, 2025

[nitpick] The docstring describes batch input shapes, but this method is vmapped over a per-sample state and key. Clarify that inputs are single-example (no batch dim) and outputs are vectorized across the batch.

Suggested change:

Current:
"""Generate a batch of sequences of observations from the generative process.
Inputs:
state: (batch_size, num_states)
key: (batch_size, 2)
Returns: tuple of (belief states, observations, observation probabilities) where:

Suggested:
"""Generate sequences of observations from the generative process.
Inputs (per-sample, no batch dimension):
state: (num_states,)
key: (2,)
Returns (vectorized across the batch):

states: (batch_size, sequence_len, num_states)
obs: (batch_size, sequence_len)
obs_probs: (batch_size, sequence_len, vocab_size)
"""
keys = jax.random.split(key, sequence_len)

def gen_sequences(state: State, key: chex.PRNGKey) -> tuple[State, tuple[State, chex.Array, chex.Array]]:
obs_probs = self.observation_probability_distribution(state)
obs = jax.random.choice(key, self.vocab_size, p=obs_probs)
new_state = self.transition_states(state, obs)
return new_state, (state, obs, obs_probs)

_, (states, obs, obs_probs) = jax.lax.scan(gen_sequences, state, keys)
return states, obs, obs_probs

@eqx.filter_jit
def emit_observation(self, state: State, key: chex.PRNGKey) -> jax.Array:
"""Emit an observation based on the state of the generative process."""
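
The review comment above is about how `eqx.filter_vmap` changes the calling convention. As a standalone sketch (not part of simplexity; `ToyProcess` and its stand-in step logic are purely illustrative), the same `in_axes=(None, 0, 0, None)` pattern looks like this: inside the decorated method the arguments are per-sample (`state` of shape `(num_states,)`, `key` of shape `(2,)`), while callers pass batched arrays and get outputs with a leading batch dimension.

import equinox as eqx
import jax
import jax.numpy as jnp


class ToyProcess(eqx.Module):
    num_states: int = 3

    @eqx.filter_vmap(in_axes=(None, 0, 0, None))
    def generate(self, state: jax.Array, key: jax.Array, sequence_len: int):
        # Per-sample shapes inside the vmapped body: state is (num_states,), key is (2,).
        def step(state, key):
            obs_probs = state / state.sum()  # stand-in for observation_probability_distribution
            obs = jax.random.choice(key, self.num_states, p=obs_probs)
            new_state = jnp.roll(state, 1)  # stand-in for transition_states
            return new_state, (state, obs, obs_probs)

        keys = jax.random.split(key, sequence_len)
        _, outputs = jax.lax.scan(step, state, keys)  # stacks per-step outputs along axis 0
        return outputs


batch_size, sequence_len = 4, 10
proc = ToyProcess()
states = jnp.ones((batch_size, proc.num_states)) / proc.num_states
keys = jax.random.split(jax.random.PRNGKey(0), batch_size)
states_seq, obs, obs_probs = proc.generate(states, keys, sequence_len)
# vmap adds the leading batch dimension to every output:
assert states_seq.shape == (batch_size, sequence_len, proc.num_states)
assert obs.shape == (batch_size, sequence_len)
assert obs_probs.shape == (batch_size, sequence_len, proc.num_states)
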
20 changes: 20 additions & 0 deletions tests/generative_processes/test_generalized_hidden_markov_model.py
@@ -122,6 +122,26 @@ def test_generate_with_intermediate_states(model_name: str, request: pytest.Fixt
assert observations.shape == (batch_size, sequence_len)


@pytest.mark.parametrize("model_name", ["z1r", "fanizza_model"])
def test_generate_with_obs_dist(model_name: str, request: pytest.FixtureRequest):
model: GeneralizedHiddenMarkovModel = request.getfixturevalue(model_name)
batch_size = 4
sequence_len = 10

initial_states = jnp.repeat(model.initial_state[None, :], batch_size, axis=0)
keys = jax.random.split(jax.random.PRNGKey(0), batch_size)
intermediate_states, observations, obs_probs = model.generate_with_obs_dist(initial_states, keys, sequence_len)
assert intermediate_states.shape == (batch_size, sequence_len, model.num_states)
assert observations.shape == (batch_size, sequence_len)
assert obs_probs.shape == (batch_size, sequence_len, model.vocab_size)
last_intermediate_states = intermediate_states[:, -1, :]

final_states, observations, obs_probs = model.generate_with_obs_dist(last_intermediate_states, keys, sequence_len)
assert final_states.shape == (batch_size, sequence_len, model.num_states)
assert observations.shape == (batch_size, sequence_len)
assert obs_probs.shape == (batch_size, sequence_len, model.vocab_size)
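
Beyond the shape assertions above, one could also sanity-check the sampled values and distributions. This is a sketch only, not part of the PR; it assumes each `obs_probs` row is a normalized probability vector (it is consumed as `p=` by `jax.random.choice` inside `generate_with_obs_dist`), and some generalized HMM parameterizations may warrant a looser tolerance:

    assert jnp.all((observations >= 0) & (observations < model.vocab_size))
    assert jnp.allclose(obs_probs.sum(axis=-1), 1.0, atol=1e-5)
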


def test_hmm_observation_probability_distribution(z1r: GeneralizedHiddenMarkovModel):
state = jnp.array([0.3, 0.1, 0.6])
obs_probs = z1r.observation_probability_distribution(state)
38 changes: 38 additions & 0 deletions tests/generative_processes/test_hidden_markov_model.py
@@ -97,6 +97,44 @@ def test_generate(z1r: HiddenMarkovModel):
assert final_observations.shape == (batch_size, sequence_len)


def test_generate_with_intermediate_states(z1r: HiddenMarkovModel):
batch_size = 4
sequence_len = 10

initial_states = jnp.repeat(z1r.normalizing_eigenvector[None, :], batch_size, axis=0)
keys = jax.random.split(jax.random.PRNGKey(0), batch_size)
intermediate_states, intermediate_observations = z1r.generate(initial_states, keys, sequence_len, True)
assert intermediate_states.shape == (batch_size, sequence_len, z1r.num_states)
assert intermediate_observations.shape == (batch_size, sequence_len)
last_intermediate_states = intermediate_states[:, -1, :]

final_states, final_observations = z1r.generate(last_intermediate_states, keys, sequence_len, True)
assert final_states.shape == (batch_size, sequence_len, z1r.num_states)
assert final_observations.shape == (batch_size, sequence_len)


def test_generate_with_obs_dist(z1r: HiddenMarkovModel):
batch_size = 4
sequence_len = 10

initial_states = jnp.repeat(z1r.normalizing_eigenvector[None, :], batch_size, axis=0)
keys = jax.random.split(jax.random.PRNGKey(0), batch_size)
intermediate_states, intermediate_observations, intermediate_obs_probs = z1r.generate_with_obs_dist(
initial_states, keys, sequence_len
)
assert intermediate_states.shape == (batch_size, sequence_len, z1r.num_states)
assert intermediate_observations.shape == (batch_size, sequence_len)
assert intermediate_obs_probs.shape == (batch_size, sequence_len, z1r.vocab_size)
last_intermediate_states = intermediate_states[:, -1, :]

final_states, final_observations, final_obs_probs = z1r.generate_with_obs_dist(
last_intermediate_states, keys, sequence_len
)
assert final_states.shape == (batch_size, sequence_len, z1r.num_states)
assert final_observations.shape == (batch_size, sequence_len)
assert final_obs_probs.shape == (batch_size, sequence_len, z1r.vocab_size)


def test_observation_probability_distribution(z1r: HiddenMarkovModel):
state = jnp.array([0.3, 0.1, 0.6])
obs_probs = z1r.observation_probability_distribution(state)