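# Local patch for diffusers model/pipeline loading.
#
# Importing this module overrides two internals of the installed diffusers package:
#   * pipeline_loading_utils.get_class_obj_and_candidates, so that component classes missing
#     from the imported library are looked up in a local "diffusers_local" module, and
#   * single_file_model.SINGLE_FILE_LOADABLE_CLASSES (plus its mapping-class resolver), so that
#     extra models, notably the Z-Image transformers, can be loaded with `from_single_file`.
# The patches are applied in place by the assignments at the bottom of the file.
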
import importlib
import os

import torch

import diffusers.pipelines.pipeline_loading_utils as pipe_loading_utils
import diffusers.loaders.single_file_model as single_file_model
from diffusers.utils import (
    _maybe_remap_transformers_class,
    get_class_from_dynamic_module,
)
from diffusers.loaders.single_file_utils import (
    convert_animatediff_checkpoint_to_diffusers,
    convert_auraflow_transformer_checkpoint_to_diffusers,
    convert_autoencoder_dc_checkpoint_to_diffusers,
    convert_chroma_transformer_checkpoint_to_diffusers,
    convert_controlnet_checkpoint,
    convert_cosmos_transformer_checkpoint_to_diffusers,
    convert_flux2_transformer_checkpoint_to_diffusers,
    convert_flux_transformer_checkpoint_to_diffusers,
    convert_hidream_transformer_to_diffusers,
    convert_hunyuan_video_transformer_to_diffusers,
    convert_ldm_unet_checkpoint,
    convert_ldm_vae_checkpoint,
    convert_ltx_transformer_checkpoint_to_diffusers,
    convert_ltx_vae_checkpoint_to_diffusers,
    convert_lumina2_to_diffusers,
    convert_mochi_transformer_checkpoint_to_diffusers,
    convert_sana_transformer_to_diffusers,
    convert_sd3_transformer_checkpoint_to_diffusers,
    convert_stable_cascade_unet_single_file_to_diffusers,
    convert_wan_transformer_to_diffusers,
    convert_wan_vae_to_diffusers,
    convert_z_image_transformer_checkpoint_to_diffusers,
    create_controlnet_diffusers_config_from_ldm,
    create_unet_diffusers_config_from_ldm,
    create_vae_diffusers_config_from_ldm,
)

def convert_z_image_control_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
    """Remap a single-file Z-Image control transformer state dict to diffusers key names."""
    Z_IMAGE_KEYS_RENAME_DICT = {
        "final_layer.": "all_final_layer.2-1.",
        "x_embedder.": "all_x_embedder.2-1.",
        ".attention.out.bias": ".attention.to_out.0.bias",
        ".attention.k_norm.weight": ".attention.norm_k.weight",
        ".attention.q_norm.weight": ".attention.norm_q.weight",
        ".attention.out.weight": ".attention.to_out.0.weight",
        "control_x_embedder.": "control_all_x_embedder.2-1.",
    }

    def convert_z_image_fused_attention(key: str, state_dict: dict[str, object]) -> None:
        # Split the fused QKV projection into separate to_q/to_k/to_v weights.
        if ".attention.qkv.weight" not in key:
            return
        fused_qkv_weight = state_dict.pop(key)
        to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
        new_q_name = key.replace(".attention.qkv.weight", ".attention.to_q.weight")
        new_k_name = key.replace(".attention.qkv.weight", ".attention.to_k.weight")
        new_v_name = key.replace(".attention.qkv.weight", ".attention.to_v.weight")
        state_dict[new_q_name] = to_q_weight
        state_dict[new_k_name] = to_k_weight
        state_dict[new_v_name] = to_v_weight

    TRANSFORMER_SPECIAL_KEYS_REMAP = {
        ".attention.qkv.weight": convert_z_image_fused_attention,
    }

    def update_state_dict(state_dict: dict[str, object], old_key: str, new_key: str) -> None:
        state_dict[new_key] = state_dict.pop(old_key)

    converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}

    # Handle single file --> diffusers key remapping via the remap dict
    for key in list(converted_state_dict.keys()):
        new_key = key[:]
        for replace_key, rename_key in Z_IMAGE_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict(converted_state_dict, key, new_key)

    # Handle any special logic which can't be expressed by a simple 1:1 remapping with the
    # handlers in TRANSFORMER_SPECIAL_KEYS_REMAP
    for key in list(converted_state_dict.keys()):
        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, converted_state_dict)

    return converted_state_dict
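
# Example of the remapping performed above (the "layers.0" prefix is illustrative, not a
# claim about the exact Z-Image block naming):
#   "x_embedder.weight"                -> "all_x_embedder.2-1.weight"
#   "layers.0.attention.q_norm.weight" -> "layers.0.attention.norm_q.weight"
#   "layers.0.attention.qkv.weight"    -> split along dim 0 into ".attention.to_q.weight",
#                                         ".attention.to_k.weight" and ".attention.to_v.weight".
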
SINGLE_FILE_LOADABLE_CLASSES = {
    "StableCascadeUNet": {
        "checkpoint_mapping_fn": convert_stable_cascade_unet_single_file_to_diffusers,
    },
    "UNet2DConditionModel": {
        "checkpoint_mapping_fn": convert_ldm_unet_checkpoint,
        "config_mapping_fn": create_unet_diffusers_config_from_ldm,
        "default_subfolder": "unet",
        "legacy_kwargs": {
            "num_in_channels": "in_channels",  # Legacy kwargs supported by `from_single_file` mapped to new args
        },
    },
    "AutoencoderKL": {
        "checkpoint_mapping_fn": convert_ldm_vae_checkpoint,
        "config_mapping_fn": create_vae_diffusers_config_from_ldm,
        "default_subfolder": "vae",
    },
    "ControlNetModel": {
        "checkpoint_mapping_fn": convert_controlnet_checkpoint,
        "config_mapping_fn": create_controlnet_diffusers_config_from_ldm,
    },
    "SD3Transformer2DModel": {
        "checkpoint_mapping_fn": convert_sd3_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "MotionAdapter": {
        "checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
    },
    "SparseControlNetModel": {
        "checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
    },
    "FluxTransformer2DModel": {
        "checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "ChromaTransformer2DModel": {
        "checkpoint_mapping_fn": convert_chroma_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "LTXVideoTransformer3DModel": {
        "checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "AutoencoderKLLTXVideo": {
        "checkpoint_mapping_fn": convert_ltx_vae_checkpoint_to_diffusers,
        "default_subfolder": "vae",
    },
    "AutoencoderDC": {"checkpoint_mapping_fn": convert_autoencoder_dc_checkpoint_to_diffusers},
    "MochiTransformer3DModel": {
        "checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "HunyuanVideoTransformer3DModel": {
        "checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers,
        "default_subfolder": "transformer",
    },
    "AuraFlowTransformer2DModel": {
        "checkpoint_mapping_fn": convert_auraflow_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "Lumina2Transformer2DModel": {
        "checkpoint_mapping_fn": convert_lumina2_to_diffusers,
        "default_subfolder": "transformer",
    },
    "SanaTransformer2DModel": {
        "checkpoint_mapping_fn": convert_sana_transformer_to_diffusers,
        "default_subfolder": "transformer",
    },
    "WanTransformer3DModel": {
        "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
        "default_subfolder": "transformer",
    },
    "WanVACETransformer3DModel": {
        "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
        "default_subfolder": "transformer",
    },
    "AutoencoderKLWan": {
        "checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
        "default_subfolder": "vae",
    },
    "HiDreamImageTransformer2DModel": {
        "checkpoint_mapping_fn": convert_hidream_transformer_to_diffusers,
        "default_subfolder": "transformer",
    },
    "CosmosTransformer3DModel": {
        "checkpoint_mapping_fn": convert_cosmos_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "QwenImageTransformer2DModel": {
        "checkpoint_mapping_fn": lambda x: x,
        "default_subfolder": "transformer",
    },
    "Flux2Transformer2DModel": {
        "checkpoint_mapping_fn": convert_flux2_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "ZImageTransformer2DModel": {
        "checkpoint_mapping_fn": convert_z_image_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "ZImageControlTransformer2DModel": {
        "checkpoint_mapping_fn": convert_z_image_control_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
}
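
# Each entry follows the structure of diffusers' own SINGLE_FILE_LOADABLE_CLASSES registry:
#   - "checkpoint_mapping_fn": converts the original single-file state dict to diffusers key names.
#   - "config_mapping_fn" (optional): derives a diffusers model config from the LDM-style checkpoint.
#   - "default_subfolder" (optional): subfolder of the reference repo the config is loaded from.
#   - "legacy_kwargs" (optional): maps deprecated `from_single_file` kwargs to their current names.
# This local copy is assigned over the stock mapping at the bottom of the file so that the
# Z-Image entries above become loadable through `from_single_file`.
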
def get_class_obj_and_candidates(
    library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
):
    """Simple helper method to retrieve class object of module as well as potential parent class objects."""
    component_folder = os.path.join(cache_dir, component_name) if component_name and cache_dir else None

    if is_pipeline_module:
        pipeline_module = getattr(pipelines, library_name)
        class_obj = getattr(pipeline_module, class_name)
        class_candidates = dict.fromkeys(importable_classes.keys(), class_obj)
    elif component_folder and os.path.isfile(os.path.join(component_folder, library_name + ".py")):
        # load custom component
        class_obj = get_class_from_dynamic_module(
            component_folder, module_file=library_name + ".py", class_name=class_name
        )
        class_candidates = dict.fromkeys(importable_classes.keys(), class_obj)
    else:
        # else we just import it from the library.
        library = importlib.import_module(library_name)

        # Handle deprecated Transformers classes
        if library_name == "transformers":
            class_name = _maybe_remap_transformers_class(class_name) or class_name

        try:
            class_obj = getattr(library, class_name)
        except AttributeError:
            # Fall back to the local module for classes the installed library does not provide.
            module = importlib.import_module("diffusers_local")
            class_obj = getattr(module, class_name)
        class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}

    return class_obj, class_candidates
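
# The override above is consulted by DiffusionPipeline.from_pretrained when it resolves the
# classes named in a pipeline's model_index.json; with the "diffusers_local" fallback, those
# entries may reference classes that ship with this repo rather than with diffusers itself.
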
def _get_single_file_loadable_mapping_class(cls):
    diffusers_module = importlib.import_module("diffusers")
    class_name_str = cls.__name__

    for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
        try:
            loadable_class = getattr(diffusers_module, loadable_class_str)
        except AttributeError:
            # Classes that only exist locally are resolved from "diffusers_local".
            module = importlib.import_module("diffusers_local")
            loadable_class = getattr(module, loadable_class_str)

        if issubclass(cls, loadable_class):
            return loadable_class_str

    return class_name_str
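
# Apply the patches in place. Anything that loads pipelines or calls `from_single_file`
# after this module has been imported will go through the overrides defined above.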
pipe_loading_utils.get_class_obj_and_candidates = get_class_obj_and_candidates
single_file_model.SINGLE_FILE_LOADABLE_CLASSES = SINGLE_FILE_LOADABLE_CLASSES
single_file_model._get_single_file_loadable_mapping_class = _get_single_file_loadable_mapping_class
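

if __name__ == "__main__":
    # Minimal usage sketch, not executed on import. It assumes a local "diffusers_local"
    # package that defines ZImageControlTransformer2DModel with FromOriginalModelMixin, and
    # a checkpoint at the placeholder path below; both names are assumptions for illustration.
    from diffusers_local import ZImageControlTransformer2DModel  # hypothetical local class

    model = ZImageControlTransformer2DModel.from_single_file(
        "checkpoints/z_image_control.safetensors",  # placeholder path
        torch_dtype=torch.bfloat16,
    )
    print(type(model).__name__, sum(p.numel() for p in model.parameters()))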