add encoder self-attention
Browse files- __pycache__/utils.cpython-313.pyc +0 -0
- app.py +105 -16
- utils.py +5 -9
__pycache__/utils.cpython-313.pyc
CHANGED
|
Binary files a/__pycache__/utils.cpython-313.pyc and b/__pycache__/utils.cpython-313.pyc differ
|
|
|
app.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
import torch
|
| 2 |
from torch import nn
|
| 3 |
import gradio as gr
|
| 4 |
-
from utils import save_data, get_attn_list, get_top_attns
|
| 5 |
|
| 6 |
|
| 7 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
|
@@ -37,8 +37,11 @@ def translate_text(input_text):
|
|
| 37 |
avg_decoder_attn_list = get_attn_list(translated.decoder_attentions, layer_index)
|
| 38 |
decoder_attn_scores = get_top_attns(avg_decoder_attn_list)
|
| 39 |
|
|
|
|
|
|
|
|
|
|
| 40 |
# save_data(outputs, src_tokens, tgt_tokens, attn_scores)
|
| 41 |
-
return outputs, render_cross_attn_html(src_tokens, tgt_tokens), cross_attn_scores, render_encoder_decoder_attn_html(tgt_tokens, "Output"), decoder_attn_scores
|
| 42 |
|
| 43 |
|
| 44 |
def render_cross_attn_html(src_tokens, tgt_tokens):
|
|
@@ -64,13 +67,17 @@ def render_cross_attn_html(src_tokens, tgt_tokens):
|
|
| 64 |
def render_encoder_decoder_attn_html(tokens, type):
|
| 65 |
# Build HTML for source and target tokens
|
| 66 |
tokens_html = ""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
for i, token in enumerate(tokens):
|
| 68 |
-
tokens_html += f'<span class="token
|
| 69 |
|
| 70 |
html = f"""
|
| 71 |
<div class="tgt-token-wrapper-text">{type} Tokens</div>
|
| 72 |
<div class="tgt-token-wrapper">{tokens_html}</div>
|
| 73 |
-
<div class="scores"><span class="score-1
|
| 74 |
"""
|
| 75 |
return html
|
| 76 |
|
|
@@ -80,7 +87,7 @@ css = """
|
|
| 80 |
.output-html {padding-top: 1rem; padding-bottom: 1rem;}
|
| 81 |
.output-html-row {margin-bottom: .5rem; border: var(--block-border-width) solid var(--block-border-color); border-radius: var(--block-radius);}
|
| 82 |
.token {padding: .5rem; border-radius: 5px;}
|
| 83 |
-
.
|
| 84 |
.tgt-token-wrapper {line-height: 2.5rem; padding: .5rem;}
|
| 85 |
.src-token-wrapper {line-height: 2.5rem; padding: .5rem;}
|
| 86 |
.src-token-wrapper-text {position: absolute; bottom: .75rem; color: #71717a;}
|
|
@@ -94,18 +101,21 @@ css = """
|
|
| 94 |
"""
|
| 95 |
|
| 96 |
js = """
|
| 97 |
-
function showCrossAttFun(attn_scores, decoder_attn) {
|
| 98 |
|
| 99 |
const scrTokens = document.querySelectorAll('.src-token');
|
| 100 |
const srcLen = scrTokens.length - 1
|
| 101 |
const targetTokens = document.querySelectorAll('.tgt-token');
|
| 102 |
const scores = document.querySelectorAll('.score');
|
| 103 |
|
| 104 |
-
|
| 105 |
const decoderTokens = document.querySelectorAll('.decoder-token');
|
| 106 |
const decLen = decoderTokens.length - 1
|
| 107 |
const decoderScores = document.querySelectorAll('.decoder-score');
|
| 108 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
function onTgtHover(event, idx) {
|
| 110 |
event.style.backgroundColor = "#C6E6E6";
|
| 111 |
|
|
@@ -153,9 +163,7 @@ function showCrossAttFun(attn_scores, decoder_attn) {
|
|
| 153 |
scores[2].style.display = "none";
|
| 154 |
}
|
| 155 |
|
| 156 |
-
function
|
| 157 |
-
event.style.backgroundColor = "#C6E6E6";
|
| 158 |
-
|
| 159 |
idx0 = decoder_attn[idx]['top_index'][0]
|
| 160 |
if (idx0 < decLen) {
|
| 161 |
el0 = decoderTokens[idx0]
|
|
@@ -181,12 +189,12 @@ function showCrossAttFun(attn_scores, decoder_attn) {
|
|
| 181 |
}
|
| 182 |
|
| 183 |
for (i=idx+1; i < decoderTokens.length; i++) {
|
| 184 |
-
decoderTokens[i].style.color = "#
|
| 185 |
}
|
| 186 |
|
| 187 |
}
|
| 188 |
|
| 189 |
-
function
|
| 190 |
event.style.backgroundColor = "";
|
| 191 |
idx0 = decoder_attn[idx]['top_index'][0]
|
| 192 |
el0 = decoderTokens[idx0]
|
|
@@ -216,6 +224,62 @@ function showCrossAttFun(attn_scores, decoder_attn) {
|
|
| 216 |
}
|
| 217 |
|
| 218 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 219 |
targetTokens.forEach((el, idx) => {
|
| 220 |
el.addEventListener("mouseover", () => {
|
| 221 |
onTgtHover(el, idx)
|
|
@@ -230,13 +294,25 @@ function showCrossAttFun(attn_scores, decoder_attn) {
|
|
| 230 |
|
| 231 |
decoderTokens.forEach((el, idx) => {
|
| 232 |
el.addEventListener("mouseover", () => {
|
| 233 |
-
|
| 234 |
})
|
| 235 |
});
|
| 236 |
|
| 237 |
decoderTokens.forEach((el, idx) => {
|
| 238 |
el.addEventListener("mouseout", () => {
|
| 239 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
})
|
| 241 |
});
|
| 242 |
}
|
|
@@ -269,6 +345,7 @@ with gr.Blocks(css=css) as demo:
|
|
| 269 |
|
| 270 |
cross_attn = gr.JSON(value=[], visible=False)
|
| 271 |
decoder_attn = gr.JSON(value=[], visible=False)
|
|
|
|
| 272 |
|
| 273 |
gr.Markdown(
|
| 274 |
"""
|
|
@@ -281,6 +358,18 @@ with gr.Blocks(css=css) as demo:
|
|
| 281 |
with gr.Row(elem_classes="output-html-row"):
|
| 282 |
output_html = gr.HTML(label="Cross Attention", elem_classes="output-html")
|
| 283 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 284 |
gr.Markdown(
|
| 285 |
"""
|
| 286 |
## Check Self Attentions for Decoder
|
|
@@ -293,9 +382,9 @@ with gr.Blocks(css=css) as demo:
|
|
| 293 |
with gr.Row(elem_classes="output-html-row"):
|
| 294 |
decoder_output_html = gr.HTML(label="Decoder Attention)", elem_classes="output-html")
|
| 295 |
|
| 296 |
-
translate_button.click(fn=translate_text, inputs=input_box, outputs=[output_box, output_html, cross_attn, decoder_output_html, decoder_attn])
|
| 297 |
|
| 298 |
-
output_box.change(None, [cross_attn, decoder_attn], None, js=js)
|
| 299 |
|
| 300 |
gr.Markdown("**Note:** I'm using a transformer model of encoder-decoder architecture (`Helsinki-NLP/opus-mt-en-zh`) in order to obtain cross attention from the decoder layers. ",
|
| 301 |
elem_classes="note-text")
|
|
|
|
| 1 |
import torch
|
| 2 |
from torch import nn
|
| 3 |
import gradio as gr
|
| 4 |
+
from utils import save_data, get_attn_list, get_top_attns, get_encoder_attn_list
|
| 5 |
|
| 6 |
|
| 7 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
|
|
|
| 37 |
avg_decoder_attn_list = get_attn_list(translated.decoder_attentions, layer_index)
|
| 38 |
decoder_attn_scores = get_top_attns(avg_decoder_attn_list)
|
| 39 |
|
| 40 |
+
avg_encoder_attn_list = get_encoder_attn_list(translated.encoder_attentions, layer_index)
|
| 41 |
+
encoder_attn_scores = get_top_attns(avg_encoder_attn_list)
|
| 42 |
+
|
| 43 |
# save_data(outputs, src_tokens, tgt_tokens, attn_scores)
|
| 44 |
+
return outputs, render_cross_attn_html(src_tokens, tgt_tokens), cross_attn_scores, render_encoder_decoder_attn_html(tgt_tokens, "Output"), decoder_attn_scores, render_encoder_decoder_attn_html(src_tokens, "Input"), encoder_attn_scores
|
| 45 |
|
| 46 |
|
| 47 |
def render_cross_attn_html(src_tokens, tgt_tokens):
|
|
|
|
| 67 |
def render_encoder_decoder_attn_html(tokens, type):
    """Build the HTML fragment that displays one token sequence for self-attention hovering.

    Args:
        tokens: Iterable of tokens to display, in order. Each token is
            rendered inside its own ``<span>`` carrying a ``data-index``.
        type: Display label for the sequence. ``"Input"`` selects the
            ``encoder`` CSS class prefix; any other value selects
            ``decoder``. (The name shadows the builtin but is kept so
            existing callers keep working.)

    Returns:
        An HTML string with a caption, the token spans, and three empty
        score spans (``score-1`` .. ``score-3``) tagged with the same
        class prefix.
    """
    # Local import: keeps the block self-contained without touching file-level imports.
    from html import escape

    class_name = "encoder" if type == "Input" else "decoder"

    # Token text may contain characters such as "<" or "&" (presumably it is
    # derived from user-entered text), so escape it before embedding in markup.
    tokens_html = "".join(
        f'<span class="token {class_name}-token" data-index="{i}">{escape(str(token))}</span> '
        for i, token in enumerate(tokens)
    )
    score_spans = "".join(
        f'<span class="score-{rank} {class_name}-score"></span>'
        for rank in (1, 2, 3)
    )

    # The scores container is now properly closed with </div>; the previous
    # markup opened a second <div> instead of closing the first one.
    return f"""
    <div class="tgt-token-wrapper-text">{type} Tokens</div>
    <div class="tgt-token-wrapper">{tokens_html}</div>
    <div class="scores">{score_spans}</div>
    """
|
| 83 |
|
|
|
|
| 87 |
.output-html {padding-top: 1rem; padding-bottom: 1rem;}
|
| 88 |
.output-html-row {margin-bottom: .5rem; border: var(--block-border-width) solid var(--block-border-color); border-radius: var(--block-radius);}
|
| 89 |
.token {padding: .5rem; border-radius: 5px;}
|
| 90 |
+
.token {cursor: pointer;}
|
| 91 |
.tgt-token-wrapper {line-height: 2.5rem; padding: .5rem;}
|
| 92 |
.src-token-wrapper {line-height: 2.5rem; padding: .5rem;}
|
| 93 |
.src-token-wrapper-text {position: absolute; bottom: .75rem; color: #71717a;}
|
|
|
|
| 101 |
"""
|
| 102 |
|
| 103 |
js = """
|
| 104 |
+
function showCrossAttFun(attn_scores, decoder_attn, encoder_attn) {
|
| 105 |
|
| 106 |
const scrTokens = document.querySelectorAll('.src-token');
|
| 107 |
const srcLen = scrTokens.length - 1
|
| 108 |
const targetTokens = document.querySelectorAll('.tgt-token');
|
| 109 |
const scores = document.querySelectorAll('.score');
|
| 110 |
|
|
|
|
| 111 |
const decoderTokens = document.querySelectorAll('.decoder-token');
|
| 112 |
const decLen = decoderTokens.length - 1
|
| 113 |
const decoderScores = document.querySelectorAll('.decoder-score');
|
| 114 |
|
| 115 |
+
const encoderTokens = document.querySelectorAll('.encoder-token');
|
| 116 |
+
const encLen = encoderTokens.length - 1
|
| 117 |
+
const encoderScores = document.querySelectorAll('.encoder-score');
|
| 118 |
+
|
| 119 |
function onTgtHover(event, idx) {
|
| 120 |
event.style.backgroundColor = "#C6E6E6";
|
| 121 |
|
|
|
|
| 163 |
scores[2].style.display = "none";
|
| 164 |
}
|
| 165 |
|
| 166 |
+
function onDecodeHover(event, idx) {
|
|
|
|
|
|
|
| 167 |
idx0 = decoder_attn[idx]['top_index'][0]
|
| 168 |
if (idx0 < decLen) {
|
| 169 |
el0 = decoderTokens[idx0]
|
|
|
|
| 189 |
}
|
| 190 |
|
| 191 |
for (i=idx+1; i < decoderTokens.length; i++) {
|
| 192 |
+
decoderTokens[i].style.color = "#ccc9c9";
|
| 193 |
}
|
| 194 |
|
| 195 |
}
|
| 196 |
|
| 197 |
+
function outDecodeHover(event, idx) {
|
| 198 |
event.style.backgroundColor = "";
|
| 199 |
idx0 = decoder_attn[idx]['top_index'][0]
|
| 200 |
el0 = decoderTokens[idx0]
|
|
|
|
| 224 |
}
|
| 225 |
|
| 226 |
|
| 227 |
+
// Highlight the (up to three) encoder tokens that the token at position `idx`
// attends to most, and show each score in the score badge of the same rank.
// `event` is the hovered element passed by the listener (unused here).
function onEncodeHover(event, idx) {
  // One color per rank, strongest attention first.
  const rankColors = ["#89C6C6", "#C6E6E6", "#E5F5F5"];
  const attn = encoder_attn[idx];

  rankColors.forEach((color, rank) => {
    // Declared locally: the previous code assigned undeclared names,
    // which leaked them into the global scope.
    const tokenIdx = attn['top_index'][rank];
    // Skips missing entries and the final token (encLen is length - 1).
    if (tokenIdx < encLen) {
      encoderTokens[tokenIdx].style.backgroundColor = color;
      const badge = encoderScores[rank];
      badge.textContent = attn['top_values'][rank];
      badge.style.display = "initial";
      badge.style.backgroundColor = color;
    }
  });
}
|
| 256 |
+
|
| 257 |
+
// Undo the highlighting applied on hover: clear the hovered element's
// background, then reset every attended-to encoder token and hide its
// score badge. `event` is the element the mouse just left.
function outEncodeHover(event, idx) {
  event.style.backgroundColor = "";
  const topIndex = encoder_attn[idx]['top_index'];

  for (let rank = 0; rank < 3; rank++) {
    // Declared locally: the previous code assigned undeclared names,
    // which leaked them into the global scope.
    const tokenIdx = topIndex[rank];
    // A missing entry has nothing to reset. Rank 0 is now guarded like
    // ranks 1 and 2 instead of dereferencing an undefined element.
    if (tokenIdx == null) {
      continue;
    }
    encoderTokens[tokenIdx].style.backgroundColor = "";
    encoderScores[rank].textContent = "";
    encoderScores[rank].style.display = "none";
  }
}
|
| 281 |
+
|
| 282 |
+
|
| 283 |
targetTokens.forEach((el, idx) => {
|
| 284 |
el.addEventListener("mouseover", () => {
|
| 285 |
onTgtHover(el, idx)
|
|
|
|
| 294 |
|
| 295 |
decoderTokens.forEach((el, idx) => {
|
| 296 |
el.addEventListener("mouseover", () => {
|
| 297 |
+
onDecodeHover(el, idx)
|
| 298 |
})
|
| 299 |
});
|
| 300 |
|
| 301 |
decoderTokens.forEach((el, idx) => {
|
| 302 |
el.addEventListener("mouseout", () => {
|
| 303 |
+
outDecodeHover(el, idx)
|
| 304 |
+
})
|
| 305 |
+
});
|
| 306 |
+
|
| 307 |
+
// Wire hover handling for every encoder token: entering a token shows its
// top self-attention targets, leaving it clears them again.
encoderTokens.forEach((tokenEl, position) => {
  tokenEl.addEventListener("mouseover", () => onEncodeHover(tokenEl, position));
  tokenEl.addEventListener("mouseout", () => outEncodeHover(tokenEl, position));
});
|
| 318 |
}
|
|
|
|
| 345 |
|
| 346 |
cross_attn = gr.JSON(value=[], visible=False)
|
| 347 |
decoder_attn = gr.JSON(value=[], visible=False)
|
| 348 |
+
encoder_attn = gr.JSON(value=[], visible=False)
|
| 349 |
|
| 350 |
gr.Markdown(
|
| 351 |
"""
|
|
|
|
| 358 |
with gr.Row(elem_classes="output-html-row"):
|
| 359 |
output_html = gr.HTML(label="Cross Attention", elem_classes="output-html")
|
| 360 |
|
| 361 |
+
gr.Markdown(
|
| 362 |
+
"""
|
| 363 |
+
## Check Self Attentions for Encoder
|
| 364 |
+
Hover your mouse over an input (English) word/token to see which word/token it is self-attending to.
|
| 365 |
+
""",
|
| 366 |
+
elem_classes="output-html-desc"
|
| 367 |
+
)
|
| 368 |
+
|
| 369 |
+
with gr.Row(elem_classes="output-html-row"):
    # Receives the encoder token HTML returned by translate_text.
    # Label fixed: it was a copy of the decoder label with a stray ")".
    encoder_output_html = gr.HTML(label="Encoder Attention", elem_classes="output-html")
|
| 371 |
+
|
| 372 |
+
|
| 373 |
gr.Markdown(
|
| 374 |
"""
|
| 375 |
## Check Self Attentions for Decoder
|
|
|
|
| 382 |
with gr.Row(elem_classes="output-html-row"):
    # Receives the decoder token HTML returned by translate_text.
    # Label fixed: the previous string ended with a stray ")".
    decoder_output_html = gr.HTML(label="Decoder Attention", elem_classes="output-html")
|
| 384 |
|
| 385 |
+
translate_button.click(fn=translate_text, inputs=input_box, outputs=[output_box, output_html, cross_attn, decoder_output_html, decoder_attn, encoder_output_html, encoder_attn])
|
| 386 |
|
| 387 |
+
output_box.change(None, [cross_attn, decoder_attn, encoder_attn], None, js=js)
|
| 388 |
|
| 389 |
gr.Markdown("**Note:** I'm using a transformer model of encoder-decoder architecture (`Helsinki-NLP/opus-mt-en-zh`) in order to obtain cross attention from the decoder layers. ",
|
| 390 |
elem_classes="note-text")
|
utils.py
CHANGED
|
@@ -40,12 +40,8 @@ def get_top_attns(avg_attn_list):
|
|
| 40 |
|
| 41 |
|
| 42 |
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
# attn_tensor = decoder_attentions[token_index][layer_index] # shape: [1, 8, 1, 24]
|
| 49 |
-
# avg_attn_list.append(attn_tensor.squeeze(0).squeeze(1).mean(0)) # shape: [24], mean across heads
|
| 50 |
-
#
|
| 51 |
-
# return avg_attn_list
|
|
|
|
| 40 |
|
| 41 |
|
| 42 |
|
| 43 |
+
def get_encoder_attn_list(encoder_attentions, layer_index):
    """Return the averaged attention map for one encoder layer.

    Args:
        encoder_attentions: Indexable collection of per-layer attention
            tensors (assumes each is (batch, heads, seq, seq) — TODO confirm
            against the model output).
        layer_index: Index of the layer to read.

    Returns:
        The selected layer's first entry along its leading axis, averaged
        over the next axis (presumably the attention heads).
    """
    first_entry = encoder_attentions[layer_index][0]
    return first_entry.mean(dim=0)
|
|
|
|
|
|
|
|
|
|
|
|