|
|
import importlib
|
|
|
import os
|
|
|
import diffusers.pipelines.pipeline_loading_utils as pipe_loading_utils
|
|
|
import diffusers.loaders.single_file_model as single_file_model
|
|
|
from diffusers.utils import (
|
|
|
_maybe_remap_transformers_class,
|
|
|
get_class_from_dynamic_module,
|
|
|
)
|
|
|
from diffusers.loaders.single_file_utils import (
|
|
|
convert_animatediff_checkpoint_to_diffusers,
|
|
|
convert_auraflow_transformer_checkpoint_to_diffusers,
|
|
|
convert_autoencoder_dc_checkpoint_to_diffusers,
|
|
|
convert_chroma_transformer_checkpoint_to_diffusers,
|
|
|
convert_controlnet_checkpoint,
|
|
|
convert_cosmos_transformer_checkpoint_to_diffusers,
|
|
|
convert_flux2_transformer_checkpoint_to_diffusers,
|
|
|
convert_flux_transformer_checkpoint_to_diffusers,
|
|
|
convert_hidream_transformer_to_diffusers,
|
|
|
convert_hunyuan_video_transformer_to_diffusers,
|
|
|
convert_ldm_unet_checkpoint,
|
|
|
convert_ldm_vae_checkpoint,
|
|
|
convert_ltx_transformer_checkpoint_to_diffusers,
|
|
|
convert_ltx_vae_checkpoint_to_diffusers,
|
|
|
convert_lumina2_to_diffusers,
|
|
|
convert_mochi_transformer_checkpoint_to_diffusers,
|
|
|
convert_sana_transformer_to_diffusers,
|
|
|
convert_sd3_transformer_checkpoint_to_diffusers,
|
|
|
convert_stable_cascade_unet_single_file_to_diffusers,
|
|
|
convert_wan_transformer_to_diffusers,
|
|
|
convert_wan_vae_to_diffusers,
|
|
|
convert_z_image_transformer_checkpoint_to_diffusers,
|
|
|
create_controlnet_diffusers_config_from_ldm,
|
|
|
create_unet_diffusers_config_from_ldm,
|
|
|
create_vae_diffusers_config_from_ldm,
|
|
|
)
|
|
|
import torch
|
|
|
# Substring renames applied, in this order, to every checkpoint key.
# The order matters because each rule runs on the output of the previous one.
_Z_IMAGE_CONTROL_KEY_RENAMES = (
    ("final_layer.", "all_final_layer.2-1."),
    # "x_embedder." is also a substring of "control_x_embedder.", so this single
    # rule turns "control_x_embedder." into "control_all_x_embedder.2-1." as
    # well. A separate "control_x_embedder." rule placed after it can never
    # match, and placed before it would be rewritten a second time.
    ("x_embedder.", "all_x_embedder.2-1."),
    (".attention.out.bias", ".attention.to_out.0.bias"),
    (".attention.k_norm.weight", ".attention.norm_k.weight"),
    (".attention.q_norm.weight", ".attention.norm_q.weight"),
    (".attention.out.weight", ".attention.to_out.0.weight"),
)

# Key fragment that marks a fused query/key/value projection weight.
_Z_IMAGE_FUSED_QKV_SUFFIX = ".attention.qkv.weight"


def _split_z_image_fused_qkv(key: str, state_dict: "dict[str, torch.Tensor]") -> None:
    """Replace the fused QKV weight stored under ``key`` with three separate entries.

    The tensor is cut into three equal chunks along dim 0 and stored, in that
    order, under the ``to_q``, ``to_k`` and ``to_v`` variants of ``key``.
    ``state_dict`` is modified in place.

    Raises:
        ValueError: if the first dimension of the fused tensor is not a
            multiple of 3 (``torch.chunk`` would otherwise return uneven
            chunks and the projections would be silently wrong).
    """
    fused_weight = state_dict[key]
    rows = fused_weight.shape[0]
    if rows % 3 != 0:
        raise ValueError(
            f"Cannot split fused QKV weight '{key}': first dimension {rows} is not divisible by 3."
        )

    del state_dict[key]
    chunks = torch.chunk(fused_weight, 3, dim=0)
    for projection, chunk in zip(("to_q", "to_k", "to_v"), chunks):
        target_key = key.replace(_Z_IMAGE_FUSED_QKV_SUFFIX, f".attention.{projection}.weight")
        state_dict[target_key] = chunk


