Skip to content

How to sample in image space with checkpoints trained using prediction="score" or prediction="noise"? #44

@Graeme-Lee

Description

@Graeme-Lee

Problem Description

I have trained a SiT model on medical images in image space (not latent space) with `--prediction noise`. When I sample from the trained checkpoint, the generated images are meaningless.

Sampling command:

python sample.py ODE --prediction noise

Current Sampling Code

import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
from torchvision.utils import save_image
from diffusers.models import AutoencoderKL
from download import find_model
from models import SiT_models
from train_utils import parse_ode_args, parse_sde_args, parse_transport_args
from transport import create_transport, Sampler
import argparse
import sys
from time import time


def main(mode, args):
    # Setup PyTorch:
    torch.manual_seed(args.seed)
    torch.set_grad_enabled(False)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    if args.ckpt is None:
        assert args.model == "SiT-XL/2", "Only SiT-XL/2 models are available for auto-download."
        assert args.image_size in [256, 512]
        assert args.num_classes == 1000
        assert args.image_size == 256, "512x512 models are not yet available for auto-download."
        learn_sigma = args.image_size == 256
    else:
        learn_sigma = False

    # Load model:
    latent_size = args.image_size // 1  # <- Is this correct for image-space training?
    model = SiT_models[args.model](
        input_size=latent_size,
        num_classes=args.num_classes,
        learn_sigma=learn_sigma,
    ).to(device)
    
    ckpt_path = args.ckpt or f"SiT-XL-2-{args.image_size}x{args.image_size}.pt"
    state_dict = find_model(ckpt_path)
    model.load_state_dict(state_dict)
    model.eval()
    
    transport = create_transport(
        args.path_type,
        args.prediction,
        args.loss_weight,
        args.train_eps,
        args.sample_eps
    )
    sampler = Sampler(transport)
    
    if mode == "ODE":
        if args.likelihood:
            assert args.cfg_scale == 1, "Likelihood is incompatible with guidance"
            sample_fn = sampler.sample_ode_likelihood(
                sampling_method=args.sampling_method,
                num_steps=args.num_sampling_steps,
                atol=args.atol,
                rtol=args.rtol,
            )
        else:
            sample_fn = sampler.sample_ode(
                sampling_method=args.sampling_method,
                num_steps=args.num_sampling_steps,
                atol=args.atol,
                rtol=args.rtol,
                reverse=args.reverse
            )
            
    elif mode == "SDE":
        sample_fn = sampler.sample_sde(
            sampling_method=args.sampling_method,
            diffusion_form=args.diffusion_form,
            diffusion_norm=args.diffusion_norm,
            last_step=args.last_step,
            last_step_size=args.last_step_size,
            num_steps=args.num_sampling_steps,
        )
    
    # Labels to condition the model with:
    class_labels = [0, 1]
    
    # Create sampling noise:
    n = len(class_labels)
    z = torch.randn(n, 1, latent_size, latent_size, device=device)  # <- Is 1 channel correct?
    y = torch.tensor(class_labels, device=device)

    # Setup classifier-free guidance:
    use_cfg = args.cfg_scale > 1.0
    if use_cfg:
        zs = torch.cat([z, z], 0)
        y_null = torch.tensor([1000] * n, device=device)
        ys = torch.cat([y, y_null], 0)
        model_kwargs = dict(y=ys, cfg_scale=args.cfg_scale)
        model_fn = model.forward_with_cfg
    else:
        model_kwargs = dict(y=y)
        model_fn = model.forward

    # Sample images:
    start_time = time()
    samples = sample_fn(z, model_fn, **model_kwargs)[-1]
    print(f"Sampling took {time() - start_time:.2f} seconds.")

    # Save images (VAE is loaded but never used for decoding):
    save_image(samples, f"sample.png", nrow=4, normalize=True, value_range=(0, 1))
if __name__ == "__main__":
    # Entry point: `python sample.py {ODE,SDE} [options]`.
    usage = "Usage: program.py <mode> [options]"
    if len(sys.argv) < 2:
        print(usage)
        sys.exit(1)

    mode = sys.argv[1]
    # Validate with explicit exits rather than `assert`, which `python -O` strips.
    if mode.startswith("--"):
        sys.exit(usage)
    if mode not in ("ODE", "SDE"):
        sys.exit("Invalid mode. Please choose 'ODE' or 'SDE'")

    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, choices=list(SiT_models.keys()), default="SiT-B/2")
    parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="mse")
    parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
    parser.add_argument("--num-classes", type=int, default=1000)
    parser.add_argument("--cfg-scale", type=float, default=1.0)
    parser.add_argument("--num-sampling-steps", type=int, default=1000)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--ckpt", type=str, default="/home/exouser/SiT/results/002-SiT-B-2-Linear-noise-None/checkpoints/0072000.pt",
                        help="Path to a SiT checkpoint (default: a local image-space training run). "
                             "The pre-trained SiT-XL/2 auto-download path is only taken when main() receives no checkpoint.")

    parse_transport_args(parser)
    # Solver-specific flags are only registered for the chosen mode.
    if mode == "ODE":
        parse_ode_args(parser)
    else:
        parse_sde_args(parser)

    # parse_known_args tolerates the positional mode, but it would also silently
    # drop misspelled flags (e.g. an underscore spelling of a hyphenated option),
    # so report anything besides the mode that was not recognised.
    args, unknown = parser.parse_known_args()
    ignored = [arg for arg in unknown if arg != mode]
    if ignored:
        print(f"Warning: ignoring unrecognized arguments: {' '.join(ignored)}", file=sys.stderr)
    main(mode, args)

Specific Questions

  1. VAE usage with image-space training:

    • The pre-trained ImageNet examples seem to use latent diffusion with a VAE, but I trained directly on images. Should the VAE be skipped entirely when sampling?
  2. Latent size configuration:

    • I'm using latent_size = args.image_size // 1, so the model's input size equals the image size. Is this correct for image-space training?
  3. Transport parameters alignment:

    • How do I ensure --path_type, --prediction, --train_eps, and --sample_eps match between training and sampling?
    • Are there any other critical hyperparameters that must match?
  4. Output value range:

    • What range should the sampled values be in before saving?
    • I'm using normalize=True, value_range=(0, 1) but unsure if this is correct.

Training Configuration

  • Checkpoint: Trained with --prediction noise
  • Training space: Image space (no VAE)
  • Dataset: medical images (grayscale, values in [0, 1], float32)
  • Image size: 256x256
  • Model: SiT-B/2
  • Number of classes: 2

Thank you for your help!

Metadata

Assignees

No one assigned

    Labels

    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions