import os
from typing import List

import cv2
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision.datasets.utils import download_url

from .longclip import longclip
from .viclip import get_viclip
from .video_utils import extract_frames

# All metrics exported by this module.
__all__ = ["VideoCLIPXLScore"]

_MODELS = {
    "ViClip-InternVid-10M-FLT": "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/video_caption/clip/ViClip-InternVid-10M-FLT.pth",
    "LongCLIP-L": "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/video_caption/clip/longclip-L.pt",
    "VideoCLIP-XL-v2": "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/video_caption/clip/VideoCLIP-XL-v2.bin",
}
_MD5 = {
    "ViClip-InternVid-10M-FLT": "b1ebf538225438b3b75e477da7735cd0",
    "LongCLIP-L": "5478b662f6f85ca0ebd4bb05f9b592f3",
    "VideoCLIP-XL-v2": "cebda0bab14b677ec061a57e80791f35",
}


def normalize(
    data: np.ndarray,
    mean: List[float] = [0.485, 0.456, 0.406],
    std: List[float] = [0.229, 0.224, 0.225],
):
    """Scale uint8 pixels to [0, 1], then apply per-channel ImageNet mean/std normalization."""
    v_mean = np.array(mean).reshape(1, 1, 3)
    v_std = np.array(std).reshape(1, 1, 3)
    return (data / 255.0 - v_mean) / v_std
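
# Illustrative usage of `normalize` (a sketch, not part of the module API):
#   frame = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)  # dummy HWC frame
#   out = normalize(frame)  # float array, roughly zero-mean / unit-std per channel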


class VideoCLIPXL(nn.Module):
    def __init__(self, root: str = "~/.cache/clip"):
        super().__init__()
        self.root = os.path.expanduser(root)
        if not os.path.exists(self.root):
            os.makedirs(self.root)

        # Download and load the LongCLIP-L checkpoint (text branch).
        k = "LongCLIP-L"
        filename = os.path.basename(_MODELS[k])
        download_url(_MODELS[k], self.root, filename=filename, md5=_MD5[k])
        self.model = longclip.load(os.path.join(self.root, filename), device="cpu")[0].float()

        # Download and load the ViCLIP checkpoint (video branch).
        k = "ViClip-InternVid-10M-FLT"
        filename = os.path.basename(_MODELS[k])
        download_url(_MODELS[k], self.root, filename=filename, md5=_MD5[k])
        self.viclip_model = get_viclip("l", os.path.join(self.root, filename))["viclip"].float()

        # Delete the unused encoders: only LongCLIP's text encoder and
        # ViCLIP's video encoder are kept.
        del self.model.visual
        del self.viclip_model.text_encoder
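
# Illustrative access pattern (a sketch; attribute names follow the code above,
# and `video_tensor` is a placeholder for a [B, T, C, H, W] float tensor):
#   xl = VideoCLIPXL()
#   text_feat = xl.model.encode_text(longclip.tokenize(["a cat"]))  # LongCLIP text branch
#   vid_feat = xl.viclip_model.get_vid_features(video_tensor)       # ViCLIP video branch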


class VideoCLIPXLScore:
    def __init__(self, root: str = "~/.cache/clip", device: str = "cpu"):
        self.root = os.path.expanduser(root)
        if not os.path.exists(self.root):
            os.makedirs(self.root)

        # Download the fused VideoCLIP-XL-v2 weights and load them into the
        # combined LongCLIP (text) + ViCLIP (video) module.
        k = "VideoCLIP-XL-v2"
        filename = os.path.basename(_MODELS[k])
        download_url(_MODELS[k], self.root, filename=filename, md5=_MD5[k])
        self.model = VideoCLIPXL()
        state_dict = torch.load(os.path.join(self.root, filename), map_location="cpu")
        self.model.load_state_dict(state_dict)
        self.model.to(device)
        self.device = device

    def __call__(self, videos: List[List[Image.Image]], texts: List[str]):
        assert len(videos) == len(texts), "videos and texts must have the same length."
        # Use cv2.resize in accordance with the official demo. Convert RGB PIL
        # frames to BGR, then resize and normalize => B * [T, 224, 224, 3].
        videos = [[cv2.cvtColor(np.array(f), cv2.COLOR_RGB2BGR) for f in v] for v in videos]
        resized_videos = [[cv2.resize(f, (224, 224)) for f in v] for v in videos]
        resized_normalized_videos = [normalize(np.stack(v)) for v in resized_videos]
        video_inputs = torch.stack([torch.from_numpy(v) for v in resized_normalized_videos])
        video_inputs = video_inputs.float().permute(0, 1, 4, 2, 3).to(self.device, non_blocking=True)  # BTCHW

        with torch.no_grad():
            # Encode each video separately, then drop the singleton dimensions.
            vid_features = torch.stack(
                [self.model.viclip_model.get_vid_features(x.unsqueeze(0)).float() for x in video_inputs]
            )
            vid_features.squeeze_()
            # Batched alternative:
            # vid_features = self.model.viclip_model.get_vid_features(video_inputs).float()

            text_inputs = longclip.tokenize(texts, truncate=True).to(self.device)
            text_features = self.model.model.encode_text(text_inputs)
            text_features = text_features / text_features.norm(dim=1, keepdim=True)

            scores = text_features @ vid_features.T
        # With a single video, `squeeze_` leaves a 1D feature vector and `scores` is 1D.
        return scores.tolist() if len(videos) == 1 else scores.diagonal().tolist()
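
    # Score convention (illustrative): for N videos and N texts, LongCLIP text
    # features are matched against ViCLIP video features, and the diagonal of
    # the resulting N x N similarity matrix is returned, i.e. the i-th score
    # pairs videos[i] with texts[i].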

    def __repr__(self):
        return "videoclipxl_score"


if __name__ == "__main__":
    videos = ["your_video_path"] * 3
    texts = [
        "a joker",
        "glasses and flower",
        "The video opens with a view of a white building with multiple windows, partially obscured by leafless tree branches. The scene transitions to a closer view of the same building, with the tree branches more prominent in the foreground. The focus then shifts to a street sign that reads 'Abesses' in bold, yellow letters against a green background. The sign is attached to a metal structure, possibly a tram or bus stop. The sign is illuminated by a light source above it, and the background reveals a glimpse of the building and tree branches from earlier shots. The colors are muted, with the yellow sign standing out against the grey and green hues.",
    ]
    video_clip_xl_score = VideoCLIPXLScore(device="cuda")

    # Uniformly sample 8 frames from each video before scoring.
    batch_frames = []
    for v in videos:
        sampled_frames = extract_frames(v, sample_method="uniform", num_sampled_frames=8)[1]
        batch_frames.append(sampled_frames)
    print(video_clip_xl_score(batch_frames, texts))