Upload modeling_esm_plusplus.py with huggingface_hub
Browse files — modeling_esm_plusplus.py (+7 −3)
modeling_esm_plusplus.py
CHANGED
|
@@ -762,8 +762,7 @@ class EmbeddingMixin:
|
|
| 762 |
for seq, emb, mask in zip(seqs, embeddings, attention_mask):
|
| 763 |
if full_embeddings:
|
| 764 |
emb = emb[mask.bool()].reshape(-1, hidden_size)
|
| 765 |
-
c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)",
|
| 766 |
-
(seq, emb.cpu().numpy().tobytes()))
|
| 767 |
|
| 768 |
if (i + 1) % 100 == 0:
|
| 769 |
conn.commit()
|
|
@@ -979,7 +978,12 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM, EmbeddingMixi
|
|
| 979 |
self.mse = nn.MSELoss()
|
| 980 |
self.ce = nn.CrossEntropyLoss()
|
| 981 |
self.bce = nn.BCEWithLogitsLoss()
|
| 982 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 983 |
self.init_weights()
|
| 984 |
|
| 985 |
def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
|
|
|
| 762 |
for seq, emb, mask in zip(seqs, embeddings, attention_mask):
|
| 763 |
if full_embeddings:
|
| 764 |
emb = emb[mask.bool()].reshape(-1, hidden_size)
|
| 765 |
+
c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)", (seq, emb.cpu().numpy().tobytes()))
|
|
|
|
| 766 |
|
| 767 |
if (i + 1) % 100 == 0:
|
| 768 |
conn.commit()
|
|
|
|
| 978 |
self.mse = nn.MSELoss()
|
| 979 |
self.ce = nn.CrossEntropyLoss()
|
| 980 |
self.bce = nn.BCEWithLogitsLoss()
|
| 981 |
+
# if kwargs has pooling_types, use them, otherwise use ['cls', 'mean']
|
| 982 |
+
if 'pooling_types' in kwargs and isinstance(kwargs['pooling_types'], list) and len(kwargs['pooling_types']) > 0:
|
| 983 |
+
pooling_types = kwargs['pooling_types']
|
| 984 |
+
else:
|
| 985 |
+
pooling_types = ['cls', 'mean']
|
| 986 |
+
self.pooler = Pooler(pooling_types)
|
| 987 |
self.init_weights()
|
| 988 |
|
| 989 |
def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|