import numpy as np
import torch
from typing import Any

from transformers import AutoTokenizer


def splade_max(features, attention_mask):
    """SPLADE pooling over token-level logits.

    Applies log(1 + ReLU(features)), zeroes out padded positions using
    ``attention_mask``, then max-pools over the sequence dimension (dim=1).

    Args:
        features: model output logits; assumed shape (batch, seq, vocab) —
            TODO confirm against the caller.
        attention_mask: 0/1 mask; assumed shape (batch, seq).

    Returns:
        The ``(values, indices)`` pair produced by ``torch.max`` along dim=1.
    """
    relu = torch.nn.ReLU(inplace=False)
    activations = torch.log(1 + relu(features)) * attention_mask.unsqueeze(-1)
    values, ids_ = torch.max(activations, dim=1)
    return values, ids_


def encode(
    self,
    sentences: list[str],
    max_length: int = 1024,
    prompt_type: str = "document",
    return_dict: bool = False,
    print_dict: bool = False,
    batch_size: int = 8,
    top_k_q: int = -1,
    top_k_d: int = -1,
    **kwargs: Any,
) -> np.ndarray:
    """Encode ``sentences`` into SPLADE sparse vectors, batch by batch.

    ``self`` is expected to be a callable model wrapper exposing
    ``create_batch_dict`` and ``model.device`` — presumably this function is
    attached to a model class elsewhere; verify against the caller.

    Args:
        sentences: texts to encode.
        max_length: tokenizer truncation length.
        prompt_type: "query" or "document"; selects which top-k cap applies.
        return_dict: if True, return per-sentence {token: weight} dicts
            instead of the dense numpy matrix.
        print_dict: if True (with ``return_dict``), pretty-print the bags
            of words as ASCII bar charts.
        batch_size: number of sentences per forward pass.
        top_k_q: keep only the top-k activations for queries (<=0 disables).
        top_k_d: keep only the top-k activations for documents (<=0 disables).

    Returns:
        A (len(sentences), vocab) float32 numpy array, or a list of dicts
        when ``return_dict`` is True.
    """
    all_embeddings = []
    for start in range(0, len(sentences), batch_size):
        batch_texts = sentences[start : start + batch_size]
        batch_dict = self.create_batch_dict(batch_texts, max_length)
        # Move every tensor in the batch onto the model's device.
        batch_dict = {
            key: value.to(self.model.device) for key, value in batch_dict.items()
        }
        with torch.no_grad():
            # NOTE: original code named this "splare_reps" (typo); fixed.
            splade_reps = self(**batch_dict)[0]
            if prompt_type == "query" and top_k_q > 0:
                splade_reps = top_k(splade_reps, top_k_q)
            if prompt_type == "document" and top_k_d > 0:
                splade_reps = top_k(splade_reps, top_k_d)
            all_embeddings.append(splade_reps.cpu().float().numpy())

    embeddings = np.concatenate(all_embeddings, axis=0)
    if return_dict:
        d = bow_dict(self, embeddings)
        if print_dict:
            print_bow_bars(sentences, d)
        return d
    return embeddings


def bow_dict(self, embeddings):
    """Convert dense SPLADE vectors into readable bag-of-words dicts.

    For each row, keeps the nonzero entries, maps vocabulary ids back to
    token strings via ``self.reverse_voc``, and sorts by weight descending.

    Returns:
        A list (one per embedding) of ``{token: weight}`` dicts.
    """
    out = []
    for vector in embeddings:
        idx = np.nonzero(vector)[0]
        weights = vector[idx]
        d = dict(zip(idx.tolist(), weights.tolist()))
        sorted_d = {
            self.reverse_voc[k]: float(v)
            for k, v in sorted(d.items(), key=lambda item: item[1], reverse=True)
        }
        out.append(sorted_d)
    return out


