Spaces:
Runtime error
Runtime error
| import json | |
| import logging | |
| from os import PathLike | |
| from pathlib import Path | |
| from typing import Any, Dict, Optional, Tuple, Union | |
| from pydantic import BaseConfig, BaseSettings, Field | |
| from pydantic.env_settings import (EnvSettingsSource, InitSettingsSource, | |
| SecretsSettingsSource, | |
| SettingsSourceCallable) | |
| from animatediff import get_dir | |
| from animatediff.schedulers import DiffusionScheduler | |
# Configure root logging at import time and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# File suffixes listed as checkpoint extensions.
CKPT_EXTENSIONS = [
    ".pt",
    ".ckpt",
    ".pth",
    ".safetensors",
]
class JsonSettingsSource:
    """Settings source callable that loads values from one or more JSON files.

    Files are read in the order given and merged with ``dict.update``, so a
    top-level key in a later file replaces the same key from an earlier file
    (the merge is shallow, not recursive).
    """

    __slots__ = ["json_config_path"]

    def __init__(
        self,
        json_config_path: Optional[
            Union[str, PathLike, list[PathLike], tuple[PathLike, ...]]
        ] = None,
    ) -> None:
        """Normalise ``json_config_path`` into a list of ``Path`` objects.

        Args:
            json_config_path: A single path, a list or tuple of paths, or
                ``None`` (the default) for "no JSON config".
        """
        # ``None`` replaces the previous ``list()`` default: same resulting
        # empty list, without a mutable object shared through the signature.
        if json_config_path is None:
            self.json_config_path = []
        elif isinstance(json_config_path, (list, tuple)):
            self.json_config_path = [Path(path) for path in json_config_path]
        else:
            self.json_config_path = [Path(json_config_path)]

    def __call__(self, settings: BaseSettings) -> Dict[str, Any]:
        """Read every configured JSON file and return the merged mapping.

        Args:
            settings: The settings instance being built; its class name is used
                in log/error messages and ``__config__.env_file_encoding`` is
                used to decode the files.

        Returns:
            The merged configuration; an empty dict when no paths are set.

        Raises:
            FileNotFoundError: If a configured path is missing or not a file.
        """
        classname = settings.__class__.__name__
        encoding = settings.__config__.env_file_encoding

        merged_config: Dict[str, Any] = {}
        for idx, path in enumerate(self.json_config_path, start=1):
            # is_file() is already False for a non-existent path.
            if not path.is_file():
                raise FileNotFoundError(f"{classname}: config #{idx} at {path} not found or not a file")
            logger.debug("%s: loading config #%d from %s", classname, idx, path)
            merged_config.update(json.loads(path.read_text(encoding=encoding)))
            logger.debug("%s: config state #%d: %s", classname, idx, merged_config)

        logger.debug("%s: loaded config: %s", classname, merged_config)
        return merged_config

    def __repr__(self) -> str:
        return f"JsonSettingsSource(json_config_path={repr(self.json_config_path)})"
class JsonConfig(BaseConfig):
    """Base pydantic ``Config`` that adds a JSON-file settings source.

    Subclasses (or callers, via a ``json_config_path`` init keyword) choose
    which JSON file(s) are loaded.
    """

    # Default JSON path(s), used when none is passed at init time.
    json_config_path: Optional[Union[Path, list[Path]]] = None
    # Encoding JsonSettingsSource uses to read the JSON files.
    env_file_encoding: str = "utf-8"

    # pydantic invokes this hook on the Config class with keyword arguments
    # only, so it has to be a classmethod for ``cls`` to be bound.
    @classmethod
    def customise_sources(
        cls,
        init_settings: InitSettingsSource,
        env_settings: EnvSettingsSource,
        file_secret_settings: SecretsSettingsSource,
    ) -> Tuple[SettingsSourceCallable, ...]:
        """Return the settings sources to use: init kwargs, then JSON files.

        ``env_settings`` and ``file_secret_settings`` are accepted because the
        hook signature requires them, but they are not returned, so neither is
        part of the resulting source tuple.

        Args:
            init_settings: Source wrapping the keyword arguments passed to the
                settings class; ``json_config_path`` is popped from it.
            env_settings: Environment-variable source (not returned).
            file_secret_settings: Secrets-directory source (not returned).

        Returns:
            ``(init_settings, json_settings)``.
        """
        # pull json_config_path from init_settings if passed, otherwise use the class var.
        # pop() also removes the key from the init kwargs — presumably so it is
        # not treated as a model field; confirm against the models' extra policy.
        json_config_path = init_settings.init_kwargs.pop("json_config_path", cls.json_config_path)

        logger.debug(f"Using JsonSettingsSource for {cls.__name__}")
        json_settings = JsonSettingsSource(json_config_path=json_config_path)

        # return the new settings sources
        return (
            init_settings,
            json_settings,
        )
