Spaces:
Running
on
Zero
Running
on
Zero
| from models.SpaTrackV2.models.predictor import Predictor | |
| from models.SpaTrackV2.models.vggt4track.models.vggt_moe import VGGT4Track | |
| import logging | |
| from .config import ( | |
| TTM_COG_AVAILABLE, TTM_WAN_AVAILABLE, | |
| TTM_COG_MODEL_ID, TTM_WAN_MODEL_ID, TTM_DTYPE | |
| ) | |
# Module-level logger, used by the TTM pipeline loaders in this file.
logger = logging.getLogger(__name__)

# Model singletons; every one starts out unset (None).
# The spatial models are assigned by init_spatial_models(); the TTM
# pipelines are assigned on first use by their get_ttm_*_pipeline() getters.
vggt4track_model = tracker_model = None
ttm_cog_pipeline = ttm_wan_pipeline = None
def init_spatial_models():
    """Load the two SpatialTrackerV2 models into the module-level globals.

    Side effects:
        * ``vggt4track_model`` is set to ``VGGT4Track.from_pretrained(...)``
          for ``"Yuxihenry/SpatialTrackerV2_Front"``, has ``.eval()`` called
          on it, and is then replaced by the result of ``.to("cuda")``.
        * ``tracker_model`` is set to ``Predictor.from_pretrained(...)`` for
          ``"Yuxihenry/SpatialTrackerV2-Offline"`` and has ``.eval()`` called
          on it; this function does not move it to any device.

    Progress is reported with ``print`` (not the module logger).
    """
    global vggt4track_model, tracker_model

    front_repo_id = "Yuxihenry/SpatialTrackerV2_Front"
    offline_repo_id = "Yuxihenry/SpatialTrackerV2-Offline"

    print("🚀 Initializing models...")

    # Front-end model: load, switch to eval, then move to the GPU.
    vggt4track_model = VGGT4Track.from_pretrained(front_repo_id)
    vggt4track_model.eval()
    # NOTE(review): only this model is moved to CUDA here; presumably the
    # Predictor is placed on a device by its callers -- confirm.
    vggt4track_model = vggt4track_model.to("cuda")

    # Offline tracker: load and switch to eval.
    tracker_model = Predictor.from_pretrained(offline_repo_id)
    tracker_model.eval()

    print("✅ Spatial Models loaded successfully!")
def get_ttm_cog_pipeline():
    """Return the cached TTM CogVideoX image-to-video pipeline, loading it once.

    On the first call with ``TTM_COG_AVAILABLE`` truthy, the pipeline is built
    from ``TTM_COG_MODEL_ID`` with ``torch_dtype=TTM_DTYPE`` and
    ``low_cpu_mem_usage=True``. VAE tiling and slicing are enabled on it, and
    it is stored in the module-level ``ttm_cog_pipeline``. Later calls return
    the stored object.

    Returns:
        The ``CogVideoXImageToVideoPipeline`` instance. Returns ``None`` when
        ``TTM_COG_AVAILABLE`` is falsy and nothing has been loaded.

    Raises:
        Whatever ``from_pretrained`` or the VAE configuration calls raise. In
        that case nothing is cached, so a later call retries the load.
    """
    global ttm_cog_pipeline
    if ttm_cog_pipeline is not None or not TTM_COG_AVAILABLE:
        return ttm_cog_pipeline

    # Imported lazily so this module can be imported without paying the
    # diffusers import cost until the pipeline is actually needed.
    from diffusers import CogVideoXImageToVideoPipeline

    logger.info("Loading TTM CogVideoX pipeline...")
    pipeline = CogVideoXImageToVideoPipeline.from_pretrained(
        TTM_COG_MODEL_ID,
        torch_dtype=TTM_DTYPE,
        low_cpu_mem_usage=True,
    )
    # Per the diffusers docs, tiled/sliced VAE decoding trades some speed for
    # lower peak memory.
    pipeline.vae.enable_tiling()
    pipeline.vae.enable_slicing()

    # Publish only a fully configured pipeline. If configuration above raises,
    # the cache stays empty instead of holding a half-configured object.
    ttm_cog_pipeline = pipeline
    logger.info("TTM CogVideoX pipeline loaded successfully!")
    return ttm_cog_pipeline
def get_ttm_wan_pipeline():
    """Return the cached TTM Wan 2.2 image-to-video pipeline, loading it once.

    On the first call with ``TTM_WAN_AVAILABLE`` truthy, the pipeline is built
    from ``TTM_WAN_MODEL_ID`` with ``torch_dtype=TTM_DTYPE``. VAE tiling and
    slicing are enabled on it, and it is stored in the module-level
    ``ttm_wan_pipeline``. Later calls return the stored object.

    Returns:
        The ``WanImageToVideoPipeline`` instance. Returns ``None`` when
        ``TTM_WAN_AVAILABLE`` is falsy and nothing has been loaded.

    Raises:
        Whatever ``from_pretrained`` or the VAE configuration calls raise. In
        that case nothing is cached, so a later call retries the load.
    """
    global ttm_wan_pipeline
    if ttm_wan_pipeline is not None or not TTM_WAN_AVAILABLE:
        return ttm_wan_pipeline

    # Imported lazily so this module can be imported without paying the
    # diffusers import cost until the pipeline is actually needed.
    from diffusers import WanImageToVideoPipeline

    logger.info("Loading TTM Wan 2.2 pipeline...")
    pipeline = WanImageToVideoPipeline.from_pretrained(
        TTM_WAN_MODEL_ID,
        torch_dtype=TTM_DTYPE,
    )
    # Per the diffusers docs, tiled/sliced VAE decoding trades some speed for
    # lower peak memory.
    pipeline.vae.enable_tiling()
    pipeline.vae.enable_slicing()

    # Publish only a fully configured pipeline. If configuration above raises,
    # the cache stays empty instead of holding a half-configured object.
    ttm_wan_pipeline = pipeline
    logger.info("TTM Wan 2.2 pipeline loaded successfully!")
    return ttm_wan_pipeline
# Import-time side effect: the spatial tracking models are loaded eagerly as
# soon as this module is imported. The TTM pipelines, by contrast, are loaded
# lazily by their get_ttm_*_pipeline() getters.
init_spatial_models()