Update modeling_esm_plusplus.py
Browse files- modeling_esm_plusplus.py +2 -2
modeling_esm_plusplus.py
CHANGED
|
@@ -647,7 +647,7 @@ class PreTrainedESMplusplusModel(PreTrainedModel):
|
|
| 647 |
if len(to_embed) > 0:
|
| 648 |
with torch.no_grad():
|
| 649 |
for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
|
| 650 |
-
seqs =
|
| 651 |
input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
|
| 652 |
x = self.embed(input_ids)
|
| 653 |
residue_embeddings = self.transformer(x, attention_mask).last_hidden_state.detach().float() # required for sql
|
|
@@ -665,7 +665,7 @@ class PreTrainedESMplusplusModel(PreTrainedModel):
|
|
| 665 |
conn.commit()
|
| 666 |
conn.close()
|
| 667 |
return None
|
| 668 |
-
|
| 669 |
embeddings_dict = {}
|
| 670 |
with torch.no_grad():
|
| 671 |
for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
|
|
|
|
| 647 |
if len(to_embed) > 0:
|
| 648 |
with torch.no_grad():
|
| 649 |
for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
|
| 650 |
+
seqs = to_embed[i * batch_size:(i + 1) * batch_size]
|
| 651 |
input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
|
| 652 |
x = self.embed(input_ids)
|
| 653 |
residue_embeddings = self.transformer(x, attention_mask).last_hidden_state.detach().float() # required for sql
|
|
|
|
| 665 |
conn.commit()
|
| 666 |
conn.close()
|
| 667 |
return None
|
| 668 |
+
|
| 669 |
embeddings_dict = {}
|
| 670 |
with torch.no_grad():
|
| 671 |
for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), desc='Embedding batches'):
|