class InferenceConfig(BaseSettings):
    """Inference settings made of two required keyword-argument mappings."""

    # Both fields are required: neither declares a default.
    unet_additional_kwargs: Dict[str, Any]
    noise_scheduler_kwargs: Dict[str, Any]

    class Config(JsonConfig):
        # Narrows the inherited annotation to a single Path; no value is set here.
        json_config_path: Path
def get_infer_config(
    is_v2: bool,
    is_sdxl: bool,
) -> InferenceConfig:
    """Load the inference config matching the requested model variant.

    Args:
        is_v2: Select ``inference/motion_v2.json`` instead of the default file.
        is_sdxl: Select ``inference/motion_sdxl.json``; takes precedence over
            ``is_v2`` when both are set.

    Returns:
        An ``InferenceConfig`` built with the chosen file as its
        ``json_config_path``.
    """
    # Choose the file name first so the config directory is looked up once;
    # the SDXL branch comes first because it overrides the v2/default choice.
    if is_sdxl:
        relative_path = "inference/motion_sdxl.json"
    elif is_v2:
        relative_path = "inference/motion_v2.json"
    else:
        relative_path = "inference/default.json"

    config_path: Path = get_dir("config").joinpath(relative_path)
    return InferenceConfig(json_config_path=config_path)
class ModelConfig(BaseSettings):
    """Per-run model and prompt settings.

    ``name``, ``path`` and ``motion_module`` are required; every other field
    has a default.
    """

    # --- model files ---
    name: str  # Config name, not actually used for much of anything
    path: Path  # Path to the model
    vae_path: str = ""  # VAE path; empty string by default
    motion_module: Path  # Path to the motion module
    context_schedule: str = "uniform"
    lcm_map: Dict[str, Any] = {}
    gradual_latent_hires_fix_map: Dict[str, Any] = {}

    # --- runtime / sampling ---
    compile: bool = False  # whether to compile the model with TorchDynamo
    tensor_interpolation_slerp: bool = True
    seed: list[int] = []  # Seed(s) for the random number generators
    scheduler: DiffusionScheduler = DiffusionScheduler.k_dpmpp_2m  # Scheduler to use
    steps: int = 25  # Number of inference steps to run
    guidance_scale: float = 7.5  # CFG scale to use
    unet_batch_size: int = 1
    clip_skip: int = 1  # skip the last N-1 layers of the CLIP text encoder

    # --- prompts ---
    prompt_fixed_ratio: float = 0.5
    head_prompt: str = ""
    prompt_map: Dict[str, str] = {}
    tail_prompt: str = ""
    n_prompt: list[str] = []  # Anti-prompt(s) to use
    is_single_prompt_mode: bool = False

    # --- optional feature maps (all empty by default) ---
    lora_map: Dict[str, Any] = {}
    motion_lora_map: Dict[str, float] = {}
    ip_adapter_map: Dict[str, Any] = {}
    img2img_map: Dict[str, Any] = {}
    region_map: Dict[str, Any] = {}
    controlnet_map: Dict[str, Any] = {}
    upscale_config: Dict[str, Any] = {}
    stylize_config: Dict[str, Any] = {}
    output: Dict[str, Any] = {}
    result: Dict[str, Any] = {}

    class Config(JsonConfig):
        # Narrows the inherited annotation to a single Path; no value is set here.
        json_config_path: Path

    def save_name(self):
        """Return ``"<name>-<model file stem>"``, both parts lower-cased."""
        lowered_name = self.name.lower()
        lowered_stem = self.path.stem.lower()
        return f"{lowered_name}-{lowered_stem}"
def get_model_config(config_path: Path) -> ModelConfig:
    """Build a ``ModelConfig`` using ``config_path`` as its JSON config path."""
    return ModelConfig(json_config_path=config_path)