def convert_z_image_control_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
    """Rename the keys of a Z-Image control transformer checkpoint and split fused QKV weights.

    Args:
        checkpoint: mutable mapping from parameter name to tensor. It is
            emptied by this call: every entry is popped and moved into the
            returned dictionary.
        **kwargs: accepted and ignored.

    Returns:
        A new ``dict`` holding the same tensors under renamed keys. Each
        ``*.attention.qkv.weight`` entry is replaced by three entries
        (``to_q``, ``to_k``, ``to_v``) appended at the end of the dict.

    Raises:
        ValueError: if a fused QKV weight cannot be split into three equal parts.
    """
    # Move (not copy) every entry out of the caller's mapping.
    converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}

    # Pass 1: plain substring renames. Entries are popped and re-inserted one
    # at a time; the key list is snapshotted because the dict is mutated.
    for key in list(converted_state_dict.keys()):
        new_key = key
        for old_fragment, new_fragment in _Z_IMAGE_CONTROL_KEY_RENAMES:
            new_key = new_key.replace(old_fragment, new_fragment)
        converted_state_dict[new_key] = converted_state_dict.pop(key)

    # Pass 2: structural rewrites that cannot be expressed as a 1:1 rename.
    for key in list(converted_state_dict.keys()):
        if _Z_IMAGE_FUSED_QKV_SUFFIX in key:
            _split_z_image_fused_qkv(key, converted_state_dict)

    return converted_state_dict
