import os
from abc import ABC, abstractmethod
import torch
import torchvision.transforms as transforms
from einops import rearrange
from torchvision.datasets.utils import download_url
from typing import Optional, Tuple
# All reward models.
__all__ = ["AestheticReward", "HPSReward", "PickScoreReward", "MPSReward"]
class BaseReward(ABC):
"""A base class for reward models. A custom reward class must implement the two functions below.
"""
def __init__(self):
"""Define your reward model and image transformations (optional) here.
"""
pass
@abstractmethod
def __call__(self, batch_frames: torch.Tensor, batch_prompt: Optional[list[str]]=None) -> Tuple[torch.Tensor, torch.Tensor]:
"""Given a batch of frames with shape `[B, C, T, H, W]` extracted from a list of videos, and optionally the
corresponding list of prompts, return the loss and reward computed by your reward model (both reduced by mean).
"""
pass
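# Illustrative sketch only (hypothetical, not part of `__all__`): a minimal custom reward that
# follows the interface above, scoring each video by its mean pixel brightness. The class name
# and the brightness heuristic are made up purely to demonstrate the expected inputs and outputs.
class ExampleBrightnessReward(BaseReward):
    def __init__(self, max_reward=1.0, loss_scale=1.0):
        self.max_reward = max_reward
        self.loss_scale = loss_scale
    def __call__(self, batch_frames: torch.Tensor, batch_prompt: Optional[list[str]]=None) -> Tuple[torch.Tensor, torch.Tensor]:
        # batch_frames: [B, C, T, H, W] in [0, 1]; the reward is the mean pixel value per video.
        reward = batch_frames.float().mean(dim=(1, 2, 3, 4))
        loss = (reward - self.max_reward).abs() * self.loss_scale
        return loss.mean(), reward.mean()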
class AestheticReward(BaseReward):
"""Aesthetic Predictor [V2](https://github.com/christophschuhmann/improved-aesthetic-predictor)
and [V2.5](https://github.com/discus0434/aesthetic-predictor-v2-5) reward model.
"""
def __init__(
self,
encoder_path="openai/clip-vit-large-patch14",
predictor_path=None,
version="v2",
device="cpu",
dtype=torch.float16,
max_reward=10,
loss_scale=0.1,
):
from .improved_aesthetic_predictor import ImprovedAestheticPredictor
from ..video_caption.utils.siglip_v2_5 import convert_v2_5_from_siglip
self.encoder_path = encoder_path
self.predictor_path = predictor_path
self.version = version
self.device = device
self.dtype = dtype
self.max_reward = max_reward
self.loss_scale = loss_scale
if self.version != "v2" and self.version != "v2.5":
raise ValueError("Only v2 and v2.5 are supported.")
if self.version == "v2":
assert "clip-vit-large-patch14" in encoder_path.lower()
self.model = ImprovedAestheticPredictor(encoder_path=self.encoder_path, predictor_path=self.predictor_path)
# https://huggingface.co/openai/clip-vit-large-patch14/blob/main/preprocessor_config.json
# TODO: [transforms.Resize(224), transforms.CenterCrop(224)] for any aspect ratio.
self.transform = transforms.Compose([
transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
])
elif self.version == "v2.5":
assert "siglip-so400m-patch14-384" in encoder_path.lower()
self.model, _ = convert_v2_5_from_siglip(encoder_model_name=self.encoder_path)
# https://huggingface.co/google/siglip-so400m-patch14-384/blob/main/preprocessor_config.json
self.transform = transforms.Compose([
transforms.Resize((384, 384), interpolation=transforms.InterpolationMode.BICUBIC),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
self.model.to(device=self.device, dtype=self.dtype)
self.model.requires_grad_(False)
def __call__(self, batch_frames: torch.Tensor, batch_prompt: Optional[list[str]]=None) -> Tuple[torch.Tensor, torch.Tensor]:
batch_frames = rearrange(batch_frames, "b c t h w -> t b c h w")
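# Iterate over the T frame positions: each step scores one frame from every video, and the final loss/reward is averaged over frames.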
batch_loss, batch_reward = 0, 0
for frames in batch_frames:
pixel_values = torch.stack([self.transform(frame) for frame in frames])
pixel_values = pixel_values.to(self.device, dtype=self.dtype)
if self.version == "v2":
reward = self.model(pixel_values)
elif self.version == "v2.5":
reward = self.model(pixel_values).logits.squeeze()
# Convert the reward to a loss; with the default max_reward and loss_scale, the loss roughly lies in [0, 1].
if self.max_reward is None:
loss = (-1 * reward) * self.loss_scale
else:
loss = abs(reward - self.max_reward) * self.loss_scale
batch_loss, batch_reward = batch_loss + loss.mean(), batch_reward + reward.mean()
return batch_loss / batch_frames.shape[0], batch_reward / batch_frames.shape[0]
class HPSReward(BaseReward):
"""[HPS](https://github.com/tgxs002/HPSv2) v2 and v2.1 reward model.
"""
def __init__(
self,
model_path=None,
version="v2.0",
device="cpu",
dtype=torch.float16,
max_reward=1,
loss_scale=1,
):
from hpsv2.src.open_clip import create_model_and_transforms, get_tokenizer
self.model_path = model_path
self.version = version
self.device = device
self.dtype = dtype
self.max_reward = max_reward
self.loss_scale = loss_scale
self.model, _, _ = create_model_and_transforms(
"ViT-H-14",
"laion2B-s32B-b79K",
precision=self.dtype,
device=self.device,
jit=False,
force_quick_gelu=False,
force_custom_text=False,
force_patch_dropout=False,
force_image_size=None,
pretrained_image=False,
image_mean=None,
image_std=None,
light_augmentation=True,
aug_cfg={},
output_dict=True,
with_score_predictor=False,
with_region_predictor=False,
)
self.tokenizer = get_tokenizer("ViT-H-14")
# https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/blob/main/preprocessor_config.json
# TODO: [transforms.Resize(224), transforms.CenterCrop(224)] for any aspect ratio.
self.transform = transforms.Compose([
transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
])
if version == "v2.0":
url = "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/Third_Party/HPS_v2_compressed.pt"
filename = "HPS_v2_compressed.pt"
md5 = "fd9180de357abf01fdb4eaad64631db4"
elif version == "v2.1":
url = "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/Third_Party/HPS_v2.1_compressed.pt"
filename = "HPS_v2.1_compressed.pt"
md5 = "4067542e34ba2553a738c5ac6c1d75c0"
else:
raise ValueError("Only v2.0 and v2.1 are supported.")
if self.model_path is None or not os.path.exists(self.model_path):
download_url(url, torch.hub.get_dir(), md5=md5)
model_path = os.path.join(torch.hub.get_dir(), filename)
state_dict = torch.load(model_path, map_location="cpu")["state_dict"]
self.model.load_state_dict(state_dict)
self.model.to(device=self.device, dtype=self.dtype)
self.model.requires_grad_(False)
self.model.eval()
def __call__(self, batch_frames: torch.Tensor, batch_prompt: list[str]) -> Tuple[torch.Tensor, torch.Tensor]:
assert batch_frames.shape[0] == len(batch_prompt)
# Compute the batch reward and loss frame-wise.
batch_frames = rearrange(batch_frames, "b c t h w -> t b c h w")
batch_loss, batch_reward = 0, 0
for frames in batch_frames:
image_inputs = torch.stack([self.transform(frame) for frame in frames])
image_inputs = image_inputs.to(device=self.device, dtype=self.dtype)
text_inputs = self.tokenizer(batch_prompt).to(device=self.device)
outputs = self.model(image_inputs, text_inputs)
image_features, text_features = outputs["image_features"], outputs["text_features"]
logits = image_features @ text_features.T
reward = torch.diagonal(logits)
# Convert the reward to a loss; with the default max_reward and loss_scale, the loss roughly lies in [0, 1].
if self.max_reward is None:
loss = (-1 * reward) * self.loss_scale
else:
loss = abs(reward - self.max_reward) * self.loss_scale
batch_loss, batch_reward = batch_loss + loss.mean(), batch_reward + reward.mean()
return batch_loss / batch_frames.shape[0], batch_reward / batch_frames.shape[0]
class PickScoreReward(BaseReward):
"""[PickScore](https://github.com/yuvalkirstain/PickScore) reward model.
"""
def __init__(
self,
model_path="yuvalkirstain/PickScore_v1",
device="cpu",
dtype=torch.float16,
max_reward=1,
loss_scale=1,
):
from transformers import AutoProcessor, AutoModel
self.model_path = model_path
self.device = device
self.dtype = dtype
self.max_reward = max_reward
self.loss_scale = loss_scale
# https://huggingface.co/yuvalkirstain/PickScore_v1/blob/main/preprocessor_config.json
self.transform = transforms.Compose([
transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
transforms.CenterCrop(224),
transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
])
self.processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=self.dtype)
self.model = AutoModel.from_pretrained(model_path, torch_dtype=self.dtype).eval().to(device)
self.model.requires_grad_(False)
self.model.eval()
def __call__(self, batch_frames: torch.Tensor, batch_prompt: list[str]) -> Tuple[torch.Tensor, torch.Tensor]:
assert batch_frames.shape[0] == len(batch_prompt)
# Compute the batch reward and loss frame-wise.
batch_frames = rearrange(batch_frames, "b c t h w -> t b c h w")
batch_loss, batch_reward = 0, 0
for frames in batch_frames:
image_inputs = torch.stack([self.transform(frame) for frame in frames])
image_inputs = image_inputs.to(device=self.device, dtype=self.dtype)
text_inputs = self.processor(
text=batch_prompt,
padding=True,
truncation=True,
max_length=77,
return_tensors="pt",
).to(self.device)
image_features = self.model.get_image_features(pixel_values=image_inputs)
text_features = self.model.get_text_features(**text_inputs)
image_features = image_features / torch.norm(image_features, dim=-1, keepdim=True)
text_features = text_features / torch.norm(text_features, dim=-1, keepdim=True)
logits = image_features @ text_features.T
reward = torch.diagonal(logits)
# Convert the reward to a loss; with the default max_reward and loss_scale, the loss roughly lies in [0, 1].
if self.max_reward is None:
loss = (-1 * reward) * self.loss_scale
else:
loss = abs(reward - self.max_reward) * self.loss_scale
batch_loss, batch_reward = batch_loss + loss.mean(), batch_reward + reward.mean()
return batch_loss / batch_frames.shape[0], batch_reward / batch_frames.shape[0]
class MPSReward(BaseReward):
"""[MPS](https://github.com/Kwai-Kolors/MPS) reward model.
"""
def __init__(
self,
model_path=None,
device="cpu",
dtype=torch.float16,
max_reward=1,
loss_scale=1,
):
from transformers import AutoTokenizer, AutoConfig
from .MPS.trainer.models.clip_model import CLIPModel
self.model_path = model_path
self.device = device
self.dtype = dtype
self.condition = "light, color, clarity, tone, style, ambiance, artistry, shape, face, hair, hands, limbs, structure, instance, texture, quantity, attributes, position, number, location, word, things."
self.max_reward = max_reward
self.loss_scale = loss_scale
processor_name_or_path = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
# https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/blob/main/preprocessor_config.json
# TODO: [transforms.Resize(224), transforms.CenterCrop(224)] for any aspect ratio.
self.transform = transforms.Compose([
transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]),
])
# We convert the original [ckpt](http://drive.google.com/file/d/17qrK_aJkVNM75ZEvMEePpLj6L867MLkN/view?usp=sharing),
# which contains the entire model, into a `state_dict`.
url = "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/easyanimate/Third_Party/MPS_overall.pth"
filename = "MPS_overall.pth"
md5 = "1491cbbbd20565747fe07e7572e2ac56"
if self.model_path is None or not os.path.exists(self.model_path):
download_url(url, torch.hub.get_dir(), md5=md5)
model_path = os.path.join(torch.hub.get_dir(), filename)
self.tokenizer = AutoTokenizer.from_pretrained(processor_name_or_path, trust_remote_code=True)
config = AutoConfig.from_pretrained(processor_name_or_path)
self.model = CLIPModel(config)
state_dict = torch.load(model_path, map_location="cpu")
self.model.load_state_dict(state_dict, strict=False)
self.model.to(device=self.device, dtype=self.dtype)
self.model.requires_grad_(False)
self.model.eval()
def _tokenize(self, caption):
input_ids = self.tokenizer(
caption,
max_length=self.tokenizer.model_max_length,
padding="max_length",
truncation=True,
return_tensors="pt"
).input_ids
return input_ids
def __call__(
self,
batch_frames: torch.Tensor,
batch_prompt: list[str],
batch_condition: Optional[list[str]] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
if batch_condition is None:
batch_condition = [self.condition] * len(batch_prompt)
batch_frames = rearrange(batch_frames, "b c t h w -> t b c h w")
batch_loss, batch_reward = 0, 0
for frames in batch_frames:
image_inputs = torch.stack([self.transform(frame) for frame in frames])
image_inputs = image_inputs.to(device=self.device, dtype=self.dtype)
text_inputs = self._tokenize(batch_prompt).to(self.device)
condition_inputs = self._tokenize(batch_condition).to(device=self.device)
text_features, image_features = self.model(text_inputs, image_inputs, condition_inputs)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
# reward = self.model.logit_scale.exp() * torch.diag(torch.einsum('bd,cd->bc', text_features, image_features))
logits = image_features @ text_features.T
reward = torch.diagonal(logits)
# Convert the reward to a loss; with the default max_reward and loss_scale, the loss roughly lies in [0, 1].
if self.max_reward is None:
loss = (-1 * reward) * self.loss_scale
else:
loss = abs(reward - self.max_reward) * self.loss_scale
batch_loss, batch_reward = batch_loss + loss.mean(), batch_reward + reward.mean()
return batch_loss / batch_frames.shape[0], batch_reward / batch_frames.shape[0]
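# Hypothetical convenience wrapper (not in `__all__`; shown here as a sketch only, not one of the
# shipped reward models): combine several of the rewards above by summing their weighted losses,
# which can be handy when optimizing against more than one preference model at once.
class WeightedSumReward(BaseReward):
    def __init__(self, rewards, weights=None):
        self.rewards = rewards
        self.weights = weights if weights is not None else [1.0] * len(rewards)
    def __call__(self, batch_frames: torch.Tensor, batch_prompt: Optional[list[str]]=None) -> Tuple[torch.Tensor, torch.Tensor]:
        total_loss, total_reward = 0, 0
        for weight, reward_model in zip(self.weights, self.rewards):
            # Each wrapped reward returns (loss, reward) reduced by mean, as defined by BaseReward.
            loss, reward = reward_model(batch_frames, batch_prompt)
            total_loss = total_loss + weight * loss
            total_reward = total_reward + weight * reward
        return total_loss, total_reward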
if __name__ == "__main__":
import numpy as np
from decord import VideoReader
video_path_list = ["your_video_path_1.mp4", "your_video_path_2.mp4"]
prompt_list = ["your_prompt_1", "your_prompt_2"]
num_sampled_frames = 8
to_tensor = transforms.ToTensor()
sampled_frames_list = []
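# Uniformly sample `num_sampled_frames` frames from each video; `to_tensor` yields [C, H, W] floats in [0, 1], stacked to [T, C, H, W] per video.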
for video_path in video_path_list:
vr = VideoReader(video_path)
sampled_frame_indices = np.linspace(0, len(vr), num_sampled_frames, endpoint=False, dtype=int)
sampled_frames = vr.get_batch(sampled_frame_indices).asnumpy()
sampled_frames = torch.stack([to_tensor(frame) for frame in sampled_frames])
sampled_frames_list.append(sampled_frames)
sampled_frames = torch.stack(sampled_frames_list)
sampled_frames = rearrange(sampled_frames, "b t c h w -> b c t h w")
aesthetic_reward_v2 = AestheticReward(device="cuda", dtype=torch.bfloat16)
print(f"aesthetic_reward_v2: {aesthetic_reward_v2(sampled_frames)}")
aesthetic_reward_v2_5 = AestheticReward(
encoder_path="google/siglip-so400m-patch14-384", version="v2.5", device="cuda", dtype=torch.bfloat16
)
print(f"aesthetic_reward_v2_5: {aesthetic_reward_v2_5(sampled_frames)}")
hps_reward_v2 = HPSReward(device="cuda", dtype=torch.bfloat16)
print(f"hps_reward_v2: {hps_reward_v2(sampled_frames, prompt_list)}")
hps_reward_v2_1 = HPSReward(version="v2.1", device="cuda", dtype=torch.bfloat16)
print(f"hps_reward_v2_1: {hps_reward_v2_1(sampled_frames, prompt_list)}")
pick_score = PickScoreReward(device="cuda", dtype=torch.bfloat16)
print(f"pick_score_reward: {pick_score(sampled_frames, prompt_list)}")
mps_score = MPSReward(device="cuda", dtype=torch.bfloat16)
print(f"mps_reward: {mps_score(sampled_frames, prompt_list)}")
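# A minimal sketch of reward backpropagation, assuming the frames come from a differentiable
# generator (here a plain leaf tensor stands in for its output): the loss returned above is
# differentiable w.r.t. the input frames, so gradients flow back to whatever produced them.
frames_for_grad = sampled_frames.clone().requires_grad_(True)
loss, reward = hps_reward_v2(frames_for_grad, prompt_list)
loss.backward()
print(f"gradient norm w.r.t. frames: {frames_for_grad.grad.norm()}")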