# Metrics comparing model behavior before and after an edit (pre/post logits).
import torch
from utils import gather_log_probs, mask_hf_labels, masked_mean
def es_sentiment(pre_logits, post_logits, raw_targets, same_sent_mask, NULL_TOKEN=0):
    """Edit-success score for sentiment editing.

    Compares token log-likelihoods of the target sequences before
    (``pre_logits``) and after (``post_logits``) an edit, split into
    positive samples (``same_sent_mask`` True) and negative samples.

    Returns:
        dict with the combined score (``acc_sent``) and its components.
    """
    with torch.no_grad():
        label_mask, targets = mask_hf_labels(raw_targets)

        # Broadcast the per-sentence mask over the token dimension,
        # then restrict to valid (non-padding) target tokens.
        pos_mask = same_sent_mask.unsqueeze(-1) * label_mask
        neg_mask = (~same_sent_mask).unsqueeze(-1) * label_mask

        # Per-token target log-likelihoods under each model.
        pre_lp = gather_log_probs(pre_logits, targets)
        post_lp = gather_log_probs(post_logits, targets)

        mean_pos_pre = masked_mean(pre_lp, pos_mask)
        mean_pos_post = masked_mean(post_lp, pos_mask)
        mean_neg_post = masked_mean(post_lp, neg_mask)

        # z_sent: post-edit preference for positive over negative samples.
        z_sent = torch.sigmoid(mean_pos_post - mean_neg_post)
        # z_topic: post/pre likelihood ratio on positive samples, capped at 1.
        z_topic_raw = torch.exp(mean_pos_post - mean_pos_pre)
        z_topic = min(1, z_topic_raw)

        es_sent = z_sent * z_topic

        return {
            "acc_sent": es_sent,
            "z_sent": z_sent,
            "z_topic": z_topic,
            "z_topic_raw": z_topic_raw,
            "correct_probs": mean_pos_post,
            "wrong_probs": mean_neg_post,
        }
# DEPRECATED
def sent_success(pre_edit_probs, post_edit_probs, pos_mask, eps=torch.finfo(torch.float32).eps, batch_size=20):
    """DEPRECATED: superseded by :func:`es_sentiment`; always raises.

    The previous implementation combined a geometric-mean content score
    with weighting- and ranking-based sentiment scores.  It was disabled
    with a bare ``assert False``, but ``assert`` statements are stripped
    under ``python -O``, which would have silently let the stale body run.
    Raise explicitly instead so the deprecation holds in all modes.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("No longer used")
# Kept for reference: earlier per-batch retention metric (never re-enabled).
# def sent_retain(pre_logits, post_logits, sent_mask, batch_size=20, eps=torch.finfo(torch.float32).eps):
#     pre_log_probs = pre_logits.log_softmax(-1).gather(-1, all_targ.unsqueeze(-1)).squeeze(-1)
#     post_log_probs = post_logits.log_softmax(-1).gather(-1, all_targ.unsqueeze(-1)).squeeze(-1)
#     pre_batch = pre_probs.view(-1, batch_size)
#     post_batch = post_probs.view(-1, batch_size)
#     mask_batch = sent_mask.view(-1, batch_size)
#     stats = []
#     for pre, post, mask in zip(pre_batch, post_batch, mask_batch):
#         avg_pre = pre.prod() ** (1 / pre.numel())
#         avg_post = post.prod() ** (1 / post.numel())
#         z_avg = min(avg_pre / avg_post, avg_post / avg_pre)
#         post_neg_avg = post[~mask].prod() ** (1 / (~mask).sum())
#         post_pos_avg = post[mask].prod() ** (1 / mask.sum())
#         pre_neg_avg = pre[~mask].prod() ** (1 / (~mask).sum())
#         pre_pos_avg = pre[mask].prod() ** (1 / mask.sum())
#         post_neg_over_pos = post_neg_avg / (eps + post_pos_avg)
#         pre_neg_over_pos = pre_neg_avg / (eps + pre_pos_avg)
#         z_post = 1 / (1 + post_neg_over_pos)
#         z_pre = 1 / (1 + pre_neg_over_pos)
#         z_sent = min(z_post / z_pre, z_pre / z_post)
#         stats.append((z_avg * z_sent) ** 0.5)
#     return sum(stats) / len(stats)

# For zsRE and F-NLI
def retain_rate(pre_logits, post_logits, mask=None):
    """Fraction of examples whose prediction is unchanged by the edit.

    Supports binary classification (1-D logits, optionally with a trailing
    singleton dimension) and sequence modeling (3-D logits), where an
    example is retained only if every masked token keeps its argmax.

    Args:
        pre_logits: logits before the edit; shape ``(B,)`` / ``(B, 1)`` for
            binary classification or ``(B, T, V)`` for sequence modeling.
        post_logits: logits after the edit, same shape as ``pre_logits``.
        mask: token mask of shape ``(B, T)``; required for sequence
            modeling, optional (unused) for binary classification.

    Returns:
        float retain rate in ``[0, 1]``.
    """
    # Drop a trailing singleton logit dimension (single-logit binary heads).
    if pre_logits.shape[-1] == 1:
        pre_logits = pre_logits.squeeze(-1)
    if post_logits.shape[-1] == 1:
        post_logits = post_logits.squeeze(-1)
    assert pre_logits.shape == post_logits.shape
    # Bug fix: the old code dereferenced `mask` unconditionally, so the
    # declared default `mask=None` always crashed with AttributeError.
    if mask is not None:
        assert pre_logits.shape[0] == mask.shape[0]

    if pre_logits.dim() == 1:
        # Binary classification: the predicted class flips at logit 0.
        pre_preds = pre_logits > 0
        post_preds = post_logits > 0
        retain = (pre_preds == post_preds).float().mean()
    elif pre_logits.dim() == 3:
        # Sequence modeling: all masked-token argmax predictions must match.
        if mask is None:
            raise ValueError("mask is required for 3-D (sequence) logits")
        pre_preds = pre_logits.argmax(-1)
        post_preds = post_logits.argmax(-1)
        match = (pre_preds == post_preds) * mask
        retain = (match.sum(-1) == mask.sum(-1)).float().mean()
    else:
        raise NotImplementedError
    return retain.item()