"""Runtime patches for the single-file and pipeline loading helpers of ``diffusers``.

Importing this module has side effects: at the bottom of the file it replaces
``get_class_obj_and_candidates`` in ``diffusers.pipelines.pipeline_loading_utils``
and ``SINGLE_FILE_LOADABLE_CLASSES`` / ``_get_single_file_loadable_mapping_class``
in ``diffusers.loaders.single_file_model`` with the versions defined here.

The replacements add a ``ZImageControlTransformer2DModel`` entry to the
single-file mapping and fall back to a ``diffusers_local`` module whenever a
class name cannot be found on the regular library module.
"""

import importlib
import os

import diffusers.loaders.single_file_model as single_file_model
import diffusers.pipelines.pipeline_loading_utils as pipe_loading_utils
import torch
from diffusers.loaders.single_file_utils import (
    convert_animatediff_checkpoint_to_diffusers,
    convert_auraflow_transformer_checkpoint_to_diffusers,
    convert_autoencoder_dc_checkpoint_to_diffusers,
    convert_chroma_transformer_checkpoint_to_diffusers,
    convert_controlnet_checkpoint,
    convert_cosmos_transformer_checkpoint_to_diffusers,
    convert_flux2_transformer_checkpoint_to_diffusers,
    convert_flux_transformer_checkpoint_to_diffusers,
    convert_hidream_transformer_to_diffusers,
    convert_hunyuan_video_transformer_to_diffusers,
    convert_ldm_unet_checkpoint,
    convert_ldm_vae_checkpoint,
    convert_ltx_transformer_checkpoint_to_diffusers,
    convert_ltx_vae_checkpoint_to_diffusers,
    convert_lumina2_to_diffusers,
    convert_mochi_transformer_checkpoint_to_diffusers,
    convert_sana_transformer_to_diffusers,
    convert_sd3_transformer_checkpoint_to_diffusers,
    convert_stable_cascade_unet_single_file_to_diffusers,
    convert_wan_transformer_to_diffusers,
    convert_wan_vae_to_diffusers,
    convert_z_image_transformer_checkpoint_to_diffusers,
    create_controlnet_diffusers_config_from_ldm,
    create_unet_diffusers_config_from_ldm,
    create_vae_diffusers_config_from_ldm,
)
from diffusers.utils import (
    _maybe_remap_transformers_class,
    get_class_from_dynamic_module,
)


# Substring renames applied, in this order, to every checkpoint key.
# NOTE: "x_embedder." also matches inside "control_x_embedder.", so control keys
# are already rewritten to "control_all_x_embedder.2-1." by the earlier entry;
# the last entry then no longer matches. The end result is the same.
_Z_IMAGE_CONTROL_KEYS_RENAME = {
    "final_layer.": "all_final_layer.2-1.",
    "x_embedder.": "all_x_embedder.2-1.",
    ".attention.out.bias": ".attention.to_out.0.bias",
    ".attention.k_norm.weight": ".attention.norm_k.weight",
    ".attention.q_norm.weight": ".attention.norm_q.weight",
    ".attention.out.weight": ".attention.to_out.0.weight",
    "control_x_embedder.": "control_all_x_embedder.2-1.",
}

_FUSED_QKV_SUFFIX = ".attention.qkv.weight"


def _split_z_image_fused_qkv(key: str, state_dict: dict[str, object]) -> None:
    """Replace a fused ``qkv`` weight entry by separate ``to_q``/``to_k``/``to_v`` entries.

    The tensor stored under ``key`` is removed from ``state_dict`` and split into
    three chunks along dimension 0 with ``torch.chunk``. Keys that do not contain
    ``.attention.qkv.weight`` are left untouched.
    """
    if _FUSED_QKV_SUFFIX not in key:
        return

    fused_weight = state_dict.pop(key)
    q_weight, k_weight, v_weight = torch.chunk(fused_weight, 3, dim=0)
    state_dict[key.replace(_FUSED_QKV_SUFFIX, ".attention.to_q.weight")] = q_weight
    state_dict[key.replace(_FUSED_QKV_SUFFIX, ".attention.to_k.weight")] = k_weight
    state_dict[key.replace(_FUSED_QKV_SUFFIX, ".attention.to_v.weight")] = v_weight


# Handlers for conversions that cannot be expressed as a plain key rename.
# Each handler receives (key, state_dict) and mutates state_dict in place.
_Z_IMAGE_CONTROL_SPECIAL_KEYS_REMAP = {
    _FUSED_QKV_SUFFIX: _split_z_image_fused_qkv,
}


