WeGenDiffusion

Codebase for the implementation of Diffusion Transformers to generate temperature maps

This code and repository are built on the official PyTorch implementation of Scalable Diffusion Models with Transformers (DiT).

This repo contains PyTorch model definitions and training/sampling code for our implementation of generating temperature maps with diffusion transformers (DiTs).

We train diffusion models, replacing the commonly-used U-Net backbone with a transformer that operates on patches.

Setup

Log in to the JURECA system.

Load uv with the command ml uv.

Run uv sync to create the virtual environment.
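In short, the setup reduces to a few shell commands (a minimal sketch; the ssh host is the standard JURECA login address and may differ for your project, and <user> is a placeholder for your JSC account name):

```bash
# Log in to JURECA (replace <user> with your account name)
ssh <user>@jureca.fz-juelich.de

# Load the uv module and create the project's virtual environment
ml uv
uv sync
```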

Training DiT

The training script for DiT is train.py. This script can be used to train class-conditional DiT models, but it can easily be modified to support other types of conditioning. To launch training:

sbatch train_DiT_jsc.sh
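The SLURM script presumably wraps a torchrun invocation like the one used in the upstream DiT repository (a sketch, assuming upstream-style train.py arguments; the model size and --data-path value are placeholders, not taken from this repo):

```bash
# Distributed class-conditional DiT training over N GPUs,
# following the upstream DiT launch convention
torchrun --nnodes=1 --nproc_per_node=N train.py \
  --model DiT-B/2 \
  --data-path /path/to/temperature/data
```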

Sampling DiT

The sampling script is sample_ddp.py. It generates samples from scratch. To launch sampling:

sbatch test_DiT.sh

Here is a sample generated temperature map:

[Figure: sample 1, a generated temperature map]

Evaluation (FID, Inception Score, etc.)

We include a sample_ddp.py script which samples a large number of images from a DiT model in parallel. This script generates a folder of samples as well as a .npz file which can be used directly with ADM's TensorFlow evaluation suite to compute FID, Inception Score, and other metrics. For example, to sample 50K images from our pre-trained DiT-XL/2 model over N GPUs, run the command sketched below.
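In the upstream DiT repository, the equivalent invocation is the following (a sketch, assuming this repo keeps the upstream sample_ddp.py arguments unchanged):

```bash
# Sample 50K images in parallel over N GPUs; writes a folder of samples
# plus an ADM-compatible .npz for FID/Inception Score evaluation
torchrun --nnodes=1 --nproc_per_node=N sample_ddp.py \
  --model DiT-XL/2 \
  --num-fid-samples 50000
```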

Checkpoint path

/p/project1/training2533/patnala1/WeGenDiffusion/results/DiT-B-2
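To sample from this checkpoint rather than from randomly initialized weights, the upstream sample_ddp.py exposes a --ckpt argument. A minimal sketch, assuming that flag is preserved here; <checkpoint>.pt is a placeholder for a specific checkpoint file inside the results directory above:

```bash
# Generate samples from the trained DiT-B/2 checkpoint
torchrun --nnodes=1 --nproc_per_node=N sample_ddp.py \
  --model DiT-B/2 \
  --ckpt /p/project1/training2533/patnala1/WeGenDiffusion/results/DiT-B-2/<checkpoint>.pt
```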

Group Tasks

Team 1: Implement Evaluation Metrics

Team 2: Conditioning based on seasonal forcings

Team 3: Conditioning based on seasonal forcings + Day + Hour

Team 4: Conditioning based on temperature from previous time steps
