Update modeling_esm_plusplus.py
modeling_esm_plusplus.py (+9 -7) CHANGED
@@ -537,10 +537,10 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
             batch_size: Batch size for processing
             max_len: Maximum sequence length
             full_embeddings: Whether to return full residue-wise (True) embeddings or pooled (False)
-            full_precision: Whether to cast to full precision (float32) before storage
+            full_precision: Whether to cast to full precision (float32) before storage - relevant for dict storage
             pooling_type: Type of pooling ('mean' or 'cls')
             num_workers: Number of workers for data loading, 0 for the main process
-            sql: Whether to store embeddings in SQLite database
+            sql: Whether to store embeddings in SQLite database - will be stored in float32
             sql_db_path: Path to SQLite database
 
         Returns:
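These arguments document the model's dataset-embedding helper. A minimal usage sketch follows; the method name embed_dataset, the checkpoint id, and the return behaviour are assumptions drawn from the docstring and surrounding code, not confirmed by this diff:

# Hedged usage sketch; embed_dataset and the checkpoint id are assumptions for illustration.
from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained(
    "Synthyra/ESMplusplus_small",   # assumed checkpoint id
    trust_remote_code=True,
)
sequences = ["MSEQENCE", "MPEPTIDE"]

# In-memory dict of mean-pooled embeddings; full_precision=True casts them to float32.
embeddings_dict = model.embed_dataset(
    sequences,
    batch_size=2,
    max_len=512,
    full_embeddings=False,
    full_precision=True,
    pooling_type='mean',
    num_workers=0,
    sql=False,
)

# Residue-wise embeddings written to SQLite; these are stored as float32 regardless of full_precision.
model.embed_dataset(
    sequences,
    full_embeddings=True,
    sql=True,
    sql_db_path='embeddings.db',
)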
@@ -553,12 +553,12 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
         device = self.device
 
         def get_embeddings(residue_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-            if full_precision:
-                residue_embeddings = residue_embeddings.float()
             if full_embeddings:
                 return residue_embeddings
-
-
+            elif pooling_type == 'mean':
+                return self.mean_pooling(residue_embeddings, attention_mask)
+            else:
+                return residue_embeddings[:, 0, :]
 
         if sql:
             import sqlite3
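The new branches mirror the documented pooling_type options: 'mean' delegates to self.mean_pooling, anything else takes the first (CLS) token. mean_pooling itself is not part of this diff; an attention-mask-weighted mean typically looks like the sketch below (name and shape conventions are assumptions):

import torch

def mean_pooling_sketch(residue_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # residue_embeddings: (batch, seq_len, hidden); attention_mask: (batch, seq_len) with 1 for real tokens.
    mask = attention_mask.unsqueeze(-1).to(residue_embeddings.dtype)  # (batch, seq_len, 1)
    summed = (residue_embeddings * mask).sum(dim=1)                   # sum over non-padded positions
    counts = mask.sum(dim=1).clamp(min=1e-9)                          # number of real tokens per sequence
    return summed / counts                                            # (batch, hidden)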
@@ -575,7 +575,7 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
                 seqs = sequences[i * batch_size:(i + 1) * batch_size]
                 input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
                 x = self.embed(input_ids)
-                residue_embeddings = self.transformer(x, attention_mask).last_hidden_state
+                residue_embeddings = self.transformer(x, attention_mask).last_hidden_state.float() # required for sql
                 embeddings = get_embeddings(residue_embeddings, attention_mask)
 
                 for seq, emb in zip(seqs, embeddings):
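In the SQLite path the hidden states are now cast to float32 immediately after the forward pass. The storage code is outside this hunk; the likely reason for the cast is that serialising to a BLOB goes through NumPy, which cannot view bfloat16 tensors. A hedged sketch of such a storage step (table layout and statement are assumptions):

import sqlite3
import torch

conn = sqlite3.connect('embeddings.db')   # sql_db_path
conn.execute('CREATE TABLE IF NOT EXISTS embeddings (sequence TEXT PRIMARY KEY, embedding BLOB)')

def store_embedding(seq: str, emb: torch.Tensor) -> None:
    # float32 on CPU serialises cleanly; calling .numpy() on a bfloat16 tensor raises an error.
    blob = emb.float().cpu().numpy().tobytes()
    conn.execute('INSERT OR REPLACE INTO embeddings VALUES (?, ?)', (seq, blob))

conn.commit()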
@@ -596,6 +596,8 @@ class ESMplusplusForMaskedLM(PreTrainedModel):
                 input_ids, attention_mask = batch['input_ids'].to(device), batch['attention_mask'].to(device)
                 x = self.embed(input_ids)
                 residue_embeddings = self.transformer(x, attention_mask).last_hidden_state
+                if full_precision:
+                    residue_embeddings = residue_embeddings.float()
                 embeddings = get_embeddings(residue_embeddings, attention_mask)
                 for seq, emb in zip(seqs, embeddings):
                     embeddings_dict[seq] = emb
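In the in-memory path the cast is now conditional on full_precision, so embeddings can stay in the model's half precision when memory matters. For completeness, a sketch of reading the two stores back (dtype and table layout assumed to match the sketch above):

import sqlite3
import numpy as np

# Dict path: tensors keep the model dtype unless full_precision=True was used.
# emb = embeddings_dict['MSEQENCE']

# SQLite path: embeddings were cast to float32 before storage, so decode with that dtype.
conn = sqlite3.connect('embeddings.db')
row = conn.execute('SELECT embedding FROM embeddings WHERE sequence = ?', ('MSEQENCE',)).fetchone()
if row is not None:
    emb = np.frombuffer(row[0], dtype=np.float32)  # flat array; reshape to (seq_len, hidden) for residue-wise embeddings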