欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 房产 > 家装 > torch_bmm验算及代码测试

torch_bmm验算及代码测试

2025/2/11 23:44:13 来源:https://blog.csdn.net/scar2016/article/details/145529531  浏览:    关键词:torch_bmm验算及代码测试

文章目录

  • 1. torch_bmm
  • 2. pytorch源码

1. torch_bmm

torch.bmm的作用是基于batch_size的矩阵乘法,torch.bmm的作用是对应batch位置的矩阵相乘,比如,

  • mat1的第1个位置和mat2的第1个位置进行矩阵相乘得到mat3的第1个位置
  • mat1的第2个位置和mat2的第2个位置进行矩阵相乘得到mat3的第2个位置
    在这里插入图片描述

2. pytorch源码

import torch
import torch.nn as nn
import torch.nn.functional as Ftorch.set_printoptions(precision=3, sci_mode=False)if __name__ == "__main__":run_code = 0batch_size = 2mat1_h = 3mat1_w = 4mat1_total = batch_size * mat1_w * mat1_hmat2_h = 4mat2_w = 5mat2_total = batch_size * mat2_w * mat2_hmat1 = torch.arange(mat1_total).reshape((batch_size, mat1_h, mat1_w))mat2 = torch.arange(mat2_total).reshape((batch_size, mat2_h, mat2_w))mat3 = torch.bmm(mat1, mat2)print(f"mat1=\n{mat1}")print(f"mat2=\n{mat2}")print(f"mat3=\n{mat3}")
  • 结果:
mat1=
tensor([[[ 0,  1,  2,  3],[ 4,  5,  6,  7],[ 8,  9, 10, 11]],[[12, 13, 14, 15],[16, 17, 18, 19],[20, 21, 22, 23]]])
mat2=
tensor([[[ 0,  1,  2,  3,  4],[ 5,  6,  7,  8,  9],[10, 11, 12, 13, 14],[15, 16, 17, 18, 19]],[[20, 21, 22, 23, 24],[25, 26, 27, 28, 29],[30, 31, 32, 33, 34],[35, 36, 37, 38, 39]]])
mat3=
tensor([[[  70,   76,   82,   88,   94],[ 190,  212,  234,  256,  278],[ 310,  348,  386,  424,  462]],[[1510, 1564, 1618, 1672, 1726],[1950, 2020, 2090, 2160, 2230],[2390, 2476, 2562, 2648, 2734]]])

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com