我们实现的目标是做image classification,使用MNIST数据集。


import platform

import numpy as np
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.datasets.mnist import MNIST
from torchvision.transforms import ToTensor
from tqdm import tqdm, trange

np.random.seed(0)


def main():
    """Train MyViT on MNIST for a few epochs, then report test loss and accuracy.

    Loads MNIST from ``./../datasets`` (downloading it when missing), selects a
    device (MPS on macOS, CUDA on other systems, CPU as the fallback), trains
    with Adam and cross-entropy, and prints the per-epoch training loss
    followed by the test loss and accuracy.
    """
    # Loading data
    transform = ToTensor()
    train_set = MNIST(root='./../datasets', train=True, download=True, transform=transform)
    test_set = MNIST(root='./../datasets', train=False, download=True, transform=transform)

    train_loader = DataLoader(train_set, shuffle=True, batch_size=128)
    test_loader = DataLoader(test_set, shuffle=False, batch_size=128)

    # Defining model and training options
    if platform.system() == 'Darwin':
        # macOS: use the MPS backend when it is available
        backend = 'mps' if torch.backends.mps.is_available() else 'cpu'
    else:
        # Linux or Windows: use the CUDA backend when it is available
        backend = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(backend)
    print('Device:', device)

    model = MyViT((1, 28, 28), n_patches=7, n_blocks=2, hidden_d=8, n_heads=2, out_d=10).to(device)
    N_EPOCHS = 5
    LR = 0.005

    # Training loop
    optimizer = Adam(model.parameters(), lr=LR)
    criterion = CrossEntropyLoss()
    for epoch in trange(N_EPOCHS, desc="Training"):
        model.train()
        train_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1} in training", leave=False):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            loss = criterion(y_hat, y)

            # Running mean of the per-batch losses over the epoch
            train_loss += loss.item() / len(train_loader)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch + 1}/{N_EPOCHS} loss: {train_loss:.2f}")

    # Test loop
    model.eval()
    with torch.no_grad():
        correct, total = 0, 0
        test_loss = 0.0
        for batch in tqdm(test_loader, desc="Testing"):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            loss = criterion(y_hat, y)
            test_loss += loss.item() / len(test_loader)

            # Count predictions whose arg-max class matches the label
            correct += torch.sum(torch.argmax(y_hat, dim=1) == y).item()
            total += len(x)
        print(f"Test loss: {test_loss:.2f}")
        print(f"Test accuracy: {correct / total * 100:.2f}%")

现在我们有了这个模板,从现在开始,我们可以只关注模型(ViT),它将对形状为(N x 1 x 28 x 28)的图像进行分类。

class MyViT(nn.Module):
    """Empty skeleton of the Vision Transformer model."""

    def __init__(self):
        # Super constructor
        super(MyViT, self).__init__()

    def forward(self, images):
        """Placeholder forward pass; does nothing and returns None."""
        pass



我们将实现Bazi等人的paper Vision Transformers for Remote Sensing Image Classification中的vit的结构,如下图所示

根据图片,我们可以看到输入图像(a)被“切割”成等大小的子图像。 每个这样的子图像都通过一个Linear Embedding。经过Linear Embedding之后,每个子图像只是一个一维向量。

然后向这些向量(标记)添加Positional Embedding。Positional Embedding允许网络知道每个子图像最初在图像中的位置。没有这些信息,网络将无法知道每个这样的图像将被放置在哪里,从而导致可能的错误预测。

然后,这些标记连同一个特殊的分类标记一起传递给Transformer Encoder。每个Encoder块由以下部分组成:层归一化(LN),后接多头自注意力(MSA)和残差连接;然后是第二个LN、一个多层感知器(MLP),以及又一个残差连接。最后,使用分类MLP块进行最终分类,该分类仅作用于特殊的分类标记,该标记在此过程结束时具有关于图片的全局信息。



Transformer Encoder是针对序列数据开发的,例如英语句子。然而,图像并不是序列。我们将图像分解成多个子图像来实现序列化,并将每个子图像映射到一个向量。






