Update modeling_esm_plusplus.py
Browse files- modeling_esm_plusplus.py +4 -2
modeling_esm_plusplus.py
CHANGED
|
@@ -625,7 +625,7 @@ class PreTrainedESMplusplusModel(PreTrainedModel):
|
|
| 625 |
|
| 626 |
def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 627 |
if full_embeddings:
|
| 628 |
-
return residue_embeddings
|
| 629 |
elif pooling_type == 'mean':
|
| 630 |
return self.mean_pooling(residue_embeddings, attention_mask)
|
| 631 |
elif pooling_type == 'max':
|
|
@@ -653,7 +653,9 @@ class PreTrainedESMplusplusModel(PreTrainedModel):
|
|
| 653 |
residue_embeddings = self.transformer(x, attention_mask).last_hidden_state.detach().float() # required for sql
|
| 654 |
embeddings = get_embeddings(residue_embeddings, attention_mask)
|
| 655 |
|
| 656 |
-
for seq, emb in zip(seqs, embeddings):
|
|
|
|
|
|
|
| 657 |
c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)",
|
| 658 |
(seq, emb.cpu().numpy().tobytes()))
|
| 659 |
|
|
|
|
| 625 |
|
| 626 |
def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 627 |
if full_embeddings:
|
| 628 |
+
return residue_embeddings
|
| 629 |
elif pooling_type == 'mean':
|
| 630 |
return self.mean_pooling(residue_embeddings, attention_mask)
|
| 631 |
elif pooling_type == 'max':
|
|
|
|
| 653 |
residue_embeddings = self.transformer(x, attention_mask).last_hidden_state.detach().float() # required for sql
|
| 654 |
embeddings = get_embeddings(residue_embeddings, attention_mask)
|
| 655 |
|
| 656 |
+
for seq, emb, mask in zip(seqs, embeddings, attention_mask):
|
| 657 |
+
if full_embeddings:
|
| 658 |
+
emb = emb[mask.bool()]
|
| 659 |
c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)",
|
| 660 |
(seq, emb.cpu().numpy().tobytes()))
|
| 661 |
|