from collections import Counter import torchvision.datasets as dset from torch.utils.data import Dataset import torch from torch.utils.data import DataLoader import glob import os from torch.utils.data import Dataset, DataLoader, random_split from shutil import copyfile import subprocess import youtokentome as yttm import re import time from tqdm import trange, tqdm import numpy as np import matplotlib.pyplot as plt import inspect # Device for dataloading and dataloading only. Dataloading on MPS was slower DEVICE = "cpu" # "mps" if torch.backends.mps.is_available() else "cpu" class BPEModelManager: def __init__(self, root_dir, vocab_size=5000): self.root_dir = root_dir self.vocab_size = vocab_size self.model_path = os.path.join(root_dir, "bpe_model.model") try: self.bpe = yttm.BPE(model=self.model_path) if self.bpe.vocab_size() != vocab_size: print( f"Vocab size mismatch: Expected {vocab_size}, got {self.bpe.vocab_size()}. Retraining model." ) self._backup_model() raise ValueError except ValueError: self._train_bpe_model() self.bpe = yttm.BPE(model=self.model_path) def _backup_model(self): backup_path = os.path.join(self.root_dir, "bpe_model.model.old") copyfile(self.model_path, backup_path) def _train_bpe_model(self): data_path = os.path.join(self.root_dir, "data/corpus.txt") processed_path = os.path.join(self.root_dir, "data/corpus_processed.txt") with open(data_path, "r", errors="ignore") as reader: raw_text = reader.read() processed_text = self.preprocess_text(raw_text) with open(processed_path, "w") as writer: writer.write(processed_text) yttm.BPE.train( data=processed_path, vocab_size=self.vocab_size, model=self.model_path, coverage=0.9999, ) def preprocess_text(self, text): return text.lower() def encode(self, text: str): return self.bpe.encode([text], output_type=yttm.OutputType.ID) def decode(self, ids): return self.bpe.decode(ids.tolist())[0] @staticmethod def attention_mask(encoded_sequence, mask_token_ids=[0, 1, 2, 3]): mask_token_tensor = torch.tensor(mask_token_ids, dtype=torch.int).to( encoded_sequence.device ) # print(mask_token_tensor) # print(encoded_sequence) return (encoded_sequence.unsqueeze(1) != mask_token_tensor).all(dim=1).int() class CodeBPEModelManager(BPEModelManager): mapping_dict = { " ": " ", "\n": " ", } def __init__(self, root_dir, vocab_size=5000): super().__init__(root_dir, vocab_size) def preprocess_text(self, text): print("Formatting....") processed_text = self.format_code(text) for key, value in CodeBPEModelManager.mapping_dict.items(): processed_text = processed_text.replace(key, value) return processed_text def encode(self, text: str): processed_text = text for key, value in CodeBPEModelManager.mapping_dict.items(): processed_text = processed_text.replace(key, value) return self.bpe.encode([processed_text], output_type=yttm.OutputType.ID)[0] def decode(self, ids): # print(ids) # print("ids^^") l = ids if isinstance(l, torch.Tensor): l = ids.tolist() if isinstance(l, int): l = [l] result = self.bpe.decode(l)[0] # print(result) for key, value in CodeBPEModelManager.mapping_dict.items(): result = result.replace(value.strip(), key) # value, key return result def raw_decode(self, id: int): return self.bpe.decode([id])[0] def _train_bpe_model(self): print("Training (1)....") data_path = os.path.join(self.root_dir, "data/corpus.txt") processed_path = os.path.join(self.root_dir, "data/corpus_processed.txt") if input("Reformat? Will take time [y/N]") == "y": with open(data_path, "r", errors="ignore", encoding="utf-8") as reader: raw_text = reader.read() processed_text = self.preprocess_text(raw_text) with open(processed_path, "w", encoding="utf-8") as writer: writer.write(processed_text) print("removing temp file...") temp_file = os.path.join(self.root_dir, "temp_code.py") # dont ask os.remove(temp_file) print("Training....") yttm.BPE.train( data=processed_path, vocab_size=self.vocab_size, model=self.model_path, coverage=1, # coverage=0.995, # TODO: revert if you want ) def format_code(self, code): try: temp_file = os.path.join(self.root_dir, "temp_code.py") with open(temp_file, "w") as file: file.write( code.replace("\t", " ") ) # Hacky replacement, black freaks out otherwise subprocess.run(["black", temp_file, "--quiet"], check=True) subprocess.run( ["autopep8", "--in-place", "--ignore=E402", temp_file], check=True ) with open(temp_file, "r") as file: formatted_code = file.read() return formatted_code except Exception as e: print(f"Error during code formatting: {e}.") return code class CodeCustomTokenizerManager(BPEModelManager): reserved_keywords = [ "false", "await", "else", "import", "pass", "none", "break", "except", "in", "raise", "true", "class", "finally", "is", "return", "and", "continue", "for", "lambda", "try", "as", "def", "from", "nonlocal", "while", "assert", "del", "global", "not", "with", "async", "elif", "if", "or", "yield", ] symbols = [ "(", ")", "[", "]", "{", "}", ".", ",", ":", ";", "+", "-", "*", "/", "%", "=", "<", ">", "&", "|", "^", "~", "!", "==", "!=", "<=", ">=", "**", "//", "@", "#", "\\", "'", '"', "`", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "0x", "0d", "0o", ] def __init__( self, root_dir, vocab_size=5000, cutoff_thresh=0.1, use_vocab_size_instead=False, use_whitespace=True, # haha ): # keep 90% with thresh 0.1 self.root_dir = root_dir self.token_to_id = {"": 0} self.id_to_token = None self._token_freqs = {} self.total_num_tokens = 0 print("This is CodeCustomTokenizerManager, vocab size will be disregarded.") print(f"Cutoff threshold: {cutoff_thresh}") self.cutoff_thresh = cutoff_thresh self.use_whitespace = use_whitespace if not use_whitespace: print("Not using whitespace! Important I guess") if use_vocab_size_instead: print("Nevermind! Using vocab size instead, no cutoff thresh") self.use_vocab_size_instead = use_vocab_size_instead self.vocab_size = vocab_size vocab_path = os.path.join(self.root_dir, "custom_tokens_vocab.txt") try: self.load_vocab(vocab_path) except FileNotFoundError: print("Making vocab!") self.make_vocab() self.save_vocab(vocab_path) print(f"Vocab size: {len(self.token_to_id)}") def make_vocab(self): data_path = os.path.join(self.root_dir, "data/corpus.txt") processed_path = os.path.join(self.root_dir, "data/corpus_processed.txt") with open(data_path, "r", errors="ignore") as reader: raw_text = reader.read() processed_text = self.preprocess_text(raw_text) with open(processed_path, "w") as writer: writer.write(" ".join(processed_text)) for token in processed_text: if token not in self.token_to_id: if len(self.token_to_id) == 0: self.token_to_id = {"": 0} # TODO: bad practice or something self.token_to_id[token] = len(self.token_to_id) print(f"Number of tokens: {len(self.token_to_id)}") def make_token_freqs(self): processed_path = os.path.join(self.root_dir, "data/corpus_processed.txt") with open(processed_path, "r", errors="ignore") as reader: raw_text = reader.read() tokens = raw_text.split(" ") token_freqs = {"": 0} for token in tqdm(tokens, leave=False): if token not in token_freqs: token_freqs[token] = 1 else: token_freqs[token] += 1 self._token_freqs = token_freqs self.total_num_tokens = len(tokens) def preprocess_text(self, code): print("Preprocessing text...", code[:20]) # print(code[:100]) # comments code = code.replace("# ", "") code = re.sub(r"#.*", "", code) code = re.sub(r'"""(.*?)"""', "", code, flags=re.DOTALL) # funny usage of re code = re.sub(r"'''(.*?)'''", "", code, flags=re.DOTALL) code = re.sub(r" ", " ", code) print("Filtered comments") # print(code[:100]) # filter non-ascii # https://regexr.com/8bmfe code = re.sub(r"[^ -~\s]+", "", code) # print(code[:100]) print("Filtered non-ascii") # # Handle hex/binary/octal sequences # def split_number_sequence(match): # prefix, digits = match.group(1), match.group(2) # return f"{prefix} " + " ".join(digits) # code = re.sub(r'(0x)([0-9a-f]+)', split_number_sequence, code) # code = re.sub(r'(0b)([01]+)', split_number_sequence, code) # code = re.sub(r'(0o)([0-7]+)', split_number_sequence, code) # print("Coped with hex") # each reserved word/symbol is a token. We split by space at the end, so this works. for word in self.reserved_keywords: code = re.sub(rf"\b{word}\b", f" {word} ", code) print("Reserved words") for symbol in self.symbols: code = code.replace(symbol, f" {symbol} ") print("Symbols") # print(code[:100]) # Split identifiers by spaces, underscores, hyphens, or capitalization def split_token(token): if token.startswith("<") and token.endswith( ">" ): # preserve ✨special✨ tokens return [token.lower()] result = re.sub(r"([a-z])([A-Z])", r"\1 \2", token) result = re.sub(r"([_-])", r" \1 ", result) result = re.sub(r"([^a-zA-Z])", r" \1 ", result) return [part.lower() for part in result.split() if part.strip()] code = code.replace(" ", " ").replace("\n", " ") if not self.use_whitespace: code = code.replace("", "").replace("", "") print("Tabs + newlines") tokens = [] for token in tqdm(code.split(" "), leave=False): if token.strip(): tokens.extend(split_token(token)) tokens = [tok.lower() for tok in tokens if tok.strip()] print("Split tokens") token_freqs = {"": 0} for token in tqdm(tokens, leave=False): if token not in token_freqs: token_freqs[token] = 1 else: token_freqs[token] += 1 print("Counted freqs") # what statistics do we want to calculate? # Number of tokens that appear only once, and percentage. # Mean number of times any given token appears. # standard things: mean, std, q1, q3, median, min, max # Print out topk most frequent and their freqs total_num_tokens = len(tokens) counter = Counter(list(token_freqs.values())) num_ones = counter[1] print( f"Number of tokens that appear only once: {num_ones}. Percentage: {num_ones / total_num_tokens}" ) print(f"Mean token count: {np.mean(list(token_freqs.values()))}") print(f"Median token count: {np.median(list(token_freqs.values()))}") print( f"Standard deviation of token count: {np.std(list(token_freqs.values()))}" ) print(f"Min token count: {np.min(list(token_freqs.values()))}") print(f"Max token count: {np.max(list(token_freqs.values()))}") print(f"Top 30 most frequent tokens:") sorted_tokens = sorted(token_freqs.items(), key=lambda x: x[1], reverse=True) for token, freq in sorted_tokens[:30]: print(f"{token}: {freq}") print(f"Bottom 30 most frequent tokens:") for token, freq in sorted_tokens[-30:]: print(f"{token}: {freq}") self._token_freqs = token_freqs self.total_num_tokens = total_num_tokens # plt.figure(figsize=(15,6)) # plt.bar(np.arange(len(sorted_tokens)), [freq for token, freq in sorted_tokens]) # plt.xlabel("Token") # plt.ylabel("Frequency") # plt.title("Token frequency distribution") # plt.show() # breakpoint() # use cutoff thresh to replace tokens with UNK cutoff_thresh = self.cutoff_thresh if self.use_vocab_size_instead: print("Using vocab size instead") print("deprecated") print("cope") exit() sorted_tokens = sorted( token_freqs.items(), key=lambda x: x[1], reverse=True ) allowed_tokens = set( token for token, _ in sorted_tokens[: self.vocab_size - 1] ) # -1 for PAD for i in range(len(tokens)): if tokens[i] not in allowed_tokens and tokens[i] != "": print(f"Replacing token with UNK: {tokens[i]}") tokens[i] = "" else: cutoff_amt = ( 10 # np.percentile(list(token_freqs.values()), (1-cutoff_thresh) * 100) ) print(f"Cuttoff amount: {cutoff_amt}") # using threshold {cutoff_thresh}") # llm-optimized low_freq_tokens = [ token for token, freq in token_freqs.items() if freq < cutoff_amt and token != "" ] low_freq_tokens_set = set(low_freq_tokens) tokens = [ "" if token in low_freq_tokens_set else token for token in tqdm(tokens) ] print(tokens[500:700]) print("500-700") return [tok for tok in tokens if tok.strip()] def encode(self, code): tokens = code.split(" ") ids = [] for token in tokens: # New token if token not in self.token_to_id: self.token_to_id[token] = len(self.token_to_id) ids.append(self.token_to_id[token]) return ids def decode(self, ids): result = "" for id in ids.tolist(): for token, id_iterator in self.token_to_id.items(): if id_iterator == id: result += token result += " " return result def raw_decode(self, id: int): for token, id_iterator in self.token_to_id.items(): if id_iterator == id: return token def format_code(self, code): try: temp_file = os.path.join(self.root_dir, "temp_code.py") with open(temp_file, "w") as file: file.write( code.replace("\t", " ") ) # Hacky replacement, black freaks out otherwise subprocess.run(["black", temp_file, "--quiet"], check=True) subprocess.run( ["autopep8", "--in-place", "--ignore=E402", temp_file], check=True ) with open(temp_file, "r") as file: formatted_code = file.read() return formatted_code except Exception as e: print(f"Error during code formatting: {e}.") return code def save_vocab(self, file_path): with open(file_path, "w") as file: for token, id in self.token_to_id.items(): file.write(f"{token}\t{id}\n") def load_vocab(self, file_path): self.token_to_id = {} with open(file_path, "r") as file: for line in file.read().split("\n"): try: token, id = line.strip().split("\t") self.token_to_id[token] = int(id) except ValueError: # print(line) # print("^^ is error") pass # Should be fine, ends up being blank lines @staticmethod def attention_mask(encoded_sequence, mask_token_ids=[0]): mask_token_tensor = torch.tensor(mask_token_ids, dtype=torch.int) # print(mask_token_tensor) # print(encoded_sequence) return (encoded_sequence.unsqueeze(1) != mask_token_tensor).all(dim=1).int() def get_rarity_score(self, sequence): scores = np.zeros_like(sequence) for idx, token in enumerate(sequence): # get token count in entire corpus # get TOTAL token count in entire corpus # divide # recriprocal # rarity score for individual token in THIS sequence # average? max? **median**? if self._token_freqs == {}: self.make_token_freqs() if not self.id_to_token: self.id_to_token = {v: k for k, v in self.token_to_id.items()} token_count = self._token_freqs.get(self.id_to_token[token.item()], 0) rarity_score = self.total_num_tokens / token_count if token_count > 0 else 0 scores[idx] = rarity_score return np.float32(np.median(scores)) def get_entropy_score(self, sequence): if len(sequence) == 0: return 0.0 unique, counts = np.unique(sequence, return_counts=True) probs = counts / counts.sum() entropy = -np.sum(probs * np.log2(probs)) if len(unique) > 1: entropy /= np.log2(len(unique)) return np.float32(entropy) class DummySequentialDataManager: def __init__(self, root_dir, vocab_size=5000): print("init") self.root_dir = root_dir self.vocab_size = vocab_size with open(os.path.join(root_dir, "data/corpus_processed.txt"), "w+") as f: f.write("dummy") def encode(self, text: str): return [list(range(50))] def decode(self, ids): l = ids if isinstance(l, torch.Tensor): l = ids.tolist() if isinstance(l, int): l = [l] return " ".join([str(id) for id in l]) @staticmethod def attention_mask(encoded_sequence, mask_token_ids=[]): mask_token_tensor = torch.tensor(mask_token_ids, dtype=torch.int).to( encoded_sequence.device ) # print(mask_token_tensor) # print(encoded_sequence) return (encoded_sequence.unsqueeze(1) != mask_token_tensor).all(dim=1).int() class TextCorpusDataset(Dataset): def __init__( self, root_dir="./test-data", train=False, max_length=512, vocab_size=10000, IS_DUMMY=False, IS_CODE=False, IS_CUSTOM=False, sliding_window=False, stride=1, get_rarity_score=False, get_entropy_score=False, ): print(root_dir) # legendary code print("[TextCorpusDataset]") frame = inspect.currentframe() args, _, _, values = inspect.getargvalues(frame) print("Arguments passed:") for arg in args[1:]: # skip 'self' print(f" {arg} = {values[arg]}") self.root = root_dir self.sliding_window = sliding_window self.window_size = max_length self.stride = stride self.get_rarity_score = get_rarity_score self.get_entropy_score = get_entropy_score if IS_DUMMY: self.manager = DummySequentialDataManager(root_dir=root_dir) elif IS_CODE: if IS_CUSTOM: self.manager = CodeCustomTokenizerManager(root_dir=root_dir) else: self.manager = CodeBPEModelManager( root_dir=root_dir, vocab_size=vocab_size ) else: self.manager = BPEModelManager(root_dir=root_dir, vocab_size=vocab_size) self.max_length = max_length self.cache_file = os.path.join(root_dir, "encoded_chunked.pt") self.rarity_cache_file = os.path.join(root_dir, "rarity_scores.pt") self.entropy_cache_file = os.path.join(root_dir, "entropy_scores.pt") start_t = time.time() if os.path.exists(self.cache_file): self.chunks = torch.load(self.cache_file, weights_only=True) if self.chunks.size(-1) != self.max_length: if ( input( "Attempting to fix and re-chunk data to correct length. Continue? [y/N]: " ) == "y" ): self._chunk_and_save(torch.flatten(self.chunks).tolist()) print("Re-chunked successfully!") else: print("Operation aborted.") else: with open( os.path.join(root_dir, "data/corpus_processed.txt"), "r", errors="ignore", ) as file: text = file.read() encoded = self.manager.encode(text) self._chunk_and_save(encoded) # Load or compute cached scores self._load_or_compute_scores() end_t = time.time() print(f"Dataset loading took {end_t - start_t} seconds.") # TODO: more "optimization" self.chunks = self.chunks.to(DEVICE) if self.get_rarity_score: self.rarity_scores = self.rarity_scores.to(DEVICE) if self.get_entropy_score: self.entropy_scores = self.entropy_scores.to(DEVICE) self.dummy = torch.tensor([1], device=DEVICE) def _chunk_and_save(self, encoded): chunked_data = [] if self.sliding_window: print("sliding!") for i in trange( 0, len(encoded) - self.window_size + 1, self.stride, leave=False ): chunked_data.append( torch.tensor(encoded[i : i + self.window_size], dtype=torch.int) ) else: for i in trange(0, len(encoded), self.max_length, leave=False): chunked_data.append( torch.tensor(encoded[i : i + self.max_length], dtype=torch.int) ) # me when the last item is not necessarily of length self.max_length padded_chunk = torch.zeros(self.max_length, dtype=torch.int) padded_chunk[: len(chunked_data[-1])] = chunked_data[-1] chunked_data[-1] = padded_chunk self.chunks = torch.stack(chunked_data) torch.save(self.chunks, self.cache_file) def _load_or_compute_scores(self): """Load cached scores or compute them if not available""" if self.get_rarity_score: if os.path.exists(self.rarity_cache_file): print("Loading cached rarity scores...") self.rarity_scores = torch.load(self.rarity_cache_file, weights_only=True) if len(self.rarity_scores) != len(self.chunks): print("Rarity cache size mismatch, recomputing...") self._compute_and_cache_rarity_scores() else: print("Computing rarity scores...") self._compute_and_cache_rarity_scores() if self.get_entropy_score: if os.path.exists(self.entropy_cache_file): print("Loading cached entropy scores...") self.entropy_scores = torch.load(self.entropy_cache_file, weights_only=True) if len(self.entropy_scores) != len(self.chunks): print("Entropy cache size mismatch, recomputing...") self._compute_and_cache_entropy_scores() else: print("Computing entropy scores...") self._compute_and_cache_entropy_scores() def _compute_and_cache_rarity_scores(self): """Compute rarity scores for all chunks and cache them""" rarity_scores = [] print("Computing rarity scores for all chunks...") for i in trange(len(self.chunks), desc="Computing rarity scores"): score = self.manager.get_rarity_score(self.chunks[i]) rarity_scores.append(score) self.rarity_scores = torch.tensor(rarity_scores, dtype=torch.float32) torch.save(self.rarity_scores, self.rarity_cache_file) print(f"Cached rarity scores to {self.rarity_cache_file}") def _compute_and_cache_entropy_scores(self): """Compute entropy scores for all chunks and cache them""" entropy_scores = [] print("Computing entropy scores for all chunks...") for i in trange(len(self.chunks), desc="Computing entropy scores"): score = self.manager.get_entropy_score(self.chunks[i]) entropy_scores.append(score) self.entropy_scores = torch.tensor(entropy_scores, dtype=torch.float32) torch.save(self.entropy_scores, self.entropy_cache_file) print(f"Cached entropy scores to {self.entropy_cache_file}") # unused # def _sliding_window(self, sequence, window_size, stride): # windows = [] # for i in range(0, len(sequence) - window_size + 1, stride): # windows.append(sequence[i : i + window_size]) # return torch.stack(windows) def __len__(self): return len(self.chunks) def __getitem__( self, idx ): seq = self.chunks[idx] if self.get_rarity_score: return seq, self.rarity_scores[idx] if self.get_entropy_score: return seq, self.entropy_scores[idx] return seq, self.dummy # self.manager.attention_mask(seq) class Datasplit_chunker(Dataset): def __init__(self, root, name, subset, slide=False, stride=1, length=512): super().__init__() self.root = root if os.path.exists(os.path.join(root, f"encoded_chunked_{name}.pt")): self.items = torch.load( os.path.join(root, f"encoded_chunked_{name}.pt"), weights_only=True ) else: self.items = torch.cat([subset.dataset[idx][0] for idx in subset.indices]) if slide: self.items = self._sliding_window( self.items, window_size=length, stride=stride ) torch.save(self.items, os.path.join(root, f"encoded_chunked_{name}.pt")) print("saved!") self.chunks = self.items self.dummy = torch.tensor([1], device=DEVICE) def _sliding_window(self, sequence, window_size, stride): num_windows = (len(sequence) - window_size) // stride + 1 windows = torch.as_strided( sequence, size=(num_windows, window_size), stride=(stride, 1) ) return windows def __len__(self): return len(self.items) def __getitem__(self, idx): return self.chunks[idx], self.dummy # print("Running....") dataset = TextCorpusDataset( root_dir=os.path.expanduser( # "./dummy-data-dir" # "./smaller-er-test-data" # "./smaller-test-data" # "~/torch_datasets/github-python/all_trains_subset_corpus/all_trains_TRAINSPLIT" #"~/torch_datasets/github-python/all_trains_subset_corpus" # "~/torch_datasets/github-python/corpus" # "~/torch_datasets/github-python/mega_corpus" "~/torch_datasets/github-python/mega_licensed_corpus" ), # os.path.expanduser("~/torch_datasets/wikitext/train") vocab_size=33819, # 3645, # edited by me IS_CODE=True, # Remember to change! IS_CUSTOM=True, # IS_DUMMY=True, max_length=256, sliding_window=False, stride=10, get_rarity_score=True, ) dset_size = int(len(dataset)) train_size = int(0.8 * dset_size) # int(dset_size - 2) test_size = int(dset_size - train_size) if test_size == 2: print("alert! test size is 2 or whatever. Change this back please.") torch.manual_seed(3407) # https://arxiv.org/pdf/2109.08203 train_dataset, test_dataset, _ = random_split( dataset, [train_size, test_size, len(dataset) - train_size - test_size] ) # train_dataset = Datasplit_chunker(dataset.root,"TRAIN", train_dataset, slide=False, stride=10, length=256) # test_dataset = Datasplit_chunker(dataset.root,"TEST", test_dataset, slide=False, stride=10, length=256) # test_dataset = train_dataset # to test if the overfitting is real # train_dataset = dataset # TODO change def get_train_dataset(): return train_dataset def get_test_dataset(): return test_dataset def get_dataloader(dataset, batch_size=64): return DataLoader(dataset, batch_size=batch_size, shuffle=True) def fromDataset(dataset): dset_size = int(len(dataset)) train_size = int(0.8 * dset_size) # int(dset_size - 2) test_size = int(dset_size - train_size) if test_size == 2: print("alert! test size is 2 or whatever. Change this back please.") torch.manual_seed(3407) # https://arxiv.org/pdf/2109.08203 train_dataset, test_dataset, _ = random_split( dataset, [train_size, test_size, len(dataset) - train_size - test_size] ) return train_dataset, test_dataset if __name__ == "__main__": d = get_train_dataset() print("Number of samples: ", len(d)) for a, b in d: # a, b = d[-1] manager = dataset.manager print(a) print(manager.decode(a)) # print(a) print("--- sep batch --- ") print(f"Number of tokens used: {len(dataset.manager.token_to_id)}") break # lazy