def patchify(images, n_patches):
    """Cut square images into a grid of flattened, non-overlapping patches.

    Args:
        images: tensor of shape (N, C, H, W) with H == W.
        n_patches: number of patches along each spatial direction, so every
            image yields ``n_patches ** 2`` patches.

    Returns:
        Tensor of shape (N, n_patches ** 2, C * (H // n_patches) ** 2) in the
        default floating dtype, on the same device as ``images``. Patches are
        ordered row-major over the grid (index ``i * n_patches + j``) and each
        patch is flattened in (channel, row, column) order.

    Raises:
        AssertionError: if the images are not square.
        RuntimeError: (from ``reshape``) if H is not divisible by n_patches.
    """
    n, c, h, w = images.shape

    assert h == w, "Patchify method is implemented for square images only"

    patch_size = h // n_patches

    # (N, C, H, W) -> (N, C, grid_row, patch_row, grid_col, patch_col)
    grid = images.reshape(n, c, n_patches, patch_size, n_patches, patch_size)
    # Bring the two grid axes forward so each (grid_row, grid_col) cell owns a
    # contiguous (C, patch_row, patch_col) block, then flatten that block.
    grid = grid.permute(0, 2, 4, 1, 3, 5)
    patches = grid.reshape(n, n_patches ** 2, c * patch_size * patch_size)

    # The loop-based version wrote into a torch.zeros buffer, i.e. the default
    # dtype; keep that output dtype.
    return patches.to(dtype=torch.get_default_dtype())
class MyViT(nn.Module):
    """ViT, step 1: split each input image into flattened patches.

    Args:
        chw: input image shape as (C, H, W).
        n_patches: number of patches along each spatial direction; both H and
            W must be divisible by it.
    """

    def __init__(self, chw=(1, 28, 28), n_patches=7):
        # Super constructor
        super(MyViT, self).__init__()

        # Attributes
        self.chw = chw  # (C, H, W)
        self.n_patches = n_patches

        assert chw[1] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        assert chw[2] % n_patches == 0, "Input shape not entirely divisible by number of patches"

    def forward(self, images):
        """Return the patches produced by ``patchify(images, self.n_patches)``."""
        patches = patchify(images, self.n_patches)
        return patches

类的构造函数现在让类知道我们输入图像的大小(通道数、高度和宽度)。在我们的实现中,n_patches 变量表示我们在宽度和高度单一方向上的块的数量,实际上要切分的块的数量是n_patches^2个块(在我们的例子中是7,因为我们将图像分割成7x7的块)。


if __name__ == '__main__':
    # Current model
    model = MyViT(chw=(1, 28, 28), n_patches=7)

    x = torch.randn(7, 1, 28, 28)  # Dummy images
    print(model(x).shape)  # torch.Size([7, 49, 16])

现在我们已经得到了展平的块,我们可以通过线性映射将它们一一映射。虽然每个块是一个4x4=16维的向量,但线性映射可以映射到任何任意的向量大小。因此,我们在类构造函数中添加了一个参数,称为 hidden_d,代表“隐藏维度”。


我们只需创建一个 nn.Linear 层,并在前向函数中调用它。

