Environment wrapper for calculating the joint KL divergence between successive (observation, action) distributions, with built-in support for LLM fine-tuning.
- KL Divergence Tracking - Calculate the joint KL divergence between successive (observation, action) distributions
- LLM Training Data Collection - Automatically format trajectories for LLM fine-tuning
- Reward Shaping - Shape rewards based on KL divergence to encourage stable policies
- Multiple Output Formats - Export data in chat, completion, or preference pair formats
- Gymnasium Compatible - Works with any Gymnasium environment
This project uses uv for package management. Do not use pip.
```bash
uv venv
uv pip install -e .
uv pip install -e ".[dev]"
```

```python
import gymnasium as gym
from kl_wrapper import JointKLDivergenceWrapper

# Wrap any Gymnasium environment
env = gym.make("CartPole-v1")
env = JointKLDivergenceWrapper(
    env,
    obs_bins=10,        # Discretization resolution
    action_bins=2,      # For discrete actions
    history_size=1000,  # KL history window
)
# Use like any Gym environment
obs, info = env.reset()
action = env.action_space.sample()  # replace with your agent's action
obs, reward, terminated, truncated, info = env.step(action)
# Access KL divergence metrics
print(info['kl_divergence']) # Current KL
print(info['kl_divergence_mean']) # Running mean
print(info['kl_divergence_std']) # Running std
# Get statistics
stats = env.get_kl_statistics()
print(f"Mean KL: {stats['mean']:.4f}")import gymnasium as gym
from llm_training_wrapper import LLMTrainingWrapper
# Create environment with LLM training
env = LLMTrainingWrapper(
    gym.make("CartPole-v1"),
    output_dir="training_data",
    format_type="chat",  # or "completion" or "preference"
    kl_penalty_weight=0.1,
    kl_bonus_weight=0.05,
)
# Collect trajectories
for episode in range(10):
    obs, _ = env.reset()
    while True:
        action = your_policy(obs)
        obs, reward, term, trunc, info = env.step(action)
        if term or trunc:
            break
# Save training data in JSONL format
env.save_all_trajectories("llm_training_data.jsonl")
# Check statistics
stats = env.get_training_statistics()
print(f"Collected {stats['num_steps']} steps")Run the basic KL divergence examples:
```bash
uv run python example.py
```

Run the LLM training data collection examples:
```bash
uv run python llm_training_example.py
```

- LLM Training Guide - Complete guide for fine-tuning LLMs with KL divergence
- Training Modes Guide - START HERE: When to use KL penalties (fine-tuning vs exploration)
The wrapper discretizes observations and actions into bins, then calculates the KL divergence between joint distributions:
KL(P_t || P_{t-1})
where P_t is the joint (observation, action) distribution at time t and P_{t-1} is the distribution at the previous step.
- Low KL divergence → Policy is stable/consistent
- High KL divergence → Policy is changing rapidly
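As a rough sketch of this computation (not necessarily the wrapper's exact implementation), two binned joint distributions can be compared like this, with `epsilon` smoothing keeping the log ratio finite when a bin is empty:

```python
import numpy as np

def joint_kl(p_counts: np.ndarray, q_counts: np.ndarray, epsilon: float = 1e-8) -> float:
    """KL(P || Q) between two histograms over (observation, action) bins.

    A minimal sketch, assuming P_t and P_{t-1} are stored as bin counts.
    """
    p = p_counts.astype(float) + epsilon
    q = q_counts.astype(float) + epsilon
    p /= p.sum()
    q /= q.sum()
    return float(np.sum(p * np.log(p / q)))

# Two 10x2 histograms: 10 observation bins x 2 action bins
rng = np.random.default_rng(0)
p_t = rng.integers(0, 20, size=(10, 2))      # counts at time t
p_prev = rng.integers(0, 20, size=(10, 2))   # counts at time t-1
print(joint_kl(p_t, p_prev))  # small value -> stable policy, large -> rapid change
```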
The LLM training wrapper shapes rewards to encourage stable policies:
```
shaped_reward = base_reward + kl_bonus - kl_penalty
```

This guides the LLM to:
- Maintain policy consistency
- Avoid catastrophic forgetting
- Balance exploration and stability
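A minimal sketch of such a shaping rule, assuming a threshold-based scheme built from the `kl_threshold`, `kl_penalty_weight`, and `kl_bonus_weight` parameters listed later in this README (the wrapper's exact rule may differ):

```python
def shape_reward(base_reward: float, kl: float,
                 kl_threshold: float = 1.0,
                 kl_penalty_weight: float = 0.1,
                 kl_bonus_weight: float = 0.05) -> float:
    """Illustrative shaping: bonus when KL stays below the threshold,
    penalty when it exceeds it."""
    kl_bonus = kl_bonus_weight * max(0.0, kl_threshold - kl)
    kl_penalty = kl_penalty_weight * max(0.0, kl - kl_threshold)
    return base_reward + kl_bonus - kl_penalty

print(shape_reward(1.0, kl=0.2))  # below threshold -> small bonus
print(shape_reward(1.0, kl=2.5))  # above threshold -> penalty
```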
Three output formats for different LLM training scenarios:
- Chat Format - For instruction-following models (ChatGPT-style)
- Completion Format - For prompt-completion models (GPT-3 style)
- Preference Format - For RLHF/preference learning
See LLM_TRAINING_GUIDE.md for details.
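To give a concrete sense of the difference between the three formats, here are hypothetical record shapes; the exact keys and prompt wording produced by `LLMTrainingWrapper` are defined in LLM_TRAINING_GUIDE.md and may differ:

```python
# Hypothetical record shapes only; the real schema is documented in LLM_TRAINING_GUIDE.md.
chat_record = {"messages": [
    {"role": "user", "content": "Observation: [0.02, -0.15, 0.03, 0.22]. Choose an action."},
    {"role": "assistant", "content": "1"},
]}
completion_record = {"prompt": "Observation: [0.02, -0.15, 0.03, 0.22] ->", "completion": " 1"}
preference_record = {"prompt": "Observation: [0.02, -0.15, 0.03, 0.22]", "chosen": "1", "rejected": "0"}
```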
```
KL_proteus/
├── kl_wrapper.py            # Base KL divergence wrapper
├── llm_training_wrapper.py  # LLM training data collector
├── example.py               # Basic usage examples
├── llm_training_example.py  # LLM training examples
├── LLM_TRAINING_GUIDE.md    # Complete LLM training guide
└── training_data/           # Generated training data (gitignored)
```
- `obs_bins` - Discretization bins per observation dimension (higher = finer granularity)
- `action_bins` - Discretization bins per action dimension
- `history_size` - Number of KL divergences to track
- `epsilon` - Small constant for numerical stability
- `normalize_obs` - Whether to normalize observations to [0, 1]
- `normalize_action` - Whether to normalize actions to [0, 1]
- `kl_penalty_weight` - Weight for KL divergence penalty (0.05-0.5)
- `kl_bonus_weight` - Weight for low-KL bonus (0.01-0.2)
- `kl_threshold` - KL threshold for penalty/bonus (0.5-2.0)
- `format_type` - Output format: "chat", "completion", or "preference"
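For reference, a configuration spelling out every parameter listed above might look like this, assuming the names map directly onto the wrappers' constructor keyword arguments (values are illustrative, not recommendations):

```python
import gymnasium as gym

from kl_wrapper import JointKLDivergenceWrapper
from llm_training_wrapper import LLMTrainingWrapper

# KL divergence wrapper with all options spelled out (illustrative values)
kl_env = JointKLDivergenceWrapper(
    gym.make("CartPole-v1"),
    obs_bins=20,            # finer observation discretization
    action_bins=2,
    history_size=500,
    epsilon=1e-8,
    normalize_obs=True,
    normalize_action=True,
)

# LLM training wrapper with all options spelled out (illustrative values)
llm_env = LLMTrainingWrapper(
    gym.make("CartPole-v1"),
    output_dir="training_data",
    format_type="completion",
    kl_penalty_weight=0.2,
    kl_bonus_weight=0.05,
    kl_threshold=1.0,
)
```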
Use uv to add new packages:
```bash
uv pip install <package-name>
```

Then update pyproject.toml to track the dependency.
Run type checking:

```bash
uv run basedpyright kl_wrapper.py llm_training_wrapper.py
```

- LLM Fine-tuning - Train language models on RL tasks
- Policy Monitoring - Track policy stability during training
- Reward Shaping - Guide learning with KL-based rewards
- RLHF - Generate preference pairs for human feedback learning
- Exploration Analysis - Measure policy exploration vs exploitation
MIT