KL Proteus

A Gymnasium environment wrapper that tracks the joint KL divergence between consecutive (observation, action) distributions, with built-in support for collecting LLM fine-tuning data.

Features

  • KL Divergence Tracking - Calculate joint KL divergence between (observation, action) pairs
  • LLM Training Data Collection - Automatically format trajectories for LLM fine-tuning
  • Reward Shaping - Shape rewards based on KL divergence to encourage stable policies
  • Multiple Output Formats - Export data in chat, completion, or preference pair formats
  • Gymnasium Compatible - Works with any Gymnasium environment

Setup

This project uses uv for package management. Do not use pip.

Create virtual environment

uv venv

Install dependencies

uv pip install -e .

Install dev dependencies

uv pip install -e ".[dev]"

Quick Start

Basic KL Divergence Tracking

import gymnasium as gym
from kl_wrapper import JointKLDivergenceWrapper

# Wrap any Gymnasium environment
env = gym.make("CartPole-v1")
env = JointKLDivergenceWrapper(
    env,
    obs_bins=10,      # Discretization resolution
    action_bins=2,    # For discrete actions
    history_size=1000 # KL history window
)

# Use like any Gym environment
obs, info = env.reset()
action = env.action_space.sample()  # replace with your policy's action
obs, reward, terminated, truncated, info = env.step(action)

# Access KL divergence metrics
print(info['kl_divergence'])          # Current KL
print(info['kl_divergence_mean'])     # Running mean
print(info['kl_divergence_std'])      # Running std

# Get statistics
stats = env.get_kl_statistics()
print(f"Mean KL: {stats['mean']:.4f}")

LLM Training Data Collection

import gymnasium as gym
from llm_training_wrapper import LLMTrainingWrapper

# Create environment with LLM training
env = LLMTrainingWrapper(
    gym.make("CartPole-v1"),
    output_dir="training_data",
    format_type="chat",  # or "completion" or "preference"
    kl_penalty_weight=0.1,
    kl_bonus_weight=0.05,
)

# Collect trajectories
for episode in range(10):
    obs, _ = env.reset()
    while True:
        action = your_policy(obs)  # your_policy: stand-in for your agent
        obs, reward, term, trunc, info = env.step(action)
        if term or trunc:
            break

# Save training data in JSONL format
env.save_all_trajectories("llm_training_data.jsonl")

# Check statistics
stats = env.get_training_statistics()
print(f"Collected {stats['num_steps']} steps")

Examples

Run the basic KL divergence examples:

uv run python example.py

Run the LLM training data collection examples:

uv run python llm_training_example.py

Documentation

How It Works

KL Divergence Calculation

The wrapper discretizes observations and actions into bins, then calculates the KL divergence between joint distributions:

KL(P_t || P_{t-1})

where P_t is the joint (observation, action) distribution at step t and P_{t-1} is the distribution at the previous step.

  • Low KL divergence → Policy is stable/consistent
  • High KL divergence → Policy is changing rapidly
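
The exact implementation lives in kl_wrapper.py and may differ in detail; as a minimal sketch, the computation amounts to epsilon-smoothing two joint (observation, action) histograms and applying the standard KL formula (the joint_kl function and the random counts below are illustrative only):

import numpy as np

def joint_kl(counts_t, counts_prev, epsilon=1e-8):
    # Turn (obs_bin, action_bin) histogram counts into probability
    # distributions, adding epsilon so no bin has zero mass.
    p = counts_t.astype(float) + epsilon
    q = counts_prev.astype(float) + epsilon
    p /= p.sum()
    q /= q.sum()
    # KL(P_t || P_{t-1}) = sum_i p_i * log(p_i / q_i)
    return float(np.sum(p * np.log(p / q)))

# Example: 10 observation bins x 2 action bins = 20 joint bins
counts_t = np.random.randint(0, 50, size=10 * 2)
counts_prev = np.random.randint(0, 50, size=10 * 2)
print(joint_kl(counts_t, counts_prev))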

Reward Shaping for LLM Training

The LLM training wrapper shapes rewards to encourage stable policies:

shaped_reward = base_reward + kl_bonus - kl_penalty

This guides the LLM to:

  • Maintain policy consistency
  • Avoid catastrophic forgetting
  • Balance exploration and stability
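
The precise penalty and bonus rule is defined in llm_training_wrapper.py and may differ from this; a minimal sketch, assuming a simple threshold rule built from the kl_threshold, kl_penalty_weight, and kl_bonus_weight parameters listed under Configuration:

def shape_reward(base_reward, kl, kl_threshold=1.0,
                 kl_penalty_weight=0.1, kl_bonus_weight=0.05):
    # Penalize KL above the threshold, reward KL below it.
    kl_penalty = kl_penalty_weight * max(kl - kl_threshold, 0.0)
    kl_bonus = kl_bonus_weight * max(kl_threshold - kl, 0.0)
    return base_reward + kl_bonus - kl_penalty

print(shape_reward(1.0, kl=0.3))  # stable step: small bonus
print(shape_reward(1.0, kl=2.5))  # unstable step: penalty applied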

Data Formats

Three output formats for different LLM training scenarios:

  1. Chat Format - For instruction-following models (ChatGPT-style)
  2. Completion Format - For prompt-completion models (GPT-3 style)
  3. Preference Format - For RLHF/preference learning

See LLM_TRAINING_GUIDE.md for details.
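
For orientation, a chat-format record might look like the following; the field names here are hypothetical, and the authoritative schema is in LLM_TRAINING_GUIDE.md:

import json

# Hypothetical chat-format record (one JSON object per line in the .jsonl file).
record = {
    "messages": [
        {"role": "user", "content": "Observation: [0.02, -0.15, 0.03, 0.22]. Choose an action."},
        {"role": "assistant", "content": "Action: 1"},
    ],
    "reward": 1.0,
    "kl_divergence": 0.12,
}
print(json.dumps(record))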

Project Structure

KL_proteus/
├── kl_wrapper.py              # Base KL divergence wrapper
├── llm_training_wrapper.py    # LLM training data collector
├── example.py                 # Basic usage examples
├── llm_training_example.py    # LLM training examples
├── LLM_TRAINING_GUIDE.md      # Complete LLM training guide
└── training_data/             # Generated training data (gitignored)

Configuration

KL Divergence Parameters

  • obs_bins - Discretization bins per observation dimension (higher = finer granularity)
  • action_bins - Discretization bins per action dimension
  • history_size - Number of KL divergences to track
  • epsilon - Small constant for numerical stability
  • normalize_obs - Whether to normalize observations to [0, 1]
  • normalize_action - Whether to normalize actions to [0, 1]
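
Putting these together, a fully specified wrapper call might look like this (parameter names are taken from the list above; the values are illustrative, not recommended defaults):

import gymnasium as gym
from kl_wrapper import JointKLDivergenceWrapper

env = JointKLDivergenceWrapper(
    gym.make("CartPole-v1"),
    obs_bins=20,            # finer observation discretization
    action_bins=2,
    history_size=500,
    epsilon=1e-8,
    normalize_obs=True,
    normalize_action=False,
)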

LLM Training Parameters

  • kl_penalty_weight - Weight for KL divergence penalty (0.05-0.5)
  • kl_bonus_weight - Weight for low KL bonus (0.01-0.2)
  • kl_threshold - KL threshold for penalty/bonus (0.5-2.0)
  • format_type - Output format: "chat", "completion", or "preference"
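
And the corresponding LLM training configuration, again with illustrative values chosen from the ranges above:

import gymnasium as gym
from llm_training_wrapper import LLMTrainingWrapper

env = LLMTrainingWrapper(
    gym.make("CartPole-v1"),
    output_dir="training_data",
    format_type="preference",  # generate preference pairs for RLHF
    kl_penalty_weight=0.2,
    kl_bonus_weight=0.05,
    kl_threshold=1.0,
)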

Development

Adding packages

Use uv to add new packages:

uv pip install <package-name>

Then update pyproject.toml to track the dependency.

Running linter

uv run basedpyright kl_wrapper.py llm_training_wrapper.py

Use Cases

  • LLM Fine-tuning - Train language models on RL tasks
  • Policy Monitoring - Track policy stability during training
  • Reward Shaping - Guide learning with KL-based rewards
  • RLHF - Generate preference pairs for human feedback learning
  • Exploration Analysis - Measure policy exploration vs exploitation

License

MIT
