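# MODE controls the representation returned by DenseEncoder.encode():
#   "Dense"  -> L2-normalized mean-pooled embedding
#   "Sparse" -> sparse lexical weights over the vocabulary
#   any other value -> concatenation of the dense and (scaled) sparse parts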
| | MODE = "Dense" |
| |
|
TASK_LIST_CLASSIFICATION = [
    "AmazonCounterfactualClassification",
    "AmazonPolarityClassification",
    "AmazonReviewsClassification",
    "Banking77Classification",
    "EmotionClassification",
    "ImdbClassification",
    "MassiveIntentClassification",
    "MassiveScenarioClassification",
    "MTOPDomainClassification",
    "MTOPIntentClassification",
    "ToxicConversationsClassification",
    "TweetSentimentExtractionClassification",
]

TASK_LIST_CLUSTERING = [
    "ArxivClusteringP2P",
    "ArxivClusteringS2S",
    "BiorxivClusteringP2P",
    "BiorxivClusteringS2S",
    "MedrxivClusteringP2P",
    "MedrxivClusteringS2S",
    "RedditClustering",
    "RedditClusteringP2P",
    "StackExchangeClustering",
    "StackExchangeClusteringP2P",
    "TwentyNewsgroupsClustering",
]

TASK_LIST_PAIR_CLASSIFICATION = [
    "SprintDuplicateQuestions",
    "TwitterSemEval2015",
    "TwitterURLCorpus",
]

TASK_LIST_RERANKING = [
    "AskUbuntuDupQuestions",
    "MindSmallReranking",
    "SciDocsRR",
    "StackOverflowDupQuestions",
]

TASK_LIST_RETRIEVAL = [
    "ArguAna",
    "FiQA2018",
    "QuoraRetrieval",
    "SCIDOCS",
    "SciFact",
    "Touche2020",
    "TRECCOVID",
    "NFCorpus",
    "NQ",
    "ClimateFEVER",
    "CQADupstackAndroidRetrieval",
    "CQADupstackEnglishRetrieval",
    "CQADupstackGamingRetrieval",
    "CQADupstackGisRetrieval",
    "CQADupstackMathematicaRetrieval",
    "CQADupstackPhysicsRetrieval",
    "CQADupstackProgrammersRetrieval",
    "CQADupstackStatsRetrieval",
    "CQADupstackTexRetrieval",
    "CQADupstackUnixRetrieval",
    "CQADupstackWebmastersRetrieval",
    "CQADupstackWordpressRetrieval",
    "DBPedia",
    "HotpotQA",
    "MSMARCO",
    "FEVER",
]

TASK_LIST_STS = [
    "BIOSSES",
    "SICK-R",
    "STS12",
    "STS13",
    "STS14",
    "STS15",
    "STS16",
    "STS17",
    "STS22",
    "STSBenchmark",
    "SummEval",
]

MTEB_TASK_LIST = (
    TASK_LIST_RETRIEVAL
    + TASK_LIST_CLASSIFICATION
    + TASK_LIST_CLUSTERING
    + TASK_LIST_PAIR_CLASSIFICATION
    + TASK_LIST_RERANKING
    + TASK_LIST_STS
)

CMTEB_TASK_LIST = [
    "TNews",
    "IFlyTek",
    "MultilingualSentiment",
    "JDReview",
    "OnlineShopping",
    "Waimai",
    "AmazonReviewsClassification",
    "MassiveIntentClassification",
    "MassiveScenarioClassification",
    "CLSClusteringS2S",
    "CLSClusteringP2P",
    "ThuNewsClusteringS2S",
    "ThuNewsClusteringP2P",
    "Ocnli",
    "Cmnli",
    "T2Reranking",
    "MMarcoReranking",
    "CMedQAv1-reranking",
    "CMedQAv2-reranking",
    "T2Retrieval",
    "MMarcoRetrieval",
    "DuRetrieval",
    "CovidRetrieval",
    "CmedqaRetrieval",
    "EcomRetrieval",
    "MedicalRetrieval",
    "VideoRetrieval",
    "ATEC",
    "BQ",
    "LCQMC",
    "PAWSX",
    "STSB",
    "AFQMC",
    "QBQTC",
    "STS22",
]

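# Evaluate the Chinese C-MTEB tasks together with the English MTEB tasks.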
MTEB_TASK_LIST = CMTEB_TASK_LIST + MTEB_TASK_LIST

import torch
import torch.nn.functional as F
import tqdm
import numpy as np
import math

from functools import partial
from torch.utils.data import DataLoader
from datasets import Dataset
from transformers import AutoModel, AutoTokenizer, DataCollatorWithPadding, PreTrainedTokenizerFast, BatchEncoding
from transformers.modeling_outputs import BaseModelOutput
from typing import List, Dict
from mteb import MTEB

def get_detailed_instruct(task_description: str) -> str:
    if not task_description:
        return ""
    return "Instruction: {} Query: ".format(task_description)

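# Map an MTEB task name and type to its instruction (English instructions for
# MTEB tasks, Chinese ones for C-MTEB tasks). STS tasks use no instruction.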
def get_task_def_by_task_name_and_type(
    task_name: str,
    task_type: str,
    default_instruct="",
):
    if task_type in ["STS"]:
        return None

    if task_type in ["Summarization"]:
        return "Given a news summary, retrieve other semantically similar summaries"

    if task_type in ["Classification"]:
        task_name_to_instruct: Dict[str, str] = {
            "AmazonCounterfactualClassification": "Classify a given Amazon customer review text as either counterfactual or not-counterfactual.",
            "AmazonPolarityClassification": "Classify Amazon reviews into positive or negative sentiment.",
            "AmazonReviewsClassification": "Classify the given Amazon review into its appropriate rating category.",
            "Banking77Classification": "Given a online banking query, find the corresponding intents.",
            "EmotionClassification": "Classify the emotion expressed in the given Twitter message into one of the six emotions: anger, fear, joy, love, sadness, and surprise.",
            "ImdbClassification": "Classify the sentiment expressed in the given movie review text from the IMDB dataset.",
            "MassiveIntentClassification": "Given a user utterance as query, find the user intents.",
            "MassiveScenarioClassification": "Given a user utterance as query, find the user scenarios.",
            "MTOPDomainClassification": "Classify the intent domain of the given utterance in task-oriented conversation.",
            "MTOPIntentClassification": "Classify the intent of the given utterance in task-oriented conversation.",
            "ToxicConversationsClassification": "Classify the given comments as either toxic or not toxic.",
            "TweetSentimentExtractionClassification": "Classify the sentiment of a given tweet as either positive, negative, or neutral.",

            "TNews": "根据标题确定新闻的类别。",
            "IFlyTek": "根据描述确定APP的类别。",
            "MultilingualSentiment": "将亚马逊评论分为积极、消极或中立情绪。",
            "JDReview": "将商品评论分为积极或消极情绪。",
            "OnlineShopping": "将商品评论分为积极或消极情绪。",
            "Waimai": "将外卖评论分为积极或消极情绪。",
        }
        return task_name_to_instruct.get(task_name, None)

| | if task_type in ["Clustering"]: |
| | task_name_to_instruct: Dict[str, str] = { |
| | "ArxivClusteringP2P": "Identify the main and secondary category of Arxiv papers based on the titles and abstracts.", |
| | "ArxivClusteringS2S": "Identify the main and secondary category of Arxiv papers based on the titles.", |
| | "BiorxivClusteringP2P": "Identify the main category of Biorxiv papers based on the titles and abstracts.", |
| | "BiorxivClusteringS2S": "Identify the main category of Biorxiv papers based on the titles.", |
| | "MedrxivClusteringP2P": "Identify the main category of Medrxiv papers based on the titles and abstracts.", |
| | "MedrxivClusteringS2S": "Identify the main category of Medrxiv papers based on the titles.", |
| | "RedditClustering": "Identify the topic or theme of Reddit posts based on the titles.", |
| | "RedditClusteringP2P": "Identify the topic or theme of Reddit posts based on the titles and posts.", |
| | "StackExchangeClustering": "Identify the topic or theme of StackExchange posts based on the titles.", |
| | "StackExchangeClusteringP2P": "Identify the topic or theme of StackExchange posts based on the given paragraphs.", |
| | "TwentyNewsgroupsClustering": "Identify the topic or theme of the given news articles.", |
| | |
| | "CLSClusteringS2S": "根据标题确定文章的类别。", |
| | "CLSClusteringP2P": "根据标题和摘要确定文章的类别。", |
| | "ThuNewsClusteringS2S": "根据标题确定新闻的类别。", |
| | "ThuNewsClusteringP2P": "根据标题和摘要确定新闻的类别。", |
| | } |
| | return task_name_to_instruct.get(task_name,None) |
| |
|
| | if task_type in ["Reranking", "PairClassification"]: |
| | task_name_to_instruct: Dict[str, str] = { |
| | "AskUbuntuDupQuestions": "Retrieve duplicate questions from AskUbuntu forum.", |
| | "MindSmallReranking": "Retrieve relevant news articles based on user browsing history.", |
| | "SciDocsRR": "Given a title of a scientific paper, retrieve the titles of other relevant papers.", |
| | "StackOverflowDupQuestions": "Retrieve duplicate questions from StackOverflow forum.", |
| | "SprintDuplicateQuestions": "Retrieve duplicate questions from Sprint forum.", |
| | "TwitterSemEval2015": "Retrieve tweets that are semantically similar to the given tweet.", |
| | "TwitterURLCorpus": "Retrieve tweets that are semantically similar to the given tweet.", |
| | |
| | "T2Reranking": "为这个问题检索相关段落。", |
| | "MMarcoReranking": "为这个查询检索相关段落。", |
| | "CMedQAv1-reranking": "为这个医疗问题检索相关回答。", |
| | "CMedQAv2-reranking": "为这个医疗问题检索相关回答。", |
| | } |
| |
|
| | return task_name_to_instruct.get(task_name,None) |
| |
|
| | if task_type in ["Retrieval"]: |
| | if task_name.lower().startswith("cqadupstack"): |
| | return "Given a question, retrieve detailed question descriptions from Stackexchange that are duplicates to the given question" |
| |
|
| | task_name_to_instruct: Dict[str, str] = { |
| | "ArguAna": "Given a claim, find documents that refute the claim.", |
| | "ClimateFEVER": "Given a claim about climate change, retrieve documents that support or refute the claim.", |
| | "DBPedia": "Given a query, retrieve relevant entity descriptions from DBPedia.", |
| | "FEVER": "Given a claim, retrieve documents that support or refute the claim.", |
| | "FiQA2018": "Given a financial question, retrieve user replies that best answer the question.", |
| | "HotpotQA": "Given a multi-hop question, retrieve documents that can help answer the question.", |
| | "MSMARCO": "Given a web search query, retrieve relevant passages that answer the query.", |
| | "NFCorpus": "Given a question, retrieve relevant documents that best answer the question.", |
| | "NQ": "Given a question, retrieve Wikipedia passages that answer the question.", |
| | "QuoraRetrieval": "Given a question, retrieve questions that are semantically equivalent to the given question.", |
| | "SCIDOCS": "Given a scientific paper title, retrieve paper abstracts that are cited by the given paper.", |
| | "SciFact": "Given a scientific claim, retrieve documents that support or refute the claim.", |
| | "Touche2020": "Given a question, retrieve detailed and persuasive arguments that answer the question.", |
| | "TRECCOVID": "Given a query on COVID-19, retrieve documents that answer the query.", |
| | |
| | "T2Retrieval": "为这个问题检索相关段落。", |
| | "MMarcoRetrieval": "为这个查询检索相关段落。", |
| | "DuRetrieval": "为这个问题检索相关百度知道回答。", |
| | "CovidRetrieval": "为这个问题检索相关政策回答。", |
| | "CmedqaRetrieval": "为这个医疗问题检索相关回答。", |
| | "EcomRetrieval": "为这个查询检索相关商品标题。", |
| | "MedicalRetrieval": "为这个医疗问题检索相关回答。", |
| | "VideoRetrieval": "为这个电影标题检索相关段落。", |
| | } |
| |
|
| | task_name_to_instruct.update({k.lower(): v for k, v in task_name_to_instruct.items()}) |
| |
|
| | return task_name_to_instruct.get(task_name,None) |
| | return default_instruct |
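# Lazy tokenization transform applied by `datasets`; inputs are truncated to 1024 tokens.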
def _transform_func(tokenizer: PreTrainedTokenizerFast,
                    examples: Dict[str, List]) -> BatchEncoding:
    batch_dict = tokenizer(examples['input_texts'],
                           max_length=1024,
                           padding=True,
                           truncation=True)
    return batch_dict

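# Pooling helpers: `mean_pooling` averages token states under the attention
# mask; `wmean_pooling` is a position-weighted mean (the i-th valid token gets
# weight i); `reverse_wmean_pooling` effectively rescales token states to
# cancel that positional weighting before per-token scoring.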
def mean_pooling(hidden, attention_mask):
    s = torch.sum(hidden * attention_mask.unsqueeze(-1).float(), dim=1)
    d = attention_mask.sum(dim=1, keepdim=True).float()
    return s / d

def wmean_pooling(hidden, attention_mask):
    attention_mask_ = attention_mask * attention_mask.cumsum(dim=1)
    hidden_masked = hidden * attention_mask_.unsqueeze(-1).float()
    s = torch.sum(hidden_masked, dim=1)
    d = attention_mask_.sum(dim=1, keepdim=True).float()
    reps = s / d
    return reps

def reverse_wmean_pooling(hidden, attention_mask):
    attention_mask_ = attention_mask * attention_mask.cumsum(dim=1)
    d = attention_mask_.sum(dim=1, keepdim=True).unsqueeze(1).float() \
        / attention_mask.sum(dim=1, keepdim=True).unsqueeze(1).float()
    hidden = hidden.float() * d
    return hidden / torch.clamp(attention_mask_.unsqueeze(-1).float(), min=1e-9)

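# Sparse representation: a linear head assigns each token a non-negative
# weight (ReLU), and weights are scattered to their vocabulary ids and
# max-pooled over positions, yielding one vocabulary-sized vector per text.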
def sparse_pooling(head, model, items, hidden, attention_mask):
    hidden = reverse_wmean_pooling(hidden, attention_mask)
    # Normalize by the largest per-token norm in each sequence before scoring.
    max_hidden_norm = torch.max(torch.norm(hidden, dim=-1), dim=-1).values
    token_weights = torch.relu(head(hidden.float() / max_hidden_norm.unsqueeze(-1).unsqueeze(-1)))
    vocab_size = model.embed_tokens.weight.size(0)
    input_ids = items["input_ids"]
    sparse_embedding_chunks = []
    # Process one sequence at a time: the [chunk, seq_len, vocab_size] scatter
    # buffer is large, so a small chunk keeps memory bounded.
    mini_chunk_size = min(1, hidden.shape[0])
    for i in range(0, token_weights.size(0), mini_chunk_size):
        now_chunk_size = min(mini_chunk_size, token_weights.size(0) - i)
        sparse_embedding = torch.zeros(now_chunk_size, input_ids.size(1), vocab_size,
                                       dtype=token_weights.dtype,
                                       device=token_weights.device)
        # Scatter each token's weight to its vocabulary id, then max-pool over positions.
        sparse_embedding_chunks.append(torch.max(
            torch.scatter(sparse_embedding, dim=-1,
                          index=input_ids[i:i + now_chunk_size, :].unsqueeze(-1),
                          src=token_weights[i:i + now_chunk_size, :]),
            dim=1).values)
    sparse_embedding = torch.concat(sparse_embedding_chunks, dim=0)
    # Zero out special token ids so they never contribute to the sparse vector.
    unused_tokens = [0, 1, 2, 73440]
    sparse_embedding[:, unused_tokens] *= 0.
    return sparse_embedding

def concat_pooling(head, model, items, hidden, attention_mask):
    mean_reps = mean_pooling(hidden, attention_mask)
    mean_reps = F.normalize(mean_reps, p=2, dim=1)
    sparse_reps = sparse_pooling(head, model, items, hidden, attention_mask) * math.sqrt(0.3)
    return torch.cat([mean_reps, sparse_reps], dim=-1)

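# MTEB-compatible wrapper around openbmb/MiniCPM-Embedding-Light: exposes
# encode / encode_queries / encode_corpus, and shards batches across GPUs
# with DataParallel when more than one is available.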
class DenseEncoder(torch.nn.Module):
    def __init__(self, **kwargs):
        super().__init__()

        model_path = "openbmb/MiniCPM-Embedding-Light"
        self.encoder = AutoModel.from_pretrained(
            model_path,
            trust_remote_code=True,
            attn_implementation="flash_attention_2",
            torch_dtype=torch.float16,
        ).to("cuda")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.gpu_count = torch.cuda.device_count()
        self.instruction = ""

        self.encoder.eval()

        if self.gpu_count > 1:
            self.encoder = torch.nn.DataParallel(self.encoder)

    @torch.no_grad()
    def encode(self, sentences, is_query=None, **kwargs) -> np.ndarray:
        """Return embeddings for the given sentences.

        Args:
            sentences (`List[str]`): sentences to encode

        Returns:
            `np.ndarray`: one embedding per sentence
        """
        # Queries get the task instruction prepended; corpus texts do not.
        if is_query is not False:
            sentences = [self.instruction + s for s in sentences]
        dataset: Dataset = Dataset.from_dict({'input_texts': sentences})
        dataset.set_transform(partial(_transform_func, self.tokenizer))

        data_collator = DataCollatorWithPadding(self.tokenizer, pad_to_multiple_of=8)
        data_loader = DataLoader(
            dataset,
            batch_size=128 * self.gpu_count,
            shuffle=False,
            drop_last=False,
            num_workers=2,
            collate_fn=data_collator,
            pin_memory=True)

        # DataParallel wraps the model, so unwrap it before accessing `head` / `embed_tokens`.
        encoder_module = self.encoder.module if self.gpu_count > 1 else self.encoder

        encoded_embeds = []
        for batch_dict in tqdm.tqdm(data_loader, desc='encoding', mininterval=10):
            with torch.cuda.amp.autocast(), torch.no_grad():
                for key in batch_dict:
                    batch_dict[key] = batch_dict[key].to("cuda")
                outputs: BaseModelOutput = self.encoder(**batch_dict)
                if MODE == "Dense":
                    embeds = mean_pooling(outputs.last_hidden_state, batch_dict['attention_mask'])
                    embeds = F.normalize(embeds, p=2, dim=1)
                elif MODE == "Sparse":
                    embeds = sparse_pooling(encoder_module.head, encoder_module, batch_dict,
                                            outputs.last_hidden_state, batch_dict['attention_mask'])
                else:
                    embeds = concat_pooling(encoder_module.head, encoder_module, batch_dict,
                                            outputs.last_hidden_state, batch_dict['attention_mask'])
                encoded_embeds.append(embeds.cpu().numpy())

        return np.concatenate(encoded_embeds, axis=0)

    @torch.no_grad()
    def encode_queries(self, queries: List[str], **kwargs):
        """Return embeddings for the given queries, with the task instruction prepended."""
        return self.encode(queries, is_query=True, **kwargs)

    @torch.no_grad()
    def encode_corpus(self, corpus: List[Dict[str, str]], **kwargs):
        # Corpus may arrive as a dict of columns, a list of dicts, or a list of strings.
        if isinstance(corpus, dict):
            sentences = [
                (corpus["title"][i] + " " + corpus["text"][i]).strip()
                if "title" in corpus
                else corpus["text"][i].strip()
                for i in range(len(corpus["text"]))
            ]
        elif isinstance(corpus[0], dict):
            sentences = [
                (doc["title"] + " " + doc["text"]).strip()
                if "title" in doc
                else doc["text"].strip()
                for doc in corpus
            ]
        else:
            sentences = corpus
        return self.encode(sentences, is_query=False, **kwargs)

model = DenseEncoder()
task_names = MTEB_TASK_LIST
task_names = ["NFCorpus"]  # evaluate a single task here; drop this line to run the full list
lang = ["en", "zh", "zh-CN"]

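# For each task: look up its instruction, set it on the model, choose the eval
# split (MSMARCO is evaluated on "dev" by MTEB convention), and run.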
for task in task_names:
    try:
        evaluation = MTEB(tasks=[task], task_langs=lang)
        task_cls = evaluation.tasks[0]
        task_name: str = task_cls.metadata_dict["name"]
        task_type: str = task_cls.metadata_dict["type"]
        instruction = get_task_def_by_task_name_and_type(task_name, task_type)
        model.instruction = get_detailed_instruct(instruction)
        print(model.instruction)
        if task == "MSMARCO":
            eval_splits = ["dev"]
        elif task in CMTEB_TASK_LIST:
            eval_splits = task_cls.metadata_dict["eval_splits"]
        else:
            eval_splits = ["test"]
        evaluation.run(model, eval_splits=eval_splits, overwrite_results=True)
    except Exception:
        import traceback
        print(traceback.format_exc())
        continue