@eqx.filter_vmap(in_axes=(None, 0, 0, None))
def generate_with_obs_dist(
    self, state: State, key: chex.PRNGKey, sequence_len: int
) -> tuple[State, chex.Array, chex.Array]:
    """Sample a batch of observation sequences, recording each step's observation distribution.

    The method is mapped over the leading axis of ``state`` and ``key``;
    ``self`` and ``sequence_len`` are shared by every batch element.

    Inputs:
        state: (batch_size, num_states)
        key: (batch_size, 2)
        sequence_len: number of steps generated for every sequence
    Returns: tuple of (states, obs, obs_probs) where:
        states: (batch_size, sequence_len, num_states), the state held before each emission
        obs: (batch_size, sequence_len)
        obs_probs: (batch_size, sequence_len, vocab_size)
    """

    def step(
        current_state: State, step_key: chex.PRNGKey
    ) -> tuple[State, tuple[State, chex.Array, chex.Array]]:
        # Distribution over the vocabulary implied by the current state.
        emission_probs = self.observation_probability_distribution(current_state)
        symbol = jax.random.choice(step_key, self.vocab_size, p=emission_probs)
        next_state = self.transition_states(current_state, symbol)
        # Carry the updated state forward; record the state the symbol was drawn from.
        return next_state, (current_state, symbol, emission_probs)

    # One independent key per generated step.
    step_keys = jax.random.split(key, sequence_len)
    _final_state, trajectory = jax.lax.scan(step, state, step_keys)
    states, obs, obs_probs = trajectory
    return states, obs, obs_probs
@pytest.mark.parametrize("model_name", ["z1r", "fanizza_model"])
def test_generate_with_obs_dist(model_name: str, request: pytest.FixtureRequest):
    """Check the output shapes of ``generate_with_obs_dist`` for each parametrized model.

    Generation runs twice with the same keys: first from the model's initial
    state, then from the last recorded state of the first run.
    """
    model: GeneralizedHiddenMarkovModel = request.getfixturevalue(model_name)
    batch_size, sequence_len = 4, 10
    keys = jax.random.split(jax.random.PRNGKey(0), batch_size)

    def run_and_check(start_states):
        states, observations, obs_probs = model.generate_with_obs_dist(start_states, keys, sequence_len)
        assert states.shape == (batch_size, sequence_len, model.num_states)
        assert observations.shape == (batch_size, sequence_len)
        assert obs_probs.shape == (batch_size, sequence_len, model.vocab_size)
        return states

    # First pass: every batch element starts from the model's initial state.
    initial_states = jnp.repeat(model.initial_state[None, :], batch_size, axis=0)
    intermediate_states = run_and_check(initial_states)

    # Second pass: resume from the last recorded state of each sequence.
    run_and_check(intermediate_states[:, -1, :])
def test_generate_with_intermediate_states(z1r: HiddenMarkovModel):
    """Check the output shapes of ``generate`` when its fourth positional argument is ``True``.

    Generation runs twice with the same keys: first from the normalizing
    eigenvector, then from the last recorded state of the first run.
    """
    batch_size, sequence_len = 4, 10
    keys = jax.random.split(jax.random.PRNGKey(0), batch_size)

    def run_and_check(start_states):
        states, observations = z1r.generate(start_states, keys, sequence_len, True)
        assert states.shape == (batch_size, sequence_len, z1r.num_states)
        assert observations.shape == (batch_size, sequence_len)
        return states

    # First pass: every batch element starts from the normalizing eigenvector.
    initial_states = jnp.repeat(z1r.normalizing_eigenvector[None, :], batch_size, axis=0)
    intermediate_states = run_and_check(initial_states)

    # Second pass: resume from the last recorded state of each sequence.
    run_and_check(intermediate_states[:, -1, :])


def test_generate_with_obs_dist(z1r: HiddenMarkovModel):
    """Check the output shapes of ``generate_with_obs_dist`` on the ``z1r`` model.

    Generation runs twice with the same keys: first from the normalizing
    eigenvector, then from the last recorded state of the first run.
    """
    batch_size, sequence_len = 4, 10
    keys = jax.random.split(jax.random.PRNGKey(0), batch_size)

    def run_and_check(start_states):
        states, observations, obs_probs = z1r.generate_with_obs_dist(start_states, keys, sequence_len)
        assert states.shape == (batch_size, sequence_len, z1r.num_states)
        assert observations.shape == (batch_size, sequence_len)
        assert obs_probs.shape == (batch_size, sequence_len, z1r.vocab_size)
        return states

    # First pass: every batch element starts from the normalizing eigenvector.
    initial_states = jnp.repeat(z1r.normalizing_eigenvector[None, :], batch_size, axis=0)
    intermediate_states = run_and_check(initial_states)

    # Second pass: resume from the last recorded state of each sequence.
    run_and_check(intermediate_states[:, -1, :])