Update modeling_esm_plusplus.py
modeling_esm_plusplus.py  +12 −2  CHANGED
@@ -669,7 +669,12 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM):
         Returns:
             ESMplusplusOutput containing loss, logits, and hidden states
         """
-        output = super().forward(
+        output = super().forward(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            labels=None,
+            output_hidden_states=output_hidden_states
+        )
         x = output.last_hidden_state
         cls_features = x[:, 0, :]
         mean_features = self.mean_pooling(x, attention_mask)
@@ -735,7 +740,12 @@ class ESMplusplusForTokenClassification(ESMplusplusForMaskedLM):
         Returns:
             ESMplusplusOutput containing loss, logits, and hidden states
         """
-        output = super().forward(
+        output = super().forward(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            labels=None,
+            output_hidden_states=output_hidden_states
+        )
         x = output.last_hidden_state
         logits = self.classifier(x)
         loss = None
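Both hunks make the same change: the single-line call to the parent's forward is expanded into explicit keyword arguments, and `labels=None` is passed so that the `ESMplusplusForMaskedLM` parent does not compute its masked-LM loss against classification labels; each subclass then derives its own logits and loss from the returned hidden states. A minimal, self-contained sketch of this pattern (toy class names, not the actual ESM++ modules):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative stand-ins for the ESM++ classes; only the super().forward(...)
# call mirrors the commit.

class ToyForMaskedLM(nn.Module):
    def __init__(self, vocab_size=16, hidden_size=8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size)

    def forward(self, input_ids, attention_mask=None, labels=None, output_hidden_states=False):
        x = self.embed(input_ids)          # (batch, seq, hidden)
        logits = self.lm_head(x)           # (batch, seq, vocab)
        loss = None
        if labels is not None:
            # Masked-LM loss over the vocabulary: meaningless for class labels.
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
        return {"loss": loss, "logits": logits, "last_hidden_state": x}

class ToyForSequenceClassification(ToyForMaskedLM):
    def __init__(self, vocab_size=16, hidden_size=8, num_labels=2):
        super().__init__(vocab_size, hidden_size)
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask=None, labels=None, output_hidden_states=False):
        # As in the commit: keyword arguments, with labels=None so the parent
        # skips its masked-LM loss; the classification loss is computed here.
        output = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=None,
            output_hidden_states=output_hidden_states,
        )
        cls_features = output["last_hidden_state"][:, 0, :]
        logits = self.classifier(cls_features)
        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits, labels)
        return {"loss": loss, "logits": logits}

# Class labels flow only into the child's cross-entropy, never the parent's.
model = ToyForSequenceClassification()
ids = torch.randint(0, 16, (2, 5))
out = model(ids, labels=torch.tensor([0, 1]))
print(out["loss"])
```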