الانتباه متعدد الرؤوس
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are each projected with a learned linear layer,
    split into ``n_heads`` heads of width ``d_k = d_model // n_heads``,
    attended per head, then concatenated and passed through an output
    projection.

    Args:
        d_model: Size of the last dimension of the inputs and of the output.
        n_heads: Number of attention heads; must divide ``d_model`` evenly.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        # Fail early: with a remainder, n_heads * d_k != d_model and the
        # reshape in forward() would either error out or silently regroup
        # sequence positions before failing at the output projection.
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.d_k = d_model // n_heads
        self.n_heads = n_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def _split_heads(self, x, batch_size):
        """Reshape (batch, seq, d_model) into (batch, n_heads, seq, d_k)."""
        return x.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                mask=None) -> torch.Tensor:
        """Apply attention of ``q`` over ``k``/``v``.

        Args:
            q: Query tensor; first dimension is the batch, last is ``d_model``.
            k: Key tensor with the same batch size and last dimension.
            v: Value tensor with the same sequence length as ``k``.
            mask: Optional tensor broadcastable to the score tensor of shape
                (batch, n_heads, query_len, key_len). Positions where
                ``mask == 0`` are excluded from attention.

        Returns:
            Tensor of shape (batch, query_len, d_model).
        """
        batch_size = q.size(0)
        queries = self._split_heads(self.W_q(q), batch_size)
        keys = self._split_heads(self.W_k(k), batch_size)
        values = self._split_heads(self.W_v(v), batch_size)

        # Scale by sqrt(d_k) so the logits do not grow with the head width.
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # Use the most negative finite value of the score dtype: a fixed
            # -1e9 is not representable in float16, while finfo.min works for
            # every floating dtype and still yields zero weight after softmax.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        attn = F.softmax(scores, dim=-1)

        # Merge heads back: (batch, n_heads, seq, d_k) -> (batch, seq, d_model).
        context = torch.matmul(attn, values).transpose(1, 2).contiguous()
        context = context.view(batch_size, -1, self.n_heads * self.d_k)
        return self.W_o(context)