This repository contains the official implementation of the paper "Categorical Schrödinger Bridge Matching", accepted at ICML 2025.
The paper extends the Schrödinger Bridge problem to discrete time and discrete state spaces.
Create the Anaconda environment using the following command:

```bash
conda env create -f environment.yml
```

- Use this link to obtain the CelebA dataset;
- Follow these instructions to obtain the AFHQv2 dataset.
Additionally, for the CelebA dataset, rename the main folder to `celeba`, then rename `celeba/img_align_celeba/img_align_celeba` to `celeba/img_align_celeba/raw`.
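A minimal shell sketch of these renames (the name of the downloaded folder is an assumption; adjust it to whatever your download produced):

```bash
# Assumed starting point: the CelebA archive was extracted into ./CelebA (adjust as needed)
mv CelebA celeba
# Move the nested image folder to the expected "raw" location
mv celeba/img_align_celeba/img_align_celeba celeba/img_align_celeba/raw
```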
- Configure the appropriate configuration file `configs/vqgan_*.yaml`;
- Run the corresponding `quantize_*.sh` script to save quantized images as `.npy` files in `celeba/img_align_celeba/quantized/` or `afhq/*/*/`.
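For example, a hypothetical invocation for CelebA (the exact script name is an assumption; use whichever `quantize_*.sh` script matches your dataset):

```bash
# Quantize CelebA images after configuring configs/vqgan_celeba_f8_1024.yaml
# (script name assumed; see the repository's quantize_*.sh scripts)
bash quantize_celeba.sh
```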
> [!TIP]
> For more details on training VQ-GAN, refer to the official repository.
- Set `tokenizer.path` in the main config file `configs/amazon.yaml` or `configs/yelp.yaml`;
- Run `train_tokenizer_*.sh` to train the tokenizer.
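As an illustration, a hypothetical run for the Amazon Reviews setup (the exact script name is an assumption; pick the `train_tokenizer_*.sh` variant matching your dataset):

```bash
# Train the tokenizer for the Amazon Reviews experiment
# (assumes tokenizer.path has already been set in configs/amazon.yaml)
bash train_tokenizer_amazon.sh
```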
- Set the corresponding configuration file;
- Run the appropriate script or notebook (see the table below; a minimal launch example is sketched after it).
| Experiment name | Script/Notebook | Configs (`configs/`) | Weights (W&B link) |
|---|---|---|---|
| Convergence of D-IMF on Discrete Spaces | `notebooks/convergence_d_imf.ipynb` | N/A | N/A |
| Illustrative 2D Experiments | `train_swiss_roll.sh` | `swiss_roll.yaml` | N/A |
| Unpaired Translation on Colored MNIST | `train_cmnist.sh` | `cmnist.yaml` | CSBM |
| Unpaired Translation of CelebA Faces | `train_celeba.sh` | `celeba.yaml`, `vqgan_celeba_f8_1024.yaml` | CSBM, VQ-GAN |
| Unpaired Translation of AFHQ Faces | `train_afhq.sh` | `afhq.yaml`, `vqgan_afhq_f32_1024.yaml` | N/A |
| Unpaired Text Style Transfer of Amazon Reviews | `train_amazon.sh` | `amazon.yaml` | CSBM, Tokenizer |
| Unpaired Text Style Transfer of Yelp Reviews | `train_yelp.sh` | `yelp.yaml` | N/A |
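For example, a minimal launch of the illustrative 2D experiment (any extra environment variables or arguments your setup needs are not shown):

```bash
# Train CSBM on the Swiss Roll example; configuration is read from configs/swiss_roll.yaml
bash train_swiss_roll.sh
```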
> [!TIP]
> Set the `exp_dir` parameter in any `train_*.sh` script to define a custom path for saving experiment results, following the structure below:
```
data.type                       # e.g., toy, images, etc.
`-- data.dataset                # e.g., swiss_roll, cmnist, etc.
    `-- prior.type              # e.g., gaussian, uniform, etc.
        |-- checkpoints
        |   |-- forward_*
        |   |   `-- model.safetensors
        |   |-- ...
        |   `-- backward_*
        |       `-- ...
        |-- samples             # images of samples
        |-- trajectories        # images of trajectories
        `-- config.yaml         # copied config
```

- Specify the `exp_path` parameter, pointing to the saved experiment folder;
- Run `eval_*.sh` with the appropriate `iteration` argument.
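For instance, a hypothetical evaluation call for the CelebA experiment (the script name and the way the iteration is passed are assumptions; check the corresponding `eval_*.sh` script):

```bash
# Evaluate the CelebA experiment at a chosen training iteration
# (exp_path should already point to the saved experiment folder)
bash eval_celeba.sh 100000   # placeholder iteration value
```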
> [!IMPORTANT]
> Reusing an earlier evaluation pipeline for the CelebA dataset may yield different results. In the paper, images were generated first (see `scripts/generate.py`) and then evaluated with the following metrics (see `notebooks/eval.ipynb`):
- FID from pytorch-fid
- CMMD from cmmd-pytorch
- LPIPS from torchmetrics
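As an illustration, FID between two image folders can be computed with the `pytorch-fid` command-line tool (folder paths below are placeholders; CMMD and LPIPS are computed through their respective packages):

```bash
# Compute FID between generated samples and reference images (placeholder paths)
python -m pytorch_fid path/to/generated_images path/to/reference_images
```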
```bibtex
@inproceedings{
    ksenofontov2025categorical,
    title={Categorical {Schr\"odinger} Bridge Matching},
    author={Grigoriy Ksenofontov and Alexander Korotin},
    booktitle={Forty-second International Conference on Machine Learning},
    year={2025},
    url={https://openreview.net/forum?id=RBly0nOr2h}
}
```

- Weights & Biases — experiment-tracking and visualization toolkit;
- Hugging Face — Tokenizers and Accelerate libraries for tokenizer implementation, parallel training, and checkpoint hosting on the Hub;
- D3PM — reference implementation of discrete-diffusion models;
- Taming Transformers — original VQ-GAN codebase;
- VQ-Diffusion — vector-quantized diffusion architecture;
- MDLM — diffusion architecture for text-generation experiments;
- ASBM — evaluation metrics and baseline models for CelebA face transfer;
- Balancing the Style-Content Trade-Off in Sentiment Transfer Using Polarity-Aware Denoising — processed Amazon Reviews dataset and sentiment-transfer baselines;
- Inkscape — an excellent open-source editor for vector graphics.