Update modeling_esm_plusplus.py

modeling_esm_plusplus.py  CHANGED  (+31 -15)
@@ -249,7 +249,7 @@ class SwiGLU(nn.Module):
         return F.silu(x1) * x2
 
 
-def swiglu_ln_ffn(d_model: int, expansion_ratio: float) -> nn.Sequential:
+def swiglu_ln_ffn(d_model: int, expansion_ratio: float, dropout: float = 0.0) -> nn.Sequential:
     """Create SwiGLU feedforward network with layer normalization."""
     return nn.Sequential(
         nn.LayerNorm(d_model),
@@ -257,6 +257,7 @@ def swiglu_ln_ffn(d_model: int, expansion_ratio: float) -> nn.Sequential:
             d_model, swiglu_correction_fn(expansion_ratio, d_model) * 2, bias=False
         ),
         SwiGLU(),
+        nn.Dropout(dropout),
         nn.Linear(swiglu_correction_fn(expansion_ratio, d_model), d_model, bias=False),
     )
 
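For orientation, the sketch below is a minimal standalone version of the modified feedforward stack. The _swiglu_hidden_dim helper is a hypothetical stand-in for swiglu_correction_fn (not shown in this diff), and the SwiGLU body is reconstructed from the F.silu(x1) * x2 line above, assuming the usual chunk-in-two gating; the only behavioral change from this commit is the nn.Dropout(dropout) between the gated activation and the down-projection.

import torch
import torch.nn as nn
import torch.nn.functional as F

def _swiglu_hidden_dim(expansion_ratio: float, d_model: int, multiple: int = 256) -> int:
    # Hypothetical stand-in for swiglu_correction_fn: round the expanded width up to a
    # convenient multiple. The real correction rule lives elsewhere in modeling_esm_plusplus.py.
    hidden = int(expansion_ratio * d_model)
    return ((hidden + multiple - 1) // multiple) * multiple

class SwiGLU(nn.Module):
    # Reconstructed from the F.silu(x1) * x2 line above, assuming chunk-in-two gating.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1, x2 = x.chunk(2, dim=-1)
        return F.silu(x1) * x2

def swiglu_ln_ffn(d_model: int, expansion_ratio: float, dropout: float = 0.0) -> nn.Sequential:
    hidden = _swiglu_hidden_dim(expansion_ratio, d_model)
    return nn.Sequential(
        nn.LayerNorm(d_model),
        nn.Linear(d_model, hidden * 2, bias=False),  # joint gate/value up-projection
        SwiGLU(),
        nn.Dropout(dropout),                         # the layer added in this commit
        nn.Linear(hidden, d_model, bias=False),      # down-projection back to d_model
    )

ffn = swiglu_ln_ffn(d_model=64, expansion_ratio=8 / 3, dropout=0.1)
print(ffn(torch.randn(2, 10, 64)).shape)  # torch.Size([2, 10, 64])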
@@ -372,10 +373,11 @@ class UnifiedTransformerBlock(nn.Module):
         n_heads: int,
         residue_scaling_factor: float = 1,
         expansion_ratio: float = 8 / 3,
+        dropout: float = 0.0,
     ):
         super().__init__()
         self.attn = MultiHeadAttention(d_model, n_heads)
-        self.ffn = swiglu_ln_ffn(d_model, expansion_ratio)
+        self.ffn = swiglu_ln_ffn(d_model, expansion_ratio, dropout)
         self.scaling_factor = residue_scaling_factor
 
     def forward(
@@ -435,6 +437,7 @@ class TransformerStack(nn.Module):
         d_model: int,
         n_heads: int,
         n_layers: int,
+        dropout: float = 0.0,
     ):
         super().__init__()
         self.blocks = nn.ModuleList(
@@ -443,6 +446,7 @@ class TransformerStack(nn.Module):
                     d_model,
                     n_heads,
                     residue_scaling_factor=math.sqrt(n_layers / 36),
+                    dropout=dropout,
                 )
                 for i in range(n_layers)
             ]
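The three constructor hunks above simply thread one dropout value from TransformerStack through each UnifiedTransformerBlock into its feedforward network. Because the default is 0.0 and nn.Dropout is inert outside training mode, existing checkpoints behave exactly as before; a quick check using the standalone sketch above:

ffn = swiglu_ln_ffn(d_model=64, expansion_ratio=8 / 3, dropout=0.2)
ffn.eval()  # dropout layers are identity functions in eval mode
x = torch.randn(2, 10, 64)
assert torch.equal(ffn(x), ffn(x))  # deterministic: no dropout applied

ffn.train()  # with p=0.2 and training=True, repeated passes will differ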
@@ -517,7 +521,7 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
         self.config = config
         self.vocab_size = config.vocab_size
         self.embed = nn.Embedding(self.vocab_size, config.hidden_size)
-        self.transformer = TransformerStack(config.hidden_size, config.num_attention_heads, config.num_hidden_layers)
+        self.transformer = TransformerStack(config.hidden_size, config.num_attention_heads, config.num_hidden_layers, config.dropout)
         self.sequence_head = RegressionHead(config.hidden_size, self.vocab_size)
         self.ce_loss = nn.CrossEntropyLoss()
         self.tokenizer = EsmSequenceTokenizer()
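Note that the constructor now reads config.dropout, so the matching config class (defined outside this file and untouched by this commit) must expose that attribute, ideally with a 0.0 default so older config.json files without a dropout entry keep loading. A hedged sketch of what that addition could look like; the class name, model_type, and default sizes here are illustrative assumptions, not part of the diff:

from transformers import PretrainedConfig

class ESMplusplusConfig(PretrainedConfig):  # name assumed; the real config class lives elsewhere
    model_type = "esm_plusplus"  # illustrative value

    def __init__(
        self,
        vocab_size: int = 64,
        hidden_size: int = 960,
        num_attention_heads: int = 15,
        num_hidden_layers: int = 30,
        dropout: float = 0.0,  # new field consumed by TransformerStack above
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.num_hidden_layers = num_hidden_layers
        self.dropout = dropout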
@@ -649,25 +653,22 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
 
         return embeddings_dict
 
-    """
-    TODO
-    - Add dropout (default 0.0)
-    - Class method for returning manually computed attention maps
-    """
-
     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
         labels: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
     ) -> ESMplusplusOutput:
         """Forward pass for masked language modeling.
 
         Args:
             input_ids: Input token IDs
             attention_mask: Attention mask
+            inputs_embeds: Optional precomputed embeddings
             labels: Optional labels for masked tokens
             output_hidden_states: Whether to return all hidden states
             output_attentions: Whether to return attention weights
@@ -675,7 +676,10 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
         Returns:
             ESMplusplusOutput containing loss, logits, hidden states and attention weights
         """
-        x = self.embed(input_ids)
+        if inputs_embeds is None:
+            x = self.embed(input_ids)
+        else:
+            x = inputs_embeds
         output = self.transformer(x, attention_mask, output_hidden_states, output_attentions)
         x = output.last_hidden_state
         logits = self.sequence_head(x)
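The inputs_embeds path follows the usual Hugging Face convention: pass either token IDs or precomputed embeddings, not both. A hedged usage sketch, assuming `model` is an already-instantiated ESMplusplusForMaskedLM in eval mode and that EsmSequenceTokenizer follows the standard tokenizer call signature:

import torch

seqs = ["MSEQVQLL", "MKTAYIAK"]
batch = model.tokenizer(seqs, padding=True, return_tensors="pt")

# Standard path: token IDs are embedded internally.
out_ids = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])

# New path: supply embeddings directly, e.g. for soft prompts or embedding-space perturbations.
embeds = model.embed(batch["input_ids"])
out_emb = model(inputs_embeds=embeds, attention_mask=batch["attention_mask"])

# Both routes should produce identical logits when the embeddings are untouched.
assert torch.allclose(out_ids.logits, out_emb.logits, atol=1e-5)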
@@ -710,15 +714,18 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM):
         self,
         input_ids: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
         labels: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
     ) -> ESMplusplusOutput:
         """Forward pass for sequence classification.
 
         Args:
             input_ids: Input token IDs
             attention_mask: Attention mask
+            inputs_embeds: Optional precomputed embeddings
             labels: Optional labels for classification
             output_hidden_states: Whether to return all hidden states
             output_attentions: Whether to return attention weights
@@ -729,7 +736,9 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM):
         output = super().forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
             labels=None,
+            output_attentions=output_attentions,
             output_hidden_states=output_hidden_states
         )
         x = output.last_hidden_state
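The classification head keeps delegating to the base forward with labels=None (so no masked-LM loss is computed there) and now forwards inputs_embeds and output_attentions as well, while return_dict is merely accepted for compatibility. In practice, generic callers that pass the full Hugging Face keyword set no longer raise TypeError; a hedged illustration, with `clf` assumed to be an instantiated ESMplusplusForSequenceClassification:

batch = clf.tokenizer(["MSEQVQLL", "MKTAYIAK"], padding=True, return_tensors="pt")
out = clf(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    output_attentions=False,  # accepted and forwarded to the base model
    return_dict=True,         # accepted only for compatibility
)
print(out.logits.shape)  # classification logits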
@@ -783,16 +792,21 @@ class ESMplusplusForTokenClassification(ESMplusplusForMaskedLM):
         self,
         input_ids: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
         labels: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None, # to play nice with HF adjacent packages
     ) -> ESMplusplusOutput:
         """Forward pass for token classification.
 
         Args:
             input_ids: Input token IDs
             attention_mask: Attention mask
+            inputs_embeds: Optional precomputed embeddings
             labels: Optional labels for token classification
             output_hidden_states: Whether to return all hidden states
+            output_attentions: Whether to return attention weights
 
         Returns:
             ESMplusplusOutput containing loss, logits, and hidden states
@@ -800,7 +814,9 @@ class ESMplusplusForTokenClassification(ESMplusplusForMaskedLM):
         output = super().forward(
             input_ids=input_ids,
             attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
             labels=None,
+            output_attentions=output_attentions,
             output_hidden_states=output_hidden_states
         )
         x = output.last_hidden_state
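The token-classification head gets the same treatment, so per-residue models accept the new keywords too. If the file is published as custom code on the Hub, the additions are also reachable through the Auto classes; a hedged sketch (the repository ID below is a placeholder, not a real checkpoint reference):

from transformers import AutoModelForMaskedLM

# Placeholder repo ID; substitute the ESM++ checkpoint you actually use.
model = AutoModelForMaskedLM.from_pretrained("your-org/esm-plusplus-checkpoint", trust_remote_code=True)
model.eval()

batch = model.tokenizer(["MSEQVQLL"], return_tensors="pt")
out = model(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    output_hidden_states=True,
    output_attentions=True,
)
print(out.logits.shape)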