from dataclasses import dataclass
from transformers import CLIPModel as HFCLIPModel
from transformers import AutoTokenizer
from torch import nn, einsum
# Modified: import
# from trainer.models.base_model import BaseModelConfig
from .base_model import BaseModelConfig
from transformers import CLIPConfig
from typing import Any, Optional, Tuple, Union
import torch
# Modified: import
# from trainer.models.cross_modeling import Cross_model
from .cross_modeling import Cross_model
import gc

class XCLIPModel(HFCLIPModel):
    def __init__(self, config: CLIPConfig):
        super().__init__(config)

    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Unlike the stock HF implementation, which projects only the pooled (EOS) embedding,
        # this version projects every token and returns both the per-token features and the EOS feature.
        # pooled_output = text_outputs[1]
        # text_features = self.text_projection(pooled_output)
        last_hidden_state = text_outputs[0]
        text_features = self.text_projection(last_hidden_state)  # [B, seq_len, projection_dim]
        pooled_output = text_outputs[1]
        text_features_EOS = self.text_projection(pooled_output)  # [B, projection_dim]

        # del last_hidden_state, text_outputs
        # gc.collect()

        return text_features, text_features_EOS

    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Project every patch token (plus CLS) instead of only the pooled output.
        # pooled_output = vision_outputs[1]  # pooled_output
        # image_features = self.visual_projection(pooled_output)
        last_hidden_state = vision_outputs[0]
        image_features = self.visual_projection(last_hidden_state)  # [B, num_patches + 1, projection_dim]

        return image_features


@dataclass
class ClipModelConfig(BaseModelConfig):
    _target_: str = "trainer.models.clip_model.CLIPModel"
    pretrained_model_name_or_path: str = "openai/clip-vit-base-patch32"


class CLIPModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Modified: We convert the original ckpt (which contains the entire pickled model) to a
        # `state_dict`, so the model is built from a `CLIPConfig` here and the weights are loaded separately.
        # self.model = XCLIPModel.from_pretrained(ckpt)
        self.model = XCLIPModel(config)
        self.cross_model = Cross_model(dim=1024, layer_num=4, heads=16)
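
    # A minimal sketch of the checkpoint conversion mentioned above (an assumption, not part of this
    # module). The original checkpoint stored the whole pickled model, so one would load it once,
    # keep only its weights, and later reload them into this wrapper (or into `self.model`,
    # depending on what the original checkpoint contained). The paths are hypothetical.
    #
    #     full_model = torch.load("original_full_model.pt", map_location="cpu")
    #     torch.save(full_model.state_dict(), "clip_model_state_dict.pt")
    #     # later, after building CLIPModel(config):
    #     # model.load_state_dict(torch.load("clip_model_state_dict.pt", map_location="cpu"))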

    def get_text_features(self, *args, **kwargs):
        return self.model.get_text_features(*args, **kwargs)

    def get_image_features(self, *args, **kwargs):
        return self.model.get_image_features(*args, **kwargs)

    def forward(self, text_inputs=None, image_inputs=None, condition_inputs=None):
        outputs = ()
        text_f, text_EOS = self.model.get_text_features(text_inputs)  # B*77*1024
        outputs += text_EOS,
        image_f = self.model.get_image_features(image_inputs.half())  # B*257*1024 (2B*257*1024 in the original two-image setup)
        # [B, 77, 1024]
        condition_f, _ = self.model.get_text_features(condition_inputs)  # B*5*1024

        # For each text token, take its best similarity to any condition token and normalize;
        # tokens that barely relate to the condition are masked out of the cross attention.
        sim_text_condition = einsum('b i d, b j d -> b j i', text_f, condition_f)
        sim_text_condition = torch.max(sim_text_condition, dim=1, keepdim=True)[0]
        sim_text_condition = sim_text_condition / sim_text_condition.max()
        mask = torch.where(sim_text_condition > 0.01, 0, float('-inf'))  # B*1*77

        # Modified: Support both torch.float16 and torch.bfloat16
        # mask = mask.repeat(1, image_f.shape[1], 1)  # B*257*77
        model_dtype = next(self.cross_model.parameters()).dtype
        mask = mask.repeat(1, image_f.shape[1], 1).to(model_dtype)  # B*257*77

        # bc = int(image_f.shape[0] / 2)
        # Modified: The original input consisted of a (batch of) text and two (batches of) images,
        # primarily used to compute which (batch of) images is more consistent with the text.
        # The modified input consists of a (batch of) text and a single (batch of) images;
        # see the sketch after this method for how the two-image comparison can be reproduced.
        # sim0 = self.cross_model(image_f[:bc, :, :], text_f, mask.half())
        # sim1 = self.cross_model(image_f[bc:, :, :], text_f, mask.half())
        # outputs += sim0[:, 0, :],
        # outputs += sim1[:, 0, :],
        sim = self.cross_model(image_f, text_f, mask)
        outputs += sim[:, 0, :],
        return outputs
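
    # A sketch of how the original two-image comparison could be reproduced with the modified
    # forward (an assumption; the names below are illustrative, not part of this module):
    #
    #     text_EOS, score_0 = model(text_ids, image_batch_0, condition_ids)  # each [B, 1024]
    #     _,        score_1 = model(text_ids, image_batch_1, condition_ids)
    #     # One natural comparison is a similarity between each score and text_EOS (e.g. cosine),
    #     # choosing the image whose score aligns better with the text; the exact scoring head
    #     # that consumed `sim0` / `sim1` in the original code lives outside this module.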

    def logit_scale(self):
        return self.model.logit_scale

    def save(self, path):
        self.model.save_pretrained(path)
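

# A minimal usage sketch (an assumption, not part of the original file). The Cross_model dimension of
# 1024 and the 257 image tokens suggest a ViT-H/14 CLIP backbone, so the config and tokenizer below use
# "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"; the state-dict path and condition text are hypothetical, and
# a CUDA device is assumed because `forward` casts images to half precision. Because of the relative
# imports, this would be run as part of its package (e.g. `python -m <package>.clip_model`).
if __name__ == "__main__":
    config = CLIPConfig.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
    tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")

    model = CLIPModel(config).half().cuda().eval()
    # model.load_state_dict(torch.load("clip_model_state_dict.pt", map_location="cpu"))  # hypothetical path

    text_ids = tokenizer(
        ["a photo of a corgi surfing a wave"],
        padding="max_length", max_length=77, truncation=True, return_tensors="pt",
    ).input_ids.cuda()
    condition_ids = tokenizer(["light, color and clarity"], return_tensors="pt").input_ids.cuda()
    pixel_values = torch.randn(1, 3, 224, 224).cuda()  # stand-in for a preprocessed image batch

    with torch.no_grad():
        text_EOS, sim = model(text_ids, pixel_values, condition_ids)
    print(text_EOS.shape, sim.shape)  # text_EOS: [1, 1024]; sim: [1, cross_model dim]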