class MyViT(nn.Module):
    """ViT, step 2: patchify the images and linearly embed every patch.

    Args:
        chw: input image shape as (C, H, W).
        n_patches: number of patches along each spatial direction; both H and
            W must be divisible by it.
        hidden_d: size of the token embedding produced by the linear mapper.
    """

    def __init__(self, chw=(1, 28, 28), n_patches=7, hidden_d=8):
        # Super constructor
        super(MyViT, self).__init__()

        # Attributes
        self.chw = chw  # (C, H, W)
        self.n_patches = n_patches
        # Must be set before the linear mapper below reads it.
        self.hidden_d = hidden_d

        assert chw[1] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        assert chw[2] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        # Divisibility is asserted above, so integer division is exact.
        self.patch_size = (chw[1] // n_patches, chw[2] // n_patches)

        # 1) Linear mapper
        self.input_d = int(chw[0] * self.patch_size[0] * self.patch_size[1])
        self.linear_mapper = nn.Linear(self.input_d, self.hidden_d)

    def forward(self, images):
        """Map images to patch tokens; the linear layer acts on the last axis."""
        patches = patchify(images, self.n_patches)
        tokens = self.linear_mapper(patches)
        return tokens

注意,我们通过一个(16, 8)的线性映射层处理一个(N, 49, 16)的张量。线性操作仅在最后一个维度上发生。


如果你仔细观察架构图,我们会发现还有一个“v_class”标记传递给Transformer Encoder。 这是一个我们添加到模型中的特殊标记,它的作用是捕获关于其他标记的信息。当所有其他标记的信息都汇聚在这里时,我们将能够仅使用这个特殊标记来对图像进行分类。v_class初始值是模型的一个参数,参与网络的学习过程。


我们现在可以向我们的模型添加一个参数,并将我们的(N, 49, 8)标记张量转换为(N, 50, 8)张量(我们在每个序列中添加了特殊标记)。

class MyViT(nn.Module):
    """ViT, step 3: patch tokens with a learnable classification token in front.

    Args:
        chw: input image shape as (C, H, W).
        n_patches: number of patches along each spatial direction; both H and
            W must be divisible by it.
        hidden_d: size of the token embedding.
    """

    def __init__(self, chw=(1, 28, 28), n_patches=7, hidden_d=8):
        # Super constructor
        super(MyViT, self).__init__()

        # Attributes
        self.chw = chw  # (C, H, W)
        self.n_patches = n_patches
        # Must be set before the layers below read it.
        self.hidden_d = hidden_d

        assert chw[1] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        assert chw[2] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        # Divisibility is asserted above, so integer division is exact.
        self.patch_size = (chw[1] // n_patches, chw[2] // n_patches)

        # 1) Linear mapper
        self.input_d = int(chw[0] * self.patch_size[0] * self.patch_size[1])
        self.linear_mapper = nn.Linear(self.input_d, self.hidden_d)

        # 2) Learnable classification token
        self.class_token = nn.Parameter(torch.rand(1, self.hidden_d))

    def forward(self, images):
        """Return tokens with the class token at sequence index 0."""
        patches = patchify(images, self.n_patches)
        tokens = self.linear_mapper(patches)

        # Adding classification token to the tokens: broadcast the single
        # (1, hidden_d) token across the batch and concatenate on the
        # sequence axis.
        n = tokens.shape[0]
        tokens = torch.cat((self.class_token.expand(n, 1, -1), tokens), dim=1)
        return tokens



正如预期的那样,位置编码允许模型理解每个块在原始图像中的位置。虽然理论上可以学习这样的位置嵌入,但Vaswani等人之前的工作Attention Is All You Need表明,我们可以直接添加正弦和余弦波。




下面的函数是其一个简单的实现。它根据标记的数量和每个标记的维度,输出一个矩阵,其中坐标(i,j)处的元素是要加到第i个标记的第j个维度上的值。

def get_positional_embeddings(sequence_length, d):
    """Build the sinusoidal positional-embedding matrix.

    Entry (i, j) is ``sin(i / 10000 ** (j / d))`` for even ``j`` and
    ``cos(i / 10000 ** ((j - 1) / d))`` for odd ``j``.

    Args:
        sequence_length: number of tokens (rows of the result).
        d: dimensionality of each token (columns of the result).

    Returns:
        Tensor of shape (sequence_length, d) in the default floating dtype.
    """
    # Angles are computed in float64 and cast once at the end.
    positions = torch.arange(sequence_length, dtype=torch.float64).unsqueeze(1)
    dims = torch.arange(d)
    is_even = dims % 2 == 0

    # Odd columns reuse the frequency of the even column just before them.
    paired_dims = (dims - dims % 2).to(torch.float64)
    angles = positions / torch.pow(10000.0, paired_dims / d)

    result = torch.where(is_even, torch.sin(angles), torch.cos(angles))
    return result.to(dtype=torch.get_default_dtype())


if __name__ == "__main__":
    import matplotlib.pyplot as plt

    plt.imshow(get_positional_embeddings(100, 300), cmap="hot", interpolation="nearest")
    plt.show()



class MyViT(nn.Module):
    """ViT, step 4: class token plus fixed sinusoidal positional embeddings.

    Args:
        chw: input image shape as (C, H, W).
        n_patches: number of patches along each spatial direction; both H and
            W must be divisible by it.
        hidden_d: size of the token embedding.
    """

    def __init__(self, chw=(1, 28, 28), n_patches=7, hidden_d=8):
        # Super constructor
        super(MyViT, self).__init__()

        # Attributes
        self.chw = chw  # (C, H, W)
        self.n_patches = n_patches
        # Must be set before the layers below read it.
        self.hidden_d = hidden_d

        assert chw[1] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        assert chw[2] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        # Divisibility is asserted above, so integer division is exact.
        self.patch_size = (chw[1] // n_patches, chw[2] // n_patches)

        # 1) Linear mapper
        self.input_d = int(chw[0] * self.patch_size[0] * self.patch_size[1])
        self.linear_mapper = nn.Linear(self.input_d, self.hidden_d)

        # 2) Learnable classification token
        self.class_token = nn.Parameter(torch.rand(1, self.hidden_d))

        # 3) Positional embedding: stored as a parameter so it moves with the
        # module, but frozen so the optimizer never updates it.
        self.pos_embed = nn.Parameter(
            get_positional_embeddings(self.n_patches ** 2 + 1, self.hidden_d),
            requires_grad=False,
        )

    def forward(self, images):
        """Return class-token-prefixed tokens with positional embeddings added."""
        n = images.shape[0]
        patches = patchify(images, self.n_patches)
        tokens = self.linear_mapper(patches)

        # Adding classification token to the tokens
        tokens = torch.cat((self.class_token.expand(n, 1, -1), tokens), dim=1)

        # Adding positional embedding: one copy of the
        # (n_patches ** 2 + 1, hidden_d) matrix per sample.
        pos_embed = self.pos_embed.repeat(n, 1, 1)
        out = tokens + pos_embed
        return out

我们将位置嵌入定义为模型的一个参数(我们通过将其 requires_grad 设置为 False 来使其不被更新)。注意,在前向方法中,由于标记的大小为 (N, 50, 8),我们必须将 (50, 8) 的位置编码矩阵重复 N 次。

步骤四:encoder block(part 1/2)

这是最核心的一部分。encoder block以当前张量[N, S, D]作为输入,并输出相同维度的张量。
encoder block的第一部分对我们的标记应用层归一化,然后是多头自注意力机制,最后加上一个残差连接。

多头自注意力(Multi-head Self Attention)是Transformer架构中的关键组成部分,它允许模型在处理一个图像时,每个块(patch)根据与其他块的相似性度量来更新自己。具体来说,每个块(在我们的示例中是一个8维向量)通过线性映射被转换为三个不同的向量:q(query),k(key)和v(value)。



由于进行了大量的计算,创建一个新的MSA(Multi-head Self Attention)类。这样,我们可以将多头自注意力的实现封装在一个类中,以便于管理和维护代码。这个类将包含必要的属性和方法来执行上述的多头自注意力计算,包括线性映射、缩放点积注意力的计算、softmax归一化以及最终的输出拼接和线性变换。

class MyMSA(nn.Module):
    """Multi-head self-attention with a separate q/k/v linear map per head.

    The token dimension ``d`` is split into ``n_heads`` contiguous slices of
    size ``d // n_heads``; each head attends over its own slice and the head
    outputs are concatenated back along the last axis.

    Args:
        d: token dimensionality; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
    """

    def __init__(self, d, n_heads=2):
        super(MyMSA, self).__init__()
        self.d = d
        self.n_heads = n_heads

        assert d % n_heads == 0, f"Can't divide dimension {d} into {n_heads} heads"

        d_head = d // n_heads
        self.q_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.k_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.v_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.d_head = d_head
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, sequences):
        """Apply self-attention to a batch of sequences.

        Args:
            sequences: tensor of shape (N, seq_length, d).

        Returns:
            Tensor of shape (N, seq_length, d).
        """
        # Each head is computed for the whole batch at once; only the loop
        # over heads remains, because every head owns its own Linear modules.
        head_outputs = []
        for head in range(self.n_heads):
            start = head * self.d_head
            chunk = sequences[:, :, start: start + self.d_head]

            q = self.q_mappings[head](chunk)
            k = self.k_mappings[head](chunk)
            v = self.v_mappings[head](chunk)

            # Scaled dot-product attention, softmax over the key axis.
            attention = self.softmax(q @ k.transpose(-2, -1) / (self.d_head ** 0.5))
            head_outputs.append(attention @ v)

        # Concatenate the heads back into the full token dimension.
        return torch.cat(head_outputs, dim=-1)

由于我们的输入将是大小为(N, 50, 8)的序列,并且我们只使用2个头,我们将在某个时候得到一个(N, 50, 2, 4)的张量,对它使用一个nn.Linear(4, 4)模块,然后在拼接后返回到一个(N, 50, 8)的张量。

我们后面添加一个残差连接,将我们的原始(N, 50, 8)张量与LN(层归一化)和MSA(多头自注意力)后得到的(N, 50, 8)张量相加。


class MyViTBlock(nn.Module):
    """Encoder block, part 1: layer norm, self-attention, residual connection.

    Args:
        hidden_d: token dimensionality.
        n_heads: number of attention heads passed to ``MyMSA``.
        mlp_ratio: accepted but not used by this intermediate version.
    """

    def __init__(self, hidden_d, n_heads, mlp_ratio=4):
        super(MyViTBlock, self).__init__()
        self.hidden_d = hidden_d
        self.n_heads = n_heads

        self.norm1 = nn.LayerNorm(hidden_d)
        self.mhsa = MyMSA(hidden_d, n_heads)

    def forward(self, x):
        """Return ``x`` plus the attention output of its normalized version."""
        out = x + self.mhsa(self.norm1(x))
        return out


步骤五:encoder block(part 2/2)


class MyViTBlock(nn.Module):
    """Transformer encoder block: LN -> MSA -> residual, then LN -> MLP -> residual.

    Args:
        hidden_d: token dimensionality.
        n_heads: number of attention heads passed to ``MyMSA``.
        mlp_ratio: the MLP's hidden layer has ``mlp_ratio * hidden_d`` units.
    """

    def __init__(self, hidden_d, n_heads, mlp_ratio=4):
        super(MyViTBlock, self).__init__()
        self.hidden_d = hidden_d
        self.n_heads = n_heads

        self.norm1 = nn.LayerNorm(hidden_d)
        self.mhsa = MyMSA(hidden_d, n_heads)
        self.norm2 = nn.LayerNorm(hidden_d)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_d, mlp_ratio * hidden_d),
            nn.GELU(),
            nn.Linear(mlp_ratio * hidden_d, hidden_d),
        )

    def forward(self, x):
        """Apply both residual sub-layers."""
        out = x + self.mhsa(self.norm1(x))
        out = out + self.mlp(self.norm2(out))
        return out



