# splade-code-06B / utils.py
# maxoul's picture
# Update utils.py
# 450c551 verified
import numpy as np
import torch
from typing import Any
from transformers import AutoTokenizer
def splade_max(features, attention_mask):
    """SPLADE max pooling.

    Applies log(1 + ReLU(logits)) to the token-level logits, zeroes out
    padding positions via ``attention_mask``, then takes the max over the
    sequence dimension (dim=1).

    Returns:
        (values, indices) as produced by ``torch.max`` along dim=1.
    """
    activated = torch.relu(features)
    # mask is broadcast over the vocab dimension; padded tokens contribute 0
    masked = torch.log(1 + activated) * attention_mask.unsqueeze(-1)
    values, indices = torch.max(masked, dim=1)
    return values, indices
def encode(
    self,
    sentences: list[str],
    max_length: int = 1024,
    prompt_type: str = "document",
    return_dict: bool = False,
    print_dict: bool = False,
    batch_size: int = 8,
    top_k_q: int = -1,
    top_k_d: int = -1,
    **kwargs: Any,
) -> np.ndarray:
    """Encode sentences into SPLADE sparse vectors.

    Args:
        sentences: texts to encode.
        max_length: tokenizer truncation length passed to ``create_batch_dict``.
        prompt_type: "query" or "document" — selects which top-k cap applies.
        return_dict: if True, return a list of {token: weight} dicts instead
            of the dense numpy matrix.
        print_dict: if True (and ``return_dict``), pretty-print the BoW bars.
        batch_size: number of sentences encoded per forward pass.
        top_k_q / top_k_d: keep only the k largest activations per vector
            (<= 0 disables the cap).
        **kwargs: ignored; accepted for MTEB-style caller compatibility.

    Returns:
        ``np.ndarray`` of shape (len(sentences), vocab_size), or a list of
        token→weight dicts when ``return_dict`` is True.
    """
    all_embeddings = []
    for start in range(0, len(sentences), batch_size):
        batch_texts = sentences[start : start + batch_size]
        batch_dict = self.create_batch_dict(batch_texts, max_length)
        batch_dict = {
            key: value.to(self.model.device) for key, value in batch_dict.items()
        }
        with torch.no_grad():
            splade_reps = self(**batch_dict)[0]  # fixed typo: was `splare_reps`
            if prompt_type == "query" and top_k_q > 0:
                splade_reps = top_k(splade_reps, top_k_q)
            if prompt_type == "document" and top_k_d > 0:
                splade_reps = top_k(splade_reps, top_k_d)
            all_embeddings.append(splade_reps.cpu().float().numpy())
    # Concatenate once (the original computed this twice on the dict path).
    embeddings = np.concatenate(all_embeddings, axis=0)
    if not return_dict:
        return embeddings
    d = bow_dict(self, embeddings)
    if print_dict:
        print_bow_bars(sentences, d)
    return d
def bow_dict(self, embeddings):
    """Convert dense sparse-vocab vectors into token→weight dictionaries.

    For each row, keeps only the nonzero entries, maps vocabulary ids to
    token strings via ``self.reverse_voc``, and orders entries by weight,
    largest first.
    """
    results = []
    for emb in embeddings:
        nonzero_ids = np.flatnonzero(emb)
        pairs = zip(nonzero_ids.tolist(), emb[nonzero_ids].tolist())
        # stable sort by descending weight, then map ids to token strings
        ranked = sorted(pairs, key=lambda p: p[1], reverse=True)
        results.append({self.reverse_voc[i]: float(w) for i, w in ranked})
    return results
def print_bow_bars(sentences, bow_list, width=20):
    """Pretty-print each sentence's bag-of-words as ASCII bar charts.

    Args:
        sentences: the original input texts.
        bow_list: parallel list of {token: weight} dicts (as from ``bow_dict``).
        width: bar width (in characters) of the largest weight.
    """
    ascii_header("TOP ACTIVATED WORDS")
    for sent, bow in zip(sentences, bow_list):
        print(f"* INPUT: {sent}\n")
        # default=0.0 avoids ValueError when a vector has no activations
        max_w = max(bow.values(), default=0.0)
        for token, weight in sorted(bow.items(), key=lambda kv: kv[1], reverse=True):
            # guard division by zero when every weight is 0
            filled = int(weight / max_w * width) if max_w > 0 else 0
            bar = "█" * filled
            print(f"{token[:25]:25} | {bar} {weight:.2f}")
        print("\n")
def ascii_header(title, width=70):
    """Print *title* centered inside a boxed ASCII banner of total *width*."""
    inner = width - 2
    border = f"+{'-' * inner}+"
    print(border)
    print(f"|{f' {title} '.center(inner)}|")
    print(border)
    print("\n")
def similarity(self, a, b) -> torch.Tensor:
    """Dot-product similarity between embedding sets (required by MTEB eval).

    Accepts tensors or array-likes; 1-D inputs are treated as a single row.
    Returns the (len(a), len(b)) score matrix ``a @ b.T``.
    """
    if not isinstance(a, torch.Tensor):
        a = torch.tensor(a)
    if not isinstance(b, torch.Tensor):
        b = torch.tensor(b)
    # promote vectors to single-row matrices
    if a.dim() == 1:
        a = a.unsqueeze(0)
    if b.dim() == 1:
        b = b.unsqueeze(0)
    return a @ b.transpose(0, 1)
def prepare_tokenizer(tokenizer_name: str, padding_side="right"):
    """Load a tokenizer and ensure it has a pad token.

    The pad token is taken from the first available of: bos, existing pad,
    eos (preference order preserved from the training setup).
    """
    tok = AutoTokenizer.from_pretrained(tokenizer_name)
    pad = tok.bos_token or tok.pad_token or tok.eos_token
    tok.pad_token = pad
    tok.padding_side = padding_side
    return tok
def get_decoder_model(
    model_name_or_path: str, attn_implementation: str, bidirectional: bool, base_cfg, token=None
):
    """Load the bidirectional Qwen3 decoder backbone.

    Args:
        model_name_or_path: HF hub id or local path of the checkpoint.
        attn_implementation: must be "flash_attention_2" (only supported impl).
        bidirectional: must be True — the model was trained with
            bi-directional attention.
        base_cfg: pretrained config of the underlying model.
        token: optional HF auth token.

    Raises:
        ValueError: if ``bidirectional`` is not True or the attention
            implementation is unsupported.
    """
    print("WARNING: bidirectional only tested for transformer 4.51.2")
    # Explicit raises instead of `assert`: assertions are stripped under
    # `python -O`, which would silently skip these required checks.
    if bidirectional is not True:
        raise ValueError("the model has been trained with bi-directional attention!")
    if attn_implementation != "flash_attention_2":
        raise ValueError(
            f"bidir models only support flash_attention_2 for now, not {attn_implementation}!"
        )
    from .modeling_qwen3_bidir import Qwen3BidirForCausalLM

    return Qwen3BidirForCausalLM.from_pretrained(
        model_name_or_path,
        config=base_cfg,
        torch_dtype=torch.bfloat16,
        attn_implementation=attn_implementation,
        token=token,
    )
def top_k(x: torch.Tensor, k: int) -> torch.Tensor:
    """Keep only the k largest values along the last dimension of *x*,
    zeroing out everything else. Shape is preserved."""
    topk_result = torch.topk(x, k, dim=-1)
    # boolean mask marking the top-k positions along the last dim
    keep = torch.zeros_like(x, dtype=torch.bool).scatter_(
        -1, topk_result.indices, True
    )
    return x * keep