def convert_z_image_control_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
    """Convert a Z-Image control transformer single-file checkpoint to diffusers key names.

    Args:
        checkpoint: Mutable mapping from parameter name to tensor. Every entry
            is popped from it, so it is empty when this function returns.
        **kwargs: Accepted for signature compatibility with the other
            ``checkpoint_mapping_fn`` callables; not used here.

    Returns:
        A new ``dict`` holding the same tensors under renamed keys, with every
        fused ``.attention.qkv.weight`` tensor split into ``to_q``, ``to_k``
        and ``to_v`` weights.
    """
    # Drain the input mapping so the tensors are not referenced from two dicts.
    converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}

    # Handle single file --> diffusers key remapping via the rename dict.
    # Iterate over a snapshot of the keys because the dict is mutated in the loop.
    for old_key in list(converted_state_dict.keys()):
        new_key = old_key
        for needle, replacement in _Z_IMAGE_CONTROL_KEYS_RENAME.items():
            new_key = new_key.replace(needle, replacement)
        converted_state_dict[new_key] = converted_state_dict.pop(old_key)

    # Handle any special logic which can't be expressed by a simple 1:1 remapping
    # with the in-place handlers registered above.
    for key in list(converted_state_dict.keys()):
        for special_key, handler_fn_inplace in _Z_IMAGE_CONTROL_SPECIAL_KEYS_REMAP.items():
            if special_key in key:
                handler_fn_inplace(key, converted_state_dict)

    return converted_state_dict


