一个很不错的 Transformer 学习仓库:https://github.com/tianxinliao/Transformer-learning ,记录一下自用。
ref:https://blog.csdn.net/zhaohongfei_358/article/details/125273126
在学习 Transformer 的时候,看到代码里面有这样一段:
# Excerpt from the attention module's forward pass in the linked repo; names
# such as self, query, N, value_len, key_len and query_len come from the
# surrounding code, which is not shown here.
values = self.values(values)  # (N, value_len, embed_size)
keys = self.keys(keys)  # (N, key_len, embed_size)
queries = self.queries(query)  # (N, query_len, embed_size)

# Split the embedding into self.heads different pieces
values = values.reshape(N, value_len, self.heads, self.head_dim)
keys = keys.reshape(N, key_len, self.heads, self.head_dim)
queries = queries.reshape(N, query_len, self.heads, self.head_dim)

# For every batch element n and every head h, take the dot product of each
# query vector (position q) with each key vector (position k).  d appears in
# both inputs but not in the output, so it is summed over.  Equivalent to
#   queries.permute(0, 2, 1, 3) @ keys.permute(0, 2, 3, 1)
# i.e. (N, heads, query_len, head_dim) @ (N, heads, head_dim, key_len).
energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
# queries shape: (N, query_len, heads, head_dim)
# keys shape: (N, key_len, heads, head_dim)
# energy: (N, heads, query_len, key_len)
把我看懵了,所以这次正经学习一下,看看是怎么回事。这个东西颇有些只可意会不可言传的感觉,还是人菜瘾大,理解不深啊!
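einsum(爱因斯坦求和约定)在 NumPy 和 PyTorch 中都有,核心是用一个下标字符串描述运算:`->` 左边给每个输入的各个维度起名字,右边写输出要保留哪些维度;在输入中出现、却没有出现在输出里的下标会被求和。例如 `'ij,kj->ik'` 中 j 被求和,等价于 `x @ v.T`;上面的 `'nqhd,nkhd->nhqk'` 中 d 被求和,就是对每个样本 n、每个头 h,把每个 query 向量和每个 key 向量做点积。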
import torch
import torch.nn as nn
import torch.optim as optim
x = torch.rand((2, 3))
v = torch.rand((1, 3))

# 'ij,kj->ik': j appears in both inputs but not in the output, so it is
# summed over -> matrix multiplication x @ v.T, shape (2, 1).
print(torch.einsum('ij,kj->ik', x, v).shape)
# 'ij,kj->ki': the same contraction with the output axes swapped
# -> (x @ v.T).T, shape (1, 2).
print(torch.einsum('ij,kj->ki', x, v).shape)
# 'ij,km->ijkm': no index is shared and every index is kept, so nothing is
# summed -> outer product (each x[i, j] times each v[k, m]), shape (2, 3, 1, 3).
# It multiplies elements pairwise; it is not a concatenation.
print(torch.einsum('ij,km->ijkm', x, v).shape)

# Check each einsum against its explicit counterpart.
assert torch.allclose(torch.einsum('ij,kj->ik', x, v), x @ v.T)
assert torch.allclose(torch.einsum('ij,kj->ki', x, v), (x @ v.T).T)
assert torch.allclose(
    torch.einsum('ij,km->ijkm', x, v), x[:, :, None, None] * v[None, None, :, :]
)
import torch

x = torch.tensor([[1, 2, 3], [4, 5, 6]])
y = torch.tensor([[7, 8, 9]])
print(x, y)
# tensor([[1, 2, 3],
#         [4, 5, 6]]) tensor([[7, 8, 9]])

# 'ij,km->ijkm': no shared index, so nothing is summed; every element x[i, j]
# multiplies the whole of y.  This is an outer product of shape (2, 3, 1, 3):
#   result[i, j, k, m] == x[i, j] * y[k, m]
result = torch.einsum('ij,km->ijkm', x, y)
print(result)
# tensor([[[[ 7,  8,  9]], [[14, 16, 18]], [[21, 24, 27]]],
#         [[[28, 32, 36]], [[35, 40, 45]], [[42, 48, 54]]]])
# Two stacks of 2-D matrices.  The leading index i selects a pair of matrices
# that get multiplied together: a[i] is (2, 2), b[i] is (2, 3).
a = [
    [[1, 2],   # i=0
     [3, 4]],  # i=0
    [[5, 6],   # i=1
     [7, 8]],  # i=1
]
b = [
    [[9, 10, 11],    # i=0
     [12, 13, 14]],  # i=0
    [[15, 16, 17],   # i=1
     [18, 19, 20]],  # i=1
]

print(torch.tensor(a[0]).shape, torch.tensor(b[0]).shape)  # (2, 2) and (2, 3)
print(torch.tensor(a[0]) @ torch.tensor(b[0]))
# tensor([[33, 36, 39],
#         [75, 82, 89]])
print(torch.tensor(a[1]) @ torch.tensor(b[1]))
# tensor([[183, 194, 205],
#         [249, 264, 279]])

# Batched matrix multiplication by hand: multiply each (a[i], b[i]) pair,
# then stack the results along a new leading axis -> shape (2, 2, 3).
res = [torch.tensor(a_i) @ torch.tensor(b_i) for a_i, b_i in zip(a, b)]
res1 = torch.stack(res)
print(res, "\n", res1)

# The same batched product as one einsum: i (the batch index) is kept, and k
# is shared by both inputs but missing from the output, so it is summed over.
# This is the pattern behind "nqhd,nkhd->nhqk" in attention, where n and h
# play the role of batch axes and d is the summed axis.
res2 = torch.einsum('ijk,ikl->ijl', torch.tensor(a), torch.tensor(b))
print(torch.equal(res1, res2))  # True
x = torch.rand(3, 3)
# 'ii->i': repeating a letter within one input walks the diagonal x[i, i];
# i is kept in the output, so nothing is summed -> the diagonal, shape (3,).
diag = torch.einsum('ii->i', x)
print(diag, x)
# Example output (x is random, so the numbers differ every run):
# tensor([0.7127, 0.3843, 0.2046]) tensor([[0.7127, 0.0171, 0.9940],
#                                          [0.6781, 0.3843, 0.9031],
#                                          [0.4963, 0.1581, 0.2046]])
assert torch.equal(diag, torch.diagonal(x))

# Dropping i from the output sums the diagonal instead: the trace.
assert torch.allclose(torch.einsum('ii->', x), torch.trace(x))