Spaces:
Runtime error
Runtime error
Patched error
Browse files- logits_ngrams.py +5 -2
logits_ngrams.py
CHANGED
|
@@ -24,7 +24,10 @@ def _no_repeat_ngram_logits(input_ids, cur_len, logits, batch_size=1, no_repeat_
|
|
| 24 |
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
|
| 25 |
banned_tokens = _calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
|
| 26 |
for batch_idx in range(batch_size):
|
| 27 |
-
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
return logits
|
| 30 |
|
|
@@ -35,7 +38,7 @@ def _calc_banned_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len
|
|
| 35 |
return [[] for _ in range(num_hypos)]
|
| 36 |
generated_ngrams = [{} for _ in range(num_hypos)]
|
| 37 |
for idx in range(num_hypos):
|
| 38 |
-
gen_tokens = prev_input_ids[idx]
|
| 39 |
generated_ngram = generated_ngrams[idx]
|
| 40 |
for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
|
| 41 |
|
|
|
|
| 24 |
# from fairseq: https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345
|
| 25 |
banned_tokens = _calc_banned_tokens(input_ids, batch_size, no_repeat_ngram_size, cur_len)
|
| 26 |
for batch_idx in range(batch_size):
|
| 27 |
+
if skip_tokens is not None:
|
| 28 |
+
logits[batch_idx, [token for token in banned_tokens[batch_idx] if int(token) not in skip_tokens]] = -float("inf")
|
| 29 |
+
else:
|
| 30 |
+
logits[batch_idx, banned_tokens[batch_idx]] = -float("inf")
|
| 31 |
|
| 32 |
return logits
|
| 33 |
|
|
|
|
| 38 |
return [[] for _ in range(num_hypos)]
|
| 39 |
generated_ngrams = [{} for _ in range(num_hypos)]
|
| 40 |
for idx in range(num_hypos):
|
| 41 |
+
gen_tokens = prev_input_ids[idx].tolist()
|
| 42 |
generated_ngram = generated_ngrams[idx]
|
| 43 |
for ngram in zip(*[gen_tokens[i:] for i in range(no_repeat_ngram_size)]):
|
| 44 |
|