class MyViT(nn.Module):
    """ViT, step 5: embedded tokens passed through the encoder blocks.

    Args:
        chw: input image shape as (C, H, W).
        n_patches: number of patches along each spatial direction; both H and
            W must be divisible by it.
        n_blocks: number of ``MyViTBlock`` encoder blocks.
        hidden_d: size of the token embedding.
        n_heads: number of attention heads per block.
        out_d: accepted but not used by this intermediate version.
    """

    def __init__(self, chw, n_patches=7, n_blocks=2, hidden_d=8, n_heads=2, out_d=10):
        # Super constructor
        super(MyViT, self).__init__()

        # Attributes
        self.chw = chw  # (C, H, W)
        self.n_patches = n_patches
        self.n_blocks = n_blocks
        self.n_heads = n_heads
        self.hidden_d = hidden_d

        # Input and patches sizes
        assert chw[1] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        assert chw[2] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        # Divisibility is asserted above, so integer division is exact.
        self.patch_size = (chw[1] // n_patches, chw[2] // n_patches)

        # 1) Linear mapper
        self.input_d = int(chw[0] * self.patch_size[0] * self.patch_size[1])
        self.linear_mapper = nn.Linear(self.input_d, self.hidden_d)

        # 2) Learnable classification token
        self.class_token = nn.Parameter(torch.rand(1, self.hidden_d))

        # 3) Positional embedding: a non-persistent buffer, so it follows the
        # module across devices but is not saved in the state dict.
        self.register_buffer(
            'positional_embeddings',
            get_positional_embeddings(n_patches ** 2 + 1, hidden_d),
            persistent=False,
        )

        # 4) Transformer encoder blocks
        self.blocks = nn.ModuleList([MyViTBlock(hidden_d, n_heads) for _ in range(n_blocks)])

    def forward(self, images):
        """Return the encoded tokens, class token at sequence index 0."""
        # Dividing images into patches
        n, c, h, w = images.shape
        patches = patchify(images, self.n_patches).to(self.positional_embeddings.device)

        # Running linear layer tokenization
        # Map the vector corresponding to each patch to the hidden size dimension
        tokens = self.linear_mapper(patches)

        # Adding classification token to the tokens
        tokens = torch.cat((self.class_token.expand(n, 1, -1), tokens), dim=1)

        # Adding positional embedding
        out = tokens + self.positional_embeddings.repeat(n, 1, 1)

        # Transformer Blocks
        for block in self.blocks:
            out = block(out)

        return out

如果我们通过我们的模型运行一个随机的(7, 1, 28, 28)张量,我们会得到一个(7, 50, 8)张量。




class MyViT(nn.Module):
    """Complete ViT classifier: patches -> tokens -> encoder blocks -> class head.

    Args:
        chw: input image shape as (C, H, W).
        n_patches: number of patches along each spatial direction; both H and
            W must be divisible by it.
        n_blocks: number of ``MyViTBlock`` encoder blocks.
        hidden_d: size of the token embedding.
        n_heads: number of attention heads per block.
        out_d: number of output classes.
    """

    def __init__(self, chw, n_patches=7, n_blocks=2, hidden_d=8, n_heads=2, out_d=10):
        # Super constructor
        super(MyViT, self).__init__()

        # Attributes
        self.chw = chw  # (C, H, W)
        self.n_patches = n_patches
        self.n_blocks = n_blocks
        self.n_heads = n_heads
        self.hidden_d = hidden_d

        # Input and patches sizes
        assert chw[1] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        assert chw[2] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        # Divisibility is asserted above, so integer division is exact.
        self.patch_size = (chw[1] // n_patches, chw[2] // n_patches)

        # 1) Linear mapper
        self.input_d = int(chw[0] * self.patch_size[0] * self.patch_size[1])
        self.linear_mapper = nn.Linear(self.input_d, self.hidden_d)

        # 2) Learnable classification token
        self.class_token = nn.Parameter(torch.rand(1, self.hidden_d))

        # 3) Positional embedding: a non-persistent buffer, so it follows the
        # module across devices but is not saved in the state dict.
        self.register_buffer(
            'positional_embeddings',
            get_positional_embeddings(n_patches ** 2 + 1, hidden_d),
            persistent=False,
        )

        # 4) Transformer encoder blocks
        self.blocks = nn.ModuleList([MyViTBlock(hidden_d, n_heads) for _ in range(n_blocks)])

        # 5) Classification MLP
        # NOTE(review): the training script in this file feeds this output to
        # CrossEntropyLoss, which expects unnormalized logits; the Softmax here
        # means softmax is effectively applied twice. Kept as-is because the
        # results shown in this file were produced with it - consider
        # returning logits and applying softmax only at inference.
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_d, out_d),
            nn.Softmax(dim=-1),
        )

    def forward(self, images):
        """Return a (N, out_d) tensor of class probabilities."""
        # Dividing images into patches
        n, c, h, w = images.shape
        patches = patchify(images, self.n_patches).to(self.positional_embeddings.device)

        # Running linear layer tokenization
        # Map the vector corresponding to each patch to the hidden size dimension
        tokens = self.linear_mapper(patches)

        # Adding classification token to the tokens
        tokens = torch.cat((self.class_token.expand(n, 1, -1), tokens), dim=1)

        # Adding positional embedding
        out = tokens + self.positional_embeddings.repeat(n, 1, 1)

        # Transformer Blocks
        for block in self.blocks:
            out = block(out)

        # Getting the classification token only
        out = out[:, 0]

        return self.mlp(out)  # Map to output dimension, output category distribution