# Maps a model class name to the functions and defaults used when loading it
# from a single checkpoint file. This dict replaces the one in
# ``diffusers.loaders.single_file_model`` (see the assignments at the end of
# this module).
SINGLE_FILE_LOADABLE_CLASSES = {
    "StableCascadeUNet": {
        "checkpoint_mapping_fn": convert_stable_cascade_unet_single_file_to_diffusers,
    },
    "UNet2DConditionModel": {
        "checkpoint_mapping_fn": convert_ldm_unet_checkpoint,
        "config_mapping_fn": create_unet_diffusers_config_from_ldm,
        "default_subfolder": "unet",
        "legacy_kwargs": {
            "num_in_channels": "in_channels",  # Legacy kwargs supported by `from_single_file` mapped to new args
        },
    },
    "AutoencoderKL": {
        "checkpoint_mapping_fn": convert_ldm_vae_checkpoint,
        "config_mapping_fn": create_vae_diffusers_config_from_ldm,
        "default_subfolder": "vae",
    },
    "ControlNetModel": {
        "checkpoint_mapping_fn": convert_controlnet_checkpoint,
        "config_mapping_fn": create_controlnet_diffusers_config_from_ldm,
    },
    "SD3Transformer2DModel": {
        "checkpoint_mapping_fn": convert_sd3_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "MotionAdapter": {
        "checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
    },
    "SparseControlNetModel": {
        "checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
    },
    "FluxTransformer2DModel": {
        "checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "ChromaTransformer2DModel": {
        "checkpoint_mapping_fn": convert_chroma_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "LTXVideoTransformer3DModel": {
        "checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "AutoencoderKLLTXVideo": {
        "checkpoint_mapping_fn": convert_ltx_vae_checkpoint_to_diffusers,
        "default_subfolder": "vae",
    },
    "AutoencoderDC": {"checkpoint_mapping_fn": convert_autoencoder_dc_checkpoint_to_diffusers},
    "MochiTransformer3DModel": {
        "checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "HunyuanVideoTransformer3DModel": {
        "checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers,
        "default_subfolder": "transformer",
    },
    "AuraFlowTransformer2DModel": {
        "checkpoint_mapping_fn": convert_auraflow_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "Lumina2Transformer2DModel": {
        "checkpoint_mapping_fn": convert_lumina2_to_diffusers,
        "default_subfolder": "transformer",
    },
    "SanaTransformer2DModel": {
        "checkpoint_mapping_fn": convert_sana_transformer_to_diffusers,
        "default_subfolder": "transformer",
    },
    "WanTransformer3DModel": {
        "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
        "default_subfolder": "transformer",
    },
    "WanVACETransformer3DModel": {
        "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
        "default_subfolder": "transformer",
    },
    "AutoencoderKLWan": {
        "checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
        "default_subfolder": "vae",
    },
    "HiDreamImageTransformer2DModel": {
        "checkpoint_mapping_fn": convert_hidream_transformer_to_diffusers,
        "default_subfolder": "transformer",
    },
    "CosmosTransformer3DModel": {
        "checkpoint_mapping_fn": convert_cosmos_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "QwenImageTransformer2DModel": {
        "checkpoint_mapping_fn": lambda x: x,
        "default_subfolder": "transformer",
    },
    "Flux2Transformer2DModel": {
        "checkpoint_mapping_fn": convert_flux2_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "ZImageTransformer2DModel": {
        "checkpoint_mapping_fn": convert_z_image_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "ZImageControlTransformer2DModel": {
        "checkpoint_mapping_fn": convert_z_image_control_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
}


def get_class_obj_and_candidates(
    library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
):
    """Simple helper method to retrieve class object of module as well as potential parent class objects.

    Args:
        library_name: Name of the module (or pipeline sub-module, or custom
            component file without ``.py``) that provides ``class_name``.
        class_name: Name of the class to resolve.
        importable_classes: Mapping whose keys are candidate base-class names.
        pipelines: Object on which ``library_name`` is looked up when
            ``is_pipeline_module`` is true.
        is_pipeline_module: Whether ``library_name`` names an attribute of
            ``pipelines``.
        component_name: Optional sub-folder name of the component inside
            ``cache_dir``.
        cache_dir: Optional directory that may contain a custom component.

    Returns:
        A ``(class_obj, class_candidates)`` tuple. ``class_candidates`` maps
        every key of ``importable_classes`` to a class object (or ``None`` when
        the library does not define that name).

    Raises:
        ModuleNotFoundError: If the class is missing from the library and the
            ``diffusers_local`` fallback module cannot be imported.
        AttributeError: If neither the library nor ``diffusers_local`` defines
            ``class_name``.
    """
    component_folder = os.path.join(cache_dir, component_name) if component_name and cache_dir else None

    if is_pipeline_module:
        pipeline_module = getattr(pipelines, library_name)
        class_obj = getattr(pipeline_module, class_name)
        class_candidates = dict.fromkeys(importable_classes, class_obj)
    elif component_folder and os.path.isfile(os.path.join(component_folder, library_name + ".py")):
        # load custom component
        class_obj = get_class_from_dynamic_module(
            component_folder, module_file=library_name + ".py", class_name=class_name
        )
        class_candidates = dict.fromkeys(importable_classes, class_obj)
    else:
        # else we just import it from the library.
        library = importlib.import_module(library_name)

        # Handle deprecated Transformers classes
        if library_name == "transformers":
            class_name = _maybe_remap_transformers_class(class_name) or class_name

        try:
            class_obj = getattr(library, class_name)
        except AttributeError:
            # Only a genuinely missing attribute triggers the local fallback;
            # other errors (and KeyboardInterrupt/SystemExit) propagate instead
            # of being hidden behind a lookup in `diffusers_local`.
            local_module = importlib.import_module("diffusers_local")
            class_obj = getattr(local_module, class_name)

        # Candidates are always looked up on the regular library module, even
        # when the class itself came from the fallback module.
        class_candidates = {name: getattr(library, name, None) for name in importable_classes}

    return class_obj, class_candidates


def _get_single_file_loadable_mapping_class(cls):
    """Return the ``SINGLE_FILE_LOADABLE_CLASSES`` key that ``cls`` is a subclass of.

    Each registered class name is resolved on the ``diffusers`` module first and
    on ``diffusers_local`` when ``diffusers`` does not define it. The first
    registered class (in dict order) that ``cls`` is a subclass of determines
    the returned key. If none matches, ``cls.__name__`` is returned.
    """
    diffusers_module = importlib.import_module("diffusers")

    for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
        try:
            loadable_class = getattr(diffusers_module, loadable_class_str)
        except AttributeError:
            # Class is not defined by diffusers; look it up in the local module.
            # Narrowed from a bare `except:` so unrelated failures are not masked.
            local_module = importlib.import_module("diffusers_local")
            loadable_class = getattr(local_module, loadable_class_str)

        if issubclass(cls, loadable_class):
            return loadable_class_str

    return cls.__name__


# Install the patched helpers into diffusers. This runs at import time.
pipe_loading_utils.get_class_obj_and_candidates = get_class_obj_and_candidates
single_file_model.SINGLE_FILE_LOADABLE_CLASSES = SINGLE_FILE_LOADABLE_CLASSES
single_file_model._get_single_file_loadable_mapping_class = _get_single_file_loadable_mapping_class