|
|
|
|
|
|
# Registry keyed by model class name. Each entry carries a
# "checkpoint_mapping_fn" and, optionally, a "config_mapping_fn", a
# "default_subfolder" and a "legacy_kwargs" rename table.
# Insertion order is significant: _get_single_file_loadable_mapping_class
# iterates this dict and returns the first entry the class is a subclass of.
SINGLE_FILE_LOADABLE_CLASSES = {
    "StableCascadeUNet": {
        "checkpoint_mapping_fn": convert_stable_cascade_unet_single_file_to_diffusers,
    },
    "UNet2DConditionModel": {
        "checkpoint_mapping_fn": convert_ldm_unet_checkpoint,
        "config_mapping_fn": create_unet_diffusers_config_from_ldm,
        "default_subfolder": "unet",
        "legacy_kwargs": {
            "num_in_channels": "in_channels",
        },
    },
    "AutoencoderKL": {
        "checkpoint_mapping_fn": convert_ldm_vae_checkpoint,
        "config_mapping_fn": create_vae_diffusers_config_from_ldm,
        "default_subfolder": "vae",
    },
    "ControlNetModel": {
        "checkpoint_mapping_fn": convert_controlnet_checkpoint,
        "config_mapping_fn": create_controlnet_diffusers_config_from_ldm,
    },
    "SD3Transformer2DModel": {
        "checkpoint_mapping_fn": convert_sd3_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "MotionAdapter": {
        "checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
    },
    "SparseControlNetModel": {
        "checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
    },
    "FluxTransformer2DModel": {
        "checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "ChromaTransformer2DModel": {
        "checkpoint_mapping_fn": convert_chroma_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "LTXVideoTransformer3DModel": {
        "checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "AutoencoderKLLTXVideo": {
        "checkpoint_mapping_fn": convert_ltx_vae_checkpoint_to_diffusers,
        "default_subfolder": "vae",
    },
    "AutoencoderDC": {
        "checkpoint_mapping_fn": convert_autoencoder_dc_checkpoint_to_diffusers,
    },
    "MochiTransformer3DModel": {
        "checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "HunyuanVideoTransformer3DModel": {
        "checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers,
        "default_subfolder": "transformer",
    },
    "AuraFlowTransformer2DModel": {
        "checkpoint_mapping_fn": convert_auraflow_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "Lumina2Transformer2DModel": {
        "checkpoint_mapping_fn": convert_lumina2_to_diffusers,
        "default_subfolder": "transformer",
    },
    "SanaTransformer2DModel": {
        "checkpoint_mapping_fn": convert_sana_transformer_to_diffusers,
        "default_subfolder": "transformer",
    },
    "WanTransformer3DModel": {
        "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
        "default_subfolder": "transformer",
    },
    "WanVACETransformer3DModel": {
        "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
        "default_subfolder": "transformer",
    },
    "AutoencoderKLWan": {
        "checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
        "default_subfolder": "vae",
    },
    "HiDreamImageTransformer2DModel": {
        "checkpoint_mapping_fn": convert_hidream_transformer_to_diffusers,
        "default_subfolder": "transformer",
    },
    "CosmosTransformer3DModel": {
        "checkpoint_mapping_fn": convert_cosmos_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "QwenImageTransformer2DModel": {
        # Identity mapping: the value passed in is returned unchanged.
        # NOTE(review): this lambda takes exactly one positional argument;
        # confirm that matches how the loader invokes checkpoint_mapping_fn.
        "checkpoint_mapping_fn": lambda x: x,
        "default_subfolder": "transformer",
    },
    "Flux2Transformer2DModel": {
        "checkpoint_mapping_fn": convert_flux2_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "ZImageTransformer2DModel": {
        "checkpoint_mapping_fn": convert_z_image_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
    "ZImageControlTransformer2DModel": {
        "checkpoint_mapping_fn": convert_z_image_control_transformer_checkpoint_to_diffusers,
        "default_subfolder": "transformer",
    },
}
|
|
|
|
|
|
def get_class_obj_and_candidates(
    library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
):
    """Resolve a component class and the candidate parent classes used to pick its loader.

    Resolution order:

    1. If ``is_pipeline_module`` is true, the class is read from
       ``getattr(pipelines, library_name)``.
    2. Otherwise, if ``<cache_dir>/<component_name>/<library_name>.py`` exists,
       the class is loaded from that file via ``get_class_from_dynamic_module``.
    3. Otherwise ``library_name`` is imported as a module and the class is read
       from it; if the module has no such attribute, the class is looked up in
       the ``diffusers_local`` module instead.

    Args:
        library_name: module name (cases 1 and 3) or file stem (case 2).
        class_name: name of the class to resolve.
        importable_classes: mapping whose keys are candidate class names.
        pipelines: object exposing pipeline sub-modules as attributes.
        is_pipeline_module: selects case 1 when true.
        component_name: sub-folder of ``cache_dir`` checked in case 2.
        cache_dir: base folder checked in case 2.

    Returns:
        ``(class_obj, class_candidates)``. In cases 1 and 2 every key of
        ``importable_classes`` maps to ``class_obj``; in case 3 each key maps
        to the attribute of that name on the imported library, or ``None``.

    Raises:
        AttributeError: if the class is found in neither the library nor
            ``diffusers_local``.
        ImportError: if ``library_name`` or ``diffusers_local`` cannot be imported.
    """
    component_folder = os.path.join(cache_dir, component_name) if component_name and cache_dir else None

    if is_pipeline_module:
        pipeline_module = getattr(pipelines, library_name)
        class_obj = getattr(pipeline_module, class_name)
        class_candidates = dict.fromkeys(importable_classes.keys(), class_obj)
    elif component_folder and os.path.isfile(os.path.join(component_folder, library_name + ".py")):
        # Custom component shipped as a .py file next to the weights.
        class_obj = get_class_from_dynamic_module(
            component_folder, module_file=library_name + ".py", class_name=class_name
        )
        class_candidates = dict.fromkeys(importable_classes.keys(), class_obj)
    else:
        library = importlib.import_module(library_name)

        # Give the remap helper a chance to substitute the class name; keep the
        # original name when it returns a falsy value.
        if library_name == "transformers":
            class_name = _maybe_remap_transformers_class(class_name) or class_name

        try:
            class_obj = getattr(library, class_name)
        except AttributeError:
            # Only a genuinely missing attribute triggers the local fallback.
            # Other failures (e.g. an error raised while the library lazily
            # imports the class) propagate instead of being masked.
            module = importlib.import_module("diffusers_local")
            class_obj = getattr(module, class_name)
        class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}

    return class_obj, class_candidates
|
|
|
|
|
|
def _get_single_file_loadable_mapping_class(cls):
    """Return the ``SINGLE_FILE_LOADABLE_CLASSES`` key that applies to ``cls``.

    Registry keys are visited in insertion order. Each key is resolved to a
    class on the ``diffusers`` package, or on ``diffusers_local`` when
    ``diffusers`` has no attribute of that name. The first key whose class
    ``cls`` is a subclass of is returned.

    Args:
        cls: the model class to look up.

    Returns:
        The matching registry key, or ``cls.__name__`` when no entry matches.

    Raises:
        AttributeError: if a registry key exists in neither ``diffusers`` nor
            ``diffusers_local``.
        ImportError: if ``diffusers_local`` is needed but cannot be imported.
    """
    diffusers_module = importlib.import_module("diffusers")

    for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
        try:
            loadable_class = getattr(diffusers_module, loadable_class_str)
        except AttributeError:
            # Only a missing attribute triggers the local fallback; any other
            # error raised while resolving the class is left to propagate.
            local_module = importlib.import_module("diffusers_local")
            loadable_class = getattr(local_module, loadable_class_str)

        if issubclass(cls, loadable_class):
            return loadable_class_str

    return cls.__name__
|
|
|
|
|
|
# Install the definitions above onto the imported diffusers modules by
# rebinding their module attributes (runs once, at import time of this file).
# NOTE(review): code that already bound these names with
# `from ... import name` presumably keeps the old objects; confirm this file
# is imported before such code runs.
pipe_loading_utils.get_class_obj_and_candidates = get_class_obj_and_candidates
single_file_model.SINGLE_FILE_LOADABLE_CLASSES = SINGLE_FILE_LOADABLE_CLASSES
single_file_model._get_single_file_loadable_mapping_class = _get_single_file_loadable_mapping_class