我们的模型的输出现在是一个(N, 10)的张量。



# Instantiate the final model with the tutorial's hyperparameters and move it to `device`.
# NOTE(review): `device` is not defined at module level here; in this file it is only
# assigned inside main(), so this line is a snippet rather than standalone code.
model = MyViT((1, 28, 28), n_patches=7, n_blocks=2, hidden_d=8, n_heads=2, out_d=10).to(device)


torch.manual_seed(0)


def patchify(images, n_patches):
    """Cut square images into a grid of flattened, non-overlapping patches.

    Args:
        images: tensor of shape (N, C, H, W) with H == W.
        n_patches: number of patches along each spatial direction.

    Returns:
        Tensor of shape (N, n_patches ** 2, C * (H // n_patches) ** 2) in the
        default floating dtype, on the same device as ``images``. Patches are
        ordered row-major over the grid and each patch is flattened in
        (channel, row, column) order.

    Raises:
        AssertionError: if the images are not square.
        RuntimeError: (from ``reshape``) if H is not divisible by n_patches.
    """
    n, c, h, w = images.shape

    assert h == w, "Patchify method is implemented for square images only"

    patch_size = h // n_patches

    # (N, C, H, W) -> (N, C, grid_row, patch_row, grid_col, patch_col)
    grid = images.reshape(n, c, n_patches, patch_size, n_patches, patch_size)
    # Bring the grid axes forward so every cell owns a (C, row, col) block.
    grid = grid.permute(0, 2, 4, 1, 3, 5)
    patches = grid.reshape(n, n_patches**2, c * patch_size * patch_size)

    # The loop-based version wrote into a torch.zeros buffer, i.e. the default
    # dtype; keep that output dtype.
    return patches.to(dtype=torch.get_default_dtype())


