AITech Study Notes - [DAY 19] Transformer, Practice - Multi-head Attention, Masked Multi-head Attention


 =================================

Learning content

(Lecture 7) Transformer

The Transformer shows that a model can be built using only attention, which until then had been used like an add-on.

 


Let's start with RNNs.

An RNN processes the sequence from left to right, or the other way around, from right to left. Bi-directional RNNs do both: if we feed in "I go home", the hidden state for "go" carries information about "I" from the forward pass and about "home" from the backward pass. The two hidden vectors are concatenated so that both directions are kept.
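A minimal sketch of this concatenation in PyTorch (the sizes and the choice of nn.GRU are just illustrative assumptions, not from the lecture):

import torch
from torch import nn

emb_dim, hidden_dim = 16, 32                       # toy sizes for illustration
birnn = nn.GRU(emb_dim, hidden_dim, bidirectional=True, batch_first=True)

x = torch.randn(1, 3, emb_dim)                     # "I go home" -> 3 token embeddings (B, L, emb_dim)
out, _ = birnn(x)                                  # (B, L, 2 * hidden_dim): forward and backward states concatenated
forward_go = out[0, 1, :hidden_dim]                # forward state of "go" (has seen "I")
backward_go = out[0, 1, hidden_dim:]               # backward state of "go" (has seen "home")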

 

Further Reading

 

http://jalammar.github.io/illustrated-transformer/

 

The Illustrated Transformer


https://nlpinkorean.github.io/illustrated-transformer/

 

The Illustrated Transformer (Korean translation)


https://pytorch.org/tutorials/beginner/transformer_tutorial.html

 

Language Modeling with nn.Transformer and TorchText (PyTorch Tutorials)


 

 

 

 

The query vector decides the weights. The key vectors are the material vectors the query gets multiplied with, determining which items to pull from. The value vectors are another set of material vectors, the things actually being pulled.

When computing attention for each word, the key and value vectors are shared across all words, but q differs per word. The changed q is multiplied with the keys, and the resulting weights are used in a linear combination of the value vectors to obtain the final attention output vector. This ultimately becomes h2 (the encoding of the second word).

 

The query and key vectors must have the same dimension, since they are combined by a dot product into a single scalar. The value dimension does not have to match, because the values are only weighted by those scalars and summed.


A(q, K, V) is a weighted average over the value vectors. The vector obtained by encoding the query through the attention module therefore ends up with dimension d_v.

The weights are similarities based on the dot product with each key.
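Written out as an equation (the standard scaled dot-product attention from the Transformer paper; the division by √d_k is explained just below):

A(q, K, V) = \sum_i \frac{\exp(q \cdot k_i / \sqrt{d_k})}{\sum_j \exp(q \cdot k_j / \sqrt{d_k})} \, v_i

Each weight is a softmax over the query-key similarities, and the output is the corresponding weighted average of the value vectors, so it has dimension d_v.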

 


So why divide by √d_k before the softmax?

Each product term has a mean and a variance; when you add two of them, the mean stays 0 but the variances add up. Since this growth comes purely from the number of terms being summed, the larger d_k is, the more terms there are and the bigger the variance gets. A large variance produces very large (and very small) score values, and big gaps between scores have a huge effect on the softmax. So we divide by the standard deviation, √d_k, to normalize. In other words, the scores are normalized once before the softmax so that the behavior is the same no matter how large or small d_k is.
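A quick numerical check of this argument (a toy simulation, not part of the lecture): with unit-variance components, the variance of q·k is roughly d_k, and dividing by √d_k brings it back to roughly 1.

import torch

for d_k in [4, 64, 256]:
    q = torch.randn(10000, d_k)                    # components ~ N(0, 1)
    k = torch.randn(10000, d_k)
    dots = (q * k).sum(dim=-1)                     # 10000 samples of q · k
    print(d_k, dots.var().item(), (dots / d_k ** 0.5).var().item())
# variance of q·k grows like d_k; after scaling by 1/sqrt(d_k) it stays near 1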

 

We create several such self-attention outputs z. Here 8 of them are made and concatenated; since the result later has to be added to the input vector as-is (the residual connection), a linear layer reduces it back to the input dimension to produce the final value.

 

n is the length of the input sequence, and d is the dimension of the query.

For the self-attention we did, the computation over the input length and the d-dimensional keys costs O(n^2 · d) by the definition of matrix multiplication. But since it can be fully parallelized, the number of sequential operations is O(1). The maximum path length between words is also O(1), since each query is compared against every key directly in a single step.

 

For an RNN, the hidden state h has dimension d and is multiplied by a d×d weight matrix, and this is done n times over the input, so the cost is O(n · d^2). Because the computation is sequential, it cannot be parallelized, so the number of sequential operations is O(n). The maximum path length is also O(n), since information has to pass through n steps to travel across the sentence.

 

 

 

One block consists of Multi-Head Attention, a residual connection, Layer Normalization, and a Feed Forward network (a fully connected layer). It is the block built around Multi-Head Attention, the self-attention proposed in the Transformer.

So far we have looked at multi-head attention itself; the Transformer attaches additional post-processing on top of it. The residual connection alleviates the vanishing gradient problem, so many layers can be stacked to improve performance. Then why do the Add & Norm? That is explained below.

 

In an ordinary neural network (something like linear regression, I think), when a hidden output comes out, its original mean and variance statistics are discarded anyway: we subtract the mean and divide by the standard deviation to normalize. Then, so that each node can end up with whatever statistics are actually optimal for it (for example, recovering a mean of 3 and a variance of 4 if that is what the underlying relation y = 2x + 3 calls for), learnable parameters re-inject a desired mean and variance. So layer norm, similarly to batch norm, consists of two stages: first make the mean and variance of the given sample 0 and 1, then inject the mean and variance we want.

 


In the figure, when the words "Thinking" and "Machines" are fed in, the resulting hidden states are normalized and then the desired mean and variance are injected through an affine transformation. That is layer normalization. Batch normalization is said to be a bit different, but roughly similar.
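A minimal sketch of that Add & Norm step with nn.LayerNorm (the tensor names here are made up for illustration); LayerNorm's weight and bias are the learnable affine parameters that inject the desired scale and shift:

import torch
from torch import nn

d_model = 512
layer_norm = nn.LayerNorm(d_model)                 # step 1: normalize, step 2: learnable scale (weight) and shift (bias)

x = torch.randn(10, 20, d_model)                   # block input                  (B, L, d_model)
attn_out = torch.randn(10, 20, d_model)            # multi-head attention output  (B, L, d_model)

out = layer_norm(x + attn_out)                     # residual connection, then layer normalization
print(out.mean(dim=-1)[0, 0].item(), out.std(dim=-1)[0, 0].item())  # ~0 and ~1 (weight=1, bias=0 at init)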

 

 

There is one problem: nothing in the model depends on order, so even if word positions are swapped it cannot tell the word order apart. This is because the weighted sum over the value vectors is commutative.

So order is encoded by adding a value that is unique to each position. In other words, regardless of what the word means, a specific value is added depending on where it sits, so positional information is also reflected. These values are built from sin and cos periodic functions with different periods and added at each position.
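A sketch of the sinusoidal positional encoding used in the original Transformer (the function name is mine; the sin/cos formulas are the standard ones):

import math
import torch

def sinusoidal_positional_encoding(max_len, d_model):
    pe = torch.zeros(max_len, d_model)
    pos = torch.arange(max_len, dtype=torch.float).unsqueeze(1)   # (max_len, 1)
    freq = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float) * (-math.log(10000.0) / d_model))
    pe[:, 0::2] = torch.sin(pos * freq)            # even dimensions
    pe[:, 1::2] = torch.cos(pos * freq)            # odd dimensions
    return pe                                      # (max_len, d_model), fixed (not learned)

pe = sinusoidal_positional_encoding(max_len=20, d_model=512)
# batch_emb = batch_emb + pe.unsqueeze(0)          # simply added to the word embeddings (B, L, d_model)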

 

As for the learning rate, because such issues arise, it is known empirically that a schedule like the one shown works better.

 

The block we studied is the encoder block, and several of them are stacked up as layers.
Looking at the word "making", we can see that each attention head attends to a different place. This is why multi-head attention is used.

https://colab.research.google.com/github/tensorflow/tensor2tensor/blob/master/tensor2tensor/notebooks/hello_t2t.ipynb

 

Looking at the architecture, the encoder's output flows into the decoder. Masked Multi-Head Attention is explained below; in the attention block above it, the query vectors are made from the decoder's input, while the key and value vectors come from the encoder's final output. This is because that attention estimates the relationship between the decoder words and the source words.

그리고 "I go home"을 "나는 집에 간다" 로 번역할 때 마지막에 linear를 거치는데 여기서 모든 한글 단어에 대해 linear을 펼쳐서 본다. 만약 한글 사전에 단어가 10만개 있으면 10만개 output을 내놓도록 linear에서 펼쳐놓고 softmax해서 가장 가능성이 높은 걸 출력한다.

 

During training all the target words are fed in at once, but at actual inference time (in the decoder, at test time, I think) it makes no sense to already know which words will be entered in the future. Since the model architecture still feeds them in that way, a mask is applied as a post-processing step.

Since the current input is the query and the rest are keys, the softmax probabilities from the current query to the keys of future words are set to 0. Because a sentence is read left to right, it does not make sense to interpret a word on the left by looking at a word to its right. Even if the left word really depends on the right one, the design choice is that once we reach the right word, it is the one that attends back to the left word's key and captures that relationship.

 

 

 

Practice

(Practice Lecture 7) Implementing Multi-head Attention

In theory, multi-head attention should generate H separate sets of Q, K, V, but due to memory issues and the like, a single matrix is created and then split into sections. In the example above with head=3, dividing d_model by the number of heads yields 3 chunks, so we really do end up with 3 heads, each of length d_k.

 

##7. Multi-head Attention
1. Implement multi-head attention and self-attention.
2. Understand the operations at each step and the input/output shapes.

 
 
 
 
 

Import required packages

 
 
 
 
[1]
 
 
 
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm
 
import torch
import math
 
 
 
 
 
 
 
 
 
 

Data preprocessing

 
 
 
 
[2]
 
 
 
pad_id = 0
vocab_size = 100
 
data = [
  [62, 13, 47, 39, 78, 33, 56, 13, 39, 29, 44, 86, 71, 36, 18, 75],
  [60, 96, 51, 32, 90],
  [35, 45, 48, 65, 91, 99, 92, 10, 3, 21, 54],
  [75, 51],
  [66, 88, 98, 47],
  [21, 39, 10, 64, 21],
  [98],
  [77, 65, 51, 77, 19, 15, 35, 19, 23, 97, 50, 46, 53, 42, 45, 91, 66, 3, 43, 10],
  [70, 64, 98, 25, 99, 53, 4, 13, 69, 62, 66, 76, 15, 75, 45, 34],
  [20, 64, 81, 35, 76, 85, 1, 62, 8, 45, 99, 77, 19, 43]
]
 
 
 
 
 
 
 
 
 
 
[3]
 
 
 
def padding(data):
  max_len = len(max(data, key=len))
  print(f"Maximum sequence length: {max_len}")
 
  for i, seq in enumerate(tqdm(data)):
    if len(seq) < max_len:
      data[i] = seq + [pad_id] * (max_len - len(seq))
 
  return data, max_len
 
 
 
 
 
 
 
 
 
 
[4]
 
 
 
data, max_len = padding(data)
 
 
 
 
 
 
100%|██████████| 10/10 [00:00<00:00, 12826.62it/s]
Maximum sequence length: 20
 
 
 
 
[5]
 
 
 
data
 
 
 
 
 
 
[[62, 13, 47, 39, 78, 33, 56, 13, 39, 29, 44, 86, 71, 36, 18, 75, 0, 0, 0, 0],
 [60, 96, 51, 32, 90, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [35, 45, 48, 65, 91, 99, 92, 10, 3, 21, 54, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [75, 51, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [66, 88, 98, 47, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [21, 39, 10, 64, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [98, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 [77, 65, 51, 77, 19, 15, 35, 19, 23, 97, 50, 46, 53, 42, 45, 91, 66, 3, 43, 10],
 [70, 64, 98, 25, 99, 53, 4, 13, 69, 62, 66, 76, 15, 75, 45, 34, 0, 0, 0, 0],
 [20, 64, 81, 35, 76, 85, 1, 62, 8, 45, 99, 77, 19, 43, 0, 0, 0, 0, 0, 0]]
 
 
 
 

Hyperparameter setup and embedding

 
 
 
 
[6]
 
 
 
d_model = 512  # hidden size of the model
num_heads = 8  # number of heads
 
 
 
 
 
 
 
 
 
 
[7]
 
 
 
embedding = nn.Embedding(vocab_size, d_model)
 
# B: batch size, L: maximum sequence length
batch = torch.LongTensor(data)  # (B, L)
batch_emb = embedding(batch)  # (B, L, d_model)
 
 
 
 
 
 
 
 
 
 
[8]
 
 
 
print(batch_emb)
print(batch_emb.shape)
 
 
 
 
 
 
tensor([[[-0.2169, -0.3583, 1.0193, ..., -0.7934, -0.9208, -1.0198], [-0.8411, 2.4772, 0.9702, ..., -0.4276, -1.3260, -0.0394], [ 0.3482, 2.8239, -1.6240, ..., 1.5651, -0.0208, -1.2387], ..., [-0.3128, 0.5407, 0.1209, ..., -0.2008, -0.6200, -0.9691], [-0.3128, 0.5407, 0.1209, ..., -0.2008, -0.6200, -0.9691], [-0.3128, 0.5407, 0.1209, ..., -0.2008, -0.6200, -0.9691]], [[-0.3289, -0.4818, -0.6164, ..., -1.8003, 0.6235, -0.3524], [-0.8258, -0.2416, 1.0993, ..., 0.8884, 0.3743, 1.3961], [ 0.7716, 0.2966, 0.1699, ..., -1.2789, -0.5366, 0.3534], ..., [-0.3128, 0.5407, 0.1209, ..., -0.2008, -0.6200, -0.9691], [-0.3128, 0.5407, 0.1209, ..., -0.2008, -0.6200, -0.9691], [-0.3128, 0.5407, 0.1209, ..., -0.2008, -0.6200, -0.9691]], [[-0.8911, -1.5136, 1.1245, ..., 1.0384, -0.5083, 0.3870], [ 0.7745, -0.3943, 0.5206, ..., -0.1021, -0.8571, -1.8890], [ 0.7830, 0.8141, -1.1696, ..., 1.6220, 1.5565, 0.6228], ..., [-0.3128, 0.5407, 0.1209, ..., -0.2008, -0.6200, -0.9691], [-0.3128, 0.5407, 0.1209, ..., -0.2008, -0.6200, -0.9691], [-0.3128, 0.5407, 0.1209, ..., -0.2008, -0.6200, -0.9691]], ..., [[ 0.4533, 0.8047, -0.1219, ..., -0.8111, 1.2460, 0.4246], [-0.3274, -1.0963, -1.2922, ..., 0.2544, -1.0975, -0.6509], [ 0.7716, 0.2966, 0.1699, ..., -1.2789, -0.5366, 0.3534], ..., [-1.2036, -1.3401, -0.3581, ..., 0.1999, 0.6540, -0.4159], [ 1.2155, -0.0542, 0.4923, ..., -0.1561, 0.9865, -0.6558], [ 1.0399, -1.5524, 0.0432, ..., -0.7237, -0.7161, 0.5026]], [[ 0.3194, -0.9475, 0.9975, ..., -0.7796, -2.1479, -0.6828], [-0.3623, -1.6566, 0.6783, ..., 2.4238, -0.3513, 1.6672], [ 0.3914, 0.2937, -0.2541, ..., 1.7687, -0.3865, -0.8186], ..., [-0.3128, 0.5407, 0.1209, ..., -0.2008, -0.6200, -0.9691], [-0.3128, 0.5407, 0.1209, ..., -0.2008, -0.6200, -0.9691], [-0.3128, 0.5407, 0.1209, ..., -0.2008, -0.6200, -0.9691]], [[-1.4615, 1.6652, -1.0368, ..., 0.6557, -0.1662, -0.8121], [-0.3623, -1.6566, 0.6783, ..., 2.4238, -0.3513, 1.6672], [-1.7808, 1.5748, 1.9841, ..., 0.1642, 1.0493, 0.2800], ..., [-0.3128, 0.5407, 0.1209, ..., -0.2008, -0.6200, -0.9691], [-0.3128, 0.5407, 0.1209, ..., -0.2008, -0.6200, -0.9691], [-0.3128, 0.5407, 0.1209, ..., -0.2008, -0.6200, -0.9691]]], grad_fn=<EmbeddingBackward>) torch.Size([10, 20, 512])
 
 
 
 

Linear transformation & splitting into multiple heads

 
 
 
 
 

Define the linear transformation matrices used inside multi-head attention.

 
 
 
 
[9]
 
 
 
w_q = nn.Linear(d_model, d_model)
w_k = nn.Linear(d_model, d_model)
w_v = nn.Linear(d_model, d_model)
 
 
 
 
 
 
 
 
 
 
[10]
 
 
 
w_0 = nn.Linear(d_model, d_model)
 
 
 
 
 
 
 
 
 
 
[11]
 
 
 
q = w_q(batch_emb)  # (B, L, d_model)
k = w_k(batch_emb)  # (B, L, d_model)
v = w_v(batch_emb)  # (B, L, d_model)
 
print(q.shape)
print(k.shape)
print(v.shape)
 
 
 
 
 
 
torch.Size([10, 20, 512]) torch.Size([10, 20, 512]) torch.Size([10, 20, 512])
 
 
 
 

Split q, k, and v into num_heads vectors, each holding its share of the divided dimension.

 
 
 
 
[12]
 
 
 
batch_size = q.shape[0]
d_k = d_model // num_heads
 
q = q.view(batch_size, -1, num_heads, d_k)  # (B, L, num_heads, d_k)
k = k.view(batch_size, -1, num_heads, d_k)  # (B, L, num_heads, d_k)
v = v.view(batch_size, -1, num_heads, d_k)  # (B, L, num_heads, d_k)
 
print(q.shape)
print(k.shape)
print(v.shape)
 
 
 
 
 
torch.Size([10, 20, 8, 64]) torch.Size([10, 20, 8, 64]) torch.Size([10, 20, 8, 64])
 
 
 
[13]
 
 
 
q = q.transpose(1, 2)  # (B, num_heads, L, d_k)
k = k.transpose(1, 2)  # (B, num_heads, L, d_k)
v = v.transpose(1, 2)  # (B, num_heads, L, d_k)
 
print(q.shape)
print(k.shape)
print(v.shape)
 
 
 
 
 
 
torch.Size([10, 8, 20, 64]) torch.Size([10, 8, 20, 64]) torch.Size([10, 8, 20, 64])
 
 
 
 

Implementing scaled dot-product self-attention

 
 
 
 
 

This is the self-attention computation carried out in each head.

 
 
 
 
[14]
 
 
 
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # (B, num_heads, L, L)
attn_dists = F.softmax(attn_scores, dim=-1)  # (B, num_heads, L, L)
 
print(attn_dists)
print(attn_dists.shape)
 
 
 
 
 
tensor([[[[0.0483, 0.0464, 0.0778, ..., 0.0614, 0.0614, 0.0614], [0.0362, 0.0656, 0.0156, ..., 0.0370, 0.0370, 0.0370], [0.0460, 0.0501, 0.0424, ..., 0.0623, 0.0623, 0.0623], ..., [0.0347, 0.0323, 0.0538, ..., 0.1034, 0.1034, 0.1034], [0.0347, 0.0323, 0.0538, ..., 0.1034, 0.1034, 0.1034], [0.0347, 0.0323, 0.0538, ..., 0.1034, 0.1034, 0.1034]], [[0.0404, 0.0403, 0.0285, ..., 0.0692, 0.0692, 0.0692], [0.0305, 0.0402, 0.0557, ..., 0.0377, 0.0377, 0.0377], [0.0372, 0.0459, 0.0572, ..., 0.0409, 0.0409, 0.0409], ..., [0.0666, 0.0519, 0.0567, ..., 0.0635, 0.0635, 0.0635], [0.0666, 0.0519, 0.0567, ..., 0.0635, 0.0635, 0.0635], [0.0666, 0.0519, 0.0567, ..., 0.0635, 0.0635, 0.0635]], [[0.0455, 0.0610, 0.0668, ..., 0.0576, 0.0576, 0.0576], [0.0735, 0.0627, 0.0385, ..., 0.0494, 0.0494, 0.0494], [0.0773, 0.0402, 0.0613, ..., 0.0504, 0.0504, 0.0504], ..., [0.0474, 0.0514, 0.0441, ..., 0.0479, 0.0479, 0.0479], [0.0474, 0.0514, 0.0441, ..., 0.0479, 0.0479, 0.0479], [0.0474, 0.0514, 0.0441, ..., 0.0479, 0.0479, 0.0479]], ..., [[0.0373, 0.0420, 0.0425, ..., 0.0405, 0.0405, 0.0405], [0.0395, 0.0286, 0.0691, ..., 0.0638, 0.0638, 0.0638], [0.0250, 0.0302, 0.0596, ..., 0.0698, 0.0698, 0.0698], ..., [0.0498, 0.0197, 0.0690, ..., 0.0537, 0.0537, 0.0537], [0.0498, 0.0197, 0.0690, ..., 0.0537, 0.0537, 0.0537], [0.0498, 0.0197, 0.0690, ..., 0.0537, 0.0537, 0.0537]], [[0.0387, 0.0249, 0.0319, ..., 0.0697, 0.0697, 0.0697], [0.0613, 0.0528, 0.0384, ..., 0.0322, 0.0322, 0.0322], [0.0498, 0.0539, 0.0328, ..., 0.0616, 0.0616, 0.0616], ..., [0.0426, 0.0712, 0.0455, ..., 0.0511, 0.0511, 0.0511], [0.0426, 0.0712, 0.0455, ..., 0.0511, 0.0511, 0.0511], [0.0426, 0.0712, 0.0455, ..., 0.0511, 0.0511, 0.0511]], [[0.0727, 0.0252, 0.0722, ..., 0.0559, 0.0559, 0.0559], [0.0863, 0.0512, 0.0345, ..., 0.0449, 0.0449, 0.0449], [0.0634, 0.0584, 0.0333, ..., 0.0503, 0.0503, 0.0503], ..., [0.0329, 0.0757, 0.0665, ..., 0.0410, 0.0410, 0.0410], [0.0329, 0.0757, 0.0665, ..., 0.0410, 0.0410, 0.0410], [0.0329, 0.0757, 0.0665, ..., 0.0410, 0.0410, 0.0410]]], [[[0.0768, 0.0889, 0.0462, ..., 0.0416, 0.0416, 0.0416], [0.0286, 0.0342, 0.0516, ..., 0.0555, 0.0555, 0.0555], [0.0185, 0.0133, 0.0208, ..., 0.0610, 0.0610, 0.0610], ..., [0.0208, 0.0198, 0.0146, ..., 0.0593, 0.0593, 0.0593], [0.0208, 0.0198, 0.0146, ..., 0.0593, 0.0593, 0.0593], [0.0208, 0.0198, 0.0146, ..., 0.0593, 0.0593, 0.0593]], [[0.0395, 0.0524, 0.0477, ..., 0.0518, 0.0518, 0.0518], [0.0460, 0.0686, 0.0540, ..., 0.0477, 0.0477, 0.0477], [0.0278, 0.0704, 0.0335, ..., 0.0522, 0.0522, 0.0522], ..., [0.0336, 0.0225, 0.0344, ..., 0.0573, 0.0573, 0.0573], [0.0336, 0.0225, 0.0344, ..., 0.0573, 0.0573, 0.0573], [0.0336, 0.0225, 0.0344, ..., 0.0573, 0.0573, 0.0573]], [[0.0422, 0.0351, 0.0672, ..., 0.0510, 0.0510, 0.0510], [0.0234, 0.0575, 0.0511, ..., 0.0536, 0.0536, 0.0536], [0.0304, 0.0461, 0.0541, ..., 0.0522, 0.0522, 0.0522], ..., [0.0769, 0.0466, 0.0611, ..., 0.0488, 0.0488, 0.0488], [0.0769, 0.0466, 0.0611, ..., 0.0488, 0.0488, 0.0488], [0.0769, 0.0466, 0.0611, ..., 0.0488, 0.0488, 0.0488]], ..., [[0.0275, 0.0409, 0.0720, ..., 0.0493, 0.0493, 0.0493], [0.0481, 0.0495, 0.0574, ..., 0.0490, 0.0490, 0.0490], [0.0561, 0.0507, 0.0705, ..., 0.0451, 0.0451, 0.0451], ..., [0.0825, 0.0692, 0.0783, ..., 0.0441, 0.0441, 0.0441], [0.0825, 0.0692, 0.0783, ..., 0.0441, 0.0441, 0.0441], [0.0825, 0.0692, 0.0783, ..., 0.0441, 0.0441, 0.0441]], [[0.0264, 0.0646, 0.0331, ..., 0.0523, 0.0523, 0.0523], [0.0760, 0.0303, 0.0844, ..., 0.0451, 0.0451, 0.0451], [0.0276, 0.0411, 0.0247, ..., 0.0574, 0.0574, 
0.0574], ..., [0.0365, 0.0452, 0.0674, ..., 0.0510, 0.0510, 0.0510], [0.0365, 0.0452, 0.0674, ..., 0.0510, 0.0510, 0.0510], [0.0365, 0.0452, 0.0674, ..., 0.0510, 0.0510, 0.0510]], [[0.0303, 0.0351, 0.0208, ..., 0.0570, 0.0570, 0.0570], [0.0325, 0.0703, 0.0617, ..., 0.0491, 0.0491, 0.0491], [0.0692, 0.0625, 0.0669, ..., 0.0377, 0.0377, 0.0377], ..., [0.0641, 0.0561, 0.0607, ..., 0.0474, 0.0474, 0.0474], [0.0641, 0.0561, 0.0607, ..., 0.0474, 0.0474, 0.0474], [0.0641, 0.0561, 0.0607, ..., 0.0474, 0.0474, 0.0474]]], [[[0.0319, 0.0667, 0.1340, ..., 0.0258, 0.0258, 0.0258], [0.0427, 0.0423, 0.0473, ..., 0.0536, 0.0536, 0.0536], [0.0788, 0.0390, 0.0182, ..., 0.0467, 0.0467, 0.0467], ..., [0.0462, 0.0179, 0.0145, ..., 0.0786, 0.0786, 0.0786], [0.0462, 0.0179, 0.0145, ..., 0.0786, 0.0786, 0.0786], [0.0462, 0.0179, 0.0145, ..., 0.0786, 0.0786, 0.0786]], [[0.0364, 0.0986, 0.0557, ..., 0.0394, 0.0394, 0.0394], [0.0673, 0.0528, 0.0870, ..., 0.0468, 0.0468, 0.0468], [0.0513, 0.0916, 0.0370, ..., 0.0340, 0.0340, 0.0340], ..., [0.0393, 0.0612, 0.0221, ..., 0.0608, 0.0608, 0.0608], [0.0393, 0.0612, 0.0221, ..., 0.0608, 0.0608, 0.0608], [0.0393, 0.0612, 0.0221, ..., 0.0608, 0.0608, 0.0608]], [[0.0255, 0.0424, 0.0443, ..., 0.0660, 0.0660, 0.0660], [0.0531, 0.0754, 0.0386, ..., 0.0541, 0.0541, 0.0541], [0.0622, 0.0611, 0.0416, ..., 0.0516, 0.0516, 0.0516], ..., [0.0385, 0.0434, 0.0429, ..., 0.0470, 0.0470, 0.0470], [0.0385, 0.0434, 0.0429, ..., 0.0470, 0.0470, 0.0470], [0.0385, 0.0434, 0.0429, ..., 0.0470, 0.0470, 0.0470]], ..., [[0.0485, 0.0438, 0.0449, ..., 0.0447, 0.0447, 0.0447], [0.0331, 0.0283, 0.0588, ..., 0.0636, 0.0636, 0.0636], [0.0586, 0.0647, 0.0924, ..., 0.0431, 0.0431, 0.0431], ..., [0.0706, 0.0612, 0.0526, ..., 0.0476, 0.0476, 0.0476], [0.0706, 0.0612, 0.0526, ..., 0.0476, 0.0476, 0.0476], [0.0706, 0.0612, 0.0526, ..., 0.0476, 0.0476, 0.0476]], [[0.0420, 0.0601, 0.0400, ..., 0.0466, 0.0466, 0.0466], [0.0508, 0.0614, 0.0849, ..., 0.0376, 0.0376, 0.0376], [0.0326, 0.0251, 0.0395, ..., 0.0514, 0.0514, 0.0514], ..., [0.0369, 0.0344, 0.0416, ..., 0.0564, 0.0564, 0.0564], [0.0369, 0.0344, 0.0416, ..., 0.0564, 0.0564, 0.0564], [0.0369, 0.0344, 0.0416, ..., 0.0564, 0.0564, 0.0564]], [[0.0524, 0.0410, 0.0576, ..., 0.0572, 0.0572, 0.0572], [0.0561, 0.0960, 0.0704, ..., 0.0430, 0.0430, 0.0430], [0.0610, 0.0537, 0.0277, ..., 0.0351, 0.0351, 0.0351], ..., [0.0485, 0.0806, 0.0452, ..., 0.0426, 0.0426, 0.0426], [0.0485, 0.0806, 0.0452, ..., 0.0426, 0.0426, 0.0426], [0.0485, 0.0806, 0.0452, ..., 0.0426, 0.0426, 0.0426]]], ..., [[[0.0356, 0.0693, 0.0371, ..., 0.0553, 0.0690, 0.0497], [0.0185, 0.0895, 0.0409, ..., 0.0365, 0.0746, 0.0806], [0.0202, 0.0567, 0.0417, ..., 0.0596, 0.0598, 0.0458], ..., [0.0617, 0.0810, 0.0333, ..., 0.0863, 0.0659, 0.0325], [0.0805, 0.0325, 0.0351, ..., 0.0385, 0.0400, 0.0501], [0.0539, 0.0366, 0.0618, ..., 0.0364, 0.0500, 0.0466]], [[0.0478, 0.0542, 0.0370, ..., 0.0801, 0.0679, 0.0344], [0.0852, 0.0416, 0.0637, ..., 0.0481, 0.0530, 0.0445], [0.0601, 0.0613, 0.0408, ..., 0.0407, 0.0693, 0.0615], ..., [0.0350, 0.0598, 0.0460, ..., 0.0949, 0.0513, 0.0563], [0.0545, 0.0409, 0.0532, ..., 0.0597, 0.0392, 0.0411], [0.0474, 0.0552, 0.0396, ..., 0.0646, 0.0546, 0.0326]], [[0.0484, 0.0461, 0.0176, ..., 0.0607, 0.0507, 0.0382], [0.0985, 0.0508, 0.0593, ..., 0.0314, 0.0546, 0.0861], [0.0388, 0.0508, 0.0483, ..., 0.0543, 0.0564, 0.0994], ..., [0.0504, 0.0384, 0.0708, ..., 0.0208, 0.0460, 0.0370], [0.0463, 0.0430, 0.0450, ..., 0.0732, 0.0456, 0.0704], [0.0380, 0.0473, 0.0382, ..., 0.0436, 0.0412, 
0.0702]], ..., [[0.0574, 0.0286, 0.0633, ..., 0.0606, 0.0433, 0.0666], [0.0491, 0.0768, 0.0442, ..., 0.0694, 0.0412, 0.0423], [0.0570, 0.0356, 0.0540, ..., 0.0382, 0.0603, 0.0340], ..., [0.0669, 0.0394, 0.0790, ..., 0.0497, 0.0385, 0.0524], [0.0518, 0.0976, 0.0309, ..., 0.0465, 0.0353, 0.0399], [0.0435, 0.0437, 0.0476, ..., 0.0469, 0.0628, 0.0385]], [[0.0511, 0.0343, 0.0497, ..., 0.0592, 0.0426, 0.0299], [0.0433, 0.0324, 0.0599, ..., 0.0774, 0.0478, 0.0462], [0.0553, 0.0347, 0.0396, ..., 0.0414, 0.0329, 0.0392], ..., [0.0405, 0.0380, 0.0432, ..., 0.0903, 0.0603, 0.0343], [0.0414, 0.0506, 0.0374, ..., 0.0455, 0.0491, 0.0531], [0.0447, 0.0297, 0.0492, ..., 0.0441, 0.0549, 0.0418]], [[0.0401, 0.0182, 0.0629, ..., 0.0664, 0.0572, 0.0494], [0.0447, 0.0362, 0.0564, ..., 0.0397, 0.0405, 0.0733], [0.0449, 0.0701, 0.0348, ..., 0.0376, 0.0419, 0.0580], ..., [0.0698, 0.0570, 0.0360, ..., 0.0525, 0.0400, 0.0683], [0.0510, 0.0493, 0.0370, ..., 0.0363, 0.0471, 0.0376], [0.0576, 0.0597, 0.0692, ..., 0.0397, 0.0494, 0.0448]]], [[[0.0406, 0.0412, 0.0219, ..., 0.0877, 0.0877, 0.0877], [0.0338, 0.0559, 0.0529, ..., 0.0806, 0.0806, 0.0806], [0.0506, 0.0337, 0.0514, ..., 0.0318, 0.0318, 0.0318], ..., [0.0344, 0.0460, 0.0434, ..., 0.1100, 0.1100, 0.1100], [0.0344, 0.0460, 0.0434, ..., 0.1100, 0.1100, 0.1100], [0.0344, 0.0460, 0.0434, ..., 0.1100, 0.1100, 0.1100]], [[0.0601, 0.0350, 0.0664, ..., 0.0631, 0.0631, 0.0631], [0.0411, 0.0600, 0.0305, ..., 0.0237, 0.0237, 0.0237], [0.0283, 0.0500, 0.1090, ..., 0.0510, 0.0510, 0.0510], ..., [0.0614, 0.0681, 0.0461, ..., 0.0622, 0.0622, 0.0622], [0.0614, 0.0681, 0.0461, ..., 0.0622, 0.0622, 0.0622], [0.0614, 0.0681, 0.0461, ..., 0.0622, 0.0622, 0.0622]], [[0.0363, 0.0662, 0.0641, ..., 0.0387, 0.0387, 0.0387], [0.0313, 0.0288, 0.0348, ..., 0.0297, 0.0297, 0.0297], [0.0246, 0.1005, 0.0401, ..., 0.0490, 0.0490, 0.0490], ..., [0.0643, 0.0340, 0.0427, ..., 0.0503, 0.0503, 0.0503], [0.0643, 0.0340, 0.0427, ..., 0.0503, 0.0503, 0.0503], [0.0643, 0.0340, 0.0427, ..., 0.0503, 0.0503, 0.0503]], ..., [[0.0553, 0.0274, 0.0333, ..., 0.0620, 0.0620, 0.0620], [0.0565, 0.0468, 0.0630, ..., 0.0389, 0.0389, 0.0389], [0.0441, 0.0327, 0.0769, ..., 0.0423, 0.0423, 0.0423], ..., [0.0263, 0.1001, 0.0387, ..., 0.0492, 0.0492, 0.0492], [0.0263, 0.1001, 0.0387, ..., 0.0492, 0.0492, 0.0492], [0.0263, 0.1001, 0.0387, ..., 0.0492, 0.0492, 0.0492]], [[0.0537, 0.0480, 0.0862, ..., 0.0379, 0.0379, 0.0379], [0.0315, 0.0965, 0.0714, ..., 0.0440, 0.0440, 0.0440], [0.0546, 0.0409, 0.0454, ..., 0.0412, 0.0412, 0.0412], ..., [0.0292, 0.0569, 0.0452, ..., 0.0497, 0.0497, 0.0497], [0.0292, 0.0569, 0.0452, ..., 0.0497, 0.0497, 0.0497], [0.0292, 0.0569, 0.0452, ..., 0.0497, 0.0497, 0.0497]], [[0.0473, 0.0449, 0.0630, ..., 0.0461, 0.0461, 0.0461], [0.0498, 0.0690, 0.0543, ..., 0.0436, 0.0436, 0.0436], [0.0312, 0.0324, 0.0388, ..., 0.0469, 0.0469, 0.0469], ..., [0.0875, 0.0861, 0.0373, ..., 0.0404, 0.0404, 0.0404], [0.0875, 0.0861, 0.0373, ..., 0.0404, 0.0404, 0.0404], [0.0875, 0.0861, 0.0373, ..., 0.0404, 0.0404, 0.0404]]], [[[0.0968, 0.0516, 0.0518, ..., 0.0462, 0.0462, 0.0462], [0.0517, 0.0533, 0.0306, ..., 0.0768, 0.0768, 0.0768], [0.0638, 0.0841, 0.0273, ..., 0.0418, 0.0418, 0.0418], ..., [0.0317, 0.0387, 0.0320, ..., 0.0924, 0.0924, 0.0924], [0.0317, 0.0387, 0.0320, ..., 0.0924, 0.0924, 0.0924], [0.0317, 0.0387, 0.0320, ..., 0.0924, 0.0924, 0.0924]], [[0.0467, 0.0609, 0.0833, ..., 0.0450, 0.0450, 0.0450], [0.0436, 0.0675, 0.0896, ..., 0.0267, 0.0267, 0.0267], [0.0476, 0.0822, 0.0771, ..., 0.0561, 0.0561, 
0.0561], ..., [0.0347, 0.0670, 0.0367, ..., 0.0612, 0.0612, 0.0612], [0.0347, 0.0670, 0.0367, ..., 0.0612, 0.0612, 0.0612], [0.0347, 0.0670, 0.0367, ..., 0.0612, 0.0612, 0.0612]], [[0.0339, 0.0394, 0.0391, ..., 0.0499, 0.0499, 0.0499], [0.0310, 0.0318, 0.0619, ..., 0.0328, 0.0328, 0.0328], [0.0831, 0.0336, 0.0379, ..., 0.0541, 0.0541, 0.0541], ..., [0.0359, 0.0349, 0.0613, ..., 0.0516, 0.0516, 0.0516], [0.0359, 0.0349, 0.0613, ..., 0.0516, 0.0516, 0.0516], [0.0359, 0.0349, 0.0613, ..., 0.0516, 0.0516, 0.0516]], ..., [[0.0544, 0.0314, 0.0356, ..., 0.0515, 0.0515, 0.0515], [0.0799, 0.0470, 0.0319, ..., 0.0391, 0.0391, 0.0391], [0.0408, 0.0561, 0.0754, ..., 0.0346, 0.0346, 0.0346], ..., [0.0207, 0.0974, 0.0423, ..., 0.0479, 0.0479, 0.0479], [0.0207, 0.0974, 0.0423, ..., 0.0479, 0.0479, 0.0479], [0.0207, 0.0974, 0.0423, ..., 0.0479, 0.0479, 0.0479]], [[0.0487, 0.0639, 0.0416, ..., 0.0497, 0.0497, 0.0497], [0.0415, 0.0952, 0.0803, ..., 0.0434, 0.0434, 0.0434], [0.0421, 0.0440, 0.0253, ..., 0.0651, 0.0651, 0.0651], ..., [0.0538, 0.0644, 0.0421, ..., 0.0563, 0.0563, 0.0563], [0.0538, 0.0644, 0.0421, ..., 0.0563, 0.0563, 0.0563], [0.0538, 0.0644, 0.0421, ..., 0.0563, 0.0563, 0.0563]], [[0.0271, 0.0523, 0.1258, ..., 0.0414, 0.0414, 0.0414], [0.0624, 0.0719, 0.0781, ..., 0.0454, 0.0454, 0.0454], [0.0386, 0.0539, 0.0486, ..., 0.0528, 0.0528, 0.0528], ..., [0.0546, 0.0886, 0.1174, ..., 0.0416, 0.0416, 0.0416], [0.0546, 0.0886, 0.1174, ..., 0.0416, 0.0416, 0.0416], [0.0546, 0.0886, 0.1174, ..., 0.0416, 0.0416, 0.0416]]]], grad_fn=<SoftmaxBackward>) torch.Size([10, 8, 20, 20])
 
 
 
[15]
 
 
 
attn_values = torch.matmul(attn_dists, v)  # (B, num_heads, L, d_k)
 
print(attn_values.shape)
 
 
 
 
 
 
torch.Size([10, 8, 20, 64])
 
 
 
 

Merging the results of each head

 
 
 
 
 

Concatenate the outputs of all heads and apply a linear transformation back to the original dimension.

 
 
 
 
[16]
 
 
 
attn_values = attn_values.transpose(1, 2)  # (B, L, num_heads, d_k)
attn_values = attn_values.contiguous().view(batch_size, -1, d_model)  # (B, L, d_model)
 
print(attn_values.shape)
 
 
 
 
 
torch.Size([10, 20, 512])
 
 
 
[17]
 
 
 
outputs = w_0(attn_values)
 
print(outputs)
print(outputs.shape)
 
 
 
 
 
 
tensor([[[-1.1352e-01, -1.9139e-01, 3.4395e-02, ..., 4.0102e-02, -2.2987e-01, 1.9529e-01], [-8.6867e-02, -2.1087e-01, -1.0989e-02, ..., 4.2762e-02, -1.9277e-01, 8.8230e-02], [-1.0772e-01, -2.3987e-01, 3.8831e-02, ..., -9.8586e-03, -2.1114e-01, 1.1364e-01], ..., [-1.4770e-01, -2.5519e-01, 3.5802e-02, ..., 2.1453e-02, -1.9920e-01, 1.3214e-01], [-1.4770e-01, -2.5519e-01, 3.5802e-02, ..., 2.1453e-02, -1.9920e-01, 1.3214e-01], [-1.4770e-01, -2.5519e-01, 3.5802e-02, ..., 2.1453e-02, -1.9920e-01, 1.3214e-01]], [[-1.9061e-01, -5.3132e-01, 7.5893e-02, ..., 1.5970e-01, -5.5873e-01, 5.3326e-01], [-2.2175e-01, -5.3580e-01, 1.1941e-01, ..., 1.4892e-01, -5.1877e-01, 5.3374e-01], [-2.5074e-01, -6.1631e-01, 1.0151e-01, ..., 1.8491e-01, -5.1544e-01, 5.0850e-01], ..., [-2.2835e-01, -5.4975e-01, 9.6207e-02, ..., 1.9166e-01, -5.5457e-01, 5.4181e-01], [-2.2835e-01, -5.4975e-01, 9.6207e-02, ..., 1.9166e-01, -5.5457e-01, 5.4181e-01], [-2.2835e-01, -5.4975e-01, 9.6207e-02, ..., 1.9166e-01, -5.5457e-01, 5.4181e-01]], [[-3.5870e-03, -2.8132e-01, 8.0012e-02, ..., 1.3743e-01, -3.1456e-01, 2.6330e-01], [-8.5198e-02, -3.6469e-01, 6.6780e-02, ..., 1.6005e-01, -3.5290e-01, 2.1697e-01], [-3.4478e-02, -3.8513e-01, 9.0698e-02, ..., 1.4251e-01, -2.7272e-01, 1.5694e-01], ..., [-1.2265e-01, -3.9496e-01, 6.3260e-02, ..., 1.1824e-01, -3.8367e-01, 2.5895e-01], [-1.2265e-01, -3.9496e-01, 6.3260e-02, ..., 1.1824e-01, -3.8367e-01, 2.5895e-01], [-1.2265e-01, -3.9496e-01, 6.3260e-02, ..., 1.1824e-01, -3.8367e-01, 2.5895e-01]], ..., [[ 3.9329e-02, -1.6083e-02, -1.3860e-01, ..., 2.2836e-02, 1.0380e-02, 1.6353e-01], [ 3.5097e-02, -8.2941e-03, -1.9275e-01, ..., -7.4558e-03, 2.7378e-02, 1.5106e-01], [ 1.3097e-02, 1.3226e-02, -1.8156e-01, ..., 3.8179e-02, 2.5640e-02, 1.0779e-01], ..., [ 4.8920e-02, 4.3918e-02, -1.8756e-01, ..., 2.1911e-02, 4.1894e-02, 1.4420e-01], [ 3.8124e-02, -3.4729e-02, -1.4246e-01, ..., 2.2699e-02, 8.5502e-02, 1.5101e-01], [ 5.7281e-02, 4.1677e-02, -1.5282e-01, ..., 5.5112e-02, -1.1841e-03, 1.2372e-01]], [[-1.8894e-02, -2.2604e-01, -6.2761e-03, ..., 2.4829e-02, -2.0479e-01, 8.3719e-02], [-1.0974e-02, -2.0397e-01, -2.7611e-03, ..., -1.2597e-02, -1.3351e-01, 5.0224e-02], [ 4.9551e-02, -1.9723e-01, -6.5305e-02, ..., 6.6988e-02, -1.7278e-01, 1.6469e-01], ..., [-1.4522e-02, -2.6608e-01, 1.3490e-04, ..., 4.5861e-02, -2.0047e-01, 1.3786e-01], [-1.4522e-02, -2.6608e-01, 1.3490e-04, ..., 4.5861e-02, -2.0047e-01, 1.3786e-01], [-1.4522e-02, -2.6608e-01, 1.3490e-04, ..., 4.5861e-02, -2.0047e-01, 1.3786e-01]], [[ 1.9836e-02, -1.5901e-01, -5.8236e-02, ..., 1.1571e-01, -1.9807e-01, 3.1994e-01], [ 3.1970e-02, -1.6160e-01, -4.8095e-02, ..., 3.1911e-02, -1.3050e-01, 3.1769e-01], [ 3.0805e-02, -1.4795e-01, -8.4109e-02, ..., 1.0025e-01, -1.9486e-01, 3.1365e-01], ..., [ 2.7737e-02, -1.9884e-01, -2.9289e-02, ..., 1.3317e-01, -2.6549e-01, 3.4827e-01], [ 2.7737e-02, -1.9884e-01, -2.9289e-02, ..., 1.3317e-01, -2.6549e-01, 3.4827e-01], [ 2.7737e-02, -1.9884e-01, -2.9289e-02, ..., 1.3317e-01, -2.6549e-01, 3.4827e-01]]], grad_fn=<AddBackward0>) torch.Size([10, 20, 512])
 
 
 
 

Full code

 
 
 
 
 

Putting all of the steps above together, we implement a single multi-head attention module.

 
 
 
 
[18]
 
 
 
class MultiheadAttention(nn.Module):
  def __init__(self):
    super(MultiheadAttention, self).__init__()
 
    # Q, K, V learnable matrices
    self.w_q = nn.Linear(d_model, d_model)
    self.w_k = nn.Linear(d_model, d_model)
    self.w_v = nn.Linear(d_model, d_model)
 
    # Linear transformation for concatenated outputs
    self.w_0 = nn.Linear(d_model, d_model)
 
  def forward(self, q, k, v):
    batch_size = q.shape[0]
 
    q = self.w_q(q)  # (B, L, d_model)
    k = self.w_k(k)  # (B, L, d_model)
    v = self.w_v(v)  # (B, L, d_model)
 
    q = q.view(batch_size, -1, num_heads, d_k)  # (B, L, num_heads, d_k)
    k = k.view(batch_size, -1, num_heads, d_k)  # (B, L, num_heads, d_k)
    v = v.view(batch_size, -1, num_heads, d_k)  # (B, L, num_heads, d_k)
 
    q = q.transpose(1, 2)  # (B, num_heads, L, d_k)
    k = k.transpose(1, 2)  # (B, num_heads, L, d_k)
    v = v.transpose(1, 2)  # (B, num_heads, L, d_k)
 
    attn_values = self.self_attention(q, k, v)  # (B, num_heads, L, d_k)
    attn_values = attn_values.transpose(1, 2).contiguous().view(batch_size, -1, d_model)  # (B, L, num_heads, d_k) => (B, L, d_model)
 
    return self.w_0(attn_values)
 
  def self_attention(self, q, k, v):
    attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # (B, num_heads, L, L)
    attn_dists = F.softmax(attn_scores, dim=-1)  # (B, num_heads, L, L)
 
    attn_values = torch.matmul(attn_dists, v)  # (B, num_heads, L, d_k)
 
    return attn_values
 
 
 
 
 
 
 
 
 
[19]
 
 
 
multihead_attn = MultiheadAttention()
 
outputs = multihead_attn(batch_emb, batch_emb, batch_emb)  # (B, L, d_model)
 
 
 
 
 
 
 
 
[20]
 
 
 
print(outputs)
print(outputs.shape)
 
 
 
 
 
 
tensor([[[ 9.2516e-03, 1.7369e-01, 1.6916e-03, ..., 1.1210e-01, -1.6726e-02, 3.4743e-01], [-3.6782e-02, 1.2917e-01, -4.9187e-02, ..., 1.4615e-01, -3.5719e-02, 3.4785e-01], [ 3.5305e-03, 1.6074e-01, -8.8263e-02, ..., 1.8464e-01, 4.2831e-02, 4.0534e-01], ..., [-2.8834e-02, 1.2947e-01, -3.5194e-02, ..., 1.8476e-01, -3.0066e-05, 3.2937e-01], [-2.8834e-02, 1.2947e-01, -3.5194e-02, ..., 1.8476e-01, -3.0066e-05, 3.2937e-01], [-2.8834e-02, 1.2947e-01, -3.5194e-02, ..., 1.8476e-01, -3.0066e-05, 3.2937e-01]], [[ 1.4557e-01, -1.0484e-01, -5.2524e-02, ..., 6.9377e-01, 6.0970e-02, 4.6097e-01], [ 1.5773e-01, -8.2220e-02, -7.7061e-02, ..., 7.2866e-01, 4.4147e-02, 4.3781e-01], [ 1.3460e-01, -9.5307e-02, -9.3904e-02, ..., 7.3514e-01, 7.2221e-02, 4.4110e-01], ..., [ 1.7378e-01, -1.1004e-01, -1.2651e-01, ..., 7.8337e-01, 6.8945e-02, 3.9528e-01], [ 1.7378e-01, -1.1004e-01, -1.2651e-01, ..., 7.8337e-01, 6.8945e-02, 3.9528e-01], [ 1.7378e-01, -1.1004e-01, -1.2651e-01, ..., 7.8337e-01, 6.8945e-02, 3.9528e-01]], [[ 1.0813e-01, -1.9237e-01, 3.3813e-02, ..., 3.8729e-01, 1.4551e-01, 3.0486e-01], [ 1.1949e-01, -1.7870e-01, 1.7301e-02, ..., 4.1831e-01, 1.1177e-01, 2.9937e-01], [ 1.2946e-01, -1.1937e-01, -2.8948e-02, ..., 5.2200e-01, 1.1423e-01, 3.4231e-01], ..., [ 7.3729e-02, -1.1033e-01, -8.8151e-02, ..., 3.9152e-01, 1.6017e-01, 2.6974e-01], [ 7.3729e-02, -1.1033e-01, -8.8151e-02, ..., 3.9152e-01, 1.6017e-01, 2.6974e-01], [ 7.3729e-02, -1.1033e-01, -8.8151e-02, ..., 3.9152e-01, 1.6017e-01, 2.6974e-01]], ..., [[ 1.1427e-01, -8.0809e-02, -9.1286e-02, ..., 1.3664e-02, 1.7932e-01, 1.6808e-02], [ 1.0275e-01, -9.2001e-02, -1.2701e-01, ..., 4.8521e-03, 1.9850e-01, 6.6440e-02], [ 5.8104e-02, -5.4810e-02, -1.3720e-01, ..., 5.1057e-02, 1.2398e-01, 3.6784e-02], ..., [ 1.1326e-01, -3.4439e-02, -7.9127e-02, ..., 2.4230e-02, 1.2355e-01, 3.7422e-02], [ 5.5170e-02, -1.9381e-02, -8.2321e-02, ..., 2.8540e-02, 1.3763e-01, 5.2429e-02], [ 1.0601e-01, -3.5269e-02, -9.8664e-02, ..., 3.1459e-03, 1.2400e-01, 6.3038e-02]], [[ 3.0379e-02, 4.2547e-02, -5.1137e-02, ..., 1.6373e-01, 3.6856e-02, 1.4134e-01], [-1.8070e-02, 7.9446e-02, 3.5455e-02, ..., 1.3126e-01, -6.4744e-02, 1.3485e-01], [ 7.5835e-02, 8.2923e-02, -2.5474e-02, ..., 2.0878e-01, 9.6988e-02, 5.7338e-02], ..., [ 2.6591e-02, 1.3556e-01, -3.2525e-02, ..., 1.7950e-01, 3.0757e-02, 1.0859e-01], [ 2.6591e-02, 1.3556e-01, -3.2525e-02, ..., 1.7950e-01, 3.0757e-02, 1.0859e-01], [ 2.6591e-02, 1.3556e-01, -3.2525e-02, ..., 1.7950e-01, 3.0757e-02, 1.0859e-01]], [[-8.6823e-03, -4.3011e-02, -1.4872e-01, ..., 3.8146e-01, -3.7269e-02, 2.0649e-01], [-3.7705e-02, -4.4422e-02, -7.4764e-02, ..., 2.4636e-01, -1.0249e-01, 1.8855e-01], [-3.8483e-02, -2.6246e-02, -5.2669e-02, ..., 2.3631e-01, 2.0639e-03, 1.6390e-01], ..., [-8.9017e-03, -7.5137e-02, -1.4088e-01, ..., 3.3598e-01, -2.4982e-02, 1.8044e-01], [-8.9017e-03, -7.5137e-02, -1.4088e-01, ..., 3.3598e-01, -2.4982e-02, 1.8044e-01], [-8.9017e-03, -7.5137e-02, -1.4088e-01, ..., 3.3598e-01, -2.4982e-02, 1.8044e-01]]], grad_fn=<AddBackward0>) torch.Size([10, 20, 512])

 

 

(Practice Lecture 8) Implementing Masked Multi-head Attention

 

The masking step: keep the model from looking at the future.

 

##8. Masked Multi-head Attention
1. Implement masked multi-head attention.
2. Implement encoder-decoder attention.

 
 
 
 
 

Import required packages

 
 
 
 
[1]
 
 
 
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm
 
import torch
import math
 
 
 
 
 
 
 
 
 
 

Data preprocessing

 
 
 
 
 

To see the data values and shapes more clearly, we reduce the number of samples.

 
 
 
 
[2]
 
 
 
pad_id = 0
vocab_size = 100
 
data = [
  [62, 13, 47, 39, 78, 33, 56, 13],
  [60, 96, 51, 32, 90],
  [35, 45, 48, 65, 91, 99, 92, 10, 3, 21],
  [66, 88, 98, 47],
  [77, 65, 51, 77, 19, 15, 35, 19, 23]
]
 
 
 
 
 
 
 
 
 
 
[3]
 
 
 
def padding(data):
  max_len = len(max(data, key=len))
  print(f"Maximum sequence length: {max_len}")
 
  for i, seq in enumerate(tqdm(data)):
    if len(seq) < max_len:
      data[i] = seq + [pad_id] * (max_len - len(seq))
 
  return data, max_len
 
 
 
 
 
 
 
 
 
 
[4]
 
 
 
data, max_len = padding(data)
 
 
 
 
 
 
100%|██████████| 5/5 [00:00<00:00, 3296.37it/s]
Maximum sequence length: 10
 
 
 
 
[5]
 
 
 
data
 
 
 
 
 
 
[[62, 13, 47, 39, 78, 33, 56, 13, 0, 0],
 [60, 96, 51, 32, 90, 0, 0, 0, 0, 0],
 [35, 45, 48, 65, 91, 99, 92, 10, 3, 21],
 [66, 88, 98, 47, 0, 0, 0, 0, 0, 0],
 [77, 65, 51, 77, 19, 15, 35, 19, 23, 0]]
 
 
 
 

Hyperparameter setup and embedding

 
 
 
 
[6]
 
 
 
d_model = 8  # hidden size of the model
num_heads = 2  # number of heads
inf = 1e12
 
 
 
 
 
 
 
 
 
 
[7]
 
 
 
embedding = nn.Embedding(vocab_size, d_model)
 
# B: batch size, L: maximum sequence length
batch = torch.LongTensor(data)  # (B, L)
batch_emb = embedding(batch)  # (B, L, d_model)
 
 
 
 
 
 
 
 
 
 
[8]
 
 
 
print(batch_emb)
print(batch_emb.shape)
 
 
 
 
 
 
tensor([[[ 2.5978e-02, -1.1719e+00, -5.6547e-01, 1.0690e+00, -7.4584e-01, -1.0695e+00, 1.4428e+00, -2.7004e+00], [-4.7617e-01, -1.3327e+00, 1.9251e+00, -6.8176e-01, 7.5115e-02, 5.3887e-01, 2.2054e-01, -2.0816e-01], [-8.6807e-01, 1.1268e+00, -7.2726e-01, -1.0275e+00, -3.0366e-01, 1.2544e+00, -7.0513e-02, -1.0134e+00], [-1.2948e+00, -2.5417e+00, -2.5985e-01, -3.3389e-01, 2.0048e-02, -1.6515e-01, -7.6054e-01, 1.1995e+00], [-1.1619e+00, -1.7698e+00, -5.5598e-01, -2.6992e-01, 1.3043e+00, -2.6215e-01, -6.2565e-01, -3.4484e-01], [-1.4553e+00, 7.6459e-01, -4.2104e-01, -5.1377e-01, 8.8455e-01, -1.5364e+00, 9.5698e-02, -1.2962e+00], [ 1.4414e+00, 6.7954e-01, 1.6368e-01, 6.5510e-01, 1.9676e-01, 2.7868e-01, 1.1996e-02, -7.4251e-01], [-4.7617e-01, -1.3327e+00, 1.9251e+00, -6.8176e-01, 7.5115e-02, 5.3887e-01, 2.2054e-01, -2.0816e-01], [ 5.7980e-01, 1.3983e+00, -4.4109e-01, -6.0635e-01, 1.6694e+00, -1.7608e+00, 3.4570e-01, 8.3854e-01], [ 5.7980e-01, 1.3983e+00, -4.4109e-01, -6.0635e-01, 1.6694e+00, -1.7608e+00, 3.4570e-01, 8.3854e-01]], [[ 1.3878e-01, -6.1272e-01, -9.2627e-01, 8.2591e-01, -4.9490e-01, 1.4858e+00, 3.7874e-01, 1.6428e-01], [ 1.6573e+00, -1.2150e+00, -1.8417e-01, 6.2360e-01, 6.1281e-01, -2.2841e-03, 8.1279e-01, 2.9292e-01], [ 6.9719e-01, 3.5959e-01, 1.0445e+00, 1.2747e+00, 2.3077e+00, 5.2847e-01, 1.1980e+00, -6.0787e-01], [ 2.5983e+00, 2.8562e+00, 6.5606e-01, -2.2477e-01, 1.8020e-01, 1.8544e+00, 1.2822e+00, -1.0173e+00], [ 2.5266e-01, 1.1753e+00, -2.5657e-01, -1.7501e+00, 2.5095e+00, 1.4618e+00, 5.3141e-01, -1.0419e+00], [ 5.7980e-01, 1.3983e+00, -4.4109e-01, -6.0635e-01, 1.6694e+00, -1.7608e+00, 3.4570e-01, 8.3854e-01], [ 5.7980e-01, 1.3983e+00, -4.4109e-01, -6.0635e-01, 1.6694e+00, -1.7608e+00, 3.4570e-01, 8.3854e-01], [ 5.7980e-01, 1.3983e+00, -4.4109e-01, -6.0635e-01, 1.6694e+00, -1.7608e+00, 3.4570e-01, 8.3854e-01], [ 5.7980e-01, 1.3983e+00, -4.4109e-01, -6.0635e-01, 1.6694e+00, -1.7608e+00, 3.4570e-01, 8.3854e-01], [ 5.7980e-01, 1.3983e+00, -4.4109e-01, -6.0635e-01, 1.6694e+00, -1.7608e+00, 3.4570e-01, 8.3854e-01]], [[ 6.4874e-01, -1.2360e+00, 6.8337e-01, 6.0631e-01, -1.6179e+00, -1.8304e+00, 1.1675e+00, -1.3559e+00], [ 1.0240e+00, -1.5537e+00, -9.4666e-01, -1.5513e+00, 2.9823e-02, -3.6872e-01, -1.4232e+00, -4.1756e-01], [ 3.0624e+00, -8.0635e-01, 2.0955e+00, 2.7434e-02, -1.0448e+00, -1.3578e+00, -1.2429e+00, -9.7899e-01], [ 5.8301e-01, 5.7118e-01, 8.3664e-02, -9.9143e-01, -5.9037e-01, 1.4771e-02, 7.2694e-01, -3.0060e-01], [-6.9838e-01, -3.6387e-01, -4.6559e-01, -2.0434e+00, -2.3196e+00, -9.8511e-01, -1.8809e-01, -5.3997e-01], [-1.0637e+00, 1.0115e+00, -1.3071e+00, -2.4907e-01, -2.4333e-02, -4.5905e-01, 9.4616e-01, 5.4789e-01], [ 7.7480e-01, -3.0079e-01, -1.7079e-01, 6.4207e-01, -8.1697e-02, 1.4789e+00, 7.9172e-01, -5.1938e-01], [ 5.0799e-01, 8.9652e-01, -1.6079e+00, -1.1147e+00, 1.5580e-01, 8.5131e-01, -7.9493e-01, 1.8839e+00], [-2.8777e-01, 4.7038e-01, 1.1657e+00, -3.4352e-01, 2.4759e-01, 1.7312e+00, -5.9322e-01, 2.5661e+00], [-6.4382e-01, 7.6634e-01, -2.5152e-02, -3.9127e-01, 3.1379e-02, 1.0803e+00, -2.6616e-01, -9.6649e-02]], [[ 3.9309e-01, 5.3615e-01, 1.4154e+00, 1.2089e+00, 1.5527e+00, 1.2730e+00, 4.5496e-01, 6.8353e-01], [ 5.6372e-01, -1.1905e+00, 7.8466e-01, -9.8275e-01, -1.4256e+00, -1.4576e-01, -9.5380e-02, -1.5898e-01], [ 1.5278e+00, 8.1257e-01, 6.3651e-01, 7.1092e-01, -4.2330e-02, 2.6004e-01, -6.3720e-01, 9.4828e-01], [-8.6807e-01, 1.1268e+00, -7.2726e-01, -1.0275e+00, -3.0366e-01, 1.2544e+00, -7.0513e-02, -1.0134e+00], [ 5.7980e-01, 1.3983e+00, -4.4109e-01, -6.0635e-01, 
1.6694e+00, -1.7608e+00, 3.4570e-01, 8.3854e-01], [ 5.7980e-01, 1.3983e+00, -4.4109e-01, -6.0635e-01, 1.6694e+00, -1.7608e+00, 3.4570e-01, 8.3854e-01], [ 5.7980e-01, 1.3983e+00, -4.4109e-01, -6.0635e-01, 1.6694e+00, -1.7608e+00, 3.4570e-01, 8.3854e-01], [ 5.7980e-01, 1.3983e+00, -4.4109e-01, -6.0635e-01, 1.6694e+00, -1.7608e+00, 3.4570e-01, 8.3854e-01], [ 5.7980e-01, 1.3983e+00, -4.4109e-01, -6.0635e-01, 1.6694e+00, -1.7608e+00, 3.4570e-01, 8.3854e-01], [ 5.7980e-01, 1.3983e+00, -4.4109e-01, -6.0635e-01, 1.6694e+00, -1.7608e+00, 3.4570e-01, 8.3854e-01]], [[ 1.8276e+00, 2.1958e+00, 7.5264e-02, -1.2217e-03, 1.6027e-01, -4.3237e-01, 1.4135e-01, -9.1643e-01], [ 5.8301e-01, 5.7118e-01, 8.3664e-02, -9.9143e-01, -5.9037e-01, 1.4771e-02, 7.2694e-01, -3.0060e-01], [ 6.9719e-01, 3.5959e-01, 1.0445e+00, 1.2747e+00, 2.3077e+00, 5.2847e-01, 1.1980e+00, -6.0787e-01], [ 1.8276e+00, 2.1958e+00, 7.5264e-02, -1.2217e-03, 1.6027e-01, -4.3237e-01, 1.4135e-01, -9.1643e-01], [-1.4659e-01, -1.5753e+00, 2.2311e+00, -1.0745e+00, 5.2471e-03, -4.5582e-01, -4.2744e-01, -4.0704e-01], [-2.0087e-01, -1.3524e+00, 9.0261e-01, 1.3093e-01, -9.6100e-02, -5.0534e-02, 1.4622e+00, -9.9551e-01], [ 6.4874e-01, -1.2360e+00, 6.8337e-01, 6.0631e-01, -1.6179e+00, -1.8304e+00, 1.1675e+00, -1.3559e+00], [-1.4659e-01, -1.5753e+00, 2.2311e+00, -1.0745e+00, 5.2471e-03, -4.5582e-01, -4.2744e-01, -4.0704e-01], [ 2.1318e-01, 8.9759e-02, 1.1890e+00, -9.0741e-01, -2.3283e+00, 8.3807e-01, -2.7013e+00, -1.0480e+00], [ 5.7980e-01, 1.3983e+00, -4.4109e-01, -6.0635e-01, 1.6694e+00, -1.7608e+00, 3.4570e-01, 8.3854e-01]]], grad_fn=<EmbeddingBackward>) torch.Size([5, 10, 8])
 
 
 
 

Building the mask

 
 
 
 
 

True marks positions where attention will be applied; False marks positions to be masked.

 
 
 
 
[9]
 
 
 
padding_mask = (batch != pad_id).unsqueeze(1)  # (B, 1, L)
 
print(padding_mask)
print(padding_mask.shape)
 
 
 
 
 
 
tensor([[[ True, True, True, True, True, True, True, True, False, False]], [[ True, True, True, True, True, False, False, False, False, False]], [[ True, True, True, True, True, True, True, True, True, True]], [[ True, True, True, True, False, False, False, False, False, False]], [[ True, True, True, True, True, True, True, True, True, False]]]) torch.Size([5, 1, 10])
 
 
 
[10]
 
 
 
nopeak_mask = torch.ones([1, max_len, max_len], dtype=torch.bool)  # (1, L, L)
nopeak_mask = torch.tril(nopeak_mask)  # (1, L, L)
 
print(nopeak_mask)
print(nopeak_mask.shape)
 
 
 
 
 
 
tensor([[[ True, False, False, False, False, False, False, False, False, False], [ True, True, False, False, False, False, False, False, False, False], [ True, True, True, False, False, False, False, False, False, False], [ True, True, True, True, False, False, False, False, False, False], [ True, True, True, True, True, False, False, False, False, False], [ True, True, True, True, True, True, False, False, False, False], [ True, True, True, True, True, True, True, False, False, False], [ True, True, True, True, True, True, True, True, False, False], [ True, True, True, True, True, True, True, True, True, False], [ True, True, True, True, True, True, True, True, True, True]]]) torch.Size([1, 10, 10])
 
 
 
[11]
 
 
 
mask = padding_mask & nopeak_mask  # (B, L, L)
 
print(mask)
print(mask.shape)
 
 
 
 
 
 
tensor([[[ True, False, False, False, False, False, False, False, False, False], [ True, True, False, False, False, False, False, False, False, False], [ True, True, True, False, False, False, False, False, False, False], [ True, True, True, True, False, False, False, False, False, False], [ True, True, True, True, True, False, False, False, False, False], [ True, True, True, True, True, True, False, False, False, False], [ True, True, True, True, True, True, True, False, False, False], [ True, True, True, True, True, True, True, True, False, False], [ True, True, True, True, True, True, True, True, False, False], [ True, True, True, True, True, True, True, True, False, False]], [[ True, False, False, False, False, False, False, False, False, False], [ True, True, False, False, False, False, False, False, False, False], [ True, True, True, False, False, False, False, False, False, False], [ True, True, True, True, False, False, False, False, False, False], [ True, True, True, True, True, False, False, False, False, False], [ True, True, True, True, True, False, False, False, False, False], [ True, True, True, True, True, False, False, False, False, False], [ True, True, True, True, True, False, False, False, False, False], [ True, True, True, True, True, False, False, False, False, False], [ True, True, True, True, True, False, False, False, False, False]], [[ True, False, False, False, False, False, False, False, False, False], [ True, True, False, False, False, False, False, False, False, False], [ True, True, True, False, False, False, False, False, False, False], [ True, True, True, True, False, False, False, False, False, False], [ True, True, True, True, True, False, False, False, False, False], [ True, True, True, True, True, True, False, False, False, False], [ True, True, True, True, True, True, True, False, False, False], [ True, True, True, True, True, True, True, True, False, False], [ True, True, True, True, True, True, True, True, True, False], [ True, True, True, True, True, True, True, True, True, True]], [[ True, False, False, False, False, False, False, False, False, False], [ True, True, False, False, False, False, False, False, False, False], [ True, True, True, False, False, False, False, False, False, False], [ True, True, True, True, False, False, False, False, False, False], [ True, True, True, True, False, False, False, False, False, False], [ True, True, True, True, False, False, False, False, False, False], [ True, True, True, True, False, False, False, False, False, False], [ True, True, True, True, False, False, False, False, False, False], [ True, True, True, True, False, False, False, False, False, False], [ True, True, True, True, False, False, False, False, False, False]], [[ True, False, False, False, False, False, False, False, False, False], [ True, True, False, False, False, False, False, False, False, False], [ True, True, True, False, False, False, False, False, False, False], [ True, True, True, True, False, False, False, False, False, False], [ True, True, True, True, True, False, False, False, False, False], [ True, True, True, True, True, True, False, False, False, False], [ True, True, True, True, True, True, True, False, False, False], [ True, True, True, True, True, True, True, True, False, False], [ True, True, True, True, True, True, True, True, True, False], [ True, True, True, True, True, True, True, True, True, False]]]) torch.Size([5, 10, 10])
 
 
 
 

Linear transformation & splitting into multiple heads

 
 
 
 
[12]
 
 
 
w_q = nn.Linear(d_model, d_model)
w_k = nn.Linear(d_model, d_model)
w_v = nn.Linear(d_model, d_model)
 
w_0 = nn.Linear(d_model, d_model)
 
 
 
 
 
 
 
 
 
 
[13]
 
 
 
q = w_q(batch_emb)  # (B, L, d_model)
k = w_k(batch_emb)  # (B, L, d_model)
v = w_v(batch_emb)  # (B, L, d_model)
 
batch_size = q.shape[0]
d_k = d_model // num_heads
 
q = q.view(batch_size, -1, num_heads, d_k)  # (B, L, num_heads, d_k)
k = k.view(batch_size, -1, num_heads, d_k)  # (B, L, num_heads, d_k)
v = v.view(batch_size, -1, num_heads, d_k)  # (B, L, num_heads, d_k)
 
q = q.transpose(1, 2)  # (B, num_heads, L, d_k)
k = k.transpose(1, 2)  # (B, num_heads, L, d_k)
v = v.transpose(1, 2)  # (B, num_heads, L, d_k)
 
print(q.shape)
print(k.shape)
print(v.shape)
 
 
 
 
 
 
torch.Size([5, 2, 10, 4]) torch.Size([5, 2, 10, 4]) torch.Size([5, 2, 10, 4])
 
 
 
 

Implementing self-attention with masking applied

 
 
 
 
[14]
 
 
 
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # (B, num_heads, L, L)
 
 
 
 
 
 
 
 
 
 
[15]
 
 
 
masks = mask.unsqueeze(1)  # (B, 1, L, L)
# Since a head dimension was added to the scores, the mask is unsqueezed by 1 so it broadcasts across every head.
masked_attn_scores = attn_scores.masked_fill_(masks == False, -1 * inf)  # (B, num_heads, L, L)
 
print(masked_attn_scores)
print(masked_attn_scores.shape)
 
 
 
 
 
 
tensor([[[[ 4.7637e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [ 4.0336e-01, 1.5477e-02, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [-6.7951e-02, 2.4981e-01, -3.1751e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [-1.7413e-01, -6.5847e-02, 7.6012e-01, 5.9436e-02, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [ 1.6878e-01, 1.2790e-01, -7.6460e-02, 1.0074e-01, 1.0696e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [ 3.2281e-01, 1.8613e-01, -4.6751e-01, 1.1887e-01, -1.5257e-01, -2.7174e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [ 4.0118e-02, 7.6968e-02, -1.7419e-01, 3.8869e-03, 5.2239e-02, 1.3832e-01, -1.2507e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12], [ 4.0336e-01, 1.5477e-02, -7.8466e-02, 7.1126e-02, 2.9099e-01, 2.6656e-01, -2.8050e-02, 1.5477e-02, -1.0000e+12, -1.0000e+12], [ 9.2820e-02, 1.7372e-01, 1.3079e-02, 1.8945e-01, -2.2064e-01, -2.8560e-01, -3.8994e-02, 1.7372e-01, -1.0000e+12, -1.0000e+12], [ 9.2820e-02, 1.7372e-01, 1.3079e-02, 1.8945e-01, -2.2064e-01, -2.8560e-01, -3.8994e-02, 1.7372e-01, -1.0000e+12, -1.0000e+12]], [[ 3.0696e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [-8.3713e-02, -3.8888e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [ 2.5046e-01, 2.0949e-01, 2.6518e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [-5.5559e-01, 2.2722e-01, -1.9210e-01, -3.1048e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [-3.3784e-01, -1.2469e-01, 2.6978e-04, -5.4546e-01, -2.2231e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [ 2.3411e-01, -2.9132e-01, 9.4263e-04, 5.5149e-02, -5.4124e-02, 3.6974e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [ 5.9486e-01, 2.5571e-01, 4.5128e-01, 6.3356e-01, 3.5252e-01, 3.4562e-01, -4.0415e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12], [-8.3713e-02, -3.8888e-01, 4.4088e-04, -5.7006e-01, -3.5372e-01, 9.7312e-02, -3.1540e-02, -3.8888e-01, -1.0000e+12, -1.0000e+12], [ 8.1846e-01, -7.7122e-02, 4.6495e-01, 4.2805e-01, 1.1123e-01, 5.6259e-01, -5.6964e-01, -7.7122e-02, -1.0000e+12, -1.0000e+12], [ 8.1846e-01, -7.7122e-02, 4.6495e-01, 4.2805e-01, 1.1123e-01, 5.6259e-01, -5.6964e-01, -7.7122e-02, -1.0000e+12, -1.0000e+12]]], [[[-1.3960e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [-1.6457e-01, -1.8845e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [-1.2020e-01, 7.4878e-01, 5.0221e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [ 9.2887e-02, 4.7570e-01, 2.9805e-01, -1.8933e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [ 3.2282e-01, 7.2596e-01, 2.9831e-01, -2.5718e-01, -1.7101e+00, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [ 3.1973e-01, 7.2967e-03, -1.9194e-01, -3.9124e-02, -5.8100e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [ 3.1973e-01, 7.2967e-03, -1.9194e-01, -3.9124e-02, -5.8100e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [ 3.1973e-01, 7.2967e-03, 
-1.9194e-01, -3.9124e-02, -5.8100e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [ 3.1973e-01, 7.2967e-03, -1.9194e-01, -3.9124e-02, -5.8100e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [ 3.1973e-01, 7.2967e-03, -1.9194e-01, -3.9124e-02, -5.8100e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12]], [[ 4.4586e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [ 2.4805e-01, -1.9373e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [-3.4653e-01, -6.7971e-01, -6.4486e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [ 4.6247e-02, -1.0307e+00, -2.2533e+00, -1.7807e+00, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [-3.3956e-01, -1.1370e+00, -1.2380e+00, -5.0333e-01, -2.0255e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [-8.0716e-02, -6.5803e-01, -1.4617e+00, -1.1117e+00, -8.4180e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [-8.0716e-02, -6.5803e-01, -1.4617e+00, -1.1117e+00, -8.4180e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [-8.0716e-02, -6.5803e-01, -1.4617e+00, -1.1117e+00, -8.4180e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [-8.0716e-02, -6.5803e-01, -1.4617e+00, -1.1117e+00, -8.4180e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [-8.0716e-02, -6.5803e-01, -1.4617e+00, -1.1117e+00, -8.4180e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12]]], [[[-5.5858e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [-1.3713e+00, 3.1185e-02, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [-1.6621e+00, 9.3574e-01, -5.5574e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [-1.2074e-01, -1.0291e-01, -1.9896e-01, 4.5147e-02, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [-1.3240e+00, 3.3535e-01, -7.6370e-01, 1.4674e-01, 2.4826e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [ 4.3187e-01, -4.9049e-01, -2.7165e-01, 4.5427e-02, 2.6163e-01, 2.2556e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [ 3.9720e-01, -3.9120e-01, -2.4132e-01, -9.5560e-02, -1.1611e-01, 8.7099e-02, -1.1915e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12], [-6.2815e-01, -6.7362e-01, -1.1487e+00, 2.2814e-01, 4.1933e-01, 6.7831e-01, 9.8324e-02, 7.0435e-01, -1.0000e+12, -1.0000e+12], [ 7.8095e-02, -3.5943e-01, -6.1643e-01, 6.0242e-02, 4.2268e-01, 3.6293e-01, 5.3020e-02, -2.5441e-02, 2.9505e-01, -1.0000e+12], [ 2.9748e-01, -5.2630e-01, -2.6692e-01, 1.8622e-02, 1.6033e-02, 1.8395e-01, 3.2762e-02, -2.4421e-01, 2.4621e-01, -3.8580e-02]], [[ 8.8483e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [ 1.8341e-01, 3.9749e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [ 9.5037e-01, 8.4297e-01, -1.1219e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [ 7.8095e-01, 5.1285e-01, 3.5245e-02, 2.2964e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [ 7.2292e-01, 8.4798e-01, 6.7199e-01, 
1.4567e-01, 1.0198e+00, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [ 4.5190e-01, 3.6850e-01, 3.8342e-01, 1.0475e-01, 5.3902e-01, -4.8990e-02, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [ 1.3664e-01, 1.0679e-01, -1.2078e-01, 7.9969e-02, 3.3342e-01, -8.0699e-02, 2.1315e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12], [ 2.3649e-01, 5.4584e-01, -1.9944e-01, 1.2506e-01, 8.4549e-01, -3.4700e-02, 3.1539e-01, -6.3801e-02, -1.0000e+12, -1.0000e+12], [-3.6674e-01, -1.4907e-01, -3.9890e-01, -5.3958e-02, -2.9095e-01, -5.9863e-02, 1.4150e-01, -4.4966e-02, 1.2219e-01, -1.0000e+12], [-1.2761e-03, 1.0568e-01, -1.9336e-01, 2.1486e-02, 2.4646e-01, 6.9932e-02, -3.2433e-03, -3.7407e-03, -1.6667e-02, 1.0430e-01]]], [[[-1.0113e-02, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [-2.0663e-01, -3.4611e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [-2.9826e-02, -1.8627e-01, -2.0277e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [ 3.6179e-01, 1.0622e-01, 1.2711e-01, -3.1751e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [ 8.9095e-02, 2.5991e-01, 1.3676e-01, 1.3079e-02, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [ 8.9095e-02, 2.5991e-01, 1.3676e-01, 1.3079e-02, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [ 8.9095e-02, 2.5991e-01, 1.3676e-01, 1.3079e-02, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [ 8.9095e-02, 2.5991e-01, 1.3676e-01, 1.3079e-02, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [ 8.9095e-02, 2.5991e-01, 1.3676e-01, 1.3079e-02, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [ 8.9095e-02, 2.5991e-01, 1.3676e-01, 1.3079e-02, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12]], [[-3.0940e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [-3.4044e-01, 5.0485e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [-6.1672e-01, 5.7552e-01, -4.5673e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [-3.5909e-01, 2.8637e-01, -2.8371e-01, 2.6518e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [-1.4753e+00, 5.8035e-01, -8.2558e-01, 4.6495e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [-1.4753e+00, 5.8035e-01, -8.2558e-01, 4.6495e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [-1.4753e+00, 5.8035e-01, -8.2558e-01, 4.6495e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [-1.4753e+00, 5.8035e-01, -8.2558e-01, 4.6495e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [-1.4753e+00, 5.8035e-01, -8.2558e-01, 4.6495e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [-1.4753e+00, 5.8035e-01, -8.2558e-01, 4.6495e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12]]], [[[-2.8563e-02, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [-2.4263e-02, 4.5147e-02, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, 
-1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [-5.8976e-01, -2.1452e-01, 5.0221e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [-2.8563e-02, 3.2587e-02, 1.0890e-01, -2.8563e-02, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [ 1.9209e-01, -6.2004e-02, -4.7334e-02, 1.9209e-01, -6.0634e-02, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [-1.8172e-01, -1.7188e-01, 2.1711e-01, -1.8172e-01, 2.6872e-01, 3.7224e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [ 2.4097e-01, -1.4395e-01, 4.1687e-02, 2.4097e-01, -5.4559e-02, -3.5269e-01, -5.5858e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12], [ 1.9209e-01, -6.2004e-02, -4.7334e-02, 1.9209e-01, -6.0634e-02, -2.6342e-01, -3.4811e-01, -6.0634e-02, -1.0000e+12, -1.0000e+12], [ 4.9793e-01, 9.6707e-02, -9.1980e-02, 4.9793e-01, -5.1975e-01, -9.9177e-01, -1.7358e+00, -5.1975e-01, 2.7955e-01, -1.0000e+12], [-1.8620e-01, 8.5844e-02, -1.9194e-01, -1.8620e-01, 3.6767e-02, 1.8844e-01, 3.0871e-01, 3.6767e-02, 2.6375e-01, -1.0000e+12]], [[-8.2484e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [-4.9804e-01, 2.2964e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [ 6.2210e-02, 9.8633e-02, -6.4486e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [-8.2484e-01, 3.5734e-01, -1.9387e+00, -8.2484e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [ 1.2728e-01, 1.2025e-02, -2.8253e-01, 1.2728e-01, -3.7051e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [ 1.2581e-01, 5.4310e-02, -1.8193e-01, 1.2581e-01, -2.2922e-01, -5.5643e-02, -1.0000e+12, -1.0000e+12, -1.0000e+12, -1.0000e+12], [-2.3377e-01, 1.7875e-01, -7.1560e-01, -2.3377e-01, 4.8038e-01, 2.6586e-01, 8.8483e-01, -1.0000e+12, -1.0000e+12, -1.0000e+12], [ 1.2728e-01, 1.2025e-02, -2.8253e-01, 1.2728e-01, -3.7051e-01, -2.2962e-01, -2.0968e-02, -3.7051e-01, -1.0000e+12, -1.0000e+12], [-5.0137e-01, 5.3234e-03, -2.2427e-01, -5.0137e-01, 7.3389e-01, 2.6270e-01, 8.9148e-02, 7.3389e-01, 6.9793e-01, -1.0000e+12], [-4.6310e-01, 2.6777e-01, -1.4617e+00, -4.6310e-01, 9.2073e-02, 1.7478e-01, 8.4084e-01, 9.2073e-02, 7.2260e-01, -1.0000e+12]]]], grad_fn=<MaskedFillBackward0>) torch.Size([5, 2, 10, 10])
 
 
 
 

The positions masked with -1 * inf become 0 after the softmax.
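As a quick standalone check (my own minimal sketch, not part of the lecture notebook), a score filled with a very large negative value effectively vanishes under softmax:

import torch
import torch.nn.functional as F

# Scores for one query position; the last two positions are "masked" with a very large negative value.
scores = torch.tensor([0.5, 1.2, -1e12, -1e12])
print(F.softmax(scores, dim=-1))  # ≈ tensor([0.3318, 0.6682, 0.0000, 0.0000])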

 
 
 
 
[16]
 
 
 
attn_dists = F.softmax(masked_attn_scores, dim=-1)  # (B, num_heads, L, L)
 
print(attn_dists)
print(attn_dists.shape)
 
 
 
 
 
 
tensor([[[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.5958, 0.4042, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.3171, 0.4358, 0.2471, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.1688, 0.1881, 0.4297, 0.2133, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.2166, 0.2079, 0.1695, 0.2024, 0.2036, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.2318, 0.2021, 0.1051, 0.1890, 0.1441, 0.1279, 0.0000, 0.0000, 0.0000, 0.0000], [0.1477, 0.1532, 0.1192, 0.1424, 0.1495, 0.1629, 0.1252, 0.0000, 0.0000, 0.0000], [0.1637, 0.1111, 0.1011, 0.1175, 0.1463, 0.1428, 0.1064, 0.1111, 0.0000, 0.0000], [0.1336, 0.1448, 0.1233, 0.1471, 0.0976, 0.0915, 0.1171, 0.1448, 0.0000, 0.0000], [0.1336, 0.1448, 0.1233, 0.1471, 0.0976, 0.0915, 0.1171, 0.1448, 0.0000, 0.0000]], [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.5757, 0.4243, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.3362, 0.3227, 0.3412, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.1694, 0.3705, 0.2436, 0.2164, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.1794, 0.2220, 0.2515, 0.1457, 0.2013, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.1955, 0.1156, 0.1549, 0.1635, 0.1466, 0.2239, 0.0000, 0.0000, 0.0000, 0.0000], [0.1802, 0.1283, 0.1561, 0.1873, 0.1414, 0.1404, 0.0663, 0.0000, 0.0000, 0.0000], [0.1390, 0.1025, 0.1512, 0.0855, 0.1061, 0.1666, 0.1465, 0.1025, 0.0000, 0.0000], [0.2122, 0.0867, 0.1490, 0.1436, 0.1046, 0.1643, 0.0530, 0.0867, 0.0000, 0.0000], [0.2122, 0.0867, 0.1490, 0.1436, 0.1046, 0.1643, 0.0530, 0.0867, 0.0000, 0.0000]]], [[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.5060, 0.4940, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.1906, 0.4544, 0.3551, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.2248, 0.3297, 0.2760, 0.1695, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.2402, 0.3595, 0.2344, 0.1345, 0.0315, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.2910, 0.2129, 0.1745, 0.2033, 0.1182, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.2910, 0.2129, 0.1745, 0.2033, 0.1182, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.2910, 0.2129, 0.1745, 0.2033, 0.1182, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.2910, 0.2129, 0.1745, 0.2033, 0.1182, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.2910, 0.2129, 0.1745, 0.2033, 0.1182, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]], [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.6087, 0.3913, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.4067, 0.2915, 0.3018, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.6243, 0.2126, 0.0626, 0.1004, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.2595, 0.1169, 0.1057, 0.2203, 0.2976, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.3793, 0.2129, 0.0953, 0.1353, 0.1772, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.3793, 0.2129, 0.0953, 0.1353, 0.1772, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.3793, 0.2129, 0.0953, 0.1353, 0.1772, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.3793, 0.2129, 0.0953, 0.1353, 0.1772, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.3793, 0.2129, 0.0953, 0.1353, 0.1772, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]], [[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.1974, 0.8026, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0573, 0.7695, 0.1732, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], 
[0.2425, 0.2469, 0.2243, 0.2863, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0582, 0.3060, 0.1020, 0.2534, 0.2805, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.2365, 0.0940, 0.1170, 0.1607, 0.1994, 0.1924, 0.0000, 0.0000, 0.0000, 0.0000], [0.2211, 0.1005, 0.1168, 0.1351, 0.1324, 0.1622, 0.1319, 0.0000, 0.0000, 0.0000], [0.0578, 0.0552, 0.0343, 0.1360, 0.1647, 0.2134, 0.1195, 0.2190, 0.0000, 0.0000], [0.1113, 0.0718, 0.0556, 0.1093, 0.1570, 0.1479, 0.1085, 0.1003, 0.1382, 0.0000], [0.1347, 0.0591, 0.0766, 0.1019, 0.1016, 0.1202, 0.1034, 0.0783, 0.1279, 0.0962]], [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.4467, 0.5533, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.4457, 0.4003, 0.1540, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.3552, 0.2717, 0.1685, 0.2047, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.2004, 0.2271, 0.1904, 0.1125, 0.2696, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.1902, 0.1750, 0.1776, 0.1344, 0.2075, 0.1153, 0.0000, 0.0000, 0.0000, 0.0000], [0.1473, 0.1429, 0.1138, 0.1392, 0.1793, 0.1185, 0.1590, 0.0000, 0.0000, 0.0000], [0.1201, 0.1636, 0.0777, 0.1074, 0.2208, 0.0916, 0.1299, 0.0889, 0.0000, 0.0000], [0.0855, 0.1064, 0.0828, 0.1170, 0.0923, 0.1163, 0.1422, 0.1180, 0.1395, 0.0000], [0.0961, 0.1069, 0.0793, 0.0983, 0.1231, 0.1032, 0.0959, 0.0958, 0.0946, 0.1068]]], [[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.5348, 0.4652, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.3709, 0.3172, 0.3120, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.3255, 0.2521, 0.2574, 0.1650, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.2403, 0.2850, 0.2520, 0.2227, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.2403, 0.2850, 0.2520, 0.2227, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.2403, 0.2850, 0.2520, 0.2227, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.2403, 0.2850, 0.2520, 0.2227, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.2403, 0.2850, 0.2520, 0.2227, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.2403, 0.2850, 0.2520, 0.2227, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]], [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.3004, 0.6996, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.1829, 0.6025, 0.2146, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.1709, 0.3258, 0.1843, 0.3190, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0565, 0.4417, 0.1083, 0.3935, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0565, 0.4417, 0.1083, 0.3935, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0565, 0.4417, 0.1083, 0.3935, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0565, 0.4417, 0.1083, 0.3935, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0565, 0.4417, 0.1083, 0.3935, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.0565, 0.4417, 0.1083, 0.3935, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]], [[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.4827, 0.5173, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.1840, 0.2677, 0.5483, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.2375, 0.2525, 0.2725, 0.2375, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.2304, 0.1787, 0.1814, 0.2304, 0.1790, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.1281, 0.1293, 0.1908, 0.1281, 0.2009, 0.2228, 0.0000, 0.0000, 0.0000, 0.0000], [0.1906, 0.1297, 
0.1562, 0.1906, 0.1419, 0.1053, 0.0857, 0.0000, 0.0000, 0.0000], [0.1579, 0.1225, 0.1243, 0.1579, 0.1226, 0.1001, 0.0920, 0.1226, 0.0000, 0.0000], [0.1967, 0.1317, 0.1091, 0.1967, 0.0711, 0.0444, 0.0211, 0.0711, 0.1581, 0.0000], [0.0872, 0.1144, 0.0867, 0.0872, 0.1090, 0.1268, 0.1430, 0.1090, 0.1367, 0.0000]], [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.3257, 0.6743, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.3952, 0.4099, 0.1949, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.1789, 0.5835, 0.0587, 0.1789, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.2402, 0.2141, 0.1595, 0.2402, 0.1460, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], [0.1923, 0.1790, 0.1413, 0.1923, 0.1348, 0.1604, 0.0000, 0.0000, 0.0000, 0.0000], [0.0919, 0.1389, 0.0568, 0.0919, 0.1877, 0.1515, 0.2813, 0.0000, 0.0000, 0.0000], [0.1579, 0.1407, 0.1048, 0.1579, 0.0960, 0.1105, 0.1362, 0.0960, 0.0000, 0.0000], [0.0523, 0.0868, 0.0690, 0.0523, 0.1798, 0.1122, 0.0944, 0.1798, 0.1735, 0.0000], [0.0596, 0.1238, 0.0220, 0.0596, 0.1038, 0.1128, 0.2195, 0.1038, 0.1951, 0.0000]]]], grad_fn=<SoftmaxBackward>) torch.Size([5, 2, 10, 10])
 
 
 
[17]
 
 
 
attn_values = torch.matmul(attn_dists, v)  # (B, num_heads, L, d_k)
 
print(attn_values.shape)
 
 
 
 
 
 
torch.Size([5, 2, 10, 4])
 
 
 
 

Full code

 
 
 
 
[18]
 
 
 
class MultiheadAttention(nn.Module):
  def __init__(self):
    super(MultiheadAttention, self).__init__()
 
    # Q, K, V learnable matrices
    self.w_q = nn.Linear(d_model, d_model)
    self.w_k = nn.Linear(d_model, d_model)
    self.w_v = nn.Linear(d_model, d_model)
 
    # Linear transformation for concatenated outputs
    self.w_0 = nn.Linear(d_model, d_model)
 
  def forward(self, q, k, v, mask=None):
    batch_size = q.shape[0]
 
    q = self.w_q(q)  # (B, L, d_model)
    k = self.w_k(k)  # (B, L, d_model)
    v = self.w_v(v)  # (B, L, d_model)
 
    q = q.view(batch_size, -1, num_heads, d_k)  # (B, L, num_heads, d_k)
    k = k.view(batch_size, -1, num_heads, d_k)  # (B, L, num_heads, d_k)
    v = v.view(batch_size, -1, num_heads, d_k)  # (B, L, num_heads, d_k)
 
    q = q.transpose(1, 2)  # (B, num_heads, L, d_k)
    k = k.transpose(1, 2)  # (B, num_heads, L, d_k)
    v = v.transpose(1, 2)  # (B, num_heads, L, d_k)
 
    attn_values = self.self_attention(q, k, v, mask=mask)  # (B, num_heads, L, d_k)
    attn_values = attn_values.transpose(1, 2).contiguous().view(batch_size, -1, d_model)  # (B, L, num_heads, d_k) => (B, L, d_model)

    return self.w_0(attn_values)

  def self_attention(self, q, k, v, mask=None):
    attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # (B, num_heads, L, L)
 
    if mask is not None:
      mask = mask.unsqueeze(1)  # (B, 1, L, L) or  (B, 1, 1, L)
      attn_scores = attn_scores.masked_fill_(mask == False, -1*inf)
 
    attn_dists = F.softmax(attn_scores, dim=-1)  # (B, num_heads, L, L)
 
    attn_values = torch.matmul(attn_dists, v)  # (B, num_heads, L, d_k)
 
    return attn_values
 
 
 
 
 
 
 
 
 
[19]
 
 
 
multihead_attn = MultiheadAttention()
 
outputs = multihead_attn(batch_emb, batch_emb, batch_emb, mask=mask)  # (B, L, d_model)
 
 
 
 
 
 
 
 
 
 
[20]
 
 
 
print(outputs)
print(outputs.shape)
 
 
 
 
 
 
tensor([[[ 0.0144, 0.1497, 0.1135, 1.3092, 0.9384, 0.3099, -0.1766, -0.5622], [-0.2898, 0.0896, 0.1427, 0.8015, 0.4785, 0.0681, 0.0933, -0.4136], [-0.3018, -0.0423, -0.0087, 0.4880, 0.5833, -0.1454, 0.1864, -0.2006], [-0.4746, 0.2947, 0.0155, 0.5849, 0.3672, -0.1584, 0.2451, -0.2677], [-0.4331, 0.4161, -0.0199, 0.6722, 0.3977, -0.1417, 0.2226, -0.2819], [-0.2968, 0.3491, -0.1067, 0.6743, 0.5806, -0.0976, 0.1937, -0.2695], [-0.2873, 0.2195, -0.0550, 0.4832, 0.3601, -0.0188, 0.2223, -0.3254], [-0.3359, 0.1678, -0.0062, 0.4428, 0.3006, -0.0507, 0.2521, -0.2963], [-0.2698, 0.2142, -0.0160, 0.5737, 0.4235, -0.0383, 0.2095, -0.2744], [-0.2698, 0.2142, -0.0160, 0.5737, 0.4235, -0.0383, 0.2095, -0.2744]], [[-0.1768, -0.0337, 0.3418, 0.6498, 0.1639, 0.2566, 0.0251, -0.3522], [-0.0263, -0.0877, 0.3811, 0.5641, -0.0212, 0.4380, 0.0564, -0.4133], [-0.0217, -0.1415, 0.3251, 0.3769, -0.1307, 0.4615, 0.1472, -0.4424], [-0.0246, -0.4154, 0.2211, -0.0939, -0.2521, 0.4058, 0.2911, -0.3925], [-0.0876, -0.3316, 0.1590, -0.0801, -0.1808, 0.3044, 0.3026, -0.3570], [-0.0793, -0.3021, 0.1924, -0.0047, -0.1807, 0.3280, 0.2785, -0.3688], [-0.0793, -0.3021, 0.1924, -0.0047, -0.1807, 0.3280, 0.2785, -0.3688], [-0.0793, -0.3021, 0.1924, -0.0047, -0.1807, 0.3280, 0.2785, -0.3688], [-0.0793, -0.3021, 0.1924, -0.0047, -0.1807, 0.3280, 0.2785, -0.3688], [-0.0793, -0.3021, 0.1924, -0.0047, -0.1807, 0.3280, 0.2785, -0.3688]], [[ 0.0925, -0.2645, 0.2301, 0.8331, 0.6290, 0.4479, -0.0454, -0.4824], [-0.0289, 0.2087, 0.0284, 0.5535, 0.2514, 0.0563, 0.1032, -0.2719], [ 0.0734, 0.0038, 0.0775, 0.3035, -0.0467, 0.2196, 0.1020, -0.4090], [ 0.1093, -0.1755, 0.1103, 0.3176, 0.1820, 0.2193, 0.1487, -0.2609], [-0.0246, -0.0402, -0.0474, 0.3214, 0.3845, 0.0386, 0.1519, -0.2278], [-0.0277, -0.1558, 0.0172, 0.2918, 0.3786, 0.0965, 0.1986, -0.1995], [-0.0653, -0.2249, 0.0930, 0.1910, 0.1728, 0.1516, 0.2128, -0.2602], [-0.0925, -0.2369, 0.0983, 0.0908, 0.1850, -0.0380, 0.3283, 0.0092], [-0.1878, -0.3238, 0.1693, 0.0309, 0.0890, -0.0245, 0.3728, -0.0209], [-0.1793, -0.2081, 0.0681, 0.0427, 0.1342, -0.0089, 0.3275, -0.1020]], [[-0.0979, -0.3978, 0.2622, -0.3837, -0.6509, 0.4486, 0.4687, -0.3805], [-0.1681, -0.2912, 0.2294, -0.0610, -0.2760, 0.2384, 0.3482, -0.3007], [-0.0446, -0.3829, 0.2550, -0.1610, -0.3536, 0.2851, 0.3518, -0.2386], [-0.1189, -0.2787, 0.0283, -0.2147, -0.1060, 0.1434, 0.3567, -0.2334], [-0.1640, -0.3223, 0.1416, -0.1231, -0.1249, 0.0697, 0.3680, -0.1537], [-0.1640, -0.3223, 0.1416, -0.1231, -0.1249, 0.0697, 0.3680, -0.1537], [-0.1640, -0.3223, 0.1416, -0.1231, -0.1249, 0.0697, 0.3680, -0.1537], [-0.1640, -0.3223, 0.1416, -0.1231, -0.1249, 0.0697, 0.3680, -0.1537], [-0.1640, -0.3223, 0.1416, -0.1231, -0.1249, 0.0697, 0.3680, -0.1537], [-0.1640, -0.3223, 0.1416, -0.1231, -0.1249, 0.0697, 0.3680, -0.1537]], [[ 0.2386, -0.8116, -0.0987, -0.8957, -0.2766, 0.2600, 0.5507, -0.0985], [-0.0426, -0.6154, -0.0750, -0.5514, -0.0237, 0.1282, 0.4729, -0.1658], [-0.0397, -0.4395, 0.0520, -0.2962, -0.1769, 0.3056, 0.4293, -0.3281], [ 0.0077, -0.5226, -0.0144, -0.4651, -0.1768, 0.2642, 0.4592, -0.2693], [-0.0573, -0.3984, -0.0025, -0.3258, -0.1420, 0.1969, 0.4423, -0.2659], [-0.1514, -0.2093, 0.0159, -0.0258, -0.0172, 0.2169, 0.3462, -0.3935], [-0.1374, -0.1000, -0.0064, 0.1789, 0.1111, 0.2921, 0.2488, -0.5185], [-0.1135, -0.2408, 0.0513, 0.0704, 0.0714, 0.2074, 0.3071, -0.3553], [-0.0981, -0.3065, 0.0201, -0.1530, -0.0656, 0.1096, 0.3669, -0.2580], [-0.0758, -0.3004, 0.0695, 0.0510, 0.1120, 0.1087, 0.3016, -0.2206]]], 
grad_fn=<AddBackward0>) torch.Size([5, 10, 8])
 
 
 
 

Encoder-Decoder attention

 
 
 
 
 

Only the query, key, and value inputs differ; the implementation itself is identical.
We will only build a separate batch to feed into the decoder.
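To illustrate that reuse, here is a minimal sketch using the MultiheadAttention class defined above. The tensors enc_output and dec_hidden are hypothetical random stand-ins for the encoder output and the decoder's masked self-attention output:

# Hypothetical stand-ins (random tensors, for illustration only).
enc_output = torch.randn(5, 10, d_model)  # (B, S_L, d_model), pretend encoder output
dec_hidden = torch.randn(5, 12, d_model)  # (B, T_L, d_model), pretend masked self-attention output

cross_attn = MultiheadAttention()
# Query comes from the decoder side, keys/values from the encoder side; no causal mask is needed here.
cross_out = cross_attn(dec_hidden, enc_output, enc_output, mask=None)
print(cross_out.shape)  # torch.Size([5, 12, 8]) = (B, T_L, d_model)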

 
 
 
 
[21]
 
 
 
trg_data = [
  [33, 11, 49, 10],
  [88, 34, 5, 29, 99, 45, 11, 25],
  [67, 25, 15, 90, 54, 4, 92, 10, 46, 20, 88 ,19],
  [16, 58, 91, 47, 12, 5, 8],
  [71, 63, 62, 7, 9, 11, 55, 91, 32, 48]
]
 
trg_data, trg_max_len = padding(trg_data)
 
 
 
 
 
100%|██████████| 5/5 [00:00<00:00, 4245.25it/s]
Maximum sequence length: 12
 
 
 
 
[22]
 
 
 
# S_L: source maximum sequence length, T_L: target maximum sequence length
src_batch = batch  # (B, S_L)
trg_batch = torch.LongTensor(trg_data)  # (B, T_L)
 
print(src_batch.shape)
print(trg_batch.shape)
 
 
 
 
 
 
torch.Size([5, 10]) torch.Size([5, 12])
 
 
 
[23]
 
 
 
src_emb = embedding(src_batch)  # (B, S_L, d_w)
trg_emb = embedding(trg_batch)  # (B, T_L, d_w)
 
print(src_emb.shape)
print(trg_emb.shape)
 
 
 
 
 
 
torch.Size([5, 10, 8]) torch.Size([5, 12, 8])
 
 
 
 

We assume src_emb is the output of the encoder and trg_emb is the result after masked multi-head attention.

 
 
 
 
[24]
 
 
 
q = w_q(trg_emb)  # (B, T_L, d_model)
k = w_k(src_emb)  # (B, S_L, d_model)
v = w_v(src_emb)  # (B, S_L, d_model)
# (checked in debug mode) the query comes from the decoder input, while k and v come from the encoder output
 
batch_size = q.shape[0]
d_k = d_model // num_heads
 
q = q.view(batch_size, -1, num_heads, d_k)  # (B, T_L, num_heads, d_k)
k = k.view(batch_size, -1, num_heads, d_k)  # (B, S_L, num_heads, d_k)
v = v.view(batch_size, -1, num_heads, d_k)  # (B, S_L, num_heads, d_k)
 
q = q.transpose(1, 2)  # (B, num_heads, T_L, d_k)
k = k.transpose(1, 2)  # (B, num_heads, S_L, d_k)
v = v.transpose(1, 2)  # (B, num_heads, S_L, d_k)
 
print(q.shape)
print(k.shape)
print(v.shape)
 
 
 
 
 
torch.Size([5, 2, 12, 4]) torch.Size([5, 2, 10, 4]) torch.Size([5, 2, 10, 4])
 
 
 
[25]
 
 
 
attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)  # (B, num_heads, T_L, S_L)
attn_dists = F.softmax(attn_scores, dim=-1)  # (B, num_heads, T_L, S_L)
 
print(attn_dists.shape)
 
 
 
 
 
 
torch.Size([5, 2, 12, 10])
 
 
 
[26]
 
 
 
attn_values = torch.matmul(attn_dists, v)  # (B, num_heads, T_L, d_k)
 
print(attn_values.shape)
 
 
 
 
 
 
torch.Size([5, 2, 12, 4])
 
 
 

This has the same shape as the output of masked multi-head attention, and the computation in the subsequent layers proceeds in exactly the same way.
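For reference, the remaining steps would mirror MultiheadAttention.forward: merge the heads back and apply an output projection. A minimal sketch, assuming a separate w_0 = nn.Linear(d_model, d_model) projection created here just for illustration (like the one inside the class):

w_0 = nn.Linear(d_model, d_model)  # output projection, as in MultiheadAttention above

attn_values = attn_values.transpose(1, 2).contiguous().view(batch_size, -1, d_model)  # (B, T_L, d_model)
outputs = w_0(attn_values)
print(outputs.shape)  # torch.Size([5, 12, 8])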

 

 

==================================

Assignment / Quiz

https://wikidocs.net/22592

 

 

Natural Language Processing

Assignment 4: Byte Pair Encoding

1. Introduction

  • Generating a single embedding for each whole word suffers from the critical out-of-vocabulary (OOV) problem: any word that never appeared in the training data must be replaced with an unknown token before being fed to the model, which can degrade overall performance. On the other hand, creating an embedding for every possible word requires far too many embedding parameters. To resolve this, sub-word tokenization emerged, which applies byte pair encoding, originally a data compression algorithm, to how words are represented for the model (see the short sketch after this list).
  • In this assignment, we implement a simple sub-word tokenizer based on byte pair encoding. Referring to the instructions in the assignment notebook, the docstring of each function, and Algorithm 1 on page 3 of the paper, complete the build_bpe function and pass all test cases.
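For illustration only, here is a minimal sketch (my own toy example, not the assignment's encode function) of how a greedy longest-match over BPE subwords can tokenize a word that never appeared in training, instead of mapping it to an unknown token:

def greedy_tokenize(word, subwords):
    # Greedily consume the longest matching subword from the front of the word.
    word = word + '_'  # '_' marks the end of a word, as in the assignment
    tokens = []
    while word:
        for sub in sorted(subwords, key=len, reverse=True):
            if word.startswith(sub):
                tokens.append(sub)
                word = word[len(sub):]
                break
        else:
            tokens.append('<unk>')  # no subword matches: fall back to an unknown token
            word = word[1:]
    return tokens

subwords = ['low', 'est_', 'l', 'o', 'w', 'e', 's', 't', '_']
print(greedy_tokenize('lowest', subwords))  # ['low', 'est_'] even though 'lowest' was never seen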
 
 
 
 
[1]
 
 
 
import re, collections
 
def get_stats(vocab):
    pairs = collections.defaultdict(int)
    for word, freq in vocab.items():
        symbols = word.split()
        for i in range(len(symbols)-1):
            pairs[symbols[i],symbols[i+1]] += freq
    return pairs
 
def merge_vocab(pair, v_in):
    v_out = {}
    bigram = re.escape(' '.join(pair))
    p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
    for word in v_in:
        w_out = p.sub(''.join(pair), word)
        v_out[w_out] = v_in[word]
    return v_out
 
vocab = {'l o w </w>' : 5, 'l o w e r </w>' : 2,
'n e w e s t </w>':6, 'w i d e s t </w>':3}
num_merges = 10
for i in range(num_merges):
    print(i)
    pairs = get_stats(vocab)
    print(pairs)
    best = max(pairs, key=pairs.get)
    vocab = merge_vocab(best, vocab)
    print(best)
    print(vocab)
 
 
 
 
 
0 defaultdict(<class 'int'>, {('l', 'o'): 7, ('o', 'w'): 7, ('w', '</w>'): 5, ('w', 'e'): 8, ('e', 'r'): 2, ('r', '</w>'): 2, ('n', 'e'): 6, ('e', 'w'): 6, ('e', 's'): 9, ('s', 't'): 9, ('t', '</w>'): 9, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'e'): 3}) ('e', 's') {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w es t </w>': 6, 'w i d es t </w>': 3} 1 defaultdict(<class 'int'>, {('l', 'o'): 7, ('o', 'w'): 7, ('w', '</w>'): 5, ('w', 'e'): 2, ('e', 'r'): 2, ('r', '</w>'): 2, ('n', 'e'): 6, ('e', 'w'): 6, ('w', 'es'): 6, ('es', 't'): 9, ('t', '</w>'): 9, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'es'): 3}) ('es', 't') {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w est </w>': 6, 'w i d est </w>': 3} 2 defaultdict(<class 'int'>, {('l', 'o'): 7, ('o', 'w'): 7, ('w', '</w>'): 5, ('w', 'e'): 2, ('e', 'r'): 2, ('r', '</w>'): 2, ('n', 'e'): 6, ('e', 'w'): 6, ('w', 'est'): 6, ('est', '</w>'): 9, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'est'): 3}) ('est', '</w>') {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w est</w>': 6, 'w i d est</w>': 3} 3 defaultdict(<class 'int'>, {('l', 'o'): 7, ('o', 'w'): 7, ('w', '</w>'): 5, ('w', 'e'): 2, ('e', 'r'): 2, ('r', '</w>'): 2, ('n', 'e'): 6, ('e', 'w'): 6, ('w', 'est</w>'): 6, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'est</w>'): 3}) ('l', 'o') {'lo w </w>': 5, 'lo w e r </w>': 2, 'n e w est</w>': 6, 'w i d est</w>': 3} 4 defaultdict(<class 'int'>, {('lo', 'w'): 7, ('w', '</w>'): 5, ('w', 'e'): 2, ('e', 'r'): 2, ('r', '</w>'): 2, ('n', 'e'): 6, ('e', 'w'): 6, ('w', 'est</w>'): 6, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'est</w>'): 3}) ('lo', 'w') {'low </w>': 5, 'low e r </w>': 2, 'n e w est</w>': 6, 'w i d est</w>': 3} 5 defaultdict(<class 'int'>, {('low', '</w>'): 5, ('low', 'e'): 2, ('e', 'r'): 2, ('r', '</w>'): 2, ('n', 'e'): 6, ('e', 'w'): 6, ('w', 'est</w>'): 6, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'est</w>'): 3}) ('n', 'e') {'low </w>': 5, 'low e r </w>': 2, 'ne w est</w>': 6, 'w i d est</w>': 3} 6 defaultdict(<class 'int'>, {('low', '</w>'): 5, ('low', 'e'): 2, ('e', 'r'): 2, ('r', '</w>'): 2, ('ne', 'w'): 6, ('w', 'est</w>'): 6, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'est</w>'): 3}) ('ne', 'w') {'low </w>': 5, 'low e r </w>': 2, 'new est</w>': 6, 'w i d est</w>': 3} 7 defaultdict(<class 'int'>, {('low', '</w>'): 5, ('low', 'e'): 2, ('e', 'r'): 2, ('r', '</w>'): 2, ('new', 'est</w>'): 6, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'est</w>'): 3}) ('new', 'est</w>') {'low </w>': 5, 'low e r </w>': 2, 'newest</w>': 6, 'w i d est</w>': 3} 8 defaultdict(<class 'int'>, {('low', '</w>'): 5, ('low', 'e'): 2, ('e', 'r'): 2, ('r', '</w>'): 2, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'est</w>'): 3}) ('low', '</w>') {'low</w>': 5, 'low e r </w>': 2, 'newest</w>': 6, 'w i d est</w>': 3} 9 defaultdict(<class 'int'>, {('low', 'e'): 2, ('e', 'r'): 2, ('r', '</w>'): 2, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'est</w>'): 3}) ('w', 'i') {'low</w>': 5, 'low e r </w>': 2, 'newest</w>': 6, 'wi d est</w>': 3}
 
 
 
 

2-1. Complete the build_bpe function.

 
 
 
 
[5]
 
 
 
from collections import defaultdict, Counter
 
corpus = ['low'] * 5 + ['lower'] * 2 + ['newest'] * 6 + ['widest'] * 3
vocab = Counter(corpus)
print(vocab)
print(dict(vocab))
 
 
 
 
 
 
Counter({'newest': 6, 'low': 5, 'widest': 3, 'lower': 2}) {'low': 5, 'lower': 2, 'newest': 6, 'widest': 3}
 
 
 
[51]
 
 
 
from typing import List, Dict, Set
from itertools import chain
import re
from collections import defaultdict, Counter
 
 
def build_bpe(
        corpus: List[str],
        max_vocab_size: int
) -> List[int]:
    """ BPE Vocabulary Builder
    Implement vocabulary builder for byte pair encoding.
    Please sort your idx2word by subword length in descending manner.
 
    Hint: Counter in collection library would be helpful
 
    Note: If you convert the sentences list to a word frequency dictionary,
          building speed is enhanced significantly because duplicated words are
          preprocessed together
 
    Arguments:
    corpus -- List of words to build vocab
    max_vocab_size -- The maximum size of vocab
 
    Return:
    idx2word -- Subword list
    """
    # Special tokens
    PAD = BytePairEncoding.PAD_token  # Index of <PAD> must be 0
    UNK = BytePairEncoding.UNK_token  # Index of <UNK> must be 1
    CLS = BytePairEncoding.CLS_token  # Index of <CLS> must be 2
    SEP = BytePairEncoding.SEP_token  # Index of <SEP> must be 3
    MSK = BytePairEncoding.MSK_token  # Index of <MSK> must be 4
    SPECIAL = [PAD, UNK, CLS, SEP, MSK]
 
    WORD_END = BytePairEncoding.WORD_END  # Use this token as the end of a word
    # YOUR CODE HERE

    # Start the vocabulary from every individual character that appears in the corpus
    idx2word = set()
    for s in corpus:
        for c in s:
            idx2word.add(c)

    # Represent each word as space-separated symbols ending with WORD_END, e.g. 'low' -> 'l o w _'
    stringList = [' '.join(s + WORD_END) for s in corpus]
    vocab = dict(Counter(stringList))

    # Repeatedly merge the most frequent symbol pair until the vocab budget is used up
    # (the 5 special tokens and WORD_END take 6 of the max_vocab_size slots)
    while len(idx2word) < max_vocab_size - 6:
        pairs = get_stats(vocab)
        if not pairs:
            break
        best = max(pairs, key=pairs.get)
        vocab = merge_vocab(best, vocab)
        idx2word.add(''.join(best))

    # Special tokens first, then subwords sorted by length in descending order, then WORD_END
    idx2word = SPECIAL + sorted(list(idx2word), key=len, reverse=True) + [WORD_END]
    print(idx2word)
    return idx2word
 
 
 
 
 
 
 
 
 
 

2-2. Evaluating the build_bpe function

 
 
 
[52]
 
 
 
#############################################
# Helper functions below. DO NOT MODIFY!    #
#############################################
 
class BytePairEncoding(object):
    """ Byte Pair Encoding class
    We aren't gonna use this class for encoding. Because it is too slow......
    We will use sentence piece Google have made.
    Thus, this class is just for special token index reference.
    """
    PAD_token = '<pad>'
    PAD_token_idx = 0
    UNK_token = '<unk>'
    UNK_token_idx = 1
    CLS_token = '<cls>'
    CLS_token_idx = 2
    SEP_token = '<sep>'
    SEP_token_idx = 3
    MSK_token = '<msk>'
    MSK_token_idx = 4
 
    WORD_END = '_'
 
    def __init__(self, corpus: List[List[str]], max_vocab_size: int) -> None:
        self.idx2word = build_bpe(corpus, max_vocab_size)
 
    def encode(self, sentence: List[str]) -> List[int]:
        return encode(sentence, self.idx2word)
 
    def decoder(self, tokens: List[int]) -> List[str]:
        return decode(tokens, self.idx2word)
 
 
#############################################
# Testing functions below.                  #
#############################################
 
 
def test_build_bpe():
    print("======Building BPE Vocab Test Case======")
    PAD = BytePairEncoding.PAD_token
    UNK = BytePairEncoding.UNK_token
    CLS = BytePairEncoding.CLS_token
    SEP = BytePairEncoding.SEP_token
    MSK = BytePairEncoding.MSK_token
    WORD_END = BytePairEncoding.WORD_END
 
    # First test
    corpus = ['abcde']
    vocab = build_bpe(corpus, max_vocab_size=15)
    assert vocab[:5] == [PAD, UNK, CLS, SEP, MSK], \
        "Please insert the special tokens properly"
    print("The first test passed!")
 
    # Second test
    assert sorted(vocab[5:], key=len, reverse=True) == vocab[5:], \
        "Please sort your idx2word by subword length in decsending manner."
    print("The second test passed!")
 
    # Third test
    corpus = ['low'] * 5 + ['lower'] * 2 + ['newest'] * 6 + ['widest'] * 3
    vocab = set(build_bpe(corpus, max_vocab_size=24))
    assert vocab > {PAD, UNK, CLS, SEP, MSK, 'est_', 'low', 'newest_', \
                    'i', 'e', 'n', 't', 'd', 's', 'o', 'l', 'r', 'w',
                    WORD_END} and \
           "low_" not in vocab and "wi" not in vocab and "id" not in vocab, 
\
        "Your bpe result does not match expected result"
    print("The third test passed!")
 
    # forth test
    corpus = ['aaaaaaaaaaaa', 'abababab']
    vocab = set(build_bpe(corpus, max_vocab_size=13))
    assert vocab == {PAD, UNK, CLS, SEP, MSK, 'aaaaaaaa', 'aaaa', 'abab', 'aa',
                     'ab', 'a', 'b', WORD_END}, \
        "Your bpe result does not match expected result"
    print("The forth test passed!")
 
    # fifth test
    corpus = ['abc', 'bcd']
    vocab = build_bpe(corpus, max_vocab_size=10000)
    assert len(vocab) == 15, \
        "Your bpe result does not match expected result"
    print("The fifth test passed!")
 
    print("All 5 tests passed!")
test_build_bpe()
 
 
 
 
 
======Building BPE Vocab Test Case====== ['<pad>', '<unk>', '<cls>', '<sep>', '<msk>', 'abcde', 'abcd', 'abc', 'ab', 'c', 'a', 'e', 'b', 'd', '_'] The first test passed! The second test passed! ['<pad>', '<unk>', '<cls>', '<sep>', '<msk>', 'newest_', 'est_', 'low', 'new', 'est', 'lo', 'es', 'ne', 'o', 'w', 'l', 'e', 's', 'i', 't', 'r', 'n', 'd', '_'] The third test passed! ['<pad>', '<unk>', '<cls>', '<sep>', '<msk>', 'aaaaaaaa', 'aaaa', 'abab', 'aa', 'ab', 'a', 'b', '_'] The forth test passed! ['<pad>', '<unk>', '<cls>', '<sep>', '<msk>', 'abc_', 'bcd_', 'abc', 'bcd', 'bc', 'c', 'a', 'b', 'd', '_'] The fifth test passed! All 5 tests passed!

 

 

==================================

Peer Session

Reviewed the material and did some data analysis and preprocessing on the DACON dataset.

 

 

===================================

Reflections

I'm tired.