[ICML 2025] Categorical Schrödinger Bridge Matching


gregkseno/csbm


Categorical Schrödinger Bridge Matching (CSBM)

Grigoriy Ksenofontov, Alexander Korotin


This repository contains the official implementation of the paper "Categorical Schrödinger Bridge Matching", accepted at ICML 2025.

📌 TL;DR

This paper extends the Schrödinger Bridge problem to discrete time and discrete state spaces.

📦 Dependencies

Create the Anaconda environment using the following command:

conda env create -f environment.yml

🛠️ Preparations

Download Datasets

  1. Use this link to obtain the CelebA dataset;
  2. Follow these instructions to obtain the AFHQv2 dataset.

Additionally, for the CelebA dataset, rename the main folder to celeba, then rename celeba/img_align_celeba/img_align_celeba to celeba/img_align_celeba/raw.
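The two CelebA renames can also be scripted. The sketch below uses only pathlib. The helper name `prepare_celeba` and its `downloaded_name` parameter (the name of the folder the download extracts to) are hypothetical, because this README does not state what that folder is called.

```python
from pathlib import Path


def prepare_celeba(root: Path, downloaded_name: str) -> Path:
    """Apply the CelebA renames described above inside `root`.

    `downloaded_name` is hypothetical: the name of the top-level folder
    produced by the download, which this README does not fix.
    """
    celeba = root / "celeba"
    if downloaded_name != "celeba":
        # Step 1: rename the main folder to "celeba".
        (root / downloaded_name).rename(celeba)
    # Step 2: celeba/img_align_celeba/img_align_celeba -> celeba/img_align_celeba/raw
    images = celeba / "img_align_celeba"
    (images / "img_align_celeba").rename(images / "raw")
    return celeba
```

For example, `prepare_celeba(Path("data"), "archive")` would turn `data/archive` into `data/celeba` with the images under `data/celeba/img_align_celeba/raw`; both `data` and `archive` are placeholder names.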

Train VQ-GAN

  1. Edit the VQ-GAN configuration file for your dataset, configs/vqgan_*.yaml.
  2. Run the corresponding quantize_*.sh script to save quantized images as .npy files in celeba/img_align_celeba/quantized/ or afhq/*/*/.
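The VQ-GAN config names used here (vqgan_celeba_f8_1024.yaml, vqgan_afhq_f32_1024.yaml) appear to follow the taming-transformers convention, in which f<k> is the spatial downsampling factor and the trailing number is the codebook size. Assuming that reading, and assuming each quantized .npy file stores codebook indices, the hypothetical helpers below recover both numbers from a config name and compute the token grid size for a square image. They are an illustration, not code from this repository.

```python
import re
from typing import Iterable, NamedTuple, Tuple


class VQSpec(NamedTuple):
    factor: int         # spatial downsampling factor f
    codebook_size: int  # number of codebook entries, i.e. the token vocabulary


def parse_vqgan_name(name: str) -> VQSpec:
    """Read f and the codebook size from names like 'vqgan_celeba_f8_1024.yaml'.

    Assumption: the name ends in '_f<factor>_<codebook_size>' (taming-transformers style).
    """
    match = re.search(r"_f(\d+)_(\d+)(?:\.yaml)?$", name)
    if match is None:
        raise ValueError(f"unrecognised VQ-GAN config name: {name!r}")
    return VQSpec(int(match.group(1)), int(match.group(2)))


def token_grid(image_size: int, spec: VQSpec) -> Tuple[int, int]:
    """Height and width of the index grid for an image_size x image_size input."""
    if image_size % spec.factor:
        raise ValueError("image_size must be divisible by the downsampling factor")
    side = image_size // spec.factor
    return side, side


def check_tokens(tokens: Iterable[int], spec: VQSpec) -> bool:
    """True if every token in a flat iterable is a valid index in [0, codebook_size)."""
    return all(0 <= int(t) < spec.codebook_size for t in tokens)
```

For example, `parse_vqgan_name("vqgan_celeba_f8_1024.yaml")` returns `VQSpec(factor=8, codebook_size=1024)`, and `token_grid(128, ...)` then gives `(16, 16)`; the 128-pixel side is illustrative and not taken from the configs. `check_tokens` expects a flat iterable, such as `arr.ravel()` for a NumPy array.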

Tip

For more details on training VQ-GAN, refer to the official repository.

Train Tokenizer

  1. Set tokenizer.path in the main config file, configs/amazon.yaml or configs/yelp.yaml;
  2. Run train_tokenizer_*.sh to train the tokenizer.
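Dotted names such as tokenizer.path refer to a nested key: the path entry inside the tokenizer section of the YAML file. As a hedged illustration, the hypothetical helpers `set_dotted` and `get_dotted` below read and set such keys on a config that has already been loaded into nested Python dicts (for example by a YAML parser). They are not part of this repository.

```python
from typing import Any, Dict


def set_dotted(config: Dict[str, Any], key: str, value: Any) -> None:
    """Set a nested entry such as 'tokenizer.path' in a dict-based config."""
    *parents, leaf = key.split(".")
    node = config
    for part in parents:
        # Create intermediate sections (e.g. 'tokenizer') if they are missing.
        node = node.setdefault(part, {})
    node[leaf] = value


def get_dotted(config: Dict[str, Any], key: str) -> Any:
    """Read a nested entry such as 'data.dataset'; raises KeyError if absent."""
    node: Any = config
    for part in key.split("."):
        node = node[part]
    return node
```

In YAML terms, `set_dotted(cfg, "tokenizer.path", p)` corresponds to writing `path: p` under a top-level `tokenizer:` section.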

🏋️‍♂️ Training

  1. Edit the configuration files listed for your experiment in the table below;
  2. Run the corresponding script or notebook.

| Experiment name | Script/Notebook | Configs (configs/) | Weights (W&B link) |
|---|---|---|---|
| Convergence of D-IMF on Discrete Spaces | notebooks/convergence_d_imf.ipynb | N/A | N/A |
| Illustrative 2D Experiments | train_swiss_roll.sh | swiss_roll.yaml | N/A |
| Unpaired Translation on Colored MNIST | train_cmnist.sh | cmnist.yaml | CSBM |
| Unpaired Translation of CelebA Faces | train_celeba.sh | celeba.yaml, vqgan_celeba_f8_1024.yaml | CSBM, VQ-GAN |
| Unpaired Translation of AFHQ Faces | train_afhq.sh | afhq.yaml, vqgan_afhq_f32_1024.yaml | N/A |
| Unpaired Text Style Transfer of Amazon Reviews | train_amazon.sh | amazon.yaml | CSBM, Tokenizer |
| Unpaired Text Style Transfer of Yelp Reviews | train_yelp.sh | yelp.yaml | N/A |

Tip

Set the exp_dir parameter in any train_*.sh script to choose where experiment results are saved. Results are organized as follows:

data.type               # e.g., toy, images, etc.
`-- data.dataset        # e.g., swiss_roll, cmnist, etc.
    `-- prior.type      # e.g., gaussian, uniform, etc.
        |-- checkpoints
        |   |-- forward_*
        |   |   `-- model.safetensors
        |   |-- ...
        |   |-- backward_*
        |   `-- ...
        |-- samples      # images of samples
        |-- trajectories # images of trajectories
        `-- config.yaml  # copied config
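Assuming exp_dir is the root under which this tree is created, a run's output folder can be derived from its config as sketched below. `experiment_dir` and `expected_paths` are hypothetical helpers, and the config is assumed to be a nested dict with data.type, data.dataset and prior.type keys.

```python
from pathlib import Path
from typing import Any, List, Mapping

# Fixed entries that the tree above shows inside every run folder.
RUN_ENTRIES = ("checkpoints", "samples", "trajectories", "config.yaml")


def experiment_dir(exp_dir: str, config: Mapping[str, Any]) -> Path:
    """Return <exp_dir>/<data.type>/<data.dataset>/<prior.type> for a run."""
    return (
        Path(exp_dir)
        / config["data"]["type"]
        / config["data"]["dataset"]
        / config["prior"]["type"]
    )


def expected_paths(run_dir: Path) -> List[Path]:
    """Paths the layout above says a finished run folder contains."""
    return [run_dir / name for name in RUN_ENTRIES]
```

With the example values from the tree comments, a toy Swiss roll run with a Gaussian prior would land in `<exp_dir>/toy/swiss_roll/gaussian`.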

📊 Evaluation

  1. Specify the exp_path parameter, pointing to the saved experiment folder;
  2. Run eval_*.sh with the appropriate iteration argument.
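To choose a value for the iteration argument, it helps to know which checkpoints were saved. The sketch below assumes, without confirmation from this README, that the * in forward_* and backward_* is an integer iteration number; `available_iterations` is a hypothetical helper, not a script from this repository.

```python
import re
from pathlib import Path
from typing import List


def available_iterations(exp_path: Path, direction: str = "forward") -> List[int]:
    """Sorted integer suffixes of checkpoints/<direction>_<n> folders in exp_path.

    Assumption: the suffix after 'forward_' / 'backward_' is the iteration number.
    """
    if direction not in ("forward", "backward"):
        raise ValueError("direction must be 'forward' or 'backward'")
    pattern = re.compile(rf"{direction}_(\d+)")
    found = []
    for entry in (exp_path / "checkpoints").iterdir():
        match = pattern.fullmatch(entry.name)
        # Skip plain files and folders whose suffix is not an integer.
        if match and entry.is_dir():
            found.append(int(match.group(1)))
    return sorted(found)
```

Sorting numerically rather than by name keeps forward_10 after forward_2, which a plain string sort would not.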

Important

Reusing an earlier evaluation pipeline for the CelebA dataset may yield different results. In the article, images were generated first (see scripts/generate.py) and then evaluated with the following metrics (see notebooks/eval.ipynb):

🎓 Citation

@inproceedings{
  ksenofontov2025categorical,
  title={Categorical {Schr\"odinger} Bridge Matching},
  author={Grigoriy Ksenofontov and Alexander Korotin},
  booktitle={Forty-second International Conference on Machine Learning},
  year={2025},
  url={https://openreview.net/forum?id=RBly0nOr2h}
}

🙏 Credits