class MyMSA(nn.Module):
    """Multi-head self-attention with a separate q/k/v linear map per head.

    Args:
        d: token dimensionality; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
    """

    def __init__(self, d, n_heads=2):
        super(MyMSA, self).__init__()
        self.d = d
        self.n_heads = n_heads

        assert d % n_heads == 0, f"Can't divide dimension {d} into {n_heads} heads"

        d_head = d // n_heads
        self.q_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.k_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.v_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.d_head = d_head
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, sequences):
        """Map (N, seq_length, d) sequences to attention outputs of the same shape."""
        # Each head is computed for the whole batch at once; only the loop
        # over heads remains, because every head owns its own Linear modules.
        head_outputs = []
        for head in range(self.n_heads):
            start = head * self.d_head
            chunk = sequences[:, :, start : start + self.d_head]

            q = self.q_mappings[head](chunk)
            k = self.k_mappings[head](chunk)
            v = self.v_mappings[head](chunk)

            # Scaled dot-product attention, softmax over the key axis.
            attention = self.softmax(q @ k.transpose(-2, -1) / (self.d_head**0.5))
            head_outputs.append(attention @ v)

        # Concatenate the heads back into the full token dimension.
        return torch.cat(head_outputs, dim=-1)


class MyViTBlock(nn.Module):
    """Transformer encoder block: LN -> MSA -> residual, then LN -> MLP -> residual.

    Args:
        hidden_d: token dimensionality.
        n_heads: number of attention heads.
        mlp_ratio: the MLP's hidden layer has ``mlp_ratio * hidden_d`` units.
    """

    def __init__(self, hidden_d, n_heads, mlp_ratio=4):
        super(MyViTBlock, self).__init__()
        self.hidden_d = hidden_d
        self.n_heads = n_heads

        self.norm1 = nn.LayerNorm(hidden_d)
        self.mhsa = MyMSA(hidden_d, n_heads)
        self.norm2 = nn.LayerNorm(hidden_d)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_d, mlp_ratio * hidden_d),
            nn.GELU(),
            nn.Linear(mlp_ratio * hidden_d, hidden_d),
        )

    def forward(self, x):
        """Apply both residual sub-layers."""
        out = x + self.mhsa(self.norm1(x))
        out = out + self.mlp(self.norm2(out))
        return out


