| """ | |
| Copyright (c) 2023, salesforce.com, inc. | |
| All rights reserved. | |
| SPDX-License-Identifier: BSD-3-Clause | |
| For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
| """ | |
| import contextlib | |
| import logging | |
| import os | |
| import torch | |
| import torch.nn as nn | |
| from lavis.common.dist_utils import download_cached_file | |
| from lavis.common.utils import is_url | |
| from lavis.models.base_model import BaseModel | |
| from lavis.models.blip2_models.Qformer import BertConfig, BertLMHeadModel | |
| from transformers import BertTokenizer | |
| from model.gin_model import GNN | |


class Blip2Base(BaseModel):
    @classmethod
    def init_tokenizer(cls):
        # Load the SciBERT tokenizer from the Hugging Face hub
        # (switch to a local 'bert_pretrained/' checkout when running offline).
        bert_name = 'allenai/scibert_scivocab_uncased'
        tokenizer = BertTokenizer.from_pretrained(bert_name)
        # [DEC] serves as the beginning-of-sequence token for the text decoder.
        tokenizer.add_special_tokens({"bos_token": "[DEC]"})
        return tokenizer

    def maybe_autocast(self, dtype=torch.float16):
        # On CPU, skip autocast entirely; on GPU, use autocast with the
        # requested dtype (torch.float16 by default).
        enable_autocast = self.device != torch.device("cpu")
        if enable_autocast:
            return torch.cuda.amp.autocast(dtype=dtype)
        else:
            return contextlib.nullcontext()

    @classmethod
    def init_Qformer(cls, model_name, num_query_token, graph_width, cross_attention_freq=2):
        assert model_name == 'scibert'
        print("bert load scibert")
        # Load the SciBERT encoder from the Hugging Face hub
        # (switch to a local 'bert_pretrained/' checkout when running offline).
        bert_name = 'allenai/scibert_scivocab_uncased'
        encoder_config = BertConfig.from_pretrained(bert_name)
        encoder_config.encoder_width = graph_width
        # Insert a cross-attention layer every `cross_attention_freq` blocks.
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        Qformer = BertLMHeadModel.from_pretrained(bert_name, config=encoder_config)
        # Learnable query tokens that attend to the graph features via cross-attention.
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
        return Qformer, query_tokens

    @classmethod
    def init_graph_encoder(cls, gin_num_layers, gin_hidden_dim, gin_drop_ratio):
        graph_encoder = GNN(
            num_layer=gin_num_layers,
            emb_dim=gin_hidden_dim,
            gnn_type='gin',
            drop_ratio=gin_drop_ratio,
            JK='last',
        )
        # Initialize the GIN from a GraphCL pre-trained checkpoint.
        ckpt = torch.load('gin_pretrained/graphcl_80.pth', map_location=torch.device('cpu'))
        missing_keys, unexpected_keys = graph_encoder.load_state_dict(ckpt, strict=False)
        if len(missing_keys) or len(unexpected_keys):
            print('missing keys:', missing_keys)
            print('unexpected keys:', unexpected_keys)
        ln_graph = LayerNorm(graph_encoder.num_features)
        return graph_encoder, ln_graph

    def load_from_pretrained(self, url_or_filename):
        if is_url(url_or_filename):
            cached_file = download_cached_file(
                url_or_filename, check_hash=False, progress=True
            )
            checkpoint = torch.load(cached_file, map_location="cpu")
        elif os.path.isfile(url_or_filename):
            checkpoint = torch.load(url_or_filename, map_location="cpu")
        else:
            raise RuntimeError("checkpoint url or path is invalid")

        state_dict = checkpoint["model"]
        msg = self.load_state_dict(state_dict, strict=False)
        # logging.info("Missing keys {}".format(msg.missing_keys))
        logging.info("load checkpoint from %s" % url_or_filename)
        return msg


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self
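

# --- Illustrative sketch (editor addition, not part of the original repo) ---
# `disabled_train` is typically monkey-patched onto a frozen submodule so that
# a later call to the parent model's train() cannot flip that submodule back
# out of eval mode. The helper name `freeze_and_pin_eval` below is hypothetical.
def freeze_and_pin_eval(module: nn.Module) -> nn.Module:
    # Stop gradient updates for every parameter of the submodule.
    for param in module.parameters():
        param.requires_grad = False
    # Switch to eval mode once, then bind `disabled_train` as the submodule's
    # train() method so that later train()/eval() toggles become no-ops.
    module.eval()
    module.train = disabled_train.__get__(module, nn.Module)
    return module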


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor, mask=None):
        orig_type = x.dtype
        # ret = super().forward(x.type(torch.float32))
        ret = super().forward(x)
        return ret.type(orig_type)
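

# --- Illustrative sketch (editor addition, not part of the original repo) ---
# Minimal example of how a subclass might wire the helpers above together.
# The subclass name `Blip2Demo`, the hyper-parameter values, and the shape of
# the graph encoder's output are assumptions for illustration; instantiating
# it also requires the SciBERT weights and 'gin_pretrained/graphcl_80.pth'
# to be available locally or downloadable.
class Blip2Demo(Blip2Base):
    def __init__(self, num_query_token=8, gin_hidden_dim=300):
        super().__init__()
        self.tokenizer = self.init_tokenizer()
        # GIN graph encoder plus the LayerNorm applied to its node features.
        self.graph_encoder, self.ln_graph = self.init_graph_encoder(
            gin_num_layers=5, gin_hidden_dim=gin_hidden_dim, gin_drop_ratio=0.0
        )
        # Q-Former whose cross-attention width matches the graph embedding size.
        self.Qformer, self.query_tokens = self.init_Qformer(
            'scibert', num_query_token, gin_hidden_dim, cross_attention_freq=2
        )

    def encode_graph(self, graph_batch):
        # Run the graph encoder under fp16 autocast on GPU (plain context on CPU).
        # Assumes the GIN forward returns dense node features plus a padding
        # mask; adapt to the actual return signature of model.gin_model.GNN.
        with self.maybe_autocast():
            node_feats, node_mask = self.graph_encoder(graph_batch)
        return self.ln_graph(node_feats), node_mask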