def print_bow_bars(sentences, bow_list, width=20):
    """Pretty-print each sentence's bag of words as an ASCII bar chart.

    Bars are scaled relative to the largest weight in that sentence's bag.
    """
    ascii_header("TOP ACTIVATED WORDS")
    for sent, bow in zip(sentences, bow_list):
        print(f"* INPUT: {sent}\n")
        if not bow:
            # An all-zero vector yields an empty bag; skip instead of
            # crashing on max() of an empty sequence.
            print("\n")
            continue
        max_w = max(bow.values())
        for k, v in sorted(bow.items(), key=lambda x: x[1], reverse=True):
            bar = "█" * int(v / max_w * width)
            print(f"{k[:25]:25} | {bar} {v:.2f}")
        print("\n")


def ascii_header(title, width=70):
    """Print ``title`` centered inside an ASCII box of the given width."""
    # Pad the title with one space on each side before centering.
    # (Reconstructed: the original literal was garbled across a line break.)
    title = f" {title} "
    print("+" + "-" * (width - 2) + "+")
    print("|" + title.center(width - 2) + "|")
    print("+" + "-" * (width - 2) + "+")
    print("\n")


def similarity(self, a, b) -> torch.Tensor:
    """Dot-product similarity between two (batches of) vectors.

    MTEB eval requires this method. Accepts tensors or anything
    ``torch.tensor`` can convert; 1-D inputs are promoted to 2-D.

    Returns:
        A (len(a), len(b)) tensor of dot products.
    """
    if not isinstance(a, torch.Tensor):
        a = torch.tensor(a)
    if not isinstance(b, torch.Tensor):
        b = torch.tensor(b)

    def _dot_score_core(a_tensor, b_tensor):
        if len(a_tensor.shape) == 1:
            a_tensor = a_tensor.unsqueeze(0)
        if len(b_tensor.shape) == 1:
            b_tensor = b_tensor.unsqueeze(0)
        return a_tensor @ b_tensor.transpose(0, 1)

    return _dot_score_core(a, b)


def prepare_tokenizer(tokenizer_name: str, padding_side="right"):
    """Load a HF tokenizer and ensure it has a pad token.

    The pad token is chosen in priority order bos > pad > eos — this
    ordering is deliberate in the original code; confirm it matches the
    model's training setup before changing it.
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    tokenizer.pad_token = (
        tokenizer.bos_token or tokenizer.pad_token or tokenizer.eos_token
    )
    tokenizer.padding_side = padding_side
    return tokenizer


def get_decoder_model(
    model_name_or_path: str,
    attn_implementation: str,
    bidirectional: bool,
    base_cfg,
    token=None,
):
    """Load the bidirectional Qwen3 decoder in bfloat16.

    Args:
        model_name_or_path: checkpoint path or hub id.
        attn_implementation: must be "flash_attention_2".
        bidirectional: must be True (the model was trained bidirectionally).
        base_cfg: the pretrained config of the underlying model.
        token: optional HF auth token.

    Raises:
        ValueError: if ``bidirectional`` or ``attn_implementation`` are not
            the supported values. (Was ``assert`` — which is stripped under
            ``python -O`` — so validation is now explicit.)
    """
    print("WARNING: bidirectional only tested for transformer 4.51.2")
    if bidirectional is not True:
        raise ValueError("the model has been trained with bi-directional attention!")
    if attn_implementation != "flash_attention_2":
        raise ValueError(
            f"bidir models only support flash_attention_2 for now, not {attn_implementation}!"
        )

    from .modeling_qwen3_bidir import Qwen3BidirForCausalLM

    return Qwen3BidirForCausalLM.from_pretrained(
        model_name_or_path,
        config=base_cfg,
        torch_dtype=torch.bfloat16,
        attn_implementation=attn_implementation,
        token=token,
    )


def top_k(x: torch.Tensor, k: int) -> torch.Tensor:
    """Zero out all but the top-k values along the last dimension of ``x``."""
    _, topk_indices = x.topk(k, dim=-1)
    # Boolean mask, True only at the top-k positions of each row.
    mask = torch.zeros_like(x, dtype=torch.bool)
    mask.scatter_(-1, topk_indices, True)
    # Multiplying by the mask zeroes everything outside the top-k.
    return x * mask