| import torch |
| import torch.nn as nn |
| from transformers import PreTrainedModel |
| from collections import OrderedDict |
| from transformers.modeling_outputs import SequenceClassifierOutput |
| from typing import List, Optional, Tuple, Union |
| from .configuration import MultiLabelClassifierConfig |
|
|
class MultiLabelClassifierModel(PreTrainedModel):
    """Multi-label sequence classifier.

    A pretrained transformer backbone (loaded via ``torch.hub``) produces
    per-token embeddings; a (optionally bidirectional, stacked) GRU summarizes
    the sequence, and a linear head maps the final GRU state to
    ``config.num_classes`` independent logits (one per label).
    """

    config_class = MultiLabelClassifierConfig

    def __init__(self, config):
        super().__init__(config)

        # Transformer backbone; NOTE: this downloads weights over the network
        # on first use via torch.hub.
        self.nlp_model = torch.hub.load(
            'huggingface/pytorch-transformers', 'model', config.transformer_name
        )
        # GRU over the backbone's token embeddings. Inter-layer dropout is only
        # defined for stacked GRUs, so it is disabled when num_layers < 2
        # (PyTorch warns otherwise).
        self.rnn = nn.GRU(
            config.embedding_dim,
            config.hidden_dim,
            num_layers=config.num_layers,
            bidirectional=config.bidirectional,
            batch_first=True,
            dropout=0 if config.num_layers < 2 else config.dropout,
        )
        self.dropout = nn.Dropout(config.dropout)
        # A bidirectional GRU concatenates the forward and backward final
        # hidden states, doubling the feature width fed to the classifier.
        self.out = nn.Linear(
            config.hidden_dim * 2 if config.bidirectional else config.hidden_dim,
            config.num_classes,
        )

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                token_type_ids: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.Tensor] = None,
                head_mask: Optional[torch.Tensor] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                return_dict: Optional[bool] = None,
                labels: Optional[torch.Tensor] = None,
                ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
        """Run the encoder + GRU head.

        Args:
            input_ids / attention_mask / ... : standard transformer inputs,
                forwarded unchanged to the backbone.
            return_dict: if ``False``, return a plain tuple instead of a
                ``SequenceClassifierOutput`` (HF convention).
            labels: optional multi-hot target tensor of shape
                ``(batch, num_classes)``; when given, a ``BCEWithLogitsLoss``
                is computed and included in the output.

        Returns:
            ``SequenceClassifierOutput`` (or tuple) with ``logits`` of shape
            ``(batch, num_classes)``.
        """
        if return_dict is None:
            # use_return_dict is the standard PretrainedConfig switch;
            # default to True if the config predates it.
            return_dict = getattr(self.config, "use_return_dict", True)

        # Always request a dict from the backbone so the attribute access
        # below is safe regardless of what the *caller* asked for.
        # (Previously, return_dict=False from the caller made the backbone
        # return a tuple and output['last_hidden_state'] raised TypeError.)
        output = self.nlp_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )

        # hidden: (num_layers * num_directions, batch, hidden_dim)
        _, hidden = self.rnn(output.last_hidden_state)
        if self.rnn.bidirectional:
            # Last layer's forward (-2) and backward (-1) final states.
            hidden = self.dropout(torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1))
        else:
            hidden = self.dropout(hidden[-1, :, :])

        logits = self.out(hidden)

        loss = None
        if labels is not None:
            # Multi-label classification: independent sigmoid per class.
            loss = nn.BCEWithLogitsLoss()(logits, labels.float())

        if not return_dict:
            extras = tuple(
                v for v in (output.hidden_states, output.attentions) if v is not None
            )
            result = (logits,) + extras
            return ((loss,) + result) if loss is not None else result

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=output.hidden_states,
            attentions=output.attentions,
        )