class MyViT(nn.Module):
    """Complete ViT classifier: patches -> tokens -> encoder blocks -> class head.

    Args:
        chw: input image shape as (C, H, W).
        n_patches: number of patches along each spatial direction; both H and
            W must be divisible by it.
        n_blocks: number of encoder blocks.
        hidden_d: size of the token embedding.
        n_heads: number of attention heads per block.
        out_d: number of output classes.
    """

    def __init__(self, chw, n_patches=7, n_blocks=2, hidden_d=8, n_heads=2, out_d=10):
        # Super constructor
        super(MyViT, self).__init__()

        # Attributes
        self.chw = chw  # (C, H, W)
        self.n_patches = n_patches
        self.n_blocks = n_blocks
        self.n_heads = n_heads
        self.hidden_d = hidden_d

        # Input and patches sizes
        assert (chw[1] % n_patches == 0), "Input shape not entirely divisible by number of patches"
        assert (chw[2] % n_patches == 0), "Input shape not entirely divisible by number of patches"
        # Divisibility is asserted above, so integer division is exact.
        self.patch_size = (chw[1] // n_patches, chw[2] // n_patches)

        # 1) Linear mapper
        self.input_d = int(chw[0] * self.patch_size[0] * self.patch_size[1])
        self.linear_mapper = nn.Linear(self.input_d, self.hidden_d)

        # 2) Learnable classification token
        self.class_token = nn.Parameter(torch.rand(1, self.hidden_d))

        # 3) Positional embedding: a non-persistent buffer, so it follows the
        # module across devices but is not saved in the state dict.
        self.register_buffer(
            "positional_embeddings",
            get_positional_embeddings(n_patches**2 + 1, hidden_d),
            persistent=False,
        )

        # 4) Transformer encoder blocks
        self.blocks = nn.ModuleList([MyViTBlock(hidden_d, n_heads) for _ in range(n_blocks)])

        # 5) Classification MLP
        # NOTE(review): main() below feeds this output to CrossEntropyLoss,
        # which expects unnormalized logits; the Softmax here means softmax is
        # effectively applied twice. Kept as-is because the results shown in
        # this file were produced with it - consider returning logits and
        # applying softmax only at inference.
        self.mlp = nn.Sequential(nn.Linear(self.hidden_d, out_d), nn.Softmax(dim=-1))

    def forward(self, images):
        """Return a (N, out_d) tensor of class probabilities."""
        # Dividing images into patches
        n, c, h, w = images.shape
        patches = patchify(images, self.n_patches).to(self.positional_embeddings.device)

        # Running linear layer tokenization
        # Map the vector corresponding to each patch to the hidden size dimension
        tokens = self.linear_mapper(patches)

        # Adding classification token to the tokens
        tokens = torch.cat((self.class_token.expand(n, 1, -1), tokens), dim=1)

        # Adding positional embedding
        out = tokens + self.positional_embeddings.repeat(n, 1, 1)

        # Transformer Blocks
        for block in self.blocks:
            out = block(out)

        # Getting the classification token only
        out = out[:, 0]

        return self.mlp(out)  # Map to output dimension, output category distribution


def get_positional_embeddings(sequence_length, d):
    """Build the sinusoidal positional-embedding matrix.

    Entry (i, j) is ``sin(i / 10000 ** (j / d))`` for even ``j`` and
    ``cos(i / 10000 ** ((j - 1) / d))`` for odd ``j``.

    Returns:
        Tensor of shape (sequence_length, d) in the default floating dtype.
    """
    # Angles are computed in float64 and cast once at the end.
    positions = torch.arange(sequence_length, dtype=torch.float64).unsqueeze(1)
    dims = torch.arange(d)
    is_even = dims % 2 == 0

    # Odd columns reuse the frequency of the even column just before them.
    paired_dims = (dims - dims % 2).to(torch.float64)
    angles = positions / torch.pow(10000.0, paired_dims / d)

    result = torch.where(is_even, torch.sin(angles), torch.cos(angles))
    return result.to(dtype=torch.get_default_dtype())


def main():
    """Train MyViT on MNIST for a few epochs, then report test loss and accuracy."""
    # Loading data
    transform = ToTensor()

    train_set = MNIST(root="./../datasets", train=True, download=True, transform=transform)
    test_set = MNIST(root="./../datasets", train=False, download=True, transform=transform)

    train_loader = DataLoader(train_set, shuffle=True, batch_size=128)
    test_loader = DataLoader(test_set, shuffle=False, batch_size=128)

    # Defining model and training options
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(
        "Using device: ",
        device,
        f"({torch.cuda.get_device_name(device)})" if torch.cuda.is_available() else "",
    )
    model = MyViT((1, 28, 28), n_patches=7, n_blocks=2, hidden_d=8, n_heads=2, out_d=10).to(device)
    N_EPOCHS = 5
    LR = 0.005

    # Training loop
    optimizer = Adam(model.parameters(), lr=LR)
    criterion = CrossEntropyLoss()
    for epoch in trange(N_EPOCHS, desc="Training"):
        model.train()
        train_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1} in training", leave=False):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            loss = criterion(y_hat, y)

            # Running mean of the per-batch losses over the epoch
            train_loss += loss.item() / len(train_loader)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch + 1}/{N_EPOCHS} loss: {train_loss:.2f}")

    # Test loop
    model.eval()
    with torch.no_grad():
        correct, total = 0, 0
        test_loss = 0.0
        for batch in tqdm(test_loader, desc="Testing"):
            x, y = batch
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            loss = criterion(y_hat, y)
            test_loss += loss.item() / len(test_loader)

            # Count predictions whose arg-max class matches the label
            correct += torch.sum(torch.argmax(y_hat, dim=1) == y).item()
            total += len(x)
        print(f"Test loss: {test_loss:.2f}")
        print(f"Test accuracy: {correct / total * 100:.2f}%")


if __name__ == "__main__":
    main()
Using device:  cpu 
Training:   0%|                                                                            | 0/5 [00:00<?, ?it/sEpoch 1/5 loss: 2.11                                                                                              
Training:  20%|█████████████▌                                                      | 1/5 [00:37<02:31, 37.76s/itEpoch 2/5 loss: 1.84                                                                                              
Training:  40%|███████████████████████████▏                                        | 2/5 [01:16<01:54, 38.13s/itEpoch 3/5 loss: 1.76                                                                                              
Training:  60%|████████████████████████████████████████▊                           | 3/5 [01:54<01:16, 38.11s/itEpoch 4/5 loss: 1.72                                                                                              
Training:  80%|██████████████████████████████████████████████████████▍             | 4/5 [02:32<00:38, 38.11s/itEpoch 5/5 loss: 1.71                                                                                              
Training: 100%|████████████████████████████████████████████████████████████████████| 5/5 [03:11<00:00, 38.27s/it]
Testing: 100%|███████████████████████████████████████████████████████████████████| 79/79 [00:03<00:00, 23.90it/s]
Test loss: 1.69
Test accuracy: 77.38%




