gbyuvd committed
Commit 42c2638 · verified · 1 parent: ad6a9f0

Proper full HF Compat

Files changed (1)
  1. FastChemTokenizerHF.py +539 -769
FastChemTokenizerHF.py CHANGED
@@ -1,769 +1,539 @@
-import torch
-import json
-import os
-from typing import List, Union, Optional, Tuple, Dict, Any
-from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase
-from transformers.utils import PaddingStrategy, TensorType
-from functools import lru_cache
-
-
-class TrieNode:
-    __slots__ = ['children', 'token_id']
-    def __init__(self):
-        self.children = {}
-        self.token_id = None  # If set, this node completes a valid token
-
-
-class FastChemTokenizer(PreTrainedTokenizerBase):
-    """
-    Fully HuggingFace API compatible tokenizer for chemical representations.
-    """
-
-    vocab_files_names = {"vocab_file": "vocab.json"}
-
-    def __init__(
-        self,
-        token_to_id=None,
-        vocab_file=None,
-        model_max_length=512,
-        padding_side="right",
-        truncation_side="right",
-        chat_template=None,
-        **kwargs
-    ):
-        # Handle vocab loading
-        if token_to_id is None and vocab_file is None:
-            raise ValueError("Either token_to_id or vocab_file must be provided")
-
-        if vocab_file is not None:
-            with open(vocab_file, "r", encoding="utf-8") as f:
-                token_to_id = json.load(f)
-            token_to_id = {str(k): int(v) for k, v in token_to_id.items()}
-
-        self.token_to_id = token_to_id
-        self.id_to_token = {v: k for k, v in token_to_id.items()}
-
-        # Precompute max token length for possible use & clarity
-        self.max_token_len = max(len(t) for t in token_to_id.keys()) if token_to_id else 0
-
-        # Build trie for fast longest-match lookup
-        self.trie_root = self._build_trie(token_to_id)
-
-        # Validate required special tokens
-        required_special_tokens = ["<s>", "</s>", "<pad>", "<unk>", "<mask>"]
-        for tok in required_special_tokens:
-            if tok not in token_to_id:
-                raise KeyError(f"Required special token '{tok}' not found in vocab.")
-
-        # Assign special token IDs explicitly
-        self.bos_token_id = token_to_id["<s>"]
-        self.eos_token_id = token_to_id["</s>"]
-        self.pad_token_id = token_to_id["<pad>"]
-        self.unk_token_id = token_to_id["<unk>"]
-        self.mask_token_id = token_to_id["<mask>"]
-
-        # Special tokens
-        bos_token = "<s>"
-        eos_token = "</s>"
-        pad_token = "<pad>"
-        unk_token = "<unk>"
-        mask_token = "<mask>"
-
-        # Initialize parent class with all required parameters
-        super().__init__(
-            bos_token=bos_token,
-            eos_token=eos_token,
-            unk_token=unk_token,
-            sep_token=None,
-            pad_token=pad_token,
-            cls_token=None,
-            mask_token=mask_token,
-            additional_special_tokens=[],
-            model_max_length=model_max_length,
-            padding_side=padding_side,
-            truncation_side=truncation_side,
-            chat_template=chat_template,
-            **kwargs,
-        )
-
-    def _build_trie(self, token_to_id):
-        root = TrieNode()
-        for token, tid in token_to_id.items():
-            node = root
-            for char in token:
-                if char not in node.children:
-                    node.children[char] = TrieNode()
-                node = node.children[char]
-            node.token_id = tid
-        return root
-
-    @property
-    def vocab_size(self):
-        return len(self.token_to_id)
-
-    def __len__(self):
-        return len(self.token_to_id)
-
-    def get_vocab(self) -> Dict[str, int]:
-        return self.token_to_id.copy()
-
-    @lru_cache(maxsize=10000)
-    def _cached_encode_str(self, s: str) -> Tuple[int, ...]:
-        return tuple(self._encode_core(s))
-
-    def _encode_core(self, text: str) -> List[int]:
-        """Core encoding logic using Trie — no caching."""
-        tokens = text
-        result_ids = []
-        i = 0
-        n = len(tokens)
-
-        while i < n:
-            node = self.trie_root
-            j = i
-            last_match_id = None
-            last_match_end = i
-
-            while j < n and tokens[j] in node.children:
-                node = node.children[tokens[j]]
-                j += 1
-                if node.token_id is not None:
-                    last_match_id = node.token_id
-                    last_match_end = j
-
-            if last_match_id is not None:
-                result_ids.append(last_match_id)
-                i = last_match_end
-            else:
-                tok = tokens[i]
-                result_ids.append(self.token_to_id.get(tok, self.unk_token_id))
-                i += 1
-
-        return result_ids
-
-    def _tokenize(self, text: str, **kwargs) -> List[str]:
-        token_ids = self._encode_core(text.strip())
-        return [self.id_to_token[tid] for tid in token_ids]
-
-    def _convert_token_to_id(self, token: str) -> int:
-        return self.token_to_id.get(token, self.unk_token_id)
-
-    def _convert_id_to_token(self, index: int) -> str:
-        return self.id_to_token.get(index, self.unk_token)
-
-    # ✅ Public methods
-    def convert_tokens_to_ids(self, tokens: Union[str, List[str]]) -> Union[int, List[int]]:
-        if isinstance(tokens, str):
-            return self._convert_token_to_id(tokens)
-        return [self._convert_token_to_id(tok) for tok in tokens]
-
-    def convert_ids_to_tokens(self, ids: Union[int, List[int]]) -> Union[str, List[str]]:
-        if isinstance(ids, int):
-            return self._convert_id_to_token(ids)
-        return [self._convert_id_to_token(i) for i in ids]
-
-    def convert_tokens_to_string(self, tokens: List[str]) -> str:
-        """SMILES-style decoding: no spaces between tokens."""
-        return "".join(tokens)
-
-    def encode(
-        self,
-        text: str,
-        text_pair: Optional[str] = None,
-        add_special_tokens: bool = True,
-        padding: bool = False,
-        truncation: bool = False,
-        max_length: Optional[int] = None,
-        return_tensors: Optional[str] = None,
-    ) -> List[int]:
-        encoded = self.encode_plus(
-            text=text,
-            text_pair=text_pair,
-            add_special_tokens=add_special_tokens,
-            padding=padding,
-            truncation=truncation,
-            max_length=max_length,
-            return_tensors=return_tensors,
-        )
-
-        input_ids = encoded["input_ids"]
-        if isinstance(input_ids, torch.Tensor):
-            if input_ids.dim() > 1:
-                input_ids = input_ids.squeeze(0)
-            input_ids = input_ids.tolist()
-
-        return input_ids
-
-    def decode(
-        self,
-        token_ids: Union[List[int], torch.Tensor],
-        skip_special_tokens: bool = False,
-        clean_up_tokenization_spaces: bool = None,
-        **kwargs
-    ) -> str:
-        if isinstance(token_ids, torch.Tensor):
-            token_ids = token_ids.tolist()
-
-        if skip_special_tokens:
-            special_ids = {
-                self.bos_token_id,
-                self.eos_token_id,
-                self.pad_token_id,
-                self.mask_token_id,
-            }
-        else:
-            special_ids = set()
-
-        tokens = []
-        for tid in token_ids:
-            if tid in special_ids:
-                continue
-            token = self.id_to_token.get(tid, self.unk_token)
-            tokens.append(token)
-
-        return "".join(tokens)
-
-    def batch_decode(
-        self,
-        sequences: Union[List[List[int]], torch.Tensor],
-        skip_special_tokens: bool = False,
-        clean_up_tokenization_spaces: bool = None,
-        **kwargs
-    ) -> List[str]:
-        """Batch decode sequences."""
-        if isinstance(sequences, torch.Tensor):
-            sequences = sequences.tolist()
-
-        return [
-            self.decode(
-                seq,
-                skip_special_tokens=skip_special_tokens,
-                clean_up_tokenization_spaces=clean_up_tokenization_spaces,
-                **kwargs
-            )
-            for seq in sequences
-        ]
-
-    def decode_with_trace(self, token_ids: List[int]) -> None:
-        print(f"\n🔍 Decoding {len(token_ids)} tokens:")
-        for i, tid in enumerate(token_ids):
-            token = self.id_to_token.get(tid, self.unk_token)
-            print(f" [{i:03d}] ID={tid:5d} → '{token}'")
-
-    def __call__(
-        self,
-        text: Union[str, List[str]],
-        text_pair: Optional[Union[str, List[str]]] = None,
-        add_special_tokens: bool = True,
-        padding: Union[bool, str, PaddingStrategy] = False,
-        truncation: Union[bool, str] = False,
-        max_length: Optional[int] = None,
-        stride: int = 0,
-        is_split_into_words: bool = False,
-        pad_to_multiple_of: Optional[int] = None,
-        return_tensors: Optional[Union[str, TensorType]] = None,
-        return_token_type_ids: Optional[bool] = None,
-        return_attention_mask: Optional[bool] = None,
-        return_overflowing_tokens: bool = False,
-        return_special_tokens_mask: bool = False,
-        return_offsets_mapping: bool = False,
-        return_length: bool = False,
-        verbose: bool = True,
-        **kwargs
-    ) -> BatchEncoding:
-        """
-        Main callable method that handles both single and batch inputs.
-        """
-        # Handle defaults
-        if return_token_type_ids is None:
-            return_token_type_ids = True
-        if return_attention_mask is None:
-            return_attention_mask = True
-
-        if isinstance(text, list):
-            if text_pair is not None:
-                batch = [(t, p) for t, p in zip(text, text_pair)]
-            else:
-                batch = text
-            return self.batch_encode_plus(
-                batch,
-                add_special_tokens=add_special_tokens,
-                padding=padding,
-                truncation=truncation,
-                max_length=max_length,
-                stride=stride,
-                is_split_into_words=is_split_into_words,
-                pad_to_multiple_of=pad_to_multiple_of,
-                return_tensors=return_tensors,
-                return_token_type_ids=return_token_type_ids,
-                return_attention_mask=return_attention_mask,
-                return_overflowing_tokens=return_overflowing_tokens,
-                return_special_tokens_mask=return_special_tokens_mask,
-                return_offsets_mapping=return_offsets_mapping,
-                return_length=return_length,
-                verbose=verbose,
-                **kwargs
-            )
-        else:
-            return self.encode_plus(
-                text=text,
-                text_pair=text_pair,
-                add_special_tokens=add_special_tokens,
-                padding=padding,
-                truncation=truncation,
-                max_length=max_length,
-                stride=stride,
-                is_split_into_words=is_split_into_words,
-                pad_to_multiple_of=pad_to_multiple_of,
-                return_tensors=return_tensors,
-                return_token_type_ids=return_token_type_ids,
-                return_attention_mask=return_attention_mask,
-                return_overflowing_tokens=return_overflowing_tokens,
-                return_special_tokens_mask=return_special_tokens_mask,
-                return_offsets_mapping=return_offsets_mapping,
-                return_length=return_length,
-                verbose=verbose,
-                **kwargs
-            )
-
-    def encode_plus(
-        self,
-        text: str,
-        text_pair: Optional[str] = None,
-        add_special_tokens: bool = True,
-        padding: Union[bool, str, PaddingStrategy] = False,
-        truncation: Union[bool, str] = False,
-        max_length: Optional[int] = None,
-        stride: int = 0,
-        is_split_into_words: bool = False,
-        pad_to_multiple_of: Optional[int] = None,
-        return_tensors: Optional[Union[str, TensorType]] = None,
-        return_token_type_ids: Optional[bool] = True,
-        return_attention_mask: Optional[bool] = True,
-        return_overflowing_tokens: bool = False,
-        return_special_tokens_mask: bool = False,
-        return_offsets_mapping: bool = False,
-        return_length: bool = False,
-        verbose: bool = True,
-        **kwargs
-    ) -> BatchEncoding:
-        if max_length is None:
-            max_length = self.model_max_length
-
-        ids_a = list(self._cached_encode_str(text.strip()))
-
-        if text_pair is not None:
-            ids_b = list(self._cached_encode_str(text_pair.strip()))
-        else:
-            ids_b = None
-
-        input_ids = []
-        token_type_ids = []
-
-        if add_special_tokens:
-            input_ids.append(self.bos_token_id)
-            token_type_ids.append(0)
-            if ids_b is not None:
-                input_ids.extend(ids_a)
-                token_type_ids.extend([0] * len(ids_a))
-                input_ids.append(self.eos_token_id)
-                token_type_ids.append(0)
-
-                input_ids.extend(ids_b)
-                token_type_ids.extend([1] * len(ids_b))
-                input_ids.append(self.eos_token_id)
-                token_type_ids.append(1)
-            else:
-                input_ids.extend(ids_a)
-                token_type_ids.extend([0] * len(ids_a))
-                input_ids.append(self.eos_token_id)
-                token_type_ids.append(0)
-        else:
-            input_ids = ids_a.copy()
-            token_type_ids = [0] * len(input_ids)
-            if ids_b is not None:
-                input_ids.extend(ids_b)
-                token_type_ids.extend([1] * len(ids_b))
-
-        # Handle truncation
-        if truncation and len(input_ids) > max_length:
-            input_ids = input_ids[:max_length]
-            token_type_ids = token_type_ids[:max_length]
-
-        # Handle padding
-        if padding == True or padding == "max_length":
-            pad_len = max_length - len(input_ids)
-            if pad_len > 0:
-                if self.padding_side == "right":
-                    input_ids.extend([self.pad_token_id] * pad_len)
-                    token_type_ids.extend([0] * pad_len)
-                else:
-                    input_ids = [self.pad_token_id] * pad_len + input_ids
-                    token_type_ids = [0] * pad_len + token_type_ids
-
-        attention_mask = [1 if tid != self.pad_token_id else 0 for tid in input_ids]
-
-        encoded_dict = {
-            "input_ids": input_ids,
-        }
-
-        if return_attention_mask:
-            encoded_dict["attention_mask"] = attention_mask
-
-        if return_token_type_ids:
-            encoded_dict["token_type_ids"] = token_type_ids
-
-        if return_special_tokens_mask:
-            special_tokens_mask = [
-                1 if tid in {self.bos_token_id, self.eos_token_id, self.pad_token_id, self.mask_token_id} else 0
-                for tid in input_ids
-            ]
-            encoded_dict["special_tokens_mask"] = special_tokens_mask
-
-        if return_length:
-            encoded_dict["length"] = len([tid for tid in input_ids if tid != self.pad_token_id])
-
-        if return_tensors == "pt":
-            output = {}
-            for k, v in encoded_dict.items():
-                tensor = torch.tensor(v, dtype=torch.long)
-                if tensor.ndim == 1:
-                    tensor = tensor.unsqueeze(0)
-                output[k] = tensor
-        else:
-            output = encoded_dict
-
-        return BatchEncoding(output, tensor_type=return_tensors)
-
-    def batch_encode_plus(
-        self,
-        batch_text_or_text_pairs: List[Union[str, Tuple[str, str]]],
-        add_special_tokens: bool = True,
-        padding: Union[bool, str, PaddingStrategy] = False,
-        truncation: Union[bool, str] = False,
-        max_length: Optional[int] = None,
-        stride: int = 0,
-        is_split_into_words: bool = False,
-        pad_to_multiple_of: Optional[int] = None,
-        return_tensors: Optional[Union[str, TensorType]] = None,
-        return_token_type_ids: Optional[bool] = True,
-        return_attention_mask: Optional[bool] = True,
-        return_overflowing_tokens: bool = False,
-        return_special_tokens_mask: bool = False,
-        return_offsets_mapping: bool = False,
-        return_length: bool = False,
-        verbose: bool = True,
-        **kwargs
-    ) -> BatchEncoding:
-        all_input_ids = []
-        all_attention_masks = []
-        all_token_type_ids = []
-        all_special_tokens_masks = []
-        all_lengths = []
-
-        for item in batch_text_or_text_pairs:
-            if isinstance(item, tuple):
-                text, text_pair = item
-            else:
-                text, text_pair = item, None
-
-            encoded = self.encode_plus(
-                text=text,
-                text_pair=text_pair,
-                add_special_tokens=add_special_tokens,
-                padding=False,  # We'll handle batch padding later
-                truncation=truncation,
-                max_length=max_length,
-                stride=stride,
-                is_split_into_words=is_split_into_words,
-                pad_to_multiple_of=pad_to_multiple_of,
-                return_tensors=None,  # Don't convert to tensors yet
-                return_token_type_ids=return_token_type_ids,
-                return_attention_mask=return_attention_mask,
-                return_overflowing_tokens=return_overflowing_tokens,
-                return_special_tokens_mask=return_special_tokens_mask,
-                return_offsets_mapping=return_offsets_mapping,
-                return_length=return_length,
-                verbose=verbose,
-                **kwargs
-            )
-
-            all_input_ids.append(encoded["input_ids"])
-            if "attention_mask" in encoded:
-                all_attention_masks.append(encoded["attention_mask"])
-            if "token_type_ids" in encoded:
-                all_token_type_ids.append(encoded["token_type_ids"])
-            if "special_tokens_mask" in encoded:
-                all_special_tokens_masks.append(encoded["special_tokens_mask"])
-            if "length" in encoded:
-                all_lengths.append(encoded["length"])
-
-        batched = {
-            "input_ids": all_input_ids,
-        }
-
-        if all_attention_masks:
-            batched["attention_mask"] = all_attention_masks
-        if all_token_type_ids:
-            batched["token_type_ids"] = all_token_type_ids
-        if all_special_tokens_masks:
-            batched["special_tokens_mask"] = all_special_tokens_masks
-        if all_lengths:
-            batched["length"] = all_lengths
-
-        # Handle batch padding
-        if padding == True or padding == "longest":
-            max_len = max(len(ids) for ids in all_input_ids)
-            for key in batched:
-                if key in ["input_ids", "attention_mask", "token_type_ids", "special_tokens_mask"]:
-                    padded_seqs = []
-                    for seq in batched[key]:
-                        pad_len = max_len - len(seq)
-                        if pad_len > 0:
-                            if key == "input_ids":
-                                padding_value = self.pad_token_id
-                            else:
-                                padding_value = 0
-
-                            if self.padding_side == "right":
-                                padded_seq = seq + [padding_value] * pad_len
-                            else:
-                                padded_seq = [padding_value] * pad_len + seq
-                        else:
-                            padded_seq = seq
-                        padded_seqs.append(padded_seq)
-                    batched[key] = padded_seqs
-
-        if return_tensors == "pt":
-            def to_tensor_list(lst):
-                return [torch.tensor(item, dtype=torch.long) for item in lst]
-
-            for key in ["input_ids", "attention_mask", "token_type_ids", "special_tokens_mask"]:
-                if key in batched:
-                    batched[key] = torch.nn.utils.rnn.pad_sequence(
-                        to_tensor_list(batched[key]),
-                        batch_first=True,
-                        padding_value=self.pad_token_id if key == "input_ids" else 0
-                    )
-
-            # Handle non-sequence data
-            if "length" in batched:
-                batched["length"] = torch.tensor(batched["length"], dtype=torch.long)
-
-        return BatchEncoding(batched, tensor_type=return_tensors)
-
-    def pad(
-        self,
-        encoded_inputs,
-        padding: Union[bool, str, PaddingStrategy] = True,
-        max_length: Optional[int] = None,
-        pad_to_multiple_of: Optional[int] = None,
-        return_attention_mask: Optional[bool] = None,
-        return_tensors: Optional[Union[str, TensorType]] = None,
-        verbose: bool = True,
-    ) -> BatchEncoding:
-        """Pad encoded inputs."""
-        # This is a simplified version - full implementation would be more complex
-        return encoded_inputs
-
-    # Save/Load methods
-    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
-        """Save vocabulary to files."""
-        if not os.path.isdir(save_directory):
-            os.makedirs(save_directory)
-
-        vocab_file = os.path.join(
-            save_directory,
-            (filename_prefix + "-" if filename_prefix else "") + "vocab.json"
-        )
-
-        with open(vocab_file, "w", encoding="utf-8") as f:
-            json.dump(self.token_to_id, f, ensure_ascii=False, indent=2)
-
-        return (vocab_file,)
-
-    def save_pretrained(
-        self,
-        save_directory: Union[str, os.PathLike],
-        legacy_format: bool = True,
-        filename_prefix: Optional[str] = None,
-        push_to_hub: bool = False,
-        **kwargs
-    ):
-        """Save tokenizer to directory."""
-        if not os.path.exists(save_directory):
-            os.makedirs(save_directory)
-
-        # Save vocabulary
-        vocab_files = self.save_vocabulary(save_directory, filename_prefix)
-
-        # Save tokenizer config
-        tokenizer_config = {
-            "tokenizer_class": self.__class__.__name__,
-            "model_max_length": self.model_max_length,
-            "padding_side": self.padding_side,
-            "truncation_side": self.truncation_side,
-            "special_tokens": {
-                "bos_token": self.bos_token,
-                "eos_token": self.eos_token,
-                "pad_token": self.pad_token,
-                "unk_token": self.unk_token,
-                "mask_token": self.mask_token,
-            }
-        }
-
-        config_file = os.path.join(save_directory, "tokenizer_config.json")
-        with open(config_file, "w", encoding="utf-8") as f:
-            json.dump(tokenizer_config, f, ensure_ascii=False, indent=2)
-
-        print(f"✅ Tokenizer saved to: {save_directory}")
-
-        return (save_directory,)
-
-    @classmethod
-    def from_pretrained(
-        cls,
-        pretrained_model_name_or_path: Union[str, os.PathLike],
-        *init_inputs,
-        **kwargs
-    ):
-        """Load tokenizer from pretrained directory or hub."""
-        if os.path.isdir(pretrained_model_name_or_path):
-            vocab_file = os.path.join(pretrained_model_name_or_path, "vocab.json")
-            config_file = os.path.join(pretrained_model_name_or_path, "tokenizer_config.json")
-
-            # Load config if available
-            config = {}
-            if os.path.exists(config_file):
-                with open(config_file, "r", encoding="utf-8") as f:
-                    config = json.load(f)
-
-            # Merge config with kwargs
-            merged_config = {**config, **kwargs}
-
-            return cls(vocab_file=vocab_file, **merged_config)
-        else:
-            raise NotImplementedError("Loading from HuggingFace Hub not implemented yet")
-
-    def get_special_tokens_mask(
-        self,
-        token_ids_0: List[int],
-        token_ids_1: Optional[List[int]] = None,
-        already_has_special_tokens: bool = False
-    ) -> List[int]:
-        """Get special tokens mask."""
-        if already_has_special_tokens:
-            return [
-                1 if tid in {self.bos_token_id, self.eos_token_id, self.pad_token_id, self.mask_token_id}
-                else 0 for tid in token_ids_0
-            ]
-
-        mask = [1]  # BOS
-        mask.extend([0] * len(token_ids_0))  # Token sequence
-        mask.append(1)  # EOS
-
-        if token_ids_1 is not None:
-            mask.extend([0] * len(token_ids_1))  # Second sequence
-            mask.append(1)  # EOS
-
-        return mask
-
-    def create_token_type_ids_from_sequences(
-        self,
-        token_ids_0: List[int],
-        token_ids_1: Optional[List[int]] = None
-    ) -> List[int]:
-        """Create token type IDs for sequences."""
-        sep = [self.eos_token_id]
-        cls = [self.bos_token_id]
-
-        if token_ids_1 is None:
-            return len(cls + token_ids_0 + sep) * [0]
-
-        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
-
-    def build_inputs_with_special_tokens(
-        self,
-        token_ids_0: List[int],
-        token_ids_1: Optional[List[int]] = None
-    ) -> List[int]:
-        """Build inputs with special tokens."""
-        if token_ids_1 is None:
-            return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
-
-        return ([self.bos_token_id] + token_ids_0 + [self.eos_token_id] +
-                token_ids_1 + [self.eos_token_id])
-
-
-class FastChemTokenizerSelfies(FastChemTokenizer):
-    """
-    SELFIES variant that handles whitespace-separated tokens.
-    Uses trie-based longest-match encoding (same as original working version).
-    """
-
-    def _encode_core(self, text: str) -> List[int]:
-        """Trie-based encoding for SELFIES with fragment + atom vocab."""
-        result_ids = []
-        i = 0
-        n = len(text)
-
-        while i < n:
-            if text[i].isspace():  # skip literal whitespace
-                i += 1
-                continue
-
-            node = self.trie_root
-            j = i
-            last_match_id = None
-            last_match_end = i
-
-            # Traverse trie character by character (including spaces if part of vocab key)
-            while j < n and text[j] in node.children:
-                node = node.children[text[j]]
-                j += 1
-                if node.token_id is not None:
-                    last_match_id = node.token_id
-                    last_match_end = j
-
-            if last_match_id is not None:
-                result_ids.append(last_match_id)
-                i = last_match_end
-            else:
-                # Fallback: encode one char as unk or atom
-                result_ids.append(self.token_to_id.get(text[i], self.unk_token_id))
-                i += 1
-
-        return result_ids
-
-    def convert_tokens_to_string(self, tokens: List[str]) -> str:
-        """SELFIES decoding: join tokens with spaces (preserve original format)."""
-        return " ".join(tokens)
-
-    def decode(
-        self,
-        token_ids: Union[List[int], torch.Tensor],
-        skip_special_tokens: bool = False,
-        clean_up_tokenization_spaces: bool = None,
-        **kwargs
-    ) -> str:
-        if isinstance(token_ids, torch.Tensor):
-            token_ids = token_ids.tolist()
-
-        if skip_special_tokens:
-            special_ids = {
-                self.bos_token_id,
-                self.eos_token_id,
-                self.pad_token_id,
-                self.mask_token_id,
-            }
-        else:
-            special_ids = set()
-
-        tokens = []
-        for tid in token_ids:
-            if tid in special_ids:
-                continue
-            token = self.id_to_token.get(tid, self.unk_token)
-            tokens.append(token)
-
-        return " ".join(tokens)  # ✅ preserve spaces
 
+import torch
+import json
+import os
+from typing import List, Union, Optional, Tuple, Dict, Any
+from functools import lru_cache
+from collections.abc import Mapping
+
+
+# ------------------------------
+# BatchEncoding
+# ------------------------------
+class BatchEncoding(dict, Mapping):
+    """Minimal BatchEncoding compatible wrapper."""
+
+    def __init__(self, data: dict, tensor_type: Optional[str] = None):
+        data = {} if data is None else {k: v for k, v in data.items()}
+        super().__init__(data)
+        self.data = data
+        self.tensor_type = tensor_type
+        for k, v in data.items():
+            setattr(self, k, v)
+
+    def __getitem__(self, key): return self.data[key]
+    def __iter__(self): return iter(self.data)
+    def __len__(self): return len(self.data)
+    def keys(self): return self.data.keys()
+    def values(self): return self.data.values()
+    def items(self): return self.data.items()
+    def get(self, key, default=None): return self.data.get(key, default)
+
+    def to(self, device):
+        if self.tensor_type in ("pt", "torch"):
+            for k, v in list(self.data.items()):
+                if torch.is_tensor(v):
+                    self.data[k] = v.to(device)
+                    setattr(self, k, self.data[k])
+        return self
+
+    def cpu(self): return self.to("cpu")
+    def cuda(self): return self.to("cuda")
+    def detach(self):
+        if self.tensor_type in ("pt", "torch"):
+            for k, v in list(self.data.items()):
+                if torch.is_tensor(v):
+                    self.data[k] = v.detach()
+                    setattr(self, k, self.data[k])
+        return self
+
+    def __repr__(self):
+        keys = ", ".join(list(self.data.keys())[:10])
+        return f"BatchEncoding(keys=[{keys}], tensor_type={self.tensor_type})"
+
+
+# ------------------------------
+# Base class
+# ------------------------------
+class PreTrainedTokenizerBase:
+    def __init__(self, **kwargs):
+        for key, value in kwargs.items():
+            if key.endswith('_token'):
+                setattr(self, f"_{key}", value)
+                setattr(self, f"{key}_id", None)
+        self.model_max_length = kwargs.get('model_max_length', 512)
+        self.padding_side = kwargs.get('padding_side', 'right')
+        self.truncation_side = kwargs.get('truncation_side', 'right')
+        self.chat_template = kwargs.get('chat_template')
+
+
+# ------------------------------
+# Trie node
+# ------------------------------
+class TrieNode:
+    __slots__ = ['children', 'token_id']
+    def __init__(self):
+        self.children = {}
+        self.token_id = None
+
+
+# ------------------------------
+# FastChemTokenizer
+# ------------------------------
+
+class FastChemTokenizer(PreTrainedTokenizerBase):
+    def __init__(self, token_to_id=None, vocab_file=None, **kwargs):
+        if vocab_file is not None:
+            with open(vocab_file, "r", encoding="utf-8") as f:
+                token_to_id = json.load(f)
+            token_to_id = {str(k): int(v) for k, v in token_to_id.items()}
+
+        self.token_to_id = token_to_id
+        self.id_to_token = {v: k for k, v in token_to_id.items()}
+
+        # Build trie
+        self.trie_root = self._build_trie(self.token_to_id)
+
+        # Call parent (sets token *strings*, may reset *_id to None)
+        super().__init__(
+            bos_token="<s>",
+            eos_token="</s>",
+            unk_token="<unk>",
+            pad_token="<pad>",
+            mask_token="<mask>",
+            model_max_length=kwargs.get("model_max_length", 512),
+            padding_side=kwargs.get("padding_side", "right"),
+            truncation_side=kwargs.get("truncation_side", "right"),
+            **kwargs,
+        )
+
+        # ✅ Re-map token strings → IDs from vocab
+        self.bos_token_id = self.token_to_id.get("<s>", 0)
+        self.eos_token_id = self.token_to_id.get("</s>", 1)
+        self.pad_token_id = self.token_to_id.get("<pad>", 2)
+        self.unk_token_id = self.token_to_id.get("<unk>", 3)
+        self.mask_token_id = self.token_to_id.get("<mask>", 4)
+
+        # Ensure reverse mapping always valid
+        self.id_to_token[self.bos_token_id] = "<s>"
+        self.id_to_token[self.eos_token_id] = "</s>"
+        self.id_to_token[self.pad_token_id] = "<pad>"
+        self.id_to_token[self.unk_token_id] = "<unk>"
+        self.id_to_token[self.mask_token_id] = "<mask>"
+
+        # Debug
+        print("✅ Special tokens bound:",
+              self.bos_token_id, self.eos_token_id, self.pad_token_id,
+              self.unk_token_id, self.mask_token_id)
+
+        # Ensure token *strings* also exist (for decode fallback)
+        self.bos_token = "<s>"
+        self.eos_token = "</s>"
+        self.pad_token = "<pad>"
+        self.unk_token = "<unk>"
+        self.mask_token = "<mask>"
+
+
+    def _build_trie(self, token_to_id):
+        root = TrieNode()
+        for token, tid in token_to_id.items():
+            node = root
+            for char in token:
+                if char not in node.children:
+                    node.children[char] = TrieNode()
+                node = node.children[char]
+            node.token_id = tid
+        return root
+
+    @property
+    def vocab_size(self): return len(self.token_to_id)
+    def __len__(self): return len(self.token_to_id)
+    def get_vocab(self) -> Dict[str, int]: return self.token_to_id.copy()
+
+    @lru_cache(maxsize=10000)
+    def _cached_encode_str(self, s: str) -> Tuple[int, ...]:
+        return tuple(self._encode_core(s))
+
+    def _encode_core(self, text: str) -> List[int]:
+        tokens, result_ids = text, []
+        i, n = 0, len(tokens)
+        while i < n:
+            node, j = self.trie_root, i
+            last_match_id, last_match_end = None, i
+            while j < n and tokens[j] in node.children:
+                node = node.children[tokens[j]]
+                j += 1
+                if node.token_id is not None:
+                    last_match_id, last_match_end = node.token_id, j
+            if last_match_id is not None:
+                result_ids.append(last_match_id)
+                i = last_match_end
+            else:
+                tid = self.token_to_id.get(tokens[i], self.unk_token_id)
+                result_ids.append(tid)
+                i += 1
+        return result_ids
+
+    # ------------------------------
+    # Converters
+    # ------------------------------
+    def _convert_token_to_id(self, token: str) -> int:
+        return self.token_to_id.get(token, self.unk_token_id)
+    def _convert_id_to_token(self, index: int) -> str:
+        return self.id_to_token.get(index, self.unk_token)
+
+    def convert_tokens_to_ids(self, tokens: Union[str, List[str]]):
+        if isinstance(tokens, str): return self._convert_token_to_id(tokens)
+        return [self._convert_token_to_id(tok) for tok in tokens]
+
+    def convert_ids_to_tokens(self, ids: Union[int, List[int]]):
+        if isinstance(ids, int): return self._convert_id_to_token(ids)
+        return [self._convert_id_to_token(i) for i in ids]
+
+    def convert_tokens_to_string(self, tokens: List[str]) -> str: return "".join(tokens)
+
+    # ------------------------------
+    # Encoding / Decoding
+    # ------------------------------
+    # ------------------------------
+    # Convenience wrappers
+    # ------------------------------
+    def encode(
+        self,
+        text: str,
+        text_pair: Optional[str] = None,
+        add_special_tokens: bool = True,
+        padding: bool = False,
+        truncation: bool = False,
+        max_length: Optional[int] = None,
+        return_tensors: Optional[str] = None,
+    ) -> List[int]:
+        encoded = self.encode_plus(
+            text=text,
+            text_pair=text_pair,
+            add_special_tokens=add_special_tokens,
+            padding=padding,
+            truncation=truncation,
+            max_length=max_length,
+            return_tensors=return_tensors,
+        )
+        input_ids = encoded["input_ids"]
+        if isinstance(input_ids, torch.Tensor):
+            if input_ids.dim() > 1:
+                input_ids = input_ids.squeeze(0)
+            input_ids = input_ids.tolist()
+        return input_ids
+
+    def __call__(
+        self,
+        text: Union[str, List[str]],
+        text_pair: Optional[Union[str, List[str]]] = None,
+        add_special_tokens: bool = True,
+        padding: Union[bool, str] = False,
+        truncation: Union[bool, str] = False,
+        max_length: Optional[int] = None,
+        stride: int = 0,
+        is_split_into_words: bool = False,
+        pad_to_multiple_of: Optional[int] = None,
+        return_tensors: Optional[Union[str, Any]] = None,
+        return_token_type_ids: Optional[bool] = None,
+        return_attention_mask: Optional[bool] = None,
+        return_overflowing_tokens: bool = False,
+        return_special_tokens_mask: bool = False,
+        return_offsets_mapping: bool = False,
+        return_length: bool = False,
+        verbose: bool = True,
+        **kwargs
+    ) -> BatchEncoding:
+        """HuggingFace-compatible: one string → encode_plus, list → batch_encode_plus"""
+        if return_token_type_ids is None:
+            return_token_type_ids = True
+        if return_attention_mask is None:
+            return_attention_mask = True
+
+        if isinstance(text, list):
+            if text_pair is not None:
+                batch = [(t, p) for t, p in zip(text, text_pair)]
+            else:
+                batch = text
+            return self.batch_encode_plus(
+                batch,
+                add_special_tokens=add_special_tokens,
+                padding=padding,
+                truncation=truncation,
+                max_length=max_length,
+                stride=stride,
+                is_split_into_words=is_split_into_words,
+                pad_to_multiple_of=pad_to_multiple_of,
+                return_tensors=return_tensors,
+                return_token_type_ids=return_token_type_ids,
+                return_attention_mask=return_attention_mask,
+                return_overflowing_tokens=return_overflowing_tokens,
+                return_special_tokens_mask=return_special_tokens_mask,
+                return_offsets_mapping=return_offsets_mapping,
+                return_length=return_length,
+                verbose=verbose,
+                **kwargs
+            )
+        else:
+            return self.encode_plus(
+                text=text,
+                text_pair=text_pair,
+                add_special_tokens=add_special_tokens,
+                padding=padding,
+                truncation=truncation,
+                max_length=max_length,
+                stride=stride,
+                is_split_into_words=is_split_into_words,
+                pad_to_multiple_of=pad_to_multiple_of,
+                return_tensors=return_tensors,
+                return_token_type_ids=return_token_type_ids,
+                return_attention_mask=return_attention_mask,
+                return_overflowing_tokens=return_overflowing_tokens,
+                return_special_tokens_mask=return_special_tokens_mask,
+                return_offsets_mapping=return_offsets_mapping,
+                return_length=return_length,
+                verbose=verbose,
+                **kwargs
+            )
+
+    def encode_plus(
+        self,
+        text: str,
+        text_pair: Optional[str] = None,
+        add_special_tokens: bool = True,
+        padding: Union[bool, str] = False,
+        truncation: Union[bool, str] = False,
+        max_length: Optional[int] = None,
+        stride: int = 0,
+        is_split_into_words: bool = False,
+        pad_to_multiple_of: Optional[int] = None,
+        return_tensors: Optional[Union[str, Any]] = None,
+        return_token_type_ids: Optional[bool] = True,
+        return_attention_mask: Optional[bool] = True,
+        return_overflowing_tokens: bool = False,
+        return_special_tokens_mask: bool = False,
+        return_offsets_mapping: bool = False,
+        return_length: bool = False,
+        verbose: bool = True,
+        **kwargs
+    ) -> BatchEncoding:
+        if max_length is None: max_length = self.model_max_length
+        ids_a = list(self._cached_encode_str(text.strip()))
+        ids_b = list(self._cached_encode_str(text_pair.strip())) if text_pair else None
+
+        input_ids, token_type_ids = [], []
+        if add_special_tokens:
+            input_ids.append(self.bos_token_id); token_type_ids.append(0)
+            input_ids.extend(ids_a); token_type_ids.extend([0] * len(ids_a))
+            input_ids.append(self.eos_token_id); token_type_ids.append(0)
+            if ids_b is not None:
+                input_ids.extend(ids_b); token_type_ids.extend([1] * len(ids_b))
+                input_ids.append(self.eos_token_id); token_type_ids.append(1)
+        else:
+            input_ids = ids_a.copy(); token_type_ids = [0] * len(input_ids)
+            if ids_b is not None:
+                input_ids.extend(ids_b); token_type_ids.extend([1] * len(ids_b))
+
+        if truncation and len(input_ids) > max_length:
+            input_ids, token_type_ids = input_ids[:max_length], token_type_ids[:max_length]
+
+        encoded_dict = {"input_ids": input_ids}
+        if return_attention_mask:
+            if padding == True or padding == "max_length":
+                pad_len = max_length - len(input_ids)
+                if pad_len > 0:
+                    if self.padding_side == "right":
+                        input_ids.extend([self.pad_token_id] * pad_len)
+                        token_type_ids.extend([0] * pad_len)
+                    else:
+                        input_ids = [self.pad_token_id] * pad_len + input_ids
+                        token_type_ids = [0] * pad_len + token_type_ids
+            attention_mask = [0 if tid == self.pad_token_id else 1 for tid in input_ids]
+            encoded_dict["attention_mask"] = attention_mask
+        if return_token_type_ids: encoded_dict["token_type_ids"] = token_type_ids
+        if return_special_tokens_mask:
+            encoded_dict["special_tokens_mask"] = [
+                1 if tid in {self.bos_token_id, self.eos_token_id, self.pad_token_id, self.mask_token_id} else 0
+                for tid in input_ids
+            ]
+        if return_length:
+            encoded_dict["length"] = len([tid for tid in input_ids if tid != self.pad_token_id])
+
+        if return_tensors in ["pt", "torch"]:
+            out = {}
+            for k, v in encoded_dict.items():
+                if isinstance(v, list):
+                    tensor = torch.tensor(
+                        [self.unk_token_id if x is None else int(x) for x in v], dtype=torch.long
+                    ).unsqueeze(0)
+                    out[k] = tensor
+                else:
+                    out[k] = v
+            return BatchEncoding(out, tensor_type=return_tensors)
+        return BatchEncoding(encoded_dict, tensor_type=None)
+
+    def batch_encode_plus(
+        self,
+        batch_text_or_text_pairs: List[Union[str, Tuple[str, str]]],
+        add_special_tokens: bool = True,
+        padding: Union[bool, str] = False,
+        truncation: Union[bool, str] = False,
+        max_length: Optional[int] = None,
+        stride: int = 0,
+        is_split_into_words: bool = False,
+        pad_to_multiple_of: Optional[int] = None,
+        return_tensors: Optional[Union[str, Any]] = None,
+        return_token_type_ids: Optional[bool] = True,
+        return_attention_mask: Optional[bool] = True,
+        return_overflowing_tokens: bool = False,
+        return_special_tokens_mask: bool = False,
+        return_offsets_mapping: bool = False,
+        return_length: bool = False,
+        verbose: bool = True,
+        **kwargs
+    ) -> BatchEncoding:
+        if padding is True: padding = "longest"
+        if padding == "max_length" and max_length is None: max_length = self.model_max_length
+
+        all_input_ids, all_token_type_ids, all_attention_masks = [], [], []
+        all_special_masks, all_lengths = [], []
+        for item in batch_text_or_text_pairs:
+            t, tp = item if isinstance(item, tuple) else (item, None)
+            enc = self.encode_plus(
+                text=t, text_pair=tp, add_special_tokens=add_special_tokens,
+                padding=False, truncation=truncation, max_length=max_length,
+                return_tensors=None, return_token_type_ids=return_token_type_ids,
+                return_attention_mask=return_attention_mask,
+                return_special_tokens_mask=return_special_tokens_mask,
+                return_length=return_length, **kwargs
+            )
+            ids, tt, am = enc["input_ids"], enc.get("token_type_ids", [0]*len(enc["input_ids"])), enc.get("attention_mask",[1]*len(enc["input_ids"]))
+            sm, ln = enc.get("special_tokens_mask",[0]*len(ids)), enc.get("length", len([x for x in ids if x != self.pad_token_id]))
+            all_input_ids.append(ids); all_token_type_ids.append(tt); all_attention_masks.append(am)
+            all_special_masks.append(sm); all_lengths.append(ln)
+
+        pad_to = max(len(x) for x in all_input_ids) if padding == "longest" else (max_length if padding == "max_length" else None)
+        batched = {
+            "input_ids": all_input_ids,
+            "token_type_ids": all_token_type_ids if return_token_type_ids else None,
+            "attention_mask": all_attention_masks if return_attention_mask else None,
+            "special_tokens_mask": all_special_masks if return_special_tokens_mask else None,
+            "length": all_lengths if return_length else None,
+        }
+        if pad_to is not None:
+            for key in ["input_ids","token_type_ids","attention_mask","special_tokens_mask"]:
+                if batched.get(key) is None: continue
+                padded = []
+                for seq in batched[key]:
+                    pad_len = pad_to - len(seq)
+                    pad_val = self.pad_token_id if key=="input_ids" else 0
+                    if pad_len > 0:
+                        seq = seq+[pad_val]*pad_len if self.padding_side=="right" else [pad_val]*pad_len+seq
+                    padded.append(seq)
+                batched[key] = padded
+
+        if return_tensors in ["pt", "torch"]:
+            def to_tensor(lst, pad_val=0):
+                return torch.tensor([[self.unk_token_id if x is None else int(x) for x in row] for row in lst], dtype=torch.long)
+            out = {}
+            if batched.get("input_ids") is not None: out["input_ids"] = to_tensor(batched["input_ids"], self.pad_token_id)
+            if batched.get("attention_mask") is not None: out["attention_mask"] = to_tensor(batched["attention_mask"],0)
+            if batched.get("token_type_ids") is not None: out["token_type_ids"] = to_tensor(batched["token_type_ids"],0)
+            if batched.get("special_tokens_mask") is not None: out["special_tokens_mask"] = to_tensor(batched["special_tokens_mask"],0)
+            if return_length and batched.get("length") is not None: out["length"] = torch.tensor([int(x) for x in batched["length"]], dtype=torch.long)
+            return BatchEncoding(out, tensor_type=return_tensors)
+        return BatchEncoding({k:v for k,v in batched.items() if v is not None}, tensor_type=None)
+
+    # ------------------------------
+    # Decoding
+    # ------------------------------
+    def decode(self, token_ids, skip_special_tokens=False, **kwargs):
+        if isinstance(token_ids, torch.Tensor): token_ids = token_ids.tolist()
+        special_ids = {self.bos_token_id,self.eos_token_id,self.pad_token_id,self.mask_token_id} if skip_special_tokens else set()
+        tokens = [self.id_to_token.get(tid,self.unk_token) for tid in token_ids if tid not in special_ids]
+        return "".join(tokens)
+
+    def batch_decode(self, sequences, skip_special_tokens=False, **kwargs):
+        if isinstance(sequences, torch.Tensor): sequences = sequences.tolist()
+        return [self.decode(seq, skip_special_tokens=skip_special_tokens, **kwargs) for seq in sequences]
+
+    def decode_with_trace(self, token_ids: List[int]):
+        print(f"\n🔍 Decoding {len(token_ids)} tokens:")
+        for i, tid in enumerate(token_ids):
+            token = self.id_to_token.get(tid, self.unk_token)
+            tid_str = "None" if tid is None else f"{tid:5d}"
+            print(f" [{i:03d}] ID={tid_str} → '{token}'")
+
+    # ------------------------------
+    # Save / Load
+    # ------------------------------
+    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        if not os.path.isdir(save_directory): os.makedirs(save_directory)
+        vocab_file = os.path.join(save_directory,(filename_prefix+"-" if filename_prefix else "")+"vocab.json")
+        with open(vocab_file,"w",encoding="utf-8") as f: json.dump(self.token_to_id,f,ensure_ascii=False,indent=2)
+        return (vocab_file,)
+
+    def save_pretrained(self, save_directory: Union[str, os.PathLike], filename_prefix: Optional[str]=None, **kwargs):
+        if not os.path.exists(save_directory): os.makedirs(save_directory)
+        self.save_vocabulary(save_directory, filename_prefix)
+        config_file = os.path.join(save_directory,"tokenizer_config.json")
+        with open(config_file,"w",encoding="utf-8") as f:
+            json.dump({
+                "tokenizer_class": self.__class__.__name__,
+                "model_max_length": self.model_max_length,
+                "padding_side": self.padding_side,
+                "truncation_side": self.truncation_side,
+                "special_tokens": {
+                    "bos_token": self.bos_token,
+                    "eos_token": self.eos_token,
+                    "pad_token": self.pad_token,
+                    "unk_token": self.unk_token,
+                    "mask_token": self.mask_token,
+                }
+            },f,ensure_ascii=False,indent=2)
+        return (save_directory,)
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+        if os.path.isdir(pretrained_model_name_or_path):
+            vocab_file = os.path.join(pretrained_model_name_or_path,"vocab.json")
+            config_file = os.path.join(pretrained_model_name_or_path,"tokenizer_config.json")
+            config = {}
+            if os.path.exists(config_file):
+                with open(config_file,"r",encoding="utf-8") as f: config=json.load(f)
+            return cls(vocab_file=vocab_file, **{**config,**kwargs})
+        else:
+            raise NotImplementedError("Loading from Hub not implemented yet")
+
+
+# ------------------------------
+# SELFIES variant
+# ------------------------------
+class FastChemTokenizerSelfies(FastChemTokenizer):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)  # ensures BOS/EOS etc. are set
+
+    """SELFIES variant that handles whitespace-separated tokens."""
+
+    def _encode_core(self, text: str) -> List[int]:
+        result_ids, i, n = [], 0, len(text)
+        while i < n:
+            if text[i].isspace(): i += 1; continue
+            node, j = self.trie_root, i
+            last_match_id, last_match_end = None, i
+            while j < n and text[j] in node.children:
+                node = node.children[text[j]]; j += 1
+                if node.token_id is not None:
+                    last_match_id, last_match_end = node.token_id, j
+            if last_match_id is not None:
+                result_ids.append(last_match_id); i = last_match_end
+            else:
+                result_ids.append(self.token_to_id.get(text[i], self.unk_token_id)); i += 1
+        return result_ids
+
+    def convert_tokens_to_string(self, tokens: List[str]) -> str: return " ".join(tokens)
+    def decode(self, token_ids, skip_special_tokens=False, **kwargs):
+        if isinstance(token_ids, torch.Tensor): token_ids = token_ids.tolist()
+        special_ids = {self.bos_token_id,self.eos_token_id,self.pad_token_id,self.mask_token_id} if skip_special_tokens else set()
+        tokens = [self.id_to_token.get(tid,self.unk_token) for tid in token_ids if tid not in special_ids]
+        return " ".join(tokens)
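
For quick reference, a minimal usage sketch of the classes added in this commit (not part of the diff; the module name FastChemTokenizerHF, the directory paths, and the example strings are illustrative, and the vocabularies are assumed to contain the five special tokens plus the fragments used):

from FastChemTokenizerHF import FastChemTokenizer, FastChemTokenizerSelfies

# Load a SMILES-style tokenizer from a local directory containing vocab.json (hypothetical path).
tok = FastChemTokenizer.from_pretrained("./smiles_tokenizer")

# A list input routes through batch_encode_plus; padding=True pads to the longest sequence in the batch.
batch = tok(["CCO", "c1ccccc1"], padding=True, return_tensors="pt")
print(batch["input_ids"].shape)                                        # (2, longest_sequence_in_batch)
print(tok.batch_decode(batch["input_ids"], skip_special_tokens=True))  # tokens re-joined without spaces

# The SELFIES variant skips whitespace while encoding and re-inserts spaces when decoding.
sf = FastChemTokenizerSelfies.from_pretrained("./selfies_tokenizer")   # hypothetical path
ids = sf.encode("[C] [C] [O]")                    # encode() wraps encode_plus() and returns a flat id list
print(sf.decode(ids, skip_special_tokens=True))   # "[C] [C] [O]", assuming these tokens are in the vocab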