Roman Castagné
committed on
Commit
·
bfe5167
1
Parent(s):
9b92850
metric file
Browse files
ERR.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Token prediction metric."""
|
| 2 |
+
|
| 3 |
+
from typing import List, Tuple
|
| 4 |
+
|
| 5 |
+
import datasets
|
| 6 |
+
import numpy as np
|
| 7 |
+
from Levenshtein import distance as levenshtein_distance
|
| 8 |
+
from scipy.optimize import linear_sum_assignment
|
| 9 |
+
|
| 10 |
+
import evaluate
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# Metric-card description surfaced by `evaluate` (also prepended to the class
# docstring via `add_start_docstrings`).
_DESCRIPTION = """
Unofficial implementation of the Error Reduction Rate (ERR) metric introduced for lexical normalization.
This implementation works on Seq2Seq models by aligning the predictions with the ground truth outputs.
"""


# Usage documentation: argument/return schema and a doctest-style example.
# NOTE(review): leading indentation inside these strings may have been lost by
# extraction — confirm against the original file before relying on exact text.
_KWARGS_DESCRIPTION = """
Args:
predictions (`list` of `str`): Predicted labels.
references (`list` of `Dict[str, str]`): Ground truth sentences, each with a field `input` and `output`.
Returns:
`err` (`float` or `int`): Error Reduction Rate. See here: http://noisy-text.github.io/2021/multi-lexnorm.html
`err_tp` (`int`): Number of true positives.
`err_fn` (`int`): Number of false negatives.
`err_tn` (`int`): Number of true negatives.
`err_fp` (`int`): Number of false positives.
Examples:
Example 1-A simple example
>>> err = evaluate.load("err")
>>> results = err.compute(predictions=[["The", "large", "dog"]], references=[{"input": ["The", "large", "dawg"], "output": ["The", "large", "dog"]}])
>>> print(results)
{'err': 1.0, 'err_tp': 2, 'err_fn': 0, 'err_tn': 1, 'err_fp': 0}
"""


# BibTeX citation for the originating shared task (W-NUT 2015 lexical
# normalization), reported through `evaluate.MetricInfo`.
_CITATION = """
@inproceedings{baldwin-etal-2015-shared,
title = "Shared Tasks of the 2015 Workshop on Noisy User-generated Text: {T}witter Lexical Normalization and Named Entity Recognition",
author = "Baldwin, Timothy and
de Marneffe, Marie Catherine and
Han, Bo and
Kim, Young-Bum and
Ritter, Alan and
Xu, Wei",
booktitle = "Proceedings of the Workshop on Noisy User-generated Text",
month = jul,
year = "2015",
address = "Beijing, China",
publisher = "Association for Computational Linguistics",
url = "https://aclanthology.org/W15-4319",
doi = "10.18653/v1/W15-4319",
pages = "126--135",
}
"""
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class ErrorReductionRate(evaluate.Metric):
    """Error Reduction Rate (ERR) for lexical normalization with Seq2Seq models.

    Predicted tokens are matched to the gold tokens with a minimum-cost
    (Levenshtein) assignment, then regrouped per input word before the
    token-level TP/FP/TN/FN counting that defines ERR.
    """

    def _info(self):
        # Input schema: a flat list of predicted tokens per sample, and a
        # reference dict carrying the raw `input` tokens plus the gold
        # normalized `output` expression for each input token.
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": datasets.Sequence(datasets.Value("string")),
                    "references": {
                        "input": datasets.Sequence(datasets.Value("string")),
                        "output": datasets.Sequence(datasets.Value("string")),
                    },
                }
            ),
        )

    def _compute(self, predictions, references):
        """Compute ERR over all (prediction, reference) pairs.

        Returns:
            dict with `err` plus the raw `err_tp`, `err_fn`, `err_tn`,
            `err_fp` counts.

        Raises:
            ZeroDivisionError: when no reference word required normalization
                (tp + fn == 0); ERR is undefined in that case.
        """
        tp, fn, tn, fp = 0, 0, 0, 0
        for pred, ref in zip(predictions, references):
            inputs, outputs = ref["input"], ref["output"]

            # One (token, segment_index) pair per word of every gold output
            # expression; `segment_index` ties tokens back to the input word.
            labels = self._split_expressions_into_tokens(outputs)

            assert len(pred) == len(
                labels
            ), f"Number of predicted words ({len(pred)}) does not match number of target words ({len(labels)})"

            formatted_preds = self._align_predictions_with_labels(pred, labels)

            for i in range(len(inputs)):
                # A word "needed normalization" when its lower-cased raw form
                # differs from the gold output (so capitalization-only changes
                # also count as normalizations).
                if inputs[i].lower() != outputs[i]:
                    tp += formatted_preds[i] == outputs[i]
                    fn += formatted_preds[i] != outputs[i]
                else:
                    tn += formatted_preds[i] == outputs[i]
                    fp += formatted_preds[i] != outputs[i]

        # ERR rescales system accuracy against the leave-as-is baseline;
        # algebraically it reduces to (TP - FP) / (TP + FN).
        err = (tp - fp) / (tp + fn)

        return {"err": err, "err_tp": tp, "err_fn": fn, "err_tn": tn, "err_fp": fp}

    def _align_predictions_with_labels(self, predictions: List[str], labels: List[Tuple[str, int]]) -> List[str]:
        """Assign each gold token one prediction (minimum total Levenshtein
        cost), then join the predictions of each segment with spaces.

        Returns one string per output segment, index-aligned with the
        reference `output` list.
        """
        if not labels:
            # No target tokens: nothing to align (also avoids max() on an
            # empty sequence below).
            return []

        levenshtein_matrix = np.zeros((len(labels), len(predictions)))
        for i, (label, _) in enumerate(labels):
            for j, pred in enumerate(predictions):
                levenshtein_matrix[i, j] = levenshtein_distance(label, pred)

        # linear_sum_assignment returns (row_ind, col_ind) with row_ind sorted
        # ascending, so col_ind[i] is the prediction assigned to label i.
        # BUG FIX: the previous code bound the pair in the wrong order and
        # sorted the prediction indices by a key that was the identity, which
        # always collapsed the alignment to [0, 1, ..., n-1] and discarded the
        # optimal assignment entirely.
        _, alignment = linear_sum_assignment(levenshtein_matrix)

        num_outputs = max(segment for _, segment in labels) + 1
        formatted_preds = [[] for _ in range(num_outputs)]
        for i, aligned_idx in enumerate(alignment):
            formatted_preds[labels[i][1]].append(predictions[aligned_idx])

        return [" ".join(preds) for preds in formatted_preds]

    def _split_expressions_into_tokens(self, outputs: List[str]) -> List[Tuple[str, int]]:
        """Flatten output expressions into (token, segment_index) pairs.

        A one-to-many normalization (e.g. "gonna" -> "going to") yields several
        pairs sharing a segment; an empty expression (a deleted word) still
        yields one ("", segment) entry so every segment keeps a slot.
        """
        labels = []
        for segment, normalized in enumerate(outputs):
            if normalized == "":
                labels.append((normalized, segment))
            else:
                for w in normalized.split():
                    labels.append((w, segment))

        return labels
|