diff --git a/scripts/audio2vid.py b/scripts/audio2vid.py
index 070eb68..446b8fe 100644
--- a/scripts/audio2vid.py
+++ b/scripts/audio2vid.py
@@ -21,7 +21,7 @@
 from configs.prompts.test_cases import TestCasesDict
 from src.models.pose_guider import PoseGuider
-from src.models.unet_2d_condition import UNet2DConditionModel
+from src.models.model_util import load_models, torch_gc, get_torch_device
 from src.models.unet_3d import UNet3DConditionModel
 from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
 from src.utils.util import get_fps, read_frames, save_videos_grid
@@ -56,21 +56,24 @@ def main():
         weight_dtype = torch.float16
     else:
         weight_dtype = torch.float32
-
+
+    device = get_torch_device()
+
     audio_infer_config = OmegaConf.load(config.audio_inference_config)
     # prepare model
     a2m_model = Audio2MeshModel(audio_infer_config['a2m_model'])
     a2m_model.load_state_dict(torch.load(audio_infer_config['pretrained_model']['a2m_ckpt']), strict=False)
-    a2m_model.cuda().eval()
-
-    vae = AutoencoderKL.from_pretrained(
-        config.pretrained_vae_path,
-    ).to("cuda", dtype=weight_dtype)
-
-    reference_unet = UNet2DConditionModel.from_pretrained(
-        config.pretrained_base_model_path,
-        subfolder="unet",
-    ).to(dtype=weight_dtype, device="cuda")
+    a2m_model.to(device).eval()
+
+    (_,_,unet,_,vae,) = load_models(
+        config.pretrained_base_model_path,
+        scheduler_name="",
+        v2=False,
+        v_pred=False,
+        weight_dtype=weight_dtype,
+    )
+    vae = vae.to(device, dtype=weight_dtype)
+    reference_unet = unet.to(dtype=weight_dtype, device=device)
 
     inference_config_path = config.inference_config
     infer_config = OmegaConf.load(inference_config_path)
@@ -79,14 +82,14 @@ def main():
         config.motion_module_path,
         subfolder="unet",
         unet_additional_kwargs=infer_config.unet_additional_kwargs,
-    ).to(dtype=weight_dtype, device="cuda")
+    ).to(dtype=weight_dtype, device=device)
 
-    pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype) # not use cross attention
+    pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device=device, dtype=weight_dtype) # not use cross attention
 
     image_enc = CLIPVisionModelWithProjection.from_pretrained(
         config.image_encoder_path
-    ).to(dtype=weight_dtype, device="cuda")
+    ).to(dtype=weight_dtype, device=device)
 
     sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
     scheduler = DDIMScheduler(**sched_kwargs)
@@ -115,7 +118,7 @@ def main():
         pose_guider=pose_guider,
         scheduler=scheduler,
     )
-    pipe = pipe.to("cuda", dtype=weight_dtype)
+    pipe = pipe.to(device, dtype=weight_dtype)
 
     date_str = datetime.now().strftime("%Y%m%d")
     time_str = datetime.now().strftime("%H%M")
@@ -145,7 +148,7 @@ def main():
         ref_pose = vis.draw_landmarks((ref_image_np.shape[1], ref_image_np.shape[0]), lmks, normed=True)
 
         sample = prepare_audio_feature(audio_path, wav2vec_model_path=audio_infer_config['a2m_model']['model_path'])
-        sample['audio_feature'] = torch.from_numpy(sample['audio_feature']).float().cuda()
+        sample['audio_feature'] = torch.from_numpy(sample['audio_feature']).float().to(device)
         sample['audio_feature'] = sample['audio_feature'].unsqueeze(0)
 
         # inference
@@ -218,6 +221,7 @@ def main():
         stream = ffmpeg.input(save_path)
         audio = ffmpeg.input(audio_path)
         ffmpeg.output(stream.video, audio.audio, save_path.replace('_noaudio.mp4', '.mp4'), vcodec='copy', acodec='aac').run()
+        torch_gc()
         os.remove(save_path)
 
 if __name__ == "__main__":
diff --git a/scripts/pose2vid.py b/scripts/pose2vid.py
index cfbc301..c9a8dd0 100644
--- a/scripts/pose2vid.py
+++ b/scripts/pose2vid.py
@@ -20,7 +20,7 @@
 from configs.prompts.test_cases import TestCasesDict
 from src.models.pose_guider import PoseGuider
-from src.models.unet_2d_condition import UNet2DConditionModel
+from src.models.model_util import load_models, torch_gc, get_torch_device
 from src.models.unet_3d import UNet3DConditionModel
 from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
 from src.utils.util import get_fps, read_frames, save_videos_grid
@@ -53,14 +53,17 @@ def main():
     else:
         weight_dtype = torch.float32
 
-    vae = AutoencoderKL.from_pretrained(
-        config.pretrained_vae_path,
-    ).to("cuda", dtype=weight_dtype)
+    device = get_torch_device()
 
-    reference_unet = UNet2DConditionModel.from_pretrained(
-        config.pretrained_base_model_path,
-        subfolder="unet",
-    ).to(dtype=weight_dtype, device="cuda")
+    (_,_,unet,_,vae,) = load_models(
+        config.pretrained_base_model_path,
+        scheduler_name="",
+        v2=False,
+        v_pred=False,
+        weight_dtype=weight_dtype,
+    )
+    vae = vae.to(device, dtype=weight_dtype)
+    reference_unet = unet.to(dtype=weight_dtype, device=device)
 
     inference_config_path = config.inference_config
     infer_config = OmegaConf.load(inference_config_path)
@@ -69,13 +72,13 @@ def main():
         config.motion_module_path,
         subfolder="unet",
         unet_additional_kwargs=infer_config.unet_additional_kwargs,
-    ).to(dtype=weight_dtype, device="cuda")
+    ).to(dtype=weight_dtype, device=device)
 
-    pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype) # not use cross attention
+    pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device=device, dtype=weight_dtype) # not use cross attention
 
     image_enc = CLIPVisionModelWithProjection.from_pretrained(
         config.image_encoder_path
-    ).to(dtype=weight_dtype, device="cuda")
+    ).to(dtype=weight_dtype, device=device)
 
     sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
     scheduler = DDIMScheduler(**sched_kwargs)
@@ -104,7 +107,7 @@ def main():
         pose_guider=pose_guider,
         scheduler=scheduler,
     )
-    pipe = pipe.to("cuda", dtype=weight_dtype)
+    pipe = pipe.to(device, dtype=weight_dtype)
 
     date_str = datetime.now().strftime("%Y%m%d")
     time_str = datetime.now().strftime("%H%M")
@@ -191,7 +194,7 @@ def main():
         stream = ffmpeg.input(save_path)
         audio = ffmpeg.input(audio_output)
         ffmpeg.output(stream.video, audio.audio, save_path.replace('_noaudio.mp4', '.mp4'), vcodec='copy', acodec='aac').run()
-
+        torch_gc()
         os.remove(save_path)
         os.remove(audio_output)
 
diff --git a/scripts/vid2vid.py b/scripts/vid2vid.py
index a5ec047..53dbbd6 100644
--- a/scripts/vid2vid.py
+++ b/scripts/vid2vid.py
@@ -20,7 +20,7 @@
 from configs.prompts.test_cases import TestCasesDict
 from src.models.pose_guider import PoseGuider
-from src.models.unet_2d_condition import UNet2DConditionModel
+from src.models.model_util import load_models, torch_gc, get_torch_device
 from src.models.unet_3d import UNet3DConditionModel
 from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
 from src.utils.util import get_fps, read_frames, save_videos_grid
@@ -54,14 +54,17 @@ def main():
     else:
         weight_dtype = torch.float32
 
-    vae = AutoencoderKL.from_pretrained(
-        config.pretrained_vae_path,
-    ).to("cuda", dtype=weight_dtype)
+    device = get_torch_device()
 
-    reference_unet = UNet2DConditionModel.from_pretrained(
-        config.pretrained_base_model_path,
-        subfolder="unet",
-    ).to(dtype=weight_dtype, device="cuda")
+    (_,_,unet,_,vae,) = load_models(
+        config.pretrained_base_model_path,
+        scheduler_name="",
+        v2=False,
+        v_pred=False,
+        weight_dtype=weight_dtype,
+    )
+    vae = vae.to(device, dtype=weight_dtype)
+    reference_unet = unet.to(dtype=weight_dtype, device=device)
 
     inference_config_path = config.inference_config
     infer_config = OmegaConf.load(inference_config_path)
@@ -70,13 +73,13 @@ def main():
         config.motion_module_path,
         subfolder="unet",
         unet_additional_kwargs=infer_config.unet_additional_kwargs,
-    ).to(dtype=weight_dtype, device="cuda")
+    ).to(dtype=weight_dtype, device=device)
 
-    pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device="cuda", dtype=weight_dtype) # not use cross attention
+    pose_guider = PoseGuider(noise_latent_channels=320, use_ca=True).to(device=device, dtype=weight_dtype) # not use cross attention
 
     image_enc = CLIPVisionModelWithProjection.from_pretrained(
         config.image_encoder_path
-    ).to(dtype=weight_dtype, device="cuda")
+    ).to(dtype=weight_dtype, device=device)
 
     sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
     scheduler = DDIMScheduler(**sched_kwargs)
@@ -105,7 +108,7 @@ def main():
         pose_guider=pose_guider,
         scheduler=scheduler,
     )
-    pipe = pipe.to("cuda", dtype=weight_dtype)
+    pipe = pipe.to(device, dtype=weight_dtype)
 
     date_str = datetime.now().strftime("%Y%m%d")
     time_str = datetime.now().strftime("%H%M")
@@ -224,7 +227,7 @@ def main():
         stream = ffmpeg.input(save_path)
         audio = ffmpeg.input(audio_output)
         ffmpeg.output(stream.video, audio.audio, save_path.replace('_noaudio.mp4', '.mp4'), vcodec='copy', acodec='aac').run()
-
+        torch_gc()
         os.remove(save_path)
         os.remove(audio_output)
 
diff --git a/src/models/model_util.py b/src/models/model_util.py
new file mode 100644
index 0000000..0ab16b3
--- /dev/null
+++ b/src/models/model_util.py
@@ -0,0 +1,469 @@
+from typing import Literal, Union, Optional
+
+import torch
+from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
+from diffusers import (
+    SchedulerMixin,
+    StableDiffusionPipeline,
+    StableDiffusionXLPipeline,
+    AutoencoderKL,
+)
+from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
+    convert_ldm_unet_checkpoint,
+)
+from safetensors.torch import load_file
+from src.models.unet_2d_condition import UNet2DConditionModel
+from diffusers.schedulers import (
+    DDIMScheduler,
+    DDPMScheduler,
+    LMSDiscreteScheduler,
+    EulerAncestralDiscreteScheduler,
+    UniPCMultistepScheduler,
+)
+
+from omegaconf import OmegaConf
+
+# Model parameters of the diffusers version of Stable Diffusion
+NUM_TRAIN_TIMESTEPS = 1000
+BETA_START = 0.00085
+BETA_END = 0.0120
+
+UNET_PARAMS_MODEL_CHANNELS = 320
+UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4]
+UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1]
+UNET_PARAMS_IMAGE_SIZE = 64  # fixed from old invalid value `32`
+UNET_PARAMS_IN_CHANNELS = 4
+UNET_PARAMS_OUT_CHANNELS = 4
+UNET_PARAMS_NUM_RES_BLOCKS = 2
+UNET_PARAMS_CONTEXT_DIM = 768
+UNET_PARAMS_NUM_HEADS = 8
+# UNET_PARAMS_USE_LINEAR_PROJECTION = False
+
+VAE_PARAMS_Z_CHANNELS = 4
+VAE_PARAMS_RESOLUTION = 256
+VAE_PARAMS_IN_CHANNELS = 3
+VAE_PARAMS_OUT_CH = 3
+VAE_PARAMS_CH = 128
+VAE_PARAMS_CH_MULT = [1, 2, 4, 4]
+VAE_PARAMS_NUM_RES_BLOCKS = 2
+
+# V2
+V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20]
+V2_UNET_PARAMS_CONTEXT_DIM = 1024
+# V2_UNET_PARAMS_USE_LINEAR_PROJECTION = True
+
+TOKENIZER_V1_MODEL_NAME = "CompVis/stable-diffusion-v1-4"
+TOKENIZER_V2_MODEL_NAME = "stabilityai/stable-diffusion-2-1"
+
+AVAILABLE_SCHEDULERS = Literal["ddim", "ddpm", "lms", "euler_a", "uniPC"]
+
+SDXL_TEXT_ENCODER_TYPE = Union[CLIPTextModel, CLIPTextModelWithProjection]
+
+DIFFUSERS_CACHE_DIR = None  # if you want to change the cache dir, change this
+
+
+def load_checkpoint_with_text_encoder_conversion(ckpt_path: str, device="cpu"):
+    # Handle models whose text encoder is stored in a different layout (no 'text_model' prefix)
+    TEXT_ENCODER_KEY_REPLACEMENTS = [
+        (
+            "cond_stage_model.transformer.embeddings.",
+            "cond_stage_model.transformer.text_model.embeddings.",
+        ),
+        (
+            "cond_stage_model.transformer.encoder.",
+            "cond_stage_model.transformer.text_model.encoder.",
+        ),
+        (
+            "cond_stage_model.transformer.final_layer_norm.",
+            "cond_stage_model.transformer.text_model.final_layer_norm.",
+        ),
+    ]
+
+    if ckpt_path.endswith(".safetensors"):
+        checkpoint = None
+        state_dict = load_file(ckpt_path)  # , device) # may causes error
+    else:
+        checkpoint = torch.load(ckpt_path, map_location=device)
+        if "state_dict" in checkpoint:
+            state_dict = checkpoint["state_dict"]
+        else:
+            state_dict = checkpoint
+            checkpoint = None
+
+    key_reps = []
+    for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
+        for key in state_dict.keys():
+            if key.startswith(rep_from):
+                new_key = rep_to + key[len(rep_from) :]
+                key_reps.append((key, new_key))
+
+    for key, new_key in key_reps:
+        state_dict[new_key] = state_dict[key]
+        del state_dict[key]
+
+    return checkpoint, state_dict
+
+
+def create_unet_diffusers_config(v2, use_linear_projection_in_v2=False):
+    """
+    Creates a config for the diffusers based on the config of the LDM model.
+    """
+    # unet_params = original_config.model.params.unet_config.params
+
+    block_out_channels = [
+        UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT
+    ]
+
+    down_block_types = []
+    resolution = 1
+    for i in range(len(block_out_channels)):
+        block_type = (
+            "CrossAttnDownBlock2D"
+            if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS
+            else "DownBlock2D"
+        )
+        down_block_types.append(block_type)
+        if i != len(block_out_channels) - 1:
+            resolution *= 2
+
+    up_block_types = []
+    for i in range(len(block_out_channels)):
+        block_type = (
+            "CrossAttnUpBlock2D"
+            if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS
+            else "UpBlock2D"
+        )
+        up_block_types.append(block_type)
+        resolution //= 2
+
+    config = dict(
+        sample_size=UNET_PARAMS_IMAGE_SIZE,
+        in_channels=UNET_PARAMS_IN_CHANNELS,
+        out_channels=UNET_PARAMS_OUT_CHANNELS,
+        down_block_types=tuple(down_block_types),
+        up_block_types=tuple(up_block_types),
+        block_out_channels=tuple(block_out_channels),
+        layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS,
+        cross_attention_dim=UNET_PARAMS_CONTEXT_DIM
+        if not v2
+        else V2_UNET_PARAMS_CONTEXT_DIM,
+        attention_head_dim=UNET_PARAMS_NUM_HEADS
+        if not v2
+        else V2_UNET_PARAMS_ATTENTION_HEAD_DIM,
+        # use_linear_projection=UNET_PARAMS_USE_LINEAR_PROJECTION if not v2 else V2_UNET_PARAMS_USE_LINEAR_PROJECTION,
+    )
+    if v2 and use_linear_projection_in_v2:
+        config["use_linear_projection"] = True
+
+    return config
+
+
+def load_diffusers_model(
+    pretrained_model_name_or_path: str,
+    v2: bool = False,
+    clip_skip: Optional[int] = None,
+    weight_dtype: torch.dtype = torch.float32,
+) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]:
+    if v2:
+        tokenizer = CLIPTokenizer.from_pretrained(
+            TOKENIZER_V2_MODEL_NAME,
+            subfolder="tokenizer",
+            torch_dtype=weight_dtype,
+            cache_dir=DIFFUSERS_CACHE_DIR,
+        )
+        text_encoder = CLIPTextModel.from_pretrained(
+            pretrained_model_name_or_path,
+            subfolder="text_encoder",
+            # default is clip skip 2
+            num_hidden_layers=24 - (clip_skip - 1) if clip_skip is not None else 23,
+            torch_dtype=weight_dtype,
+            cache_dir=DIFFUSERS_CACHE_DIR,
+        )
+    else:
+        tokenizer = CLIPTokenizer.from_pretrained(
+            TOKENIZER_V1_MODEL_NAME,
+            subfolder="tokenizer",
+            torch_dtype=weight_dtype,
+            cache_dir=DIFFUSERS_CACHE_DIR,
+        )
+        text_encoder = CLIPTextModel.from_pretrained(
+            pretrained_model_name_or_path,
+            subfolder="text_encoder",
+            num_hidden_layers=12 - (clip_skip - 1) if clip_skip is not None else 12,
+            torch_dtype=weight_dtype,
+            cache_dir=DIFFUSERS_CACHE_DIR,
+        )
+
+    unet = UNet2DConditionModel.from_pretrained(
+        pretrained_model_name_or_path,
+        subfolder="unet",
+        torch_dtype=weight_dtype,
+        cache_dir=DIFFUSERS_CACHE_DIR,
+    )
+
+    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
+
+    return tokenizer, text_encoder, unet, vae
+
+
+def load_checkpoint_model(
+    checkpoint_path: str,
+    v2: bool = False,
+    clip_skip: Optional[int] = None,
+    weight_dtype: torch.dtype = torch.float32,
+) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]:
+    pipe = StableDiffusionPipeline.from_single_file(
+        checkpoint_path,
+        upcast_attention=True if v2 else False,
+        torch_dtype=weight_dtype,
+        cache_dir=DIFFUSERS_CACHE_DIR,
+    )
+
+    _, state_dict = load_checkpoint_with_text_encoder_conversion(checkpoint_path)
+    unet_config = create_unet_diffusers_config(v2, use_linear_projection_in_v2=v2)
+    unet_config["class_embed_type"] = None
+    unet_config["addition_embed_type"] = None
+    converted_unet_checkpoint = convert_ldm_unet_checkpoint(state_dict, unet_config)
+    unet = UNet2DConditionModel(**unet_config)
+    converted_unet_checkpoint.pop("conv_out.weight")
+    converted_unet_checkpoint.pop("conv_out.bias")
+    unet.load_state_dict(converted_unet_checkpoint)
+
+    tokenizer = pipe.tokenizer
+    text_encoder = pipe.text_encoder
+    vae = pipe.vae
+    if clip_skip is not None:
+        if v2:
+            text_encoder.config.num_hidden_layers = 24 - (clip_skip - 1)
+        else:
+            text_encoder.config.num_hidden_layers = 12 - (clip_skip - 1)
+
+    del pipe
+
+    return tokenizer, text_encoder, unet, vae
+
+
+def load_models(
+    pretrained_model_name_or_path: str,
+    scheduler_name: str,
+    v2: bool = False,
+    v_pred: bool = False,
+    weight_dtype: torch.dtype = torch.float32,
+) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, SchedulerMixin,]:
+    if pretrained_model_name_or_path.endswith(
+        ".ckpt"
+    ) or pretrained_model_name_or_path.endswith(".safetensors"):
+        tokenizer, text_encoder, unet, vae = load_checkpoint_model(
+            pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
+        )
+    else:  # diffusers
+        tokenizer, text_encoder, unet, vae = load_diffusers_model(
+            pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
+        )
+
+    if scheduler_name:
+        scheduler = create_noise_scheduler(
+            scheduler_name,
+            prediction_type="v_prediction" if v_pred else "epsilon",
+        )
+    else:
+        scheduler = None
+
+    return tokenizer, text_encoder, unet, scheduler, vae
+
+
+def load_diffusers_model_xl(
+    pretrained_model_name_or_path: str,
+    weight_dtype: torch.dtype = torch.float32,
+) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel,]:
+    # returns tokenizer, tokenizer_2, text_encoder, text_encoder_2, unet
+
+    tokenizers = [
+        CLIPTokenizer.from_pretrained(
+            pretrained_model_name_or_path,
+            subfolder="tokenizer",
+            torch_dtype=weight_dtype,
+            cache_dir=DIFFUSERS_CACHE_DIR,
+        ),
+        CLIPTokenizer.from_pretrained(
+            pretrained_model_name_or_path,
+            subfolder="tokenizer_2",
+            torch_dtype=weight_dtype,
+            cache_dir=DIFFUSERS_CACHE_DIR,
+            pad_token_id=0,  # same as open clip
+        ),
+    ]
+
+    text_encoders = [
+        CLIPTextModel.from_pretrained(
+            pretrained_model_name_or_path,
+            subfolder="text_encoder",
+            torch_dtype=weight_dtype,
+            cache_dir=DIFFUSERS_CACHE_DIR,
+        ),
+        CLIPTextModelWithProjection.from_pretrained(
+            pretrained_model_name_or_path,
+            subfolder="text_encoder_2",
+            torch_dtype=weight_dtype,
+            cache_dir=DIFFUSERS_CACHE_DIR,
+        ),
+    ]
+
+    unet = UNet2DConditionModel.from_pretrained(
+        pretrained_model_name_or_path,
+        subfolder="unet",
+        torch_dtype=weight_dtype,
+        cache_dir=DIFFUSERS_CACHE_DIR,
+    )
+    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
+    return tokenizers, text_encoders, unet, vae
+
+
+def load_checkpoint_model_xl(
+    checkpoint_path: str,
+    weight_dtype: torch.dtype = torch.float32,
+) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel,]:
+    pipe = StableDiffusionXLPipeline.from_single_file(
+        checkpoint_path,
+        torch_dtype=weight_dtype,
+        cache_dir=DIFFUSERS_CACHE_DIR,
+    )
+
+    unet = pipe.unet
+    vae = pipe.vae
+    tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
+    text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
+    if len(text_encoders) == 2:
+        text_encoders[1].pad_token_id = 0
+
+    del pipe
+
+    return tokenizers, text_encoders, unet, vae
+
+
+def load_models_xl(
+    pretrained_model_name_or_path: str,
+    scheduler_name: str,
+    weight_dtype: torch.dtype = torch.float32,
+    noise_scheduler_kwargs=None,
+) -> tuple[
+    list[CLIPTokenizer],
+    list[SDXL_TEXT_ENCODER_TYPE],
+    UNet2DConditionModel,
+    SchedulerMixin,
+]:
+    if pretrained_model_name_or_path.endswith(
+        ".ckpt"
+    ) or pretrained_model_name_or_path.endswith(".safetensors"):
+        (tokenizers, text_encoders, unet, vae) = load_checkpoint_model_xl(
+            pretrained_model_name_or_path, weight_dtype
+        )
+    else:  # diffusers
+        (tokenizers, text_encoders, unet, vae) = load_diffusers_model_xl(
+            pretrained_model_name_or_path, weight_dtype
+        )
+    if scheduler_name:
+        scheduler = create_noise_scheduler(scheduler_name, noise_scheduler_kwargs)
+    else:
+        scheduler = None
+
+    return tokenizers, text_encoders, unet, scheduler, vae
+
+
+def create_noise_scheduler(
+    scheduler_name: AVAILABLE_SCHEDULERS = "ddpm",
+    noise_scheduler_kwargs=None,
+    prediction_type: Literal["epsilon", "v_prediction"] = "epsilon",
+) -> SchedulerMixin:
+    name = scheduler_name.lower().replace(" ", "_")
+    if name.lower() == "ddim":
+        # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddim
+        scheduler = DDIMScheduler(**OmegaConf.to_container(noise_scheduler_kwargs))
+    elif name.lower() == "ddpm":
+        # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddpm
+        scheduler = DDPMScheduler(**OmegaConf.to_container(noise_scheduler_kwargs))
+    elif name.lower() == "lms":
+        # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/lms_discrete
+        scheduler = LMSDiscreteScheduler(
+            **OmegaConf.to_container(noise_scheduler_kwargs)
+        )
+    elif name.lower() == "euler_a":
+        # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/euler_ancestral
+        scheduler = EulerAncestralDiscreteScheduler(
+            **OmegaConf.to_container(noise_scheduler_kwargs)
+        )
+    elif name.lower() == "unipc":
+        # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/unipc
+        scheduler = UniPCMultistepScheduler(
+            **OmegaConf.to_container(noise_scheduler_kwargs)
+        )
+    else:
+        raise ValueError(f"Unknown scheduler name: {name}")
+
+    return scheduler
+
+
+def torch_gc():
+    import gc
+
+    gc.collect()
+    if torch.cuda.is_available():
+        with torch.cuda.device("cuda"):
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+
+
+from enum import Enum
+
+
+class CPUState(Enum):
+    GPU = 0
+    CPU = 1
+    MPS = 2
+
+
+cpu_state = CPUState.GPU
+xpu_available = False
+directml_enabled = False
+
+
+def is_intel_xpu():
+    global cpu_state
+    global xpu_available
+    if cpu_state == CPUState.GPU:
+        if xpu_available:
+            return True
+    return False
+
+
+try:
+    import intel_extension_for_pytorch as ipex
+
+    if torch.xpu.is_available():
+        xpu_available = True
+except:
+    pass
+
+try:
+    if torch.backends.mps.is_available():
+        cpu_state = CPUState.MPS
+        import torch.mps
+except:
+    pass
+
+
+def get_torch_device():
+    global directml_enabled
+    global cpu_state
+    if directml_enabled:
+        global directml_device
+        return directml_device
+    if cpu_state == CPUState.MPS:
+        return torch.device("mps")
+    if cpu_state == CPUState.CPU:
+        return torch.device("cpu")
+    else:
+        if is_intel_xpu():
+            return torch.device("xpu")
+        else:
+            return torch.device(torch.cuda.current_device())
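
Note (reviewer sketch, not part of the patch): the scripts above consume the new model_util helpers roughly as follows. The path and dtype are illustrative, not taken from this diff; load_models returns (tokenizer, text_encoder, unet, scheduler, vae) as defined in the file above.

    import torch
    from src.models.model_util import load_models, torch_gc, get_torch_device

    device = get_torch_device()  # cuda, xpu, mps, or cpu, depending on what is available
    # load_models accepts either a diffusers directory or a single .ckpt/.safetensors file
    tokenizer, text_encoder, unet, scheduler, vae = load_models(
        "path/to/base_model",    # illustrative path
        scheduler_name="",       # empty string -> no scheduler is created, as in the scripts
        v2=False,
        v_pred=False,
        weight_dtype=torch.float16,
    )
    vae = vae.to(device, dtype=torch.float16)
    unet = unet.to(device, dtype=torch.float16)
    torch_gc()                   # run gc and empty the CUDA cache once GPU work is finished
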
diff --git a/src/models/unet_3d.py b/src/models/unet_3d.py
index 6f6181e..21675df 100644
--- a/src/models/unet_3d.py
+++ b/src/models/unet_3d.py
@@ -590,7 +590,7 @@ def from_pretrained_2d(
     ):
         pretrained_model_path = Path(pretrained_model_path)
         motion_module_path = Path(motion_module_path)
-        if subfolder is not None:
+        if not pretrained_model_path.is_file() and subfolder is not None:
             pretrained_model_path = pretrained_model_path.joinpath(subfolder)
         logger.info(
             f"loaded temporal unet's pretrained weights from {pretrained_model_path} ..."
@@ -598,43 +598,91 @@ def from_pretrained_2d(
         config_file = pretrained_model_path / "config.json"
         if not (config_file.exists() and config_file.is_file()):
-            raise RuntimeError(f"{config_file} does not exist or is not a file")
-
-        unet_config = cls.load_config(config_file)
-        unet_config["_class_name"] = cls.__name__
-        unet_config["down_block_types"] = [
-            "CrossAttnDownBlock3D",
-            "CrossAttnDownBlock3D",
-            "CrossAttnDownBlock3D",
-            "DownBlock3D",
-        ]
-        unet_config["up_block_types"] = [
-            "UpBlock3D",
-            "CrossAttnUpBlock3D",
-            "CrossAttnUpBlock3D",
-            "CrossAttnUpBlock3D",
-        ]
-        unet_config["mid_block_type"] = "UNetMidBlock3DCrossAttn"
+            unet_config = {
+                "_class_name": "UNet2DConditionModel",
+                "_diffusers_version": "0.6.0",
+                "act_fn": "silu",
+                "attention_head_dim": 8,
+                "block_out_channels": [320, 640, 1280, 1280],
+                "center_input_sample": False,
+                "cross_attention_dim": 768,
+                "down_block_types": [
+                    "CrossAttnDownBlock3D",
+                    "CrossAttnDownBlock3D",
+                    "CrossAttnDownBlock3D",
+                    "DownBlock3D",
+                ],
+                "downsample_padding": 1,
+                "flip_sin_to_cos": True,
+                "freq_shift": 0,
+                "in_channels": 4,
+                "layers_per_block": 2,
+                "mid_block_scale_factor": 1,
+                "norm_eps": 1e-05,
+                "norm_num_groups": 32,
+                "out_channels": 4,
+                "sample_size": 64,
+                "up_block_types": [
+                    "UpBlock3D",
+                    "CrossAttnUpBlock3D",
+                    "CrossAttnUpBlock3D",
+                    "CrossAttnUpBlock3D",
+                ],
+            }
+        else:
+            unet_config = cls.load_config(config_file)
+            unet_config["_class_name"] = cls.__name__
+            unet_config["down_block_types"] = [
+                "CrossAttnDownBlock3D",
+                "CrossAttnDownBlock3D",
+                "CrossAttnDownBlock3D",
+                "DownBlock3D",
+            ]
+            unet_config["up_block_types"] = [
+                "UpBlock3D",
+                "CrossAttnUpBlock3D",
+                "CrossAttnUpBlock3D",
+                "CrossAttnUpBlock3D",
+            ]
+            unet_config["mid_block_type"] = "UNetMidBlock3DCrossAttn"
 
         model = cls.from_config(unet_config, **unet_additional_kwargs)
         # load the vanilla weights
-        if pretrained_model_path.joinpath(SAFETENSORS_WEIGHTS_NAME).exists():
+        if str(pretrained_model_path).endswith(".safetensors"):
             logger.debug(
                 f"loading safeTensors weights from {pretrained_model_path} ..."
             )
             state_dict = load_file(
-                pretrained_model_path.joinpath(SAFETENSORS_WEIGHTS_NAME), device="cpu"
+                pretrained_model_path, device="cpu"
             )
-        elif pretrained_model_path.joinpath(WEIGHTS_NAME).exists():
+        elif str(pretrained_model_path).endswith(".ckpt"):
             logger.debug(f"loading weights from {pretrained_model_path} ...")
             state_dict = torch.load(
-                pretrained_model_path.joinpath(WEIGHTS_NAME),
+                pretrained_model_path,
                 map_location="cpu",
-                weights_only=True,
+                weights_only=False,
             )
         else:
-            raise FileNotFoundError(f"no weights file found in {pretrained_model_path}")
+            if pretrained_model_path.joinpath(SAFETENSORS_WEIGHTS_NAME).exists():
+
+                logger.debug(
+                    f"loading safeTensors weights from {pretrained_model_path} ..."
+                )
+                state_dict = load_file(
+                    pretrained_model_path.joinpath(SAFETENSORS_WEIGHTS_NAME), device="cpu"
+                )
+            elif pretrained_model_path.joinpath(WEIGHTS_NAME).exists():
+                logger.debug(f"loading weights from {pretrained_model_path} ...")
+                state_dict = torch.load(
+                    pretrained_model_path.joinpath(WEIGHTS_NAME),
+                    map_location="cpu",
+                    weights_only=True,
+                )
+            else:
+                raise FileNotFoundError(
+                    f"no weights file found for {pretrained_model_path}"
+                )
 
         # load the motion module weights
         if motion_module_path.exists() and motion_module_path.is_file():
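
Note (reviewer sketch, not part of the patch): with the unet_3d.py change, UNet3DConditionModel.from_pretrained_2d accepts either a diffusers-style directory (subfolder is joined and config.json is read) or a single .safetensors/.ckpt file (the path is a file, so subfolder is skipped and the hard-coded SD1.5-style UNet config above is used). Paths below are illustrative.

    from omegaconf import OmegaConf
    from src.models.unet_3d import UNet3DConditionModel

    infer_config = OmegaConf.load("path/to/inference_config.yaml")  # illustrative path

    # diffusers layout: <base>/unet/config.json and the weights inside that folder are used
    unet3d = UNet3DConditionModel.from_pretrained_2d(
        "path/to/stable-diffusion-v1-5",
        "path/to/motion_module.pth",
        subfolder="unet",
        unet_additional_kwargs=infer_config.unet_additional_kwargs,
    )

    # single-file checkpoint: falls back to the built-in UNet config and loads the file directly
    unet3d = UNet3DConditionModel.from_pretrained_2d(
        "path/to/model.safetensors",
        "path/to/motion_module.pth",
        subfolder="unet",
        unet_additional_kwargs=infer_config.unet_additional_kwargs,
    )
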