Update modeling_esm_plusplus.py
modeling_esm_plusplus.py (+19 -3)
```diff
@@ -339,9 +339,7 @@ class MultiHeadAttention(nn.Module):
 
 
 ### Regression Head
-def RegressionHead(
-    d_model: int, output_dim: int, hidden_dim: Optional[int] = None
-) -> nn.Module:
+def RegressionHead(d_model: int, output_dim: int, hidden_dim: Optional[int] = None) -> nn.Module:
     """Create a regression head with optional hidden dimension.
 
     Args:
@@ -707,6 +705,12 @@ class ESMplusplusModel(PreTrainedESMplusplusModel):
         self.tokenizer = EsmSequenceTokenizer()
         self.init_weights()
 
+    def get_input_embeddings(self):
+        return self.embed
+
+    def set_input_embeddings(self, value):
+        self.embed = value
+
     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
@@ -752,6 +756,18 @@ class ESMplusplusForMaskedLM(PreTrainedESMplusplusModel):
         self.tokenizer = EsmSequenceTokenizer()
         self.init_weights()
 
+    def get_input_embeddings(self):
+        return self.embed
+
+    def set_input_embeddings(self, value):
+        self.embed = value
+
+    def get_output_embeddings(self):
+        return self.sequence_head[-1]
+
+    def set_output_embeddings(self, new_embeddings):
+        self.sequence_head[-1] = new_embeddings
+
     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
```