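"""Demo chat backend for VLM-RLAIF (ACL 2024): loads an (optionally RLHF-tuned)
LLaVA-style checkpoint and streams multimodal chat responses."""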
import os
import shutil
import sys

import torch
from transformers import TextStreamer

# Make the repo's Evaluation/ directory importable so the llava package resolves.
sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "Evaluation"))
from llava.constants import IMAGE_TOKEN_INDEX
from llava.conversation import conv_templates, SeparatorStyle
from llava.mm_utils import get_model_name_from_path, KeywordsStoppingCriteria, tokenizer_image_token
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init

cur_dir = os.path.dirname(os.path.abspath(__file__))
title_markdown = ("""
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
  <div>
    <h1>VLM-RLAIF: Tuning Large Multimodal Models for Videos using Reinforcement Learning from AI Feedback (ACL 2024 Oral)</h1>
    <h5 style="margin: 0;">If you like our project, please give us a star ✨ on GitHub for the latest updates.</h5>
  </div>
</div>
<div align="center">
  <div style="display: flex; gap: 0.25rem;" align="center">
    <a href='https://github.com/yonseivnl/vlm-rlaif'><img src='https://img.shields.io/badge/Github-Code-blue'></a>
    <a href="https://arxiv.org/abs/2402.03746"><img src="https://img.shields.io/badge/Paper-arxiv-green"></a>
  </div>
</div>
""")
block_css = """
#buttons button {
    min-width: min(120px, 100%);
}
"""
tos_markdown = ""
learn_more_markdown = ("""
### License
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA.
""")
class Chat:
    def __init__(self, model_path, conv_mode, model_base=None, load_8bit=False, load_4bit=False, device='cuda', cache_dir=None):
        disable_torch_init()
        model_name = get_model_name_from_path(model_path)
        is_rlhf_checkpoint = 'rlhf' in model_path.lower()
        print("MODEL_PATH:", model_path)
        print("RLHF checkpoint:", is_rlhf_checkpoint)
        if not model_base or model_base == "none":
            model_base = None
        if is_rlhf_checkpoint:
            model_name = model_path
            # RLHF checkpoints may ship without a config.json; reuse the SFT base model's config.
            if not os.path.exists(os.path.join(model_path, "config.json")):
                shutil.copy(os.path.join(model_base, "config.json"), os.path.join(model_path, "config.json"))
        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
            model_path, model_base, model_name, load_8bit, load_4bit, device=device)
        self.conv_mode = conv_mode
        self.conv = conv_templates[conv_mode].copy()
        self.device = self.model.device
    def get_prompt(self, qs, state):
        # Append the user turn and an empty assistant turn to the conversation state.
        state.append_message(state.roles[0], qs)
        state.append_message(state.roles[1], None)
        return state

    def _get_latest_prompt(self, state):
        # Build a prompt from only the most recent user/assistant exchange.
        new_state = state.copy()
        new_state.messages = state.messages[-2:]
        return new_state
    def generate(self, images_tensor: torch.Tensor, prompt: str, first_run: bool, state):
        tokenizer, model = self.tokenizer, self.model
        state = self.get_prompt(prompt, state)
        # Condition generation on only the latest user/assistant turn.
        latest_state = self._get_latest_prompt(state)
        prompt = latest_state.get_prompt()
        # Expand the image placeholder in the prompt into IMAGE_TOKEN_INDEX ids.
        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
        temperature = 0.2
        max_new_tokens = 1024
        # Stop on the separator that matches the active conversation template style.
        stop_str = self.conv.sep if self.conv.sep_style != SeparatorStyle.TWO else self.conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
        streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        print(prompt, input_ids.shape, images_tensor.shape)
        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=images_tensor,
                do_sample=True,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                streamer=streamer,
                use_cache=True,
                stopping_criteria=[stopping_criteria])
        input_token_len = input_ids.shape[1]
        n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
        if n_diff_input_output > 0:
            print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
        # Decode only the newly generated tokens, then trim artifacts and the stop string.
        outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
        outputs = outputs.strip()
        outputs = outputs.replace("QA_GT_caption_based_noisy", "")
        if outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)]
        outputs = outputs.strip()
        print('response', outputs)
        return outputs, state
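
# --- Usage sketch (illustrative) ---
# A minimal sketch of driving Chat end to end, assuming a local checkpoint and
# video frames already preprocessed into a tensor. The checkpoint path, the
# "llava_v1" conversation template name, and the dummy frame tensor shape are
# assumptions for illustration, not values taken from this repo.
if __name__ == "__main__":
    chat = Chat(model_path="checkpoints/vlm-rlaif", conv_mode="llava_v1")  # hypothetical path/template
    state = chat.conv.copy()
    # Stand-in for real frames produced by chat.image_processor.
    frames = torch.zeros(1, 8, 3, 224, 224, dtype=torch.float16, device=chat.device)
    # "<image>" is the placeholder that tokenizer_image_token expands to IMAGE_TOKEN_INDEX.
    reply, state = chat.generate(frames, "<image>\nDescribe the video.", True, state)
    print(reply)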