lhallee commited on
Commit
bdb4649
·
verified ·
1 Parent(s): 11333ff

Upload modeling_esm_plusplus.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_esm_plusplus.py +7 -3
modeling_esm_plusplus.py CHANGED
@@ -762,8 +762,7 @@ class EmbeddingMixin:
762
  for seq, emb, mask in zip(seqs, embeddings, attention_mask):
763
  if full_embeddings:
764
  emb = emb[mask.bool()].reshape(-1, hidden_size)
765
- c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)",
766
- (seq, emb.cpu().numpy().tobytes()))
767
 
768
  if (i + 1) % 100 == 0:
769
  conn.commit()
@@ -979,7 +978,12 @@ class ESMplusplusForSequenceClassification(ESMplusplusForMaskedLM, EmbeddingMixi
979
  self.mse = nn.MSELoss()
980
  self.ce = nn.CrossEntropyLoss()
981
  self.bce = nn.BCEWithLogitsLoss()
982
- self.pooler = Pooler(['cls','mean'])
 
 
 
 
 
983
  self.init_weights()
984
 
985
  def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
 
762
  for seq, emb, mask in zip(seqs, embeddings, attention_mask):
763
  if full_embeddings:
764
  emb = emb[mask.bool()].reshape(-1, hidden_size)
765
+ c.execute("INSERT OR REPLACE INTO embeddings VALUES (?, ?)", (seq, emb.cpu().numpy().tobytes()))
 
766
 
767
  if (i + 1) % 100 == 0:
768
  conn.commit()
 
978
  self.mse = nn.MSELoss()
979
  self.ce = nn.CrossEntropyLoss()
980
  self.bce = nn.BCEWithLogitsLoss()
981
+ # if kwargs has pooling_types, use them, otherwise use ['cls', 'mean']
982
+ if 'pooling_types' in kwargs and isinstance(kwargs['pooling_types'], List[str]) and len(kwargs['pooling_types']) > 0:
983
+ pooling_types = kwargs['pooling_types']
984
+ else:
985
+ pooling_types = ['cls', 'mean']
986
+ self.pooler = Pooler(pooling_types)
987
  self.init_weights()
988
 
989
  def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: