diff --git a/README.md b/README.md
index 4a8b803..58c86d9 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,21 @@
+# Optimized DeepFloyd IF by neonsecret
+
+Tested on an RTX 3070 with 8 GB VRAM
+
+stage 1: ~1 min 30 sec for 27 steps, ~3.35 sec per iteration \
+stage 2: 11 seconds for 27 steps \
+stage 3: 30 seconds for 40 steps
+
+### To run the UI:
+
+```bash
+python run_ui.py
+```
+
+All the models are downloaded automatically.
+
+Original readme:
+
 [![License](https://img.shields.io/badge/Code_License-Modified_MIT-blue.svg)](LICENSE)
 [![License](https://img.shields.io/badge/Weights_License-DeepFloyd_IF-orange.svg)](LICENSE-MODEL)
 [![Downloads](https://pepy.tech/badge/deepfloyd_if)](https://pepy.tech/project/deepfloyd_if)
@@ -11,21 +29,32 @@

-We introduce DeepFloyd IF, a novel state-of-the-art open-source text-to-image model with a high degree of photorealism and language understanding. DeepFloyd IF is a modular composed of a frozen text encoder and three cascaded pixel diffusion modules: a base model that generates 64x64 px image based on text prompt and two super-resolution models, each designed to generate images of increasing resolution: 256x256 px and 1024x1024 px. All stages of the model utilize a frozen text encoder based on the T5 transformer to extract text embeddings, which are then fed into a UNet architecture enhanced with cross-attention and attention pooling. The result is a highly efficient model that outperforms current state-of-the-art models, achieving a zero-shot FID score of 6.66 on the COCO dataset. Our work underscores the potential of larger UNet architectures in the first stage of cascaded diffusion models and depicts a promising future for text-to-image synthesis.
+We introduce DeepFloyd IF, a novel state-of-the-art open-source text-to-image model with a high degree of photorealism
+and language understanding. DeepFloyd IF is modular, composed of a frozen text encoder and three cascaded pixel
+diffusion modules: a base model that generates a 64x64 px image based on a text prompt and two super-resolution
+models, each designed to generate images of increasing resolution: 256x256 px and 1024x1024 px. All stages of the
+model utilize a frozen text encoder based on the T5 transformer to extract text embeddings, which are then fed into a
+UNet architecture enhanced with cross-attention and attention pooling. The result is a highly efficient model that
+outperforms current state-of-the-art models, achieving a zero-shot FID score of 6.66 on the COCO dataset. Our work
+underscores the potential of larger UNet architectures in the first stage of cascaded diffusion models and depicts a
+promising future for text-to-image synthesis.

-*Inspired by* [*Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding*](https://arxiv.org/pdf/2205.11487.pdf) +*Inspired by* [*Photorealistic Text-to-Image Diffusion Models with Deep Language +Understanding*](https://arxiv.org/pdf/2205.11487.pdf) ## Minimum requirements to use all IF models: + - 16GB vRAM for IF-I-XL (4.3B text to 64x64 base module) & IF-II-L (1.2B to 256x256 upscaler module) -- 24GB vRAM for IF-I-XL (4.3B text to 64x64 base module) & IF-II-L (1.2B to 256x256 upscaler module) & Stable x4 (to 1024x1024 upscaler) +- 24GB vRAM for IF-I-XL (4.3B text to 64x64 base module) & IF-II-L (1.2B to 256x256 upscaler module) & Stable x4 (to + 1024x1024 upscaler) - `xformers` and set env variable `FORCE_MEM_EFFICIENT_ATTN=1` - ## Quick Start + [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/deepfloyd_if_free_tier_google_colab.ipynb) [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/DeepFloyd/IF) @@ -36,25 +65,28 @@ pip install git+https://github.com/openai/CLIP.git --no-deps ``` ## Local notebooks + [![Jupyter Notebook](https://img.shields.io/badge/jupyter_notebook-%23FF7A01.svg?logo=jupyter&logoColor=white)](https://huggingface.co/DeepFloyd/IF-notebooks/blob/main/pipes-DeepFloyd-IF-v1.0.ipynb) [![Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://www.kaggle.com/code/shonenkov/deepfloyd-if-4-3b-generator-of-pictures) -The Dream, Style Transfer, Super Resolution or Inpainting modes are avaliable in a Jupyter Notebook [here](https://huggingface.co/DeepFloyd/IF-notebooks/blob/main/pipes-DeepFloyd-IF-v1.0.ipynb). - - +The Dream, Style Transfer, Super Resolution or Inpainting modes are avaliable in a Jupyter +Notebook [here](https://huggingface.co/DeepFloyd/IF-notebooks/blob/main/pipes-DeepFloyd-IF-v1.0.ipynb). ## Integration with 🤗 Diffusers IF is also integrated with the 🤗 Hugging Face [Diffusers library](https://github.com/huggingface/diffusers/). -Diffusers runs each stage individually allowing the user to customize the image generation process as well as allowing to inspect intermediate results easily. +Diffusers runs each stage individually allowing the user to customize the image generation process as well as allowing +to inspect intermediate results easily. ### Example Before you can use IF, you need to accept its usage conditions. To do so: + 1. Make sure to have a [Hugging Face account](https://huggingface.co/join) and be loggin in 2. Accept the license on the model card of [DeepFloyd/IF-I-XL-v1.0](https://huggingface.co/DeepFloyd/IF-I-XL-v1.0) 3. Make sure to login locally. Install `huggingface_hub` + ```sh pip install huggingface_hub --upgrade ``` @@ -67,7 +99,8 @@ from huggingface_hub import login login() ``` -and enter your [Hugging Face Hub access token](https://huggingface.co/docs/hub/security-tokens#what-are-user-access-tokens). +and enter +your [Hugging Face Hub access token](https://huggingface.co/docs/hub/security-tokens#what-are-user-access-tokens). Next we install `diffusers` and dependencies: @@ -77,7 +110,9 @@ pip install diffusers accelerate transformers safetensors And we can now run the model locally. -By default `diffusers` makes use of [model cpu offloading](https://huggingface.co/docs/diffusers/optimization/fp16#model-offloading-for-fast-inference-and-memory-savings) to run the whole IF pipeline with as little as 14 GB of VRAM. 
+By default `diffusers` makes use +of [model cpu offloading](https://huggingface.co/docs/diffusers/optimization/fp16#model-offloading-for-fast-inference-and-memory-savings) +to run the whole IF pipeline with as little as 14 GB of VRAM. If you are using `torch>=2.0.0`, make sure to **delete all** `enable_xformers_memory_efficient_attention()` functions. @@ -100,8 +135,10 @@ stage_2.enable_xformers_memory_efficient_attention() # remove line if torch.__v stage_2.enable_model_cpu_offload() # stage 3 -safety_modules = {"feature_extractor": stage_1.feature_extractor, "safety_checker": stage_1.safety_checker, "watermarker": stage_1.watermarker} -stage_3 = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-x4-upscaler", **safety_modules, torch_dtype=torch.float16) +safety_modules = {"feature_extractor": stage_1.feature_extractor, "safety_checker": stage_1.safety_checker, + "watermarker": stage_1.watermarker} +stage_3 = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-x4-upscaler", **safety_modules, + torch_dtype=torch.float16) stage_3.enable_xformers_memory_efficient_attention() # remove line if torch.__version__ >= 2.0.0 stage_3.enable_model_cpu_offload() @@ -113,12 +150,14 @@ prompt_embeds, negative_embeds = stage_1.encode_prompt(prompt) generator = torch.manual_seed(0) # stage 1 -image = stage_1(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, output_type="pt").images +image = stage_1(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, + output_type="pt").images pt_to_pil(image)[0].save("./if_stage_I.png") # stage 2 image = stage_2( - image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, output_type="pt" + image=image, prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, + output_type="pt" ).images pt_to_pil(image)[0].save("./if_stage_II.png") @@ -127,12 +166,16 @@ image = stage_3(prompt=prompt, image=image, generator=generator, noise_level=100 image[0].save("./if_stage_III.png") ``` - There are multiple ways to speed up the inference time and lower the memory consumption even more with `diffusers`. To do so, please have a look at the Diffusers docs: +There are multiple ways to speed up the inference time and lower the memory consumption even more with `diffusers`. To +do so, please have a look at the Diffusers docs: - 🚀 [Optimizing for inference time](https://huggingface.co/docs/diffusers/api/pipelines/if#optimizing-for-speed) -- ⚙️ [Optimizing for low memory during inference](https://huggingface.co/docs/diffusers/api/pipelines/if#optimizing-for-memory) +- +⚙️ [Optimizing for low memory during inference](https://huggingface.co/docs/diffusers/api/pipelines/if#optimizing-for-memory) -For more in-detail information about how to use IF, please have a look at [the IF blog post](https://huggingface.co/blog/if) and [the documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/if) 📖. +For more in-detail information about how to use IF, please have a look +at [the IF blog post](https://huggingface.co/blog/if) +and [the documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/if) 📖. ## Run the code locally @@ -150,6 +193,7 @@ t5 = T5Embedder(device="cpu") ``` ### I. 
Dream + Dream is the text-to-image mode of the IF model ```python @@ -160,7 +204,7 @@ count = 4 result = dream( t5=t5, if_I=if_I, if_II=if_II, if_III=if_III, - prompt=[prompt]*count, + prompt=[prompt] * count, seed=42, if_I_kwargs={ "guidance_scale": 7.0, @@ -179,6 +223,7 @@ result = dream( if_III.show(result['III'], size=14) ``` + ![](./pics/dream-III.jpg) ## II. Zero-shot Image-to-Image Translation @@ -186,6 +231,7 @@ if_III.show(result['III'], size=14) ![](./pics/img_to_img_scheme.jpeg) In Style Transfer mode, the output of your prompt comes out at the style of the `support_pil_img` + ```python from deepfloyd_if.pipelines import style_transfer @@ -215,9 +261,10 @@ if_I.show(result['II'], 1, 20) ![Alternative Text](./pics/deep_floyd_if_image_2_image.gif) - ## III. Super Resolution -For super-resolution, users can run `IF-II` and `IF-III` or 'Stable x4' on an image that was not necessarely generated by IF (two cascades): + +For super-resolution, users can run `IF-II` and `IF-III` or 'Stable x4' on an image that was not necessarely generated +by IF (two cascades): ```python from deepfloyd_if.pipelines import super_resolution @@ -253,7 +300,6 @@ show_superres(raw_pil_image, high_res['III'][0]) ![](./pics/if_as_upscaler.jpg) - ### IV. Zero-shot Inpainting ```python @@ -289,9 +335,11 @@ if_I.show(result['I'], 2, 3) if_I.show(result['II'], 2, 6) if_I.show(result['III'], 2, 14) ``` + ![](./pics/deep_floyd_if_inpainting.gif) ### 🤗 Model Zoo 🤗 + The link to download the weights as well as the model cards will be available soon on each model of the model zoo #### Original @@ -305,7 +353,7 @@ The link to download the weights as well as the model cards will be available so | [IF-II-L](https://huggingface.co/DeepFloyd/IF-II-L-v1.0)* | II | 1.2B | - | 1536 | 2.5M | | IF-III-L* _(soon)_ | III | 700M | - | 3072 | 1.25M | - *best modules +*best modules ### Quantitative Evaluation @@ -315,16 +363,19 @@ The link to download the weights as well as the model cards will be available so ## License -The code in this repository is released under the bespoke license (see added [point two](https://github.com/deep-floyd/IF/blob/main/LICENSE#L13)). +The code in this repository is released under the bespoke license (see +added [point two](https://github.com/deep-floyd/IF/blob/main/LICENSE#L13)). -The weights will be available soon via [the DeepFloyd organization at Hugging Face](https://huggingface.co/DeepFloyd) and have their own LICENSE. +The weights will be available soon via [the DeepFloyd organization at Hugging Face](https://huggingface.co/DeepFloyd) +and have their own LICENSE. -**Disclaimer:** *The initial release of the IF model is under a restricted research-purposes-only license temporarily to gather feedback, and after that we intend to release a fully open-source model in line with other Stability AI models.* +**Disclaimer:** *The initial release of the IF model is under a restricted research-purposes-only license temporarily to +gather feedback, and after that we intend to release a fully open-source model in line with other Stability AI models.* ## Limitations and Biases -The models available in this codebase have known limitations and biases. Please refer to [the model card](https://huggingface.co/DeepFloyd/IF-I-L-v1.0) for more information. - +The models available in this codebase have known limitations and biases. Please refer +to [the model card](https://huggingface.co/DeepFloyd/IF-I-L-v1.0) for more information. 
## 🎓 DeepFloyd IF creators: @@ -335,17 +386,26 @@ The models available in this codebase have known limitations and biases. Please - Ksenia Ivanova [GitHub](https://github.com/ivksu) | [Twitter](https://twitter.com/susiaiv) - Nadiia Klokova [GitHub](https://github.com/vauimpuls) | [Twitter](https://twitter.com/vauimpuls) - ## 📄 Research Paper (Soon) ## Acknowledgements -Special thanks to [StabilityAI](http://stability.ai) and its CEO [Emad Mostaque](https://twitter.com/emostaque) for invaluable support, providing GPU compute and infrastructure to train the models (our gratitude goes to [Richard Vencu](https://github.com/rvencu)); thanks to [LAION](https://laion.ai) and [Christoph Schuhmann](https://github.com/christophschuhmann) in particular for contribution to the project and well-prepared datasets; thanks to [Huggingface](https://huggingface.co) teams for optimizing models' speed and memory consumption during inference, creating demos and giving cool advice! +Special thanks to [StabilityAI](http://stability.ai) and its CEO [Emad Mostaque](https://twitter.com/emostaque) for +invaluable support, providing GPU compute and infrastructure to train the models (our gratitude goes +to [Richard Vencu](https://github.com/rvencu)); thanks to [LAION](https://laion.ai) +and [Christoph Schuhmann](https://github.com/christophschuhmann) in particular for contribution to the project and +well-prepared datasets; thanks to [Huggingface](https://huggingface.co) teams for optimizing models' speed and memory +consumption during inference, creating demos and giving cool advice! ## 🚀 External Contributors 🚀 -- The Biggest Thanks [@Apolinário](https://github.com/apolinario), for ideas, consultations, help and support on all stages to make IF available in open-source; for writing a lot of documentation and instructions; for creating a friendly atmosphere in difficult moments 🦉; + +- The Biggest Thanks [@Apolinário](https://github.com/apolinario), for ideas, consultations, help and support on all + stages to make IF available in open-source; for writing a lot of documentation and instructions; for creating a + friendly atmosphere in difficult moments 🦉; - Thanks, [@patrickvonplaten](https://github.com/patrickvonplaten), for improving loading time of unet models by 80%; -for integration Stable-Diffusion-x4 as native pipeline 💪; -- Thanks, [@williamberman](https://github.com/williamberman) and [@patrickvonplaten](https://github.com/patrickvonplaten) for diffusers integration 🙌; -- Thanks, [@hysts](https://github.com/hysts) and [@Apolinário](https://github.com/apolinario) for creating [the best gradio demo with IF](https://huggingface.co/spaces/DeepFloyd/IF) 🚀; + for integration Stable-Diffusion-x4 as native pipeline 💪; +- Thanks, [@williamberman](https://github.com/williamberman) + and [@patrickvonplaten](https://github.com/patrickvonplaten) for diffusers integration 🙌; +- Thanks, [@hysts](https://github.com/hysts) and [@Apolinário](https://github.com/apolinario) for + creating [the best gradio demo with IF](https://huggingface.co/spaces/DeepFloyd/IF) 🚀; - Thanks, [@Dango233](https://github.com/Dango233), for adapting IF with xformers memory efficient attention 💪; diff --git a/deepfloyd_if/model/__init__.py b/deepfloyd_if/model/__init__.py index 332da3e..55e58ca 100644 --- a/deepfloyd_if/model/__init__.py +++ b/deepfloyd_if/model/__init__.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- from .unet import UNetModel, SuperResUNetModel +from .unet_split import UNetSplitModel -__all__ = ['UNetModel', 'SuperResUNetModel'] 
+__all__ = ['UNetModel', 'SuperResUNetModel', 'UNetSplitModel'] diff --git a/deepfloyd_if/model/gaussian_diffusion.py b/deepfloyd_if/model/gaussian_diffusion.py index e058fbc..2394b31 100644 --- a/deepfloyd_if/model/gaussian_diffusion.py +++ b/deepfloyd_if/model/gaussian_diffusion.py @@ -110,13 +110,13 @@ class GaussianDiffusion: """ def __init__( - self, - *, - betas, - model_mean_type, - model_var_type, - loss_type, - rescale_timesteps=False, + self, + *, + betas, + model_mean_type, + model_var_type, + loss_type, + rescale_timesteps=False, ): self.model_mean_type = model_mean_type self.model_var_type = model_var_type @@ -146,7 +146,7 @@ def __init__( # calculations for posterior q(x_{t-1} | x_t, x_0) self.posterior_variance = ( - betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) ) # log calculation clipped because the posterior variance is 0 at the # beginning of the diffusion chain. @@ -154,12 +154,12 @@ def __init__( np.append(self.posterior_variance[1], self.posterior_variance[1:]) ) self.posterior_mean_coef1 = ( - betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) ) self.posterior_mean_coef2 = ( - (1.0 - self.alphas_cumprod_prev) - * np.sqrt(alphas) - / (1.0 - self.alphas_cumprod) + (1.0 - self.alphas_cumprod_prev) + * np.sqrt(alphas) + / (1.0 - self.alphas_cumprod) ) def dynamic_thresholding(self, x, p=0.995, c=1.7): @@ -189,7 +189,7 @@ def q_mean_variance(self, x_start, t): :return: A tuple (mean, variance, log_variance), all of x_start's shape. """ mean = ( - _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start ) variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) log_variance = _extract_into_tensor( @@ -210,9 +210,9 @@ def q_sample(self, x_start, t, noise=None): noise = torch.randn_like(x_start) assert noise.shape == x_start.shape return ( - _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start - + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) - * noise + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) + * noise ) def q_posterior_mean_variance(self, x_start, x_t, t): @@ -222,24 +222,24 @@ def q_posterior_mean_variance(self, x_start, x_t, t): """ assert x_start.shape == x_t.shape posterior_mean = ( - _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start - + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t ) posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) posterior_log_variance_clipped = _extract_into_tensor( self.posterior_log_variance_clipped, t, x_t.shape ) assert ( - posterior_mean.shape[0] - == posterior_variance.shape[0] - == posterior_log_variance_clipped.shape[0] - == x_start.shape[0] + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] ) return posterior_mean, posterior_variance, posterior_log_variance_clipped def p_mean_variance( - self, model, x, t, clip_denoised=True, dynamic_thresholding_p=0.99, dynamic_thresholding_c=1.7, - 
denoised_fn=None, model_kwargs=None + self, model, x, t, clip_denoised=True, dynamic_thresholding_p=0.99, dynamic_thresholding_c=1.7, + denoised_fn=None, model_kwargs=None ): """ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of @@ -280,7 +280,7 @@ def p_mean_variance( max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) # The model_var_values is [-1, 1] for [min_var, max_var]. frac = (model_var_values + 1) / 2 - model_log_variance = frac * max_log + (1 - frac) * min_log + model_log_variance = frac * max_log.to(frac.device) + (1 - frac) * min_log.to(frac.device) model_variance = torch.exp(model_log_variance) else: model_variance, model_log_variance = { @@ -306,6 +306,7 @@ def process_xstart(x): return x # x.clamp(-1, 1) return x + x, t = x.to(model_output.device), t.to(model_output.device) if self.model_mean_type == ModelMeanType.PREVIOUS_X: pred_xstart = process_xstart( self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output) @@ -325,7 +326,7 @@ def process_xstart(x): raise NotImplementedError(self.model_mean_type) assert ( - model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape + model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape ) return { 'mean': model_mean, @@ -337,25 +338,25 @@ def process_xstart(x): def _predict_xstart_from_eps(self, x_t, t, eps): assert x_t.shape == eps.shape return ( - _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps ) def _predict_xstart_from_xprev(self, x_t, t, xprev): assert x_t.shape == xprev.shape return ( # (xprev - coef2*x_t) / coef1 - _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev - - _extract_into_tensor( - self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape - ) - * x_t + _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev + - _extract_into_tensor( + self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape + ) + * x_t ) def _predict_eps_from_xstart(self, x_t, t, pred_xstart): return ( - _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - - pred_xstart - ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - pred_xstart + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) def _scale_timesteps(self, t): if self.rescale_timesteps: @@ -363,8 +364,8 @@ def _scale_timesteps(self, t): return t def p_sample( - self, model, x, t, clip_denoised=True, dynamic_thresholding_p=0.99, dynamic_thresholding_c=1.7, - denoised_fn=None, model_kwargs=None, inpainting_mask=None, + self, model, x, t, clip_denoised=True, dynamic_thresholding_p=0.99, dynamic_thresholding_c=1.7, + denoised_fn=None, model_kwargs=None, inpainting_mask=None, ): """ Sample x_{t-1} from the model at the given timestep. 
@@ -390,31 +391,36 @@ def p_sample( denoised_fn=denoised_fn, model_kwargs=model_kwargs, ) - noise = torch.randn_like(x) + device = out['mean'].device + noise = torch.randn_like(x, device=device) nonzero_mask = ( (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) - ) # no noise when t == 0 + ).to(device) # no noise when t == 0 if inpainting_mask is None: - inpainting_mask = torch.ones_like(x, device=x.device) + inpainting_mask = torch.ones_like(x, device=device) + + x, t = x.to(device), t.to(device) - sample = out['mean'] + nonzero_mask * torch.exp(0.5 * out['log_variance']) * noise - sample = (1 - inpainting_mask)*x + inpainting_mask*sample + noise = (torch.exp(0.5 * out['log_variance']) * noise).to(device) + + sample = out['mean'] + nonzero_mask * noise + sample = (1 - inpainting_mask) * x + inpainting_mask * sample return {'sample': sample, 'pred_xstart': out['pred_xstart']} def p_sample_loop( - self, - model, - shape, - noise=None, - clip_denoised=True, - dynamic_thresholding_p=0.99, - dynamic_thresholding_c=1.7, - inpainting_mask=None, - denoised_fn=None, - model_kwargs=None, - device=None, - progress=False, - sample_fn=None, + self, + model, + shape, + noise=None, + clip_denoised=True, + dynamic_thresholding_p=0.99, + dynamic_thresholding_c=1.7, + inpainting_mask=None, + denoised_fn=None, + model_kwargs=None, + device=None, + progress=False, + sample_fn=None, ): """ Generate samples from the model. @@ -434,17 +440,17 @@ def p_sample_loop( """ final = None for step_idx, sample in enumerate(self.p_sample_loop_progressive( - model, - shape, - noise=noise, - clip_denoised=clip_denoised, - dynamic_thresholding_p=dynamic_thresholding_p, - dynamic_thresholding_c=dynamic_thresholding_c, - denoised_fn=denoised_fn, - inpainting_mask=inpainting_mask, - model_kwargs=model_kwargs, - device=device, - progress=progress, + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + dynamic_thresholding_p=dynamic_thresholding_p, + dynamic_thresholding_c=dynamic_thresholding_c, + denoised_fn=denoised_fn, + inpainting_mask=inpainting_mask, + model_kwargs=model_kwargs, + device=device, + progress=progress, )): if sample_fn is not None: sample = sample_fn(step_idx, sample) @@ -452,18 +458,18 @@ def p_sample_loop( return final['sample'] def p_sample_loop_progressive( - self, - model, - shape, - inpainting_mask=None, - noise=None, - clip_denoised=True, - dynamic_thresholding_p=0.99, - dynamic_thresholding_c=1.7, - denoised_fn=None, - model_kwargs=None, - device=None, - progress=False, + self, + model, + shape, + inpainting_mask=None, + noise=None, + clip_denoised=True, + dynamic_thresholding_p=0.99, + dynamic_thresholding_c=1.7, + denoised_fn=None, + model_kwargs=None, + device=None, + progress=False, ): """ Generate samples from the model and yield intermediate samples from @@ -472,8 +478,6 @@ def p_sample_loop_progressive( Returns a generator over dicts, where each dict is the return value of p_sample(). """ - if device is None: - device = next(model.parameters()).device assert isinstance(shape, (tuple, list)) if noise is not None: img = noise @@ -505,16 +509,16 @@ def p_sample_loop_progressive( img = out['sample'] def ddim_sample( - self, - model, - x, - t, - clip_denoised=True, - dynamic_thresholding_p=0.99, - dynamic_thresholding_c=1.7, - denoised_fn=None, - model_kwargs=None, - eta=0.0, + self, + model, + x, + t, + clip_denoised=True, + dynamic_thresholding_p=0.99, + dynamic_thresholding_c=1.7, + denoised_fn=None, + model_kwargs=None, + eta=0.0, ): """ Sample x_{t-1} from the model using DDIM. 
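For reference, the update that `p_sample` computes (unchanged by this patch apart from the device placement of `noise`, `nonzero_mask` and `inpainting_mask`, which now follow `out['mean'].device`) is the standard DDPM ancestral step; the notation below is a paraphrase of the code above, not text from the repository:

```latex
x_{t-1} = \mu_\theta(x_t, t)
        + \mathbf{1}[t > 0]\,\exp\!\left(\tfrac{1}{2}\log\Sigma_\theta(x_t, t)\right) z,
\qquad z \sim \mathcal{N}(0, I)
```

When an `inpainting_mask` m is supplied, pixels with m = 1 take the freshly sampled value while pixels with m = 0 keep x_t, i.e. `sample = (1 - m) * x_t + m * sample`.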
@@ -536,15 +540,15 @@ def ddim_sample( alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) sigma = ( - eta - * torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) - * torch.sqrt(1 - alpha_bar / alpha_bar_prev) + eta + * torch.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) + * torch.sqrt(1 - alpha_bar / alpha_bar_prev) ) # Equation 12. noise = torch.randn_like(x) mean_pred = ( - out['pred_xstart'] * torch.sqrt(alpha_bar_prev) - + torch.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps + out['pred_xstart'] * torch.sqrt(alpha_bar_prev) + + torch.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps ) nonzero_mask = ( (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) @@ -553,16 +557,16 @@ def ddim_sample( return {'sample': sample, 'pred_xstart': out['pred_xstart']} def ddim_reverse_sample( - self, - model, - x, - t, - clip_denoised=True, - dynamic_thresholding_p=0.99, - dynamic_thresholding_c=1.7, - denoised_fn=None, - model_kwargs=None, - eta=0.0, + self, + model, + x, + t, + clip_denoised=True, + dynamic_thresholding_p=0.99, + dynamic_thresholding_c=1.7, + denoised_fn=None, + model_kwargs=None, + eta=0.0, ): """ Sample x_{t+1} from the model using DDIM reverse ODE. @@ -581,33 +585,33 @@ def ddim_reverse_sample( # Usually our model outputs epsilon, but we re-derive it # in case we used x_start or x_prev prediction. eps = ( - _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - - out['pred_xstart'] - ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x + - out['pred_xstart'] + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) # Equation 12. reversed mean_pred = ( - out['pred_xstart'] * torch.sqrt(alpha_bar_next) - + torch.sqrt(1 - alpha_bar_next) * eps + out['pred_xstart'] * torch.sqrt(alpha_bar_next) + + torch.sqrt(1 - alpha_bar_next) * eps ) return {'sample': mean_pred, 'pred_xstart': out['pred_xstart']} def ddim_sample_loop( - self, - model, - shape, - noise=None, - clip_denoised=True, - dynamic_thresholding_p=0.99, - dynamic_thresholding_c=1.7, - denoised_fn=None, - model_kwargs=None, - device=None, - progress=False, - eta=0.0, - sample_fn=None, + self, + model, + shape, + noise=None, + clip_denoised=True, + dynamic_thresholding_p=0.99, + dynamic_thresholding_c=1.7, + denoised_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + sample_fn=None, ): """ Generate samples from the model using DDIM. 
@@ -615,17 +619,17 @@ def ddim_sample_loop( """ final = None for step_idx, sample in enumerate(self.ddim_sample_loop_progressive( - model, - shape, - noise=noise, - clip_denoised=clip_denoised, - denoised_fn=denoised_fn, - dynamic_thresholding_p=dynamic_thresholding_p, - dynamic_thresholding_c=dynamic_thresholding_c, - model_kwargs=model_kwargs, - device=device, - progress=progress, - eta=eta, + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + dynamic_thresholding_p=dynamic_thresholding_p, + dynamic_thresholding_c=dynamic_thresholding_c, + model_kwargs=model_kwargs, + device=device, + progress=progress, + eta=eta, )): if sample_fn is not None: sample = sample_fn(step_idx, sample) @@ -633,18 +637,18 @@ def ddim_sample_loop( return final['sample'] def ddim_sample_loop_progressive( - self, - model, - shape, - noise=None, - clip_denoised=True, - dynamic_thresholding_p=0.99, - dynamic_thresholding_c=1.7, - denoised_fn=None, - model_kwargs=None, - device=None, - progress=False, - eta=0.0, + self, + model, + shape, + noise=None, + clip_denoised=True, + dynamic_thresholding_p=0.99, + dynamic_thresholding_c=1.7, + denoised_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, ): """ Use DDIM to sample from the model and yield intermediate samples from @@ -684,7 +688,7 @@ def ddim_sample_loop_progressive( img = out['sample'] def _vb_terms_bpd( - self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None + self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None ): """ Get a term for the variational lower-bound. @@ -871,7 +875,7 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape): dimension equal to the length of timesteps. :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. """ - res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps.to(torch.long)].float() while len(res.shape) < len(broadcast_shape): res = res[..., None] return res.expand(broadcast_shape) diff --git a/deepfloyd_if/model/nn.py b/deepfloyd_if/model/nn.py index 4f1a0f0..a957825 100644 --- a/deepfloyd_if/model/nn.py +++ b/deepfloyd_if/model/nn.py @@ -171,17 +171,16 @@ def timestep_embedding(timesteps, dim, max_period=10000, dtype=None): :param max_period: controls the minimum frequency of the embeddings. :return: an [N x dim] Tensor of positional embeddings. 
""" - if dtype is None: - dtype = torch.float32 + dtype2 = torch.float32 half = dim // 2 freqs = torch.exp( -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half - ).to(device=timesteps.device, dtype=dtype) - args = timesteps[:, None].type(dtype) * freqs[None] + ).to(device=timesteps.device, dtype=dtype2) + args = timesteps[:, None].type(dtype2) * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - return embedding + return embedding.to(dtype) def attention(q, k, v, d_k): diff --git a/deepfloyd_if/model/unet.py b/deepfloyd_if/model/unet.py index bb83590..d82fdeb 100644 --- a/deepfloyd_if/model/unet.py +++ b/deepfloyd_if/model/unet.py @@ -11,10 +11,7 @@ from .nn import avg_pool_nd, conv_nd, linear, normalization, timestep_embedding, zero_module, get_activation, \ AttentionPooling -_FORCE_MEM_EFFICIENT_ATTN = int(os.environ.get('FORCE_MEM_EFFICIENT_ATTN', 0)) -print('FORCE_MEM_EFFICIENT_ATTN=', _FORCE_MEM_EFFICIENT_ATTN, '@UNET:QKVATTENTION') -if _FORCE_MEM_EFFICIENT_ATTN: - from xformers.ops import memory_efficient_attention # noqa +from xformers.ops import memory_efficient_attention # noqa class TimestepBlock(nn.Module): @@ -246,7 +243,7 @@ def __init__( self.num_heads = num_heads else: assert ( - channels % num_head_channels == 0 + channels % num_head_channels == 0 ), f'q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}' self.num_heads = channels // num_head_channels self.norm = normalization(channels, dtype=self.dtype) @@ -310,16 +307,16 @@ def forward(self, qkv, encoder_kv=None): k = torch.cat([ek, k], dim=-1) v = torch.cat([ev, v], dim=-1) scale = 1 / math.sqrt(math.sqrt(ch)) - if _FORCE_MEM_EFFICIENT_ATTN: - q, k, v = map(lambda t: t.permute(0, 2, 1).contiguous(), (q, k, v)) - a = memory_efficient_attention(q, k, v) - a = a.permute(0, 2, 1) - else: - weight = torch.einsum( - 'bct,bcs->bts', q * scale, k * scale - ) # More stable with f16 than dividing afterwards - weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) - a = torch.einsum('bts,bcs->bct', weight, v) + # if True: # legacy + q, k, v = map(lambda t: t.permute(0, 2, 1).contiguous(), (q, k, v)) + a = memory_efficient_attention(q, k, v) + a = a.permute(0, 2, 1) + # else: + # weight = torch.einsum( + # 'bct,bcs->bts', q * scale, k * scale + # ) # More stable with f16 than dividing afterwards + # weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + # a = torch.einsum('bts,bcs->bct', weight, v) return a.reshape(bs, -1, length) @@ -456,7 +453,7 @@ def __init__( ds = 1 if isinstance(num_res_blocks, int): - num_res_blocks = [num_res_blocks]*len(self.channel_mult) + num_res_blocks = [num_res_blocks] * len(self.channel_mult) self.num_res_blocks = num_res_blocks for level, mult in enumerate(self.channel_mult): @@ -632,7 +629,7 @@ def forward(self, x, timesteps, text_emb, timestep_text_emb=None, aug_emb=None, if use_cache and self.cache is not None: encoder_out, encoder_pool = self.cache else: - text_emb = text_emb.type(self.dtype) + text_emb = text_emb.type(self.dtype).to(x.device) encoder_out = self.encoder_proj(text_emb) encoder_out = encoder_out.permute(0, 2, 1) # NLC -> NCL if timestep_text_emb is None: @@ -677,6 +674,7 @@ def __init__(self, low_res_diffusion, interpolate_mode='bilinear', *args, **kwar get_activation(kwargs['activation']), linear(self.time_embed_dim, self.time_embed_dim, dtype=self.dtype), ) + 
self.primary_device = torch.device(0) def forward(self, x, timesteps, low_res, aug_level=None, **kwargs): bs, _, new_height, new_width = x.shape @@ -687,18 +685,19 @@ def forward(self, x, timesteps, low_res, aug_level=None, **kwargs): upsampled = F.interpolate( low_res, (new_height, new_width), mode=self.interpolate_mode, align_corners=align_corners - ) + ).to(x.device) if aug_level is None: - aug_steps = (np.random.random(bs)*1000).astype(np.int64) # uniform [0, 1) + aug_steps = (np.random.random(bs) * 1000).astype(np.int64) # uniform [0, 1) aug_steps = torch.from_numpy(aug_steps).to(x.device, dtype=torch.long) else: aug_steps = torch.tensor([int(aug_level * 1000)]).repeat(bs).to(x.device, dtype=torch.long) + aug_steps = aug_steps.to(self.dtype) + upsampled = self.low_res_diffusion.q_sample(upsampled, aug_steps) x = torch.cat([x, upsampled], dim=1) - aug_emb = self.aug_proj( - timestep_embedding(aug_steps, self.model_channels, dtype=self.dtype) - ) + timestep_embedding(aug_steps, self.model_channels, dtype=self.dtype).to(self.dtype) + ).to(x.device) return super().forward(x, timesteps, aug_emb=aug_emb, **kwargs) diff --git a/deepfloyd_if/model/unet_split.py b/deepfloyd_if/model/unet_split.py new file mode 100644 index 0000000..bdcdb8d --- /dev/null +++ b/deepfloyd_if/model/unet_split.py @@ -0,0 +1,744 @@ +# -*- coding: utf-8 -*- +import gc +import os +import math +import time +from abc import abstractmethod + +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F + +from .nn import avg_pool_nd, conv_nd, linear, normalization, timestep_embedding, zero_module, get_activation, \ + AttentionPooling + +from xformers.ops import memory_efficient_attention # noqa + + +class TimestepBlock(nn.Module): + """ + Any module where forward() takes timestep embeddings as a second argument. + """ + + @abstractmethod + def forward(self, x, emb): + """ + Apply the module to `x` given `emb` timestep embeddings. + """ + + +class TimestepEmbedSequential(nn.Sequential, TimestepBlock): + """ + A sequential module that passes timestep embeddings to the children that + support it as an extra input. + """ + + def forward(self, x, emb, encoder_out=None): + for layer in self: + if isinstance(layer, TimestepBlock): + x = layer(x, emb) + elif isinstance(layer, AttentionBlock): + x = layer(x, encoder_out) + else: + x = layer(x) + return x + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining a convolution is applied. + :param dims: determines the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. 
+ """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, dtype=None): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + self.dtype = dtype + if use_conv: + self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=1, dtype=self.dtype) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode='nearest') + else: + if self.dtype == torch.bfloat16: + x = x.type(torch.float32 if x.device.type == 'cpu' else torch.float16) + x = F.interpolate(x, scale_factor=2, mode='nearest') + if self.dtype == torch.bfloat16: + x = x.type(torch.bfloat16) + if self.use_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining a convolution is applied. + :param dims: determines the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, channels, use_conv, dims=2, out_channels=None, dtype=None): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + self.dtype = dtype + stride = 2 if dims != 3 else (1, 2, 2) + if use_conv: + self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=1, dtype=self.dtype) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class ResBlock(TimestepBlock): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: specified, the number of out channels. + :param use_conv: True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines the signal is 1D, 2D, or 3D. + :param up: True, use this block for upsampling. + :param down: True, use this block for downsampling. 
+ """ + + def __init__( + self, + channels, + emb_channels, + dropout, + activation, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + up=False, + down=False, + dtype=None, + efficient_activation=False, + scale_skip_connection=False, + ): + super().__init__() + self.dtype = dtype + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_scale_shift_norm = use_scale_shift_norm + self.efficient_activation = efficient_activation + self.scale_skip_connection = scale_skip_connection + + self.in_layers = nn.Sequential( + normalization(channels, dtype=self.dtype), + get_activation(activation), + conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=self.dtype), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims, dtype=self.dtype) + self.x_upd = Upsample(channels, False, dims, dtype=self.dtype) + elif down: + self.h_upd = Downsample(channels, False, dims, dtype=self.dtype) + self.x_upd = Downsample(channels, False, dims, dtype=self.dtype) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.Identity() if self.efficient_activation else get_activation(activation), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + dtype=self.dtype + ), + ) + self.out_layers = nn.Sequential( + normalization(self.out_channels, dtype=self.dtype), + get_activation(activation), + nn.Dropout(p=dropout), + zero_module(conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1, dtype=self.dtype)), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1, dtype=self.dtype) + else: + self.skip_connection = conv_nd(dims, channels, self.out_channels, 1, dtype=self.dtype) + + def forward(self, x, emb): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = torch.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + + res = self.skip_connection(x) + h + if self.scale_skip_connection: + res *= 0.7071 # 1 / sqrt(2), https://arxiv.org/pdf/2104.07636.pdf + return res + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. + Originally ported from here, but adapted to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
+ """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + disable_self_attention=False, + encoder_channels=None, + dtype=None, + ): + super().__init__() + self.dtype = dtype + self.channels = channels + self.disable_self_attention = disable_self_attention + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f'q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}' + self.num_heads = channels // num_head_channels + self.norm = normalization(channels, dtype=self.dtype) + self.qkv = conv_nd(1, channels, channels * 3, 1, dtype=self.dtype) + if self.disable_self_attention: + self.qkv = conv_nd(1, channels, channels, 1, dtype=self.dtype) + else: + self.qkv = conv_nd(1, channels, channels * 3, 1, dtype=self.dtype) + self.attention = QKVAttention(self.num_heads, disable_self_attention=disable_self_attention) + + if encoder_channels is not None: + self.encoder_kv = conv_nd(1, encoder_channels, channels * 2, 1, dtype=self.dtype) + self.norm_encoder = normalization(encoder_channels, dtype=self.dtype) + self.proj_out = zero_module(conv_nd(1, channels, channels, 1, dtype=self.dtype)) + + def forward(self, x, encoder_out=None): + b, c, *spatial = x.shape + qkv = self.qkv(self.norm(x).view(b, c, -1)) + if encoder_out is not None: + # from imagen article: https://arxiv.org/pdf/2205.11487.abs + encoder_out = self.norm_encoder(encoder_out) + # # # + encoder_out = self.encoder_kv(encoder_out) + h = self.attention(qkv, encoder_out) + else: + h = self.attention(qkv) + h = self.proj_out(h) + return x + h.reshape(b, c, *spatial) + + +class QKVAttention(nn.Module): + """ + A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping + """ + + def __init__(self, n_heads, disable_self_attention=False): + super().__init__() + self.n_heads = n_heads + self.disable_self_attention = disable_self_attention + + def forward(self, qkv, encoder_kv=None): + """ + Apply QKV attention. + :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. + :return: an [N x (H * C) x T] tensor after attention. + """ + bs, width, length = qkv.shape + if self.disable_self_attention: + ch = width // (1 * self.n_heads) + q, = qkv.reshape(bs * self.n_heads, ch * 1, length).split(ch, dim=1) + else: + assert width % (3 * self.n_heads) == 0 + ch = width // (3 * self.n_heads) + q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1) + if encoder_kv is not None: + assert encoder_kv.shape[1] == self.n_heads * ch * 2 + if self.disable_self_attention: + k, v = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1) + else: + ek, ev = encoder_kv.reshape(bs * self.n_heads, ch * 2, -1).split(ch, dim=1) + k = torch.cat([ek, k], dim=-1) + v = torch.cat([ev, v], dim=-1) + scale = 1 / math.sqrt(math.sqrt(ch)) + # if _FORCE_MEM_EFFICIENT_ATTN: + q, k, v = map(lambda t: t.permute(0, 2, 1).contiguous(), (q, k, v)) + a = memory_efficient_attention(q, k, v) + a = a.permute(0, 2, 1) + # else: + # weight = torch.einsum( + # 'bct,bcs->bts', q * scale, k * scale + # ) # More stable with f16 than dividing afterwards + # weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + # a = torch.einsum('bts,bcs->bct', weight, v) + return a.reshape(bs, -1, length) + + +class UNetSplitModel(nn.Module): + """ + The full UNet model with attention and timestep embedding. + :param in_channels: channels in the input Tensor. + :param model_channels: base channel count for the model. 
+ :param out_channels: channels in the output Tensor. + :param num_res_blocks: number of residual blocks per downsample. + :param attention_resolutions: a collection of downsample rates at which + attention will take place. May be a set, list, or tuple. + For example, this contains 4, then at 4x downsampling, attention + will be used. + :param dropout: the dropout probability. + :param channel_mult: channel multiplier for each level of the UNet. + :param conv_resample: True, use learned convolutions for upsampling and + downsampling. + :param dims: determines the signal is 1D, 2D, or 3D. + :param num_classes: specified (as an int), then this model will be + class-conditional with `num_classes` classes. + :param num_heads: the number of attention heads in each attention layer. + :param num_heads_channels: specified, ignore num_heads and instead use + a fixed channel width per attention head. + :param num_heads_upsample: works with num_heads to set a different number + of heads for upsampling. Deprecated. + :param use_scale_shift_norm: use a FiLM-like conditioning mechanism. + :param resblock_updown: use residual blocks for up/downsampling. + """ + + def __init__( + self, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + activation, + encoder_dim, + att_pool_heads, + encoder_channels, + image_size, + disable_self_attentions=None, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + num_classes=None, + precision='32', + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + efficient_activation=False, + scale_skip_connection=False, + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.encoder_channels = encoder_channels + self.encoder_dim = encoder_dim + self.efficient_activation = efficient_activation + self.scale_skip_connection = scale_skip_connection + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.dropout = dropout + + self.secondary_device = torch.device("cpu") + + # adapt attention resolutions + if isinstance(attention_resolutions, str): + self.attention_resolutions = [] + for res in attention_resolutions.split(','): + self.attention_resolutions.append(image_size // int(res)) + else: + self.attention_resolutions = attention_resolutions + self.attention_resolutions = tuple(self.attention_resolutions) + # + + # adapt disable self attention resolutions + if not disable_self_attentions: + self.disable_self_attentions = [] + elif disable_self_attentions is True: + self.disable_self_attentions = attention_resolutions + elif isinstance(disable_self_attentions, str): + self.disable_self_attentions = [] + for res in disable_self_attentions.split(','): + self.disable_self_attentions.append(image_size // int(res)) + else: + self.disable_self_attentions = disable_self_attentions + self.disable_self_attentions = tuple(self.disable_self_attentions) + # + + # adapt channel mult + if isinstance(channel_mult, str): + self.channel_mult = tuple(int(ch_mult) for ch_mult in channel_mult.split(',')) + else: + self.channel_mult = tuple(channel_mult) + # + + self.conv_resample = conv_resample + self.num_classes = num_classes + self.dtype = torch.float32 + + self.precision = str(precision) + self.use_fp16 = precision == '16' + if self.precision == '16': + self.dtype = torch.float16 + elif self.precision == 'bf16': + self.dtype = torch.bfloat16 + + self.num_heads = 
num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + self.time_embed_dim = model_channels * max(self.channel_mult) + self.time_embed = nn.Sequential( + linear(model_channels, self.time_embed_dim, dtype=self.dtype), + get_activation(activation), + linear(self.time_embed_dim, self.time_embed_dim, dtype=self.dtype), + ) + + if self.num_classes is not None: + self.label_emb = nn.Embedding(num_classes, self.time_embed_dim) + + ch = input_ch = int(self.channel_mult[0] * model_channels) + self.input_blocks = nn.ModuleList( + [TimestepEmbedSequential(conv_nd(dims, in_channels, ch, 3, padding=1, dtype=self.dtype))] + ) + self._feature_size = ch + input_block_chans = [ch] + ds = 1 + + if isinstance(num_res_blocks, int): + num_res_blocks = [num_res_blocks] * len(self.channel_mult) + self.num_res_blocks = num_res_blocks + + for level, mult in enumerate(self.channel_mult): + for _ in range(num_res_blocks[level]): + layers = [ + ResBlock( + ch, + self.time_embed_dim, + dropout, + out_channels=int(mult * model_channels), + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + dtype=self.dtype, + activation=activation, + efficient_activation=self.efficient_activation, + scale_skip_connection=self.scale_skip_connection, + ) + ] + ch = int(mult * model_channels) + if ds in self.attention_resolutions: + layers.append( + AttentionBlock( + ch, + num_heads=num_heads, + num_head_channels=num_head_channels, + encoder_channels=encoder_channels, + dtype=self.dtype, + disable_self_attention=ds in self.disable_self_attentions, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(self.channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + self.time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + dtype=self.dtype, + activation=activation, + efficient_activation=self.efficient_activation, + scale_skip_connection=self.scale_skip_connection, + ) + if resblock_updown + else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + self.time_embed_dim, + dropout, + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + dtype=self.dtype, + activation=activation, + efficient_activation=self.efficient_activation, + scale_skip_connection=self.scale_skip_connection, + ), + AttentionBlock( + ch, + num_heads=num_heads, + num_head_channels=num_head_channels, + encoder_channels=encoder_channels, + dtype=self.dtype, + disable_self_attention=ds in self.disable_self_attentions, + ), + ResBlock( + ch, + self.time_embed_dim, + dropout, + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + dtype=self.dtype, + activation=activation, + efficient_activation=self.efficient_activation, + scale_skip_connection=self.scale_skip_connection, + ), + ) + self._feature_size += ch + + self.output_blocks = nn.ModuleList([]) + for level, mult in list(enumerate(self.channel_mult))[::-1]: + for i in range(num_res_blocks[level] + 1): + ich = input_block_chans.pop() + layers = [ + ResBlock( + ch + ich, + self.time_embed_dim, + dropout, + out_channels=int(model_channels * mult), + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + dtype=self.dtype, + activation=activation, + 
efficient_activation=self.efficient_activation, + scale_skip_connection=self.scale_skip_connection, + ) + ] + ch = int(model_channels * mult) + if ds in self.attention_resolutions: + layers.append( + AttentionBlock( + ch, + num_heads=num_heads_upsample, + num_head_channels=num_head_channels, + encoder_channels=encoder_channels, + dtype=self.dtype, + disable_self_attention=ds in self.disable_self_attentions, + ) + ) + if level and i == num_res_blocks[level]: + out_ch = ch + layers.append( + ResBlock( + ch, + self.time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_scale_shift_norm=use_scale_shift_norm, + up=True, + dtype=self.dtype, + activation=activation, + efficient_activation=self.efficient_activation, + scale_skip_connection=self.scale_skip_connection, + ) + if resblock_updown + else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch) + ) + ds //= 2 + self.output_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + + self.out = nn.Sequential( + normalization(ch, dtype=self.dtype), + get_activation(activation), + zero_module(conv_nd(dims, input_ch, out_channels, 3, padding=1, dtype=self.dtype)), + ) + + self.activation_layer = get_activation(activation) if self.efficient_activation else nn.Identity() + + self.encoder_pooling = nn.Sequential( + nn.LayerNorm(encoder_dim, dtype=self.dtype), + AttentionPooling(att_pool_heads, encoder_dim, dtype=self.dtype), + nn.Linear(encoder_dim, self.time_embed_dim, dtype=self.dtype), + nn.LayerNorm(self.time_embed_dim, dtype=self.dtype) + ) + + if encoder_dim != encoder_channels: + self.encoder_proj = nn.Linear(encoder_dim, encoder_channels, dtype=self.dtype) + else: + self.encoder_proj = nn.Identity() + + self.cache = None + + def collect(self): + gc.collect() + torch.cuda.empty_cache() + + def to(self, x, stage=1): # 0, 1, 2, 3 + if isinstance(x, torch.device): + secondary_device = self.secondary_device + if stage == 1: + self.output_blocks.to(secondary_device) + self.out.to(secondary_device) + + # self.collect() + + self.time_embed.to(x) + self.encoder_proj.to(x) + self.encoder_pooling.to(x) + self.input_blocks.to(x) + self.middle_block.to(x) + elif stage == 2: + self.time_embed.to(secondary_device) + self.encoder_proj.to(secondary_device) + self.encoder_pooling.to(secondary_device) + self.input_blocks.to(secondary_device) + self.middle_block.to(secondary_device) + + # self.collect() + + self.output_blocks.to(x) + self.out.to(x) + else: + super().to(x) + + def forward(self, x, timesteps, text_emb, timestep_text_emb=None, aug_emb=None, use_cache=False, **kwargs): + hs = [] + self.to(self.primary_device, stage=1) + emb = self.time_embed(timestep_embedding(timesteps.to(torch.float32), self.model_channels, + dtype=torch.float32).to(self.primary_device).to(self.dtype)) + + if use_cache and self.cache is not None: + encoder_out, encoder_pool = self.cache + else: + text_emb = text_emb.type(self.dtype).to(self.primary_device) + encoder_out = self.encoder_proj(text_emb) + encoder_out = encoder_out.permute(0, 2, 1) # NLC -> NCL + if timestep_text_emb is None: + timestep_text_emb = text_emb + encoder_pool = self.encoder_pooling(timestep_text_emb) + if use_cache: + self.cache = (encoder_out, encoder_pool) + + emb = emb + encoder_pool.to(emb) + + if aug_emb is not None: + emb = emb + aug_emb.to(emb) + + emb = self.activation_layer(emb) + + h = x.type(self.dtype).to(self.primary_device) + + for module in self.input_blocks: + h = module(h, emb, encoder_out) + hs.append(h) + + h = self.middle_block(h, emb, 
encoder_out) + + self.to(self.primary_device, stage=2) + + for module in self.output_blocks: + h = torch.cat([h, hs.pop()], dim=1) + h = module(h, emb, encoder_out) + h = h.type(self.dtype) + h = self.out(h) + return h + + +class SuperResUNetModel(UNetSplitModel): + """ + A text2im model that performs super-resolution. + Expects an extra kwarg `low_res` to condition on a low-resolution image. + """ + + def __init__(self, low_res_diffusion, interpolate_mode='bilinear', *args, **kwargs): + self.low_res_diffusion = low_res_diffusion + self.interpolate_mode = interpolate_mode + super().__init__(*args, **kwargs) + + self.aug_proj = nn.Sequential( + linear(self.model_channels, self.time_embed_dim, dtype=self.dtype), + get_activation(kwargs['activation']), + linear(self.time_embed_dim, self.time_embed_dim, dtype=self.dtype), + ) + + def forward(self, x, timesteps, low_res, aug_level=None, **kwargs): + bs, _, new_height, new_width = x.shape + + align_corners = True + if self.interpolate_mode == 'nearest': + align_corners = None + + upsampled = F.interpolate( + low_res, (new_height, new_width), mode=self.interpolate_mode, align_corners=align_corners + ) + + if aug_level is None: + aug_steps = (np.random.random(bs) * 1000).astype(np.int64) # uniform [0, 1) + aug_steps = torch.from_numpy(aug_steps).to(x.device, dtype=torch.long) + else: + aug_steps = torch.tensor([int(aug_level * 1000)]).repeat(bs).to(x.device, dtype=torch.long) + + upsampled = self.low_res_diffusion.q_sample(upsampled, aug_steps) + x = torch.cat([x, upsampled], dim=1) + + aug_emb = self.aug_proj( + timestep_embedding(aug_steps, self.model_channels, dtype=self.dtype) + ) + return super().forward(x, timesteps, aug_emb=aug_emb, **kwargs) diff --git a/deepfloyd_if/modules/base.py b/deepfloyd_if/modules/base.py index c808a3c..6cad9f3 100644 --- a/deepfloyd_if/modules/base.py +++ b/deepfloyd_if/modules/base.py @@ -14,14 +14,12 @@ from huggingface_hub import hf_hub_download from accelerate.utils import set_module_tensor_to_device - from .. 
import utils from ..model.respace import create_gaussian_diffusion from .utils import load_model_weights, predict_proba, clip_process_generations class IFBaseModule: - stage = '-' available_models = [] @@ -68,45 +66,49 @@ def use_diffusers(self): return False def embeddings_to_image( - self, t5_embs, low_res=None, *, - style_t5_embs=None, - positive_t5_embs=None, - negative_t5_embs=None, - batch_repeat=1, - dynamic_thresholding_p=0.95, - sample_loop='ddpm', - sample_timestep_respacing='smart185', - dynamic_thresholding_c=1.5, - guidance_scale=7.0, - aug_level=0.25, - positive_mixer=0.15, - blur_sigma=None, - img_size=None, - img_scale=4.0, - aspect_ratio='1:1', - progress=True, - seed=None, - sample_fn=None, - support_noise=None, - support_noise_less_qsample_steps=0, - inpainting_mask=None, - **kwargs, + self, t5_embs, low_res=None, *, + style_t5_embs=None, + positive_t5_embs=None, + negative_t5_embs=None, + batch_repeat=1, + dynamic_thresholding_p=0.95, + sample_loop='ddpm', + sample_timestep_respacing='smart185', + dynamic_thresholding_c=1.5, + guidance_scale=7.0, + aug_level=0.25, + positive_mixer=0.15, + blur_sigma=None, + img_size=None, + img_scale=4.0, + aspect_ratio='1:1', + progress=True, + seed=None, + sample_fn=None, + support_noise=None, + support_noise_less_qsample_steps=0, + inpainting_mask=None, + device=None, + force_size=False, + **kwargs, ): + if device is None: + device = self.model.primary_device self._clear_cache() - image_w, image_h = self._get_image_sizes(low_res, img_size, aspect_ratio, img_scale) + image_w, image_h = self._get_image_sizes(low_res, img_size, aspect_ratio, img_scale, force_size=force_size) diffusion = self.get_diffusion(sample_timestep_respacing) bs_scale = 2 if positive_t5_embs is None else 3 def model_fn(x_t, ts, **kwargs): half = x_t[: len(x_t) // bs_scale] - combined = torch.cat([half]*bs_scale, dim=0) + combined = torch.cat([half] * bs_scale, dim=0).to(device) model_out = self.model(combined, ts, **kwargs) eps, rest = model_out[:, :3], model_out[:, 3:] if bs_scale == 3: cond_eps, pos_cond_eps, uncond_eps = torch.split(eps, len(eps) // bs_scale, dim=0) half_eps = uncond_eps + guidance_scale * ( - cond_eps * (1 - positive_mixer) + pos_cond_eps * positive_mixer - uncond_eps) + cond_eps * (1 - positive_mixer) + pos_cond_eps * positive_mixer - uncond_eps) pos_half_eps = uncond_eps + guidance_scale * (pos_cond_eps - uncond_eps) eps = torch.cat([half_eps, pos_half_eps, half_eps], dim=0) else: @@ -170,7 +172,7 @@ def model_fn(x_t, ts, **kwargs): if low_res is not None: if blur_sigma is not None: low_res = T.GaussianBlur(3, sigma=(blur_sigma, blur_sigma))(low_res) - model_kwargs['low_res'] = torch.cat([low_res]*bs_scale, dim=0).to(self.device) + model_kwargs['low_res'] = torch.cat([low_res] * bs_scale, dim=0).to(self.device) model_kwargs['aug_level'] = aug_level if support_noise is None: @@ -186,7 +188,7 @@ def model_fn(x_t, ts, **kwargs): support_noise[inpainting_mask.cpu().bool() if inpainting_mask is not None else ...], q_sample_steps, ) - noise = noise.repeat(batch_size*bs_scale, 1, 1, 1).to(device=self.device, dtype=self.model.dtype) + noise = noise.repeat(batch_size * bs_scale, 1, 1, 1).to(device=self.device, dtype=self.model.dtype) if inpainting_mask is not None: inpainting_mask = inpainting_mask.to(device=self.device, dtype=torch.long) @@ -202,7 +204,7 @@ def model_fn(x_t, ts, **kwargs): dynamic_thresholding_p=dynamic_thresholding_p, dynamic_thresholding_c=dynamic_thresholding_c, inpainting_mask=inpainting_mask, - device=self.device, + 
device=device, progress=progress, sample_fn=sample_fn, )[:batch_size] @@ -216,7 +218,7 @@ def model_fn(x_t, ts, **kwargs): model_kwargs=model_kwargs, dynamic_thresholding_p=dynamic_thresholding_p, dynamic_thresholding_c=dynamic_thresholding_c, - device=self.device, + device=device, progress=progress, sample_fn=sample_fn, )[:batch_size] @@ -311,7 +313,7 @@ def to_images(self, generations, disable_watermark=False): def show(self, pil_images, nrow=None, size=10): if nrow is None: - nrow = round(len(pil_images)**0.5) + nrow = round(len(pil_images) ** 0.5) imgs = torchvision.utils.make_grid(utils.pil_list_to_torch_tensors(pil_images), nrow=nrow) if not isinstance(imgs, list): @@ -330,19 +332,21 @@ def show(self, pil_images, nrow=None, size=10): def _clear_cache(self): self.model.cache = None - def _get_image_sizes(self, low_res, img_size, aspect_ratio, img_scale): + def _get_image_sizes(self, low_res, img_size, aspect_ratio, img_scale, force_size=True): if low_res is not None: bs, c, h, w = low_res.shape - image_h, image_w = int((h*img_scale)//32)*32, int((w*img_scale//32))*32 + image_h, image_w = int((h * img_scale) // 32) * 32, int((w * img_scale // 32)) * 32 else: + if force_size: + return img_size[0], img_size[1] scale_w, scale_h = aspect_ratio.split(':') scale_w, scale_h = int(scale_w), int(scale_h) coef = scale_w / scale_h image_h, image_w = img_size, img_size if coef >= 1: - image_w = int(round(img_size/8 * coef) * 8) + image_w = int(round(img_size / 8 * coef) * 8) else: - image_h = int(round(img_size/8 / coef) * 8) + image_h = int(round(img_size / 8 / coef) * 8) assert image_h % 8 == 0 assert image_w % 8 == 0 diff --git a/deepfloyd_if/modules/stage_I.py b/deepfloyd_if/modules/stage_I.py index a9c62cc..92883ac 100644 --- a/deepfloyd_if/modules/stage_I.py +++ b/deepfloyd_if/modules/stage_I.py @@ -1,15 +1,16 @@ # -*- coding: utf-8 -*- import accelerate +import torch from .base import IFBaseModule -from ..model import UNetModel +from ..model import UNetModel, UNetSplitModel class IFStageI(IFBaseModule): stage = 'I' available_models = ['IF-I-M-v1.0', 'IF-I-L-v1.0', 'IF-I-XL-v1.0'] - def __init__(self, *args, model_kwargs=None, pil_img_size=64, **kwargs): + def __init__(self, *args, model_kwargs=None, pil_img_size=64, use_split=True, **kwargs): """ :param conf_or_path: :param device: @@ -19,16 +20,23 @@ def __init__(self, *args, model_kwargs=None, pil_img_size=64, **kwargs): super().__init__(*args, pil_img_size=pil_img_size, **kwargs) model_params = dict(self.conf.params) model_params.update(model_kwargs or {}) + UNetClass = UNetSplitModel if use_split else UNetModel with accelerate.init_empty_weights(): - self.model = UNetModel(**model_params) + self.model = UNetClass(**model_params) self.model = self.load_checkpoint(self.model, self.dir_or_name) self.model.eval().to(self.device) + def to(self, x, stage=1, secondary_device=torch.device("cpu")): # 0, 1, 2, 3 + if isinstance(x, torch.device): + self.model.primary_device = x + self.model.secondary_device = secondary_device + self.model.to(x) + def embeddings_to_image(self, t5_embs, style_t5_embs=None, positive_t5_embs=None, negative_t5_embs=None, batch_repeat=1, dynamic_thresholding_p=0.95, sample_loop='ddpm', positive_mixer=0.25, sample_timestep_respacing='150', dynamic_thresholding_c=1.5, guidance_scale=7.0, - aspect_ratio='1:1', progress=True, seed=None, sample_fn=None, **kwargs): - + img_size=(64, 64), aspect_ratio='1:1', progress=True, seed=None, sample_fn=None, + force_size=False, **kwargs): return super().embeddings_to_image( 
t5_embs=t5_embs, style_t5_embs=style_t5_embs, @@ -40,11 +48,12 @@ def embeddings_to_image(self, t5_embs, style_t5_embs=None, positive_t5_embs=None sample_loop=sample_loop, sample_timestep_respacing=sample_timestep_respacing, guidance_scale=guidance_scale, - img_size=64, + img_size=img_size, aspect_ratio=aspect_ratio, progress=progress, seed=seed, sample_fn=sample_fn, positive_mixer=positive_mixer, + force_size=force_size, **kwargs ) diff --git a/deepfloyd_if/modules/stage_II.py b/deepfloyd_if/modules/stage_II.py index d14b838..d4eb0af 100644 --- a/deepfloyd_if/modules/stage_II.py +++ b/deepfloyd_if/modules/stage_II.py @@ -22,7 +22,7 @@ def embeddings_to_image( self, low_res, t5_embs, style_t5_embs=None, positive_t5_embs=None, negative_t5_embs=None, batch_repeat=1, aug_level=0.25, dynamic_thresholding_p=0.95, dynamic_thresholding_c=1.0, sample_loop='ddpm', sample_timestep_respacing='smart50', guidance_scale=4.0, img_scale=4.0, positive_mixer=0.5, - progress=True, seed=None, sample_fn=None, **kwargs): + progress=True, seed=None, sample_fn=None, device=None, **kwargs): return super().embeddings_to_image( t5_embs=t5_embs, low_res=low_res, @@ -42,5 +42,9 @@ def embeddings_to_image( progress=progress, seed=seed, sample_fn=sample_fn, + device=device, **kwargs ) + + def to(self, x): + self.model.to(x) diff --git a/deepfloyd_if/modules/stage_III_sd_x4.py b/deepfloyd_if/modules/stage_III_sd_x4.py index 307fad2..2f6d6ef 100644 --- a/deepfloyd_if/modules/stage_III_sd_x4.py +++ b/deepfloyd_if/modules/stage_III_sd_x4.py @@ -9,7 +9,6 @@ class StableStageIII(IFBaseModule): - available_models = ['stable-diffusion-x4-upscaler'] def __init__(self, *args, model_kwargs=None, pil_img_size=1024, **kwargs): @@ -20,7 +19,7 @@ def __init__(self, *args, model_kwargs=None, pil_img_size=1024, **kwargs): ' Please run `pip install diffusers --upgrade`' ) - model_id = os.path.join('stabilityai', self.dir_or_name) + model_id = 'stabilityai' + "/" + self.dir_or_name.strip() model_kwargs = model_kwargs or {} precision = str(model_kwargs.get('precision', '16')) @@ -34,8 +33,8 @@ def __init__(self, *args, model_kwargs=None, pil_img_size=1024, **kwargs): self.model = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch_dtype, token=self.hf_token) self.model.to(self.device) - if bool(os.environ.get('FORCE_MEM_EFFICIENT_ATTN')): - self.model.enable_xformers_memory_efficient_attention() + # if bool(os.environ.get('FORCE_MEM_EFFICIENT_ATTN')): + self.model.enable_xformers_memory_efficient_attention() @property def use_diffusers(self): @@ -46,13 +45,10 @@ def use_diffusers(self): return False def embeddings_to_image( - self, low_res, t5_embs, style_t5_embs=None, positive_t5_embs=None, negative_t5_embs=None, batch_repeat=1, + self, low_res, prompt, style_t5_embs=None, positive_t5_embs=None, negative_t5_embs=None, batch_repeat=1, aug_level=0.0, blur_sigma=None, dynamic_thresholding_p=0.95, dynamic_thresholding_c=1.0, positive_mixer=0.5, sample_loop='ddpm', sample_timestep_respacing='75', guidance_scale=4.0, img_scale=4.0, - progress=True, seed=None, sample_fn=None, **kwargs): - - prompt = kwargs.pop('prompt') - noise_level = kwargs.pop('noise_level', 20) + progress=True, seed=None, sample_fn=None, device=None, **kwargs): if sample_loop == 'ddpm': self.model.scheduler = DDPMScheduler.from_config(self.model.scheduler.config) @@ -64,13 +60,9 @@ def embeddings_to_image( self.model.set_progress_bar_config(disable=not progress) generator = torch.manual_seed(seed) - prompt = sum([batch_repeat * [p] for p in prompt], []) - low_res = 
low_res.repeat(batch_repeat, 1, 1, 1) - metadata = { - 'image': low_res, + 'image': low_res, # 1 3 256 256 'prompt': prompt, - 'noise_level': noise_level, 'generator': generator, 'guidance_scale': guidance_scale, 'num_inference_steps': num_inference_steps, @@ -82,3 +74,6 @@ def embeddings_to_image( sample = self._IFBaseModule__validate_generations(images) return sample, metadata + + def to(self, x): + self.model.to(x) diff --git a/deepfloyd_if/modules/t5.py b/deepfloyd_if/modules/t5.py index 7443426..1cbd3dd 100644 --- a/deepfloyd_if/modules/t5.py +++ b/deepfloyd_if/modules/t5.py @@ -12,9 +12,9 @@ class T5Embedder: - available_models = ['t5-v1_1-xxl'] - bad_punct_regex = re.compile(r'['+'#®•©™&@·º½¾¿¡§~'+'\)'+'\('+'\]'+'\['+'\}'+'\{'+'\|'+'\\'+'\/'+'\*' + r']{1,}') # noqa + bad_punct_regex = re.compile( + r'[' + '#®•©™&@·º½¾¿¡§~' + '\)' + '\(' + '\]' + '\[' + '\}' + '\{' + '\|' + '\\' + '\/' + '\*' + r']{1,}') # noqa def __init__(self, device, dir_or_name='t5-v1_1-xxl', *, cache_dir=None, hf_token=None, use_text_preprocessing=True, t5_model_kwargs=None, torch_dtype=None, use_offload_folder=None): @@ -75,6 +75,22 @@ def __init__(self, device, dir_or_name='t5-v1_1-xxl', *, cache_dir=None, hf_toke self.tokenizer = AutoTokenizer.from_pretrained(path) self.model = T5EncoderModel.from_pretrained(path, **t5_model_kwargs).eval() + self.saved_path = path + self.saved_kwargs = t5_model_kwargs + self.loaded = True + + def reload(self, dmap): + del self.model + torch.cuda.empty_cache() + self.saved_kwargs["device_map"] = dmap + self.model = T5EncoderModel.from_pretrained(self.saved_path, **self.saved_kwargs).eval() + self.loaded = True + + def to(self, x): + self.model.to(x) + + def cpu(self): + self.model.base_model().to(torch.device("cpu")) def get_text_embeddings(self, texts): texts = [self.text_preprocessing(text) for text in texts] @@ -121,10 +137,12 @@ def clean_caption(self, caption): caption = re.sub('', 'person', caption) # urls: caption = re.sub( - r'\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa + r'\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', + # noqa '', caption) # regex for urls caption = re.sub( - r'\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', # noqa + r'\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))', + # noqa '', caption) # regex for urls # html: caption = BeautifulSoup(caption, features='html.parser').text @@ -150,7 +168,8 @@ def clean_caption(self, caption): # все виды тире / all types of dash --> "-" caption = re.sub( - r'[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+', # noqa + r'[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+', + # noqa '-', caption) # кавычки к одному стандарту diff --git a/deepfloyd_if/pipelines/dream.py b/deepfloyd_if/pipelines/dream.py index 13d7479..0b98621 100644 --- a/deepfloyd_if/pipelines/dream.py +++ b/deepfloyd_if/pipelines/dream.py @@ -5,22 +5,22 @@ def dream( - t5, - if_I, - if_II=None, - if_III=None, - *, - prompt, - style_prompt=None, - negative_prompt=None, - seed=None, - aspect_ratio='1:1', - if_I_kwargs=None, - if_II_kwargs=None, - if_III_kwargs=None, - progress=True, - return_tensors=False, - disable_watermark=False, + t5, + if_I, + 
if_II=None, + if_III=None, + *, + prompt, + style_prompt=None, + negative_prompt=None, + seed=None, + aspect_ratio='1:1', + if_I_kwargs=None, + if_II_kwargs=None, + if_III_kwargs=None, + progress=True, + return_tensors=False, + disable_watermark=False, ): """ Generate pictures using text description! @@ -108,18 +108,18 @@ def dream( stageIII_generations = [] for idx in range(len(stageII_generations)): if if_III.use_diffusers: - if_III_kwargs['prompt'] = prompt[idx: idx+1] + if_III_kwargs['prompt'] = prompt[idx: idx + 1] - if_III_kwargs['low_res'] = stageII_generations[idx:idx+1] + if_III_kwargs['low_res'] = stageII_generations[idx:idx + 1] if_III_kwargs['seed'] = seed - if_III_kwargs['t5_embs'] = t5_embs[idx:idx+1] + if_III_kwargs['t5_embs'] = t5_embs[idx:idx + 1] if_III_kwargs['progress'] = progress style_t5_embs = if_I_kwargs.get('style_t5_embs') if style_t5_embs is not None: - style_t5_embs = style_t5_embs[idx:idx+1] + style_t5_embs = style_t5_embs[idx:idx + 1] positive_t5_embs = if_I_kwargs.get('positive_t5_embs') if positive_t5_embs is not None: - positive_t5_embs = positive_t5_embs[idx:idx+1] + positive_t5_embs = positive_t5_embs[idx:idx + 1] if_III_kwargs['style_t5_embs'] = style_t5_embs if_III_kwargs['positive_t5_embs'] = positive_t5_embs diff --git a/deepfloyd_if/pipelines/optimized_dream.py b/deepfloyd_if/pipelines/optimized_dream.py new file mode 100644 index 0000000..2267db2 --- /dev/null +++ b/deepfloyd_if/pipelines/optimized_dream.py @@ -0,0 +1,100 @@ +import gc +import numpy as np + +import torch.cuda +from PIL import Image + + +def run_garbage_collection(): + gc.collect() + torch.cuda.empty_cache() + + +def to_pil_images(images: torch.Tensor) -> list[Image]: + images = (images / 2 + 0.5).clamp(0, 1) + images = images.cpu().permute(0, 2, 3, 1).float().numpy() + images = np.round(images * 255).astype(np.uint8) + return [Image.fromarray(image) for image in images] + + +def run_stage1( + model, + t5_embs, + negative_t5_embs, + seed: int = 0, + num_images: int = 1, + guidance_scale_1: float = 7.0, + custom_timesteps_1: str = 'smart100', + num_inference_steps_1: int = 100, + aspect_ratio='1:1', + img_size=(64, 64) +): + run_garbage_collection() + + if custom_timesteps_1 == "none": + custom_timesteps_1 = str(num_inference_steps_1) + + ret_images1, ret_images2 = [], [] + for _ in range(num_images): + images, _ = model.embeddings_to_image(t5_embs=t5_embs, + negative_t5_embs=negative_t5_embs, + guidance_scale=guidance_scale_1, img_size=img_size, + sample_timestep_respacing=custom_timesteps_1, + seed=seed, aspect_ratio=aspect_ratio, force_size=True + ) + pil_images_I = model.to_images(images, disable_watermark=True) + ret_images1.append(pil_images_I[0]) + ret_images2.append(images[0]) + + return ret_images2, ret_images1 + + +def run_stage2( + model, + t5_embs, + negative_t5_embs, + images, + seed: int = 0, + guidance_scale: float = 4.0, + custom_timesteps_2: str = 'smart50', + num_inference_steps_2: int = 50, + disable_watermark: bool = True, + device=None +) -> Image: + run_garbage_collection() + + if custom_timesteps_2 == "none": + custom_timesteps_2 = str(num_inference_steps_2) + stageII_generations, _ = model.embeddings_to_image(low_res=images, + t5_embs=t5_embs, + negative_t5_embs=negative_t5_embs, + guidance_scale=guidance_scale, + sample_timestep_respacing=custom_timesteps_2, + seed=seed, device=device) + pil_images_II = model.to_images(stageII_generations, disable_watermark=disable_watermark) + return stageII_generations, pil_images_II + + +def run_stage3( + model, + prompt, + 
negative_t5_embs, + images, + seed: int = 0, + guidance_scale: float = 4.0, + custom_timesteps_2: str = 'smart50', + num_inference_steps_2: int = 50, + disable_watermark: bool = True, + device=None +) -> Image: + run_garbage_collection() + stageII_generations, _ = model.embeddings_to_image(low_res=images, + prompt=prompt, + negative_t5_embs=negative_t5_embs, + guidance_scale=guidance_scale, + sample_timestep_respacing=num_inference_steps_2, + num_images_per_prompt=1, + noise_level=20, + seed=seed, device=device) + pil_images_III = model.to_images(stageII_generations, disable_watermark=disable_watermark) + return pil_images_III diff --git a/run_ui.py b/run_ui.py new file mode 100644 index 0000000..ff63bca --- /dev/null +++ b/run_ui.py @@ -0,0 +1,408 @@ +import argparse +import gc + +import numpy as np +from accelerate import dispatch_model + +from deepfloyd_if.modules import IFStageI, IFStageII, StableStageIII +from deepfloyd_if.modules.t5 import T5Embedder +from deepfloyd_if.pipelines.optimized_dream import run_stage1, run_stage2, run_stage3 + +import torch + +import gradio as gr + +from ui_files.utils import randomize_seed_fn, show_gallery_view, update_upscale_button, get_stage2_index, \ + check_if_stage2_selected, show_upscaled_view, get_device_map + +device = torch.device(0) +if_I = IFStageI('IF-I-XL-v1.0', device=torch.device("cpu")) +if_I.to(torch.float16) # half +if_II = IFStageII('IF-II-L-v1.0', device=torch.device("cpu")) +if_I.to(torch.float16) # half +if_III = StableStageIII('stable-diffusion-x4-upscaler', device=torch.device("cpu")) +t5_device = torch.device(0) +t5 = T5Embedder(device=t5_device, t5_model_kwargs={"low_cpu_mem_usage": True, + "torch_dtype": torch.float16, + "device_map": get_device_map(t5_device), + "offload_folder": True}) + + +def switch_devices(stage): + if stage == 0: + if_I.to(torch.device("cpu")) + if_II.to(torch.device("cpu")) + if_III.to(torch.device("cpu")) + + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() + + if not t5.loaded: + print("Reloading t5") + t5.reload(get_device_map(t5_device, all2cpu=False)) + # dispatch_model(t5.model, get_device_map(t5_device, all2cpu=False)) + if stage == 1: + # t5.model.cpu() + dispatch_model(t5.model, get_device_map(t5_device, all2cpu=True)) + t5.loaded = False + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() + if_I.to(torch.device(0)) + elif stage == 2: + if_I.to(torch.device("cpu")) + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() + if_II.to(torch.device(0)) + elif stage == 3: + if_II.to(torch.device("cpu")) + gc.collect() + torch.cuda.empty_cache() + torch.cuda.synchronize() + if_III.to(torch.device(0)) + + +def process_and_run_stage1(prompt, + negative_prompt, + seed_1, + num_images, + guidance_scale_1, + custom_timesteps_1, + num_inference_steps_1, + # aspect_ratio, + width, + height): + global t5_embs, negative_t5_embs, images + print("Encoding prompts..") + switch_devices(stage=0) + t5_embs = t5.get_text_embeddings([prompt] * num_images) + negative_t5_embs = t5.get_text_embeddings([negative_prompt] * num_images) + switch_devices(stage=1) + t5_embs = t5_embs.to(if_I.device) + negative_t5_embs = negative_t5_embs.to(if_I.device) + print("Encoded. 
Running 1st stage") + images, images_ret = run_stage1( + if_I, + t5_embs=t5_embs, + negative_t5_embs=negative_t5_embs, + seed=seed_1, + num_images=num_images, + guidance_scale_1=guidance_scale_1, + custom_timesteps_1=custom_timesteps_1, + num_inference_steps_1=num_inference_steps_1, + # aspect_ratio=aspect_ratio, + img_size=(width, height) + ) + return images_ret + + +def process_and_run_stage2( + index, + seed_2, + guidance_scale_2, + custom_timesteps_2, + num_inference_steps_2): + global t5_embs, negative_t5_embs, images + print("Stage 2..") + switch_devices(stage=2) + images, images_ret = run_stage2( + if_II, + t5_embs=t5_embs, + negative_t5_embs=negative_t5_embs, + images=images[index].unsqueeze(0).to(device), + seed=seed_2, + guidance_scale=guidance_scale_2, + custom_timesteps_2=custom_timesteps_2, + num_inference_steps_2=num_inference_steps_2, + device=device + ) + return images_ret + + +def process_and_run_stage3( + index, + prompt, + seed_2, + guidance_scale_2, + custom_timesteps_2, + num_inference_steps_2): + global t5_embs, negative_t5_embs, images + print("Stage 3..") + switch_devices(stage=3) + return run_stage3( + if_III, + prompt=prompt, + negative_t5_embs=negative_t5_embs, + images=images[index].unsqueeze(0).to(device), + seed=seed_2, + guidance_scale=guidance_scale_2, + custom_timesteps_2=custom_timesteps_2, + num_inference_steps_2=num_inference_steps_2, + device=device + ) + + +def create_ui(args): + with gr.Blocks(css='ui_files/style.css') as demo: + with gr.Box(): + with gr.Row(elem_id='prompt-container').style(equal_height=True): + with gr.Column(): + prompt = gr.Text( + label='Prompt', + show_label=False, + max_lines=1, + placeholder='Enter your prompt', + elem_id='prompt-text-input', + ).style(container=False) + negative_prompt = gr.Text( + label='Negative prompt', + show_label=False, + max_lines=1, + placeholder='Enter a negative prompt', + elem_id='negative-prompt-text-input', + ).style(container=False) + width = gr.Slider(32, 128, value=64, step=8, label="Width").style(container=False) + height = gr.Slider(32, 128, value=64, step=8, label="Height").style(container=False) + # aspect_ratio_1 = gr.Radio( + # ["16:9", "4:3", "1:1", "3:4", "9:16"], value="1:1", label="Aspect ratio" + # ).style(container=False) + generate_button = gr.Button('Generate').style(full_width=False) + + with gr.Column() as gallery_view: + gallery = gr.Gallery(label='Stage 1 results', + show_label=False, + elem_id='gallery').style( + columns=args.GALLERY_COLUMN_NUM, + object_fit='contain') + gr.Markdown('Pick your favorite generation to upscale.') + with gr.Row(): + upscale_to_256_button = gr.Button( + 'Upscale to 256px', + visible=args.DISABLE_SD_X4_UPSCALER, + interactive=False) + with gr.Column(visible=False) as upscale_view: + result = gr.Gallery(label='Result', + show_label=False, + elem_id='upscaled-image').style( + columns=args.GALLERY_COLUMN_NUM, + object_fit='contain') + back_to_selection_button = gr.Button('Back to selection') + upscale_button = gr.Button('Upscale 4x', + interactive=False, + visible=True) + with gr.Accordion('Advanced options', + open=False, + visible=args.SHOW_ADVANCED_OPTIONS): + with gr.Tabs(): + with gr.Tab(label='Generation'): + seed_1 = gr.Slider(label='Seed', + minimum=0, + maximum=args.MAX_SEED, + step=1, + value=0) + randomize_seed_1 = gr.Checkbox(label='Randomize seed', + value=True) + guidance_scale_1 = gr.Slider(label='Guidance scale', + minimum=1, + maximum=20, + step=0.1, + value=7.0) + custom_timesteps_1 = gr.Dropdown( + label='Custom timesteps 1', + 
choices=[ + 'none', + 'fast27', + 'smart27', + 'smart50', + 'smart100', + 'smart185', + ], + value="smart50", + visible=True) + num_inference_steps_1 = gr.Slider( + label='Number of inference steps', + minimum=1, + maximum=200, + step=1, + value=100, + visible=True) + num_images = gr.Slider(label='Number of images', + minimum=1, + maximum=4, + step=1, + value=1, + visible=True) + with gr.Tab(label='Super-resolution 1'): + seed_2 = gr.Slider(label='Seed', + minimum=0, + maximum=args.MAX_SEED, + step=1, + value=0) + randomize_seed_2 = gr.Checkbox(label='Randomize seed', + value=True) + guidance_scale_2 = gr.Slider(label='Guidance scale', + minimum=1, + maximum=20, + step=0.1, + value=4.0) + custom_timesteps_2 = gr.Dropdown( + label='Custom timesteps 2', + choices=[ + 'none', + 'fast27', + 'smart27', + 'smart50', + 'smart100', + 'smart185', + ], + value="smart27", + visible=True) + num_inference_steps_2 = gr.Slider( + label='Number of inference steps', + minimum=1, + maximum=200, + step=1, + value=50, + visible=True) + with gr.Tab(label='Super-resolution 2'): + seed_3 = gr.Slider(label='Seed', + minimum=0, + maximum=args.MAX_SEED, + step=1, + value=0) + randomize_seed_3 = gr.Checkbox(label='Randomize seed', + value=True) + guidance_scale_3 = gr.Slider(label='Guidance scale', + minimum=1, + maximum=20, + step=0.1, + value=9.0) + num_inference_steps_3 = gr.Slider( + label='Number of inference steps', + minimum=1, + maximum=200, + step=1, + value=60, + visible=True) + with gr.Box(): + with gr.Row(): + with gr.Accordion(label='Hidden params'): + selected_index_for_stage2 = gr.Number( + label='Selected index for Stage 2', value=-1, precision=0) + + generate_button.click( + process_and_run_stage1, + [prompt, + negative_prompt, + seed_1, + num_images, + guidance_scale_1, + custom_timesteps_1, + num_inference_steps_1, + # aspect_ratio_1, + width, + height], + gallery + ) + + gallery.select( + fn=get_stage2_index, + outputs=selected_index_for_stage2, + queue=False, + ) + + selected_index_for_stage2.change( + fn=update_upscale_button, + inputs=selected_index_for_stage2, + outputs=[ + upscale_button, + upscale_to_256_button, + ], + queue=False, + ) + + upscale_to_256_button.click( + fn=check_if_stage2_selected, + inputs=selected_index_for_stage2, + queue=False, + ).then( + fn=randomize_seed_fn, + inputs=[seed_2, randomize_seed_2], + outputs=seed_2, + queue=False, + ).then( + fn=show_upscaled_view, + outputs=[ + gallery_view, + upscale_view, + ], + queue=False, + ).then( + fn=process_and_run_stage2, + inputs=[ + selected_index_for_stage2, + seed_2, + guidance_scale_2, + custom_timesteps_2, + num_inference_steps_2, + ], + outputs=result, + ) + + upscale_button.click( + fn=check_if_stage2_selected, + inputs=selected_index_for_stage2, + queue=False, + ).then( + fn=randomize_seed_fn, + inputs=[seed_2, randomize_seed_2], + outputs=seed_2, + queue=False, + ).then( + fn=randomize_seed_fn, + inputs=[seed_3, randomize_seed_3], + outputs=seed_3, + queue=False, + ).then( + fn=show_upscaled_view, + outputs=[ + gallery_view, + upscale_view, + ], + queue=False, + ).then( + fn=process_and_run_stage3, + inputs=[ + selected_index_for_stage2, + prompt, + seed_2, + guidance_scale_3, + custom_timesteps_2, + num_inference_steps_3, + ], + outputs=result, + ) + + back_to_selection_button.click( + fn=show_gallery_view, + outputs=[ + gallery_view, + upscale_view, + ], + queue=False, + ) + return demo + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description='IF UI settings') + 
parser.add_argument('--GALLERY_COLUMN_NUM', type=int, default=4) + parser.add_argument('--DISABLE_SD_X4_UPSCALER', type=bool, default=True) + parser.add_argument('--SHOW_ADVANCED_OPTIONS', type=bool, default=True) + parser.add_argument('--MAX_SEED', type=int, default=np.iinfo(np.int32).max) + + demo = create_ui(parser.parse_args()) + demo.launch() diff --git a/ui_files/style.css b/ui_files/style.css new file mode 100644 index 0000000..17fb109 --- /dev/null +++ b/ui_files/style.css @@ -0,0 +1,238 @@ +/* +This CSS file is modified from: +https://huggingface.co/spaces/stabilityai/stable-diffusion/blob/2794a3c3ba66115c307075098e713f572b08bf80/app.py +*/ + +h1 { + text-align: center; +} + +.gradio-container { + font-family: 'IBM Plex Sans', sans-serif; +} + +.gr-button { + color: white; + border-color: black; + background: black; +} + +input[type='range'] { + accent-color: black; +} + +.dark input[type='range'] { + accent-color: #dfdfdf; +} + +.container { + max-width: 730px; + margin: auto; + padding-top: 1.5rem; +} + +#gallery { + min-height: auto; + height: 185px; + margin-top: 15px; + margin-left: auto; + margin-right: auto; + border-bottom-right-radius: .5rem !important; + border-bottom-left-radius: .5rem !important; +} +#gallery .grid-wrap, #gallery .empty{ + height: 185px; + min-height: 185px; +} +#gallery .preview{ + height: 185px; + min-height: 185px!important; +} +#gallery>div>.h-full { + min-height: 20rem; +} + +.details:hover { + text-decoration: underline; +} + +.gr-button { + white-space: nowrap; +} + +.gr-button:focus { + border-color: rgb(147 197 253 / var(--tw-border-opacity)); + outline: none; + box-shadow: var(--tw-ring-offset-shadow), var(--tw-ring-shadow), var(--tw-shadow, 0 0 #0000); + --tw-border-opacity: 1; + --tw-ring-offset-shadow: var(--tw-ring-inset) 0 0 0 var(--tw-ring-offset-width) var(--tw-ring-offset-color); + --tw-ring-shadow: var(--tw-ring-inset) 0 0 0 calc(3px var(--tw-ring-offset-width)) var(--tw-ring-color); + --tw-ring-color: rgb(191 219 254 / var(--tw-ring-opacity)); + --tw-ring-opacity: .5; +} + +#advanced-btn { + font-size: .7rem !important; + line-height: 19px; + margin-top: 12px; + margin-bottom: 12px; + padding: 2px 8px; + border-radius: 14px !important; +} + +#advanced-options { + display: none; + margin-bottom: 20px; +} + +.footer { + margin-bottom: 45px; + margin-top: 35px; + text-align: center; + border-bottom: 1px solid #e5e5e5; +} + +.footer>p { + font-size: .8rem; + display: inline-block; + padding: 0 10px; + transform: translateY(10px); + background: white; +} + +.dark .footer { + border-color: #303030; +} + +.dark .footer>p { + background: #0b0f19; +} + +.acknowledgments h4 { + margin: 1.25em 0 .25em 0; + font-weight: bold; + font-size: 115%; +} + +.animate-spin { + animation: spin 1s linear infinite; +} + +@keyframes spin { + from { + transform: rotate(0deg); + } + + to { + transform: rotate(360deg); + } +} + +#share-btn-container { + display: flex; + padding-left: 0.5rem !important; + padding-right: 0.5rem !important; + background-color: #000000; + justify-content: center; + align-items: center; + border-radius: 9999px !important; + width: 13rem; + margin-top: 10px; + margin-left: auto; +} + +#share-btn { + all: initial; + color: #ffffff; + font-weight: 600; + cursor: pointer; + font-family: 'IBM Plex Sans', sans-serif; + margin-left: 0.5rem !important; + padding-top: 0.25rem !important; + padding-bottom: 0.25rem !important; + right: 0; +} + +#share-btn * { + all: unset; +} + +#share-btn-container div:nth-child(-n+2) { + width: auto 
!important; + min-height: 0px !important; +} + +#share-btn-container .wrap { + display: none !important; +} + +.gr-form { + flex: 1 1 50%; + border-top-right-radius: 0; + border-bottom-right-radius: 0; +} + +#prompt-container { + gap: 0; +} + +#prompt-text-input, +#negative-prompt-text-input { + padding: .45rem 0.625rem +} + +#component-16 { + border-top-width: 1px !important; + margin-top: 1em +} + +.image_duplication { + position: absolute; + width: 100px; + left: 50px +} + +#component-0 { + max-width: 730px; + margin: auto; + padding-top: 1.5rem; +} + +#upscaled-image img { + object-fit: scale-down; +} +/* share button */ +#share-btn-container { + display: flex; + padding-left: 0.5rem !important; + padding-right: 0.5rem !important; + background-color: #000000; + justify-content: center; + align-items: center; + border-radius: 9999px !important; + width: 13rem; + margin-top: 10px; + margin-left: auto; + flex: unset !important; +} +#share-btn { + all: initial; + color: #ffffff; + font-weight: 600; + cursor: pointer; + font-family: 'IBM Plex Sans', sans-serif; + margin-left: 0.5rem !important; + padding-top: 0.25rem !important; + padding-bottom: 0.25rem !important; + right:0; +} +#share-btn * { + all: unset !important; +} +#share-btn-container div:nth-child(-n+2){ + width: auto !important; + min-height: 0px !important; +} +#share-btn-container .wrap { + display: none !important; +} \ No newline at end of file diff --git a/ui_files/utils.py b/ui_files/utils.py new file mode 100644 index 0000000..541a577 --- /dev/null +++ b/ui_files/utils.py @@ -0,0 +1,75 @@ +import gradio as gr +import numpy as np +import random + +import torch + + +def _update_result_view(show_gallery: bool) -> tuple[dict, dict]: + return gr.update(visible=show_gallery), gr.update(visible=not show_gallery) + + +def show_gallery_view() -> tuple[dict, dict]: + return _update_result_view(True) + + +def show_upscaled_view() -> tuple[dict, dict]: + return _update_result_view(False) + + +def randomize_seed_fn(seed: int, randomize_seed: bool) -> int: + if randomize_seed: + seed = random.randint(0, np.iinfo(np.int32).max) + return seed + + +def update_upscale_button(selected_index: int) -> tuple[dict, dict]: + if selected_index == -1: + return gr.update(interactive=False), gr.update(interactive=False) + else: + return gr.update(interactive=True), gr.update(interactive=True) + + +def get_stage2_index(evt: gr.SelectData) -> int: + return evt.index + + +def check_if_stage2_selected(index: int) -> None: + if index == -1: + raise gr.Error( + 'You need to select the image you would like to upscale from the Stage 1 results by clicking.' 
+        )
+
+
+def get_device_map(device, all2cpu=False):
+    device = device if not all2cpu else torch.device("cpu")
+    return {
+        'shared': device,
+        'encoder.embed_tokens': device,
+        'encoder.block.0': device,
+        'encoder.block.1': device,
+        'encoder.block.2': device,
+        'encoder.block.3': device,
+        'encoder.block.4': device,
+        'encoder.block.5': device,
+        'encoder.block.6': device,
+        'encoder.block.7': device,
+        'encoder.block.8': device,
+        'encoder.block.9': device,
+        'encoder.block.10': device,
+        'encoder.block.11': device,
+        'encoder.block.12': 'cpu',
+        'encoder.block.13': 'cpu',
+        'encoder.block.14': 'cpu',
+        'encoder.block.15': 'cpu',
+        'encoder.block.16': 'cpu',
+        'encoder.block.17': 'cpu',
+        'encoder.block.18': 'cpu',
+        'encoder.block.19': 'cpu',
+        'encoder.block.20': 'cpu',
+        'encoder.block.21': 'cpu',
+        'encoder.block.22': 'cpu',
+        'encoder.block.23': 'cpu',
+        'encoder.final_layer_norm': 'cpu',
+        'encoder.dropout': 'cpu',
+    }
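
For readers skimming the patch, the sketch below shows one way the new helpers fit together outside the Gradio UI: `get_device_map` keeps only the first half of the T5 encoder blocks on the GPU, the diffusion stages are created on the CPU, and each stage is moved onto the GPU only while it runs. This is an illustrative sketch that mirrors the wiring in `run_ui.py`, not an additional entry point shipped by the patch; the prompt, seed, output filename, and single-GPU `torch.device(0)` are placeholder assumptions, and the T5 re-offloading step done in the UI via `dispatch_model` is omitted for brevity.

```python
# Illustrative sketch only: a hand-wired version of the flow in run_ui.py.
# Assumes a single CUDA device; prompt/seed/output path are placeholders.
import torch

from deepfloyd_if.modules import IFStageI, IFStageII
from deepfloyd_if.modules.t5 import T5Embedder
from deepfloyd_if.pipelines.optimized_dream import run_stage1, run_stage2
from ui_files.utils import get_device_map

device = torch.device(0)

# T5: encoder blocks 0-11 stay on the GPU, blocks 12-23 (plus the final norm
# and dropout) stay on the CPU, as defined by get_device_map above.
t5 = T5Embedder(device=device, t5_model_kwargs={
    "low_cpu_mem_usage": True,
    "torch_dtype": torch.float16,
    "device_map": get_device_map(device),
    "offload_folder": True,
})

# Diffusion stages start on the CPU and are swapped onto the GPU one at a time.
if_I = IFStageI('IF-I-XL-v1.0', device=torch.device("cpu"))
if_II = IFStageII('IF-II-L-v1.0', device=torch.device("cpu"))

prompt = "ultra-detailed photo of a red panda"  # placeholder prompt
t5_embs = t5.get_text_embeddings([prompt]).to(device)
negative_t5_embs = t5.get_text_embeddings([""]).to(device)

if_I.to(device)  # stage I onto the GPU; stage II stays on the CPU
images, pil_images_64 = run_stage1(
    if_I, t5_embs=t5_embs, negative_t5_embs=negative_t5_embs,
    seed=0, num_images=1, img_size=(64, 64),
)

if_I.to(torch.device("cpu"))  # stage I out, stage II in
if_II.to(device)
_, pil_images_256 = run_stage2(
    if_II, t5_embs=t5_embs, negative_t5_embs=negative_t5_embs,
    images=images[0].unsqueeze(0).to(device), seed=0, device=device,
)
pil_images_256[0].save("out_256.png")  # placeholder output path
```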