nioushasadjadi
committed on
Commit
·
4e98ce2
1
Parent(s):
4a303bd
Changing the call function.
Browse files- tokenizer.py +20 -0
tokenizer.py
CHANGED
|
@@ -135,3 +135,23 @@ class KmerTokenizer(PreTrainedTokenizer):
|
|
| 135 |
|
| 136 |
# Instantiate the tokenizer with loaded values
|
| 137 |
return cls(vocab=vocab, k=k, stride=stride, max_len=max_len, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
|
| 136 |
# Instantiate the tokenizer with loaded values
|
| 137 |
return cls(vocab=vocab, k=k, stride=stride, max_len=max_len, **kwargs)
|
| 138 |
+
|
| 139 |
+
def __call__(self, text, padding=False, **kwargs):
    """Encode ``text`` and package the result as a model-input dictionary.

    Args:
        text: Input handed unchanged to ``self.encode``.
        padding: Forwarded to ``self.encode`` as its ``padding`` argument.
        **kwargs: Forwarded to ``self.encode``. If ``return_tensors`` is
            ``'pt'``, every value in the returned dictionary is a
            ``torch.Tensor``; otherwise the masks are plain lists and
            ``input_ids`` is whatever ``self.encode`` returned.

    Returns:
        A dict with the keys ``"input_ids"``, ``"token_type_ids"`` and
        ``"attention_mask"``. ``attention_mask`` is 0 at every position whose
        id equals the ``"[UNK]"`` id in ``self.vocab_dict`` and 1 elsewhere;
        ``token_type_ids`` is all zeros.
    """
    # NOTE(review): assumes encode returns a flat sequence of ids for a
    # single input (it is iterated and measured with len below) -- confirm.
    token_ids = self.encode(text, padding=padding, **kwargs)

    # Mask is keyed on the "[UNK]" id, not on a dedicated pad id.
    # NOTE(review): padded positions are only masked if encode pads with
    # tokens that map to "[UNK]" -- confirm against encode/tokenize.
    unk_token_id = self.vocab_dict.get("[UNK]")
    attention_mask = [1 if id_ != unk_token_id else 0 for id_ in token_ids]

    token_type_ids = [0] * len(token_ids)

    # Convert to the specified tensor format
    if kwargs.get("return_tensors") == "pt":
        # input_ids must be converted together with the masks; otherwise the
        # dict mixes a list with tensors. as_tensor leaves an existing tensor
        # untouched if encode already produced one.
        token_ids = torch.as_tensor(token_ids)
        attention_mask = torch.tensor(attention_mask)
        token_type_ids = torch.tensor(token_type_ids)

    # Return the output dictionary
    return {
        "input_ids": token_ids,
        "token_type_ids": token_type_ids,
        "attention_mask": attention_mask,
    }
|