"""Model registry for SpatialTrackerV2 models and the TTM video pipelines.

Importing this module calls ``init_spatial_models()`` (last line of the file),
which loads the SpatialTrackerV2 models via ``from_pretrained`` and moves the
VGGT4Track front-end to CUDA. The TTM diffusers pipelines are instead loaded
lazily by ``get_ttm_cog_pipeline()`` / ``get_ttm_wan_pipeline()``.
"""
import logging

from models.SpaTrackV2.models.predictor import Predictor
from models.SpaTrackV2.models.vggt4track.models.vggt_moe import VGGT4Track

from .config import (
    TTM_COG_AVAILABLE,
    TTM_WAN_AVAILABLE,
    TTM_COG_MODEL_ID,
    TTM_WAN_MODEL_ID,
    TTM_DTYPE,
)

logger = logging.getLogger(__name__)

# Set by init_spatial_models(); remain None until that call fully succeeds.
vggt4track_model = None
tracker_model = None

# Cached by the get_ttm_*_pipeline() accessors after a successful first load.
ttm_cog_pipeline = None
ttm_wan_pipeline = None


def init_spatial_models():
    """Load the SpatialTrackerV2 models into the module-level globals.

    Sets ``vggt4track_model`` (VGGT4Track front-end, eval mode, moved to
    ``"cuda"``) and ``tracker_model`` (offline ``Predictor``, eval mode).

    Both models are prepared in local variables and only assigned to the
    globals after every step has succeeded. A failure part-way through (for
    example, no CUDA device) therefore propagates without leaving a
    half-initialised model visible to other code.
    """
    global vggt4track_model, tracker_model
    print("🚀 Initializing models...")

    front_model = VGGT4Track.from_pretrained(
        "Yuxihenry/SpatialTrackerV2_Front")
    front_model.eval()
    front_model = front_model.to("cuda")

    offline_tracker = Predictor.from_pretrained(
        "Yuxihenry/SpatialTrackerV2-Offline")
    offline_tracker.eval()
    # NOTE(review): unlike the front-end, the tracker is not moved to a device
    # here; presumably callers place it themselves — confirm.

    # Publish only once both models are fully prepared.
    vggt4track_model = front_model
    tracker_model = offline_tracker
    print("✅ Spatial Models loaded successfully!")


def _enable_vae_memory_savings(pipeline):
    """Enable VAE tiling and slicing on a diffusers *pipeline* to reduce peak memory."""
    pipeline.vae.enable_tiling()
    pipeline.vae.enable_slicing()


def get_ttm_cog_pipeline():
    """Return the shared CogVideoX image-to-video pipeline, loading it on first use.

    Returns:
        The cached ``CogVideoXImageToVideoPipeline``, or ``None`` if
        ``TTM_COG_AVAILABLE`` is false and no pipeline has been cached.

    Errors from ``from_pretrained`` or the VAE setup propagate. The pipeline
    is only cached after it is fully configured, so a failed load is retried
    on the next call instead of returning a partially configured object.
    """
    global ttm_cog_pipeline
    if ttm_cog_pipeline is None and TTM_COG_AVAILABLE:
        # Deferred import: diffusers is only imported when this branch runs.
        from diffusers import CogVideoXImageToVideoPipeline
        logger.info("Loading TTM CogVideoX pipeline...")
        pipeline = CogVideoXImageToVideoPipeline.from_pretrained(
            TTM_COG_MODEL_ID,
            torch_dtype=TTM_DTYPE,
            low_cpu_mem_usage=True,
        )
        _enable_vae_memory_savings(pipeline)
        ttm_cog_pipeline = pipeline
        logger.info("TTM CogVideoX pipeline loaded successfully!")
    return ttm_cog_pipeline


def get_ttm_wan_pipeline():
    """Return the shared Wan 2.2 image-to-video pipeline, loading it on first use.

    Returns:
        The cached ``WanImageToVideoPipeline``, or ``None`` if
        ``TTM_WAN_AVAILABLE`` is false and no pipeline has been cached.

    Errors from ``from_pretrained`` or the VAE setup propagate. The pipeline
    is only cached after it is fully configured, so a failed load is retried
    on the next call instead of returning a partially configured object.
    """
    global ttm_wan_pipeline
    if ttm_wan_pipeline is None and TTM_WAN_AVAILABLE:
        # Deferred import: diffusers is only imported when this branch runs.
        from diffusers import WanImageToVideoPipeline
        logger.info("Loading TTM Wan 2.2 pipeline...")
        pipeline = WanImageToVideoPipeline.from_pretrained(
            TTM_WAN_MODEL_ID,
            torch_dtype=TTM_DTYPE,
        )
        _enable_vae_memory_savings(pipeline)
        ttm_wan_pipeline = pipeline
        logger.info("TTM Wan 2.2 pipeline loaded successfully!")
    return ttm_wan_pipeline


# Runs at import time: the spatial models are loaded as soon as this module is imported.
init_spatial_models()