Azx is a JAX implementation of AlphaZero- and MuZero-style algorithms.
Run an example end-to-end:
```bash
python examples/maze.py
```

Or the MuZero variant:

```bash
python examples/maze-mz.py
```

There are four main components to set up an Azx training run:
- An environment adapter (`EnvironmentAdapter`)
- A network definition
- An agent (`AlphaZero` or `MuZero`)
- A trainer (`AlphaZeroTrainer` or `MuZeroTrainer`)
The adapter wraps a Jumanji-compatible environment and provides standard interfaces for observations and action masks.
```python
from azx.core.env import EnvironmentAdapter

adapter = EnvironmentAdapter(
    env=env,
    obs_fn=flatten_observation,
    action_mask_fn=action_mask_fn,
)
```

Where:
- `obs_fn(state) -> jax.Array`: converts the env state to a model observation.
- `action_mask_fn(state) -> jax.Array`: boolean mask of valid actions.
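As a concrete illustration, here is a minimal sketch of the two callbacks. The `DummyState` with `grid` and `action_mask` fields is hypothetical; real Jumanji states have their own fields.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

class DummyState(NamedTuple):
    # Stand-in for a real Jumanji env state; field names are illustrative only.
    grid: jax.Array
    action_mask: jax.Array

def flatten_observation(state):
    # Flatten the grid into a float vector the network can consume.
    return jnp.asarray(state.grid, dtype=jnp.float32).reshape(-1)

def action_mask_fn(state):
    # Boolean mask over the discrete action space; True marks legal actions.
    return jnp.asarray(state.action_mask, dtype=bool)

state = DummyState(grid=jnp.zeros((5, 5)), action_mask=jnp.array([True, True, False, True]))
obs = flatten_observation(state)   # shape (25,)
mask = action_mask_fn(state)       # shape (4,)
```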
AlphaZero expects a single network:
```python
def policy_value(obs):
    return (policy_logits, value_logits)
```

MuZero expects three networks:
```python
def representation(obs):
    return latent

def dynamics(latent, action):
    return (next_latent, reward_logits)

def prediction(latent):
    return (policy_logits, value_logits)
```

Logits are over a discrete support for value/reward (see configs below).
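For orientation, below is a minimal sketch of both network sets written with dm-haiku. The library choice, layer sizes, and the `NUM_ACTIONS`, `SUPPORT_SIZE`, and `LATENT_DIM` constants are assumptions, not part of Azx; only the function signatures and return values above come from the docs.

```python
import jax
import jax.numpy as jnp
import haiku as hk

NUM_ACTIONS = 4    # hypothetical: size of the env's discrete action space
SUPPORT_SIZE = 21  # hypothetical: number of bins in the value/reward support
LATENT_DIM = 64    # hypothetical: MuZero latent width

def policy_value(obs):
    # AlphaZero: observation -> (policy logits, value logits over the support)
    h = hk.nets.MLP([128, 128])(obs)
    return hk.Linear(NUM_ACTIONS)(h), hk.Linear(SUPPORT_SIZE)(h)

def representation(obs):
    # MuZero: observation -> latent state
    return hk.nets.MLP([128, LATENT_DIM])(obs)

def dynamics(latent, action):
    # MuZero: (latent, action) -> (next latent, reward logits over the support)
    a = jax.nn.one_hot(action, NUM_ACTIONS)
    h = hk.nets.MLP([128, LATENT_DIM])(jnp.concatenate([latent, a], axis=-1))
    return h, hk.Linear(SUPPORT_SIZE)(h)

def prediction(latent):
    # MuZero: latent -> (policy logits, value logits over the support)
    h = hk.nets.MLP([128])(latent)
    return hk.Linear(NUM_ACTIONS)(h), hk.Linear(SUPPORT_SIZE)(h)

# One way to initialise and call the AlphaZero network directly; whether azx
# expects the raw function or a transformed (init, apply) pair is an assumption
# to verify against the examples/ scripts.
policy_value_net = hk.without_apply_rng(hk.transform(policy_value))
params = policy_value_net.init(jax.random.PRNGKey(0), jnp.zeros((25,)))
policy_logits, value_logits = policy_value_net.apply(params, jnp.zeros((25,)))
```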
The agent is the component that gets trained. It implements the inference logic, including MCTS.
AlphaZero:
```python
from azx.alphazero.agent import AlphaZero, Config as AZConfig

agent = AlphaZero(
    adapter=adapter,
    config=AZConfig(...),
    network_fn=policy_value_net,
)
```

MuZero:
```python
from azx.muzero.agent import MuZero, Config as MZConfig

agent = MuZero(
    config=MZConfig(...),
    representation_fn=representation,
    dynamics_fn=dynamics,
    prediction_fn=prediction,
)
```

Trainers are responsible for optimizing the provided agent via reinforcement learning.
AlphaZero:
```python
import optax

from azx.alphazero.trainer import AlphaZeroTrainer, TrainConfig as AZTrain

trainer = AlphaZeroTrainer(
    agent=agent,
    config=AZTrain(...),
    opt=optax.adamw(1e-4),
)
```

MuZero:
```python
from azx.muzero.trainer import MuZeroTrainer, TrainConfig as MZTrain

trainer = MuZeroTrainer(
    agent=agent,
    adapter=adapter,
    config=MZTrain(...),
    opt=optax.adamw(1e-4),
)
```

Train:
```python
from pathlib import Path

import jax

state = trainer.init(jax.random.PRNGKey(0))
state = trainer.learn(state, num_steps=100000, checkpoints_dir=Path("./checkpoints"))
```

Agent `Config` fields:

- `discount`: discount factor for value targets.
- `num_simulations`: MCTS simulations per move.
- `use_mixed_value`, `value_scale`: MCTS q-transform settings.
- `support_min`, `support_max`, `support_eps`: discrete support for value/reward logits. All values use square-root scaling (see the sketch below).
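To make the support settings concrete, here is a sketch of the MuZero-style square-root transform and two-hot projection that fields like these typically configure. The exact formula and bin layout in Azx may differ, and the `SUPPORT_MIN` / `SUPPORT_MAX` / `SUPPORT_EPS` values below are illustrative.

```python
import jax
import jax.numpy as jnp

SUPPORT_MIN, SUPPORT_MAX, SUPPORT_EPS = -10, 10, 1e-3  # illustrative values

def sqrt_scale(x):
    # h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x  (MuZero-style transform)
    return jnp.sign(x) * (jnp.sqrt(jnp.abs(x) + 1.0) - 1.0) + SUPPORT_EPS * x

def scalar_to_two_hot(x):
    # Project the scaled scalar onto integer bins in [SUPPORT_MIN, SUPPORT_MAX];
    # the result is the categorical target matched against value/reward logits.
    num_bins = SUPPORT_MAX - SUPPORT_MIN + 1
    x = jnp.clip(sqrt_scale(x), SUPPORT_MIN, SUPPORT_MAX)
    lower = jnp.floor(x)
    upper_weight = x - lower
    lower_idx = (lower - SUPPORT_MIN).astype(jnp.int32)
    upper_idx = jnp.minimum(lower_idx + 1, num_bins - 1)
    return (jax.nn.one_hot(lower_idx, num_bins) * (1.0 - upper_weight)
            + jax.nn.one_hot(upper_idx, num_bins) * upper_weight)

target = scalar_to_two_hot(jnp.asarray(3.7))  # shape (21,), sums to 1
```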
Trainer `TrainConfig` fields:

- `actor_batch_size`: number of parallel envs for rollouts.
- `train_batch_size`: batch size sampled from the trajectory buffer.
- `n_step`: n-step return length.
- `unroll_steps`: number of steps used for loss computation.
- `eval_frequency`: number of actor steps between eval+log events.
- `max_eval_steps`: cap on steps for eval rollouts.
- `checkpoint_frequency`: save interval (in env steps).
- `gumbel_scale`: exploration temperature for MCTS.
- `max_length_buffer`, `min_length_buffer`: trajectory buffer size controls.
- `value_loss_weight`: weight for value loss.
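As a rough illustration of how these fields come together, the snippet below fills in a few of them with made-up values. Which fields live on the agent `Config` versus the `TrainConfig`, and their defaults, should be checked against `azx.alphazero.agent` and `azx.alphazero.trainer`.

```python
# Illustrative values only; field placement and defaults are assumptions.
from azx.alphazero.agent import Config as AZConfig
from azx.alphazero.trainer import TrainConfig as AZTrain

agent_config = AZConfig(
    discount=0.997,
    num_simulations=64,
    support_min=-10,
    support_max=10,
)
train_config = AZTrain(
    actor_batch_size=128,
    train_batch_size=256,
    n_step=5,
    unroll_steps=5,
    eval_frequency=10_000,
    checkpoint_frequency=50_000,
)
```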
MuZero configs accept all of the AlphaZero fields, plus:

- `consistency_loss_weight`: weight for representation consistency.