欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 新闻 > 焦点 > C#,数值计算,矩阵相乘的斯特拉森(Strassen’s Matrix Multiplication)分治算法与源代码

C#,数值计算,矩阵相乘的斯特拉森(Strassen’s Matrix Multiplication)分治算法与源代码

2025/1/15 2:40:31 来源:https://blog.csdn.net/beijinghorn/article/details/125074303  浏览:    关键词:C#,数值计算,矩阵相乘的斯特拉森(Strassen’s Matrix Multiplication)分治算法与源代码

Volker Strassen

矩阵乘法

矩阵乘法机器学习中最基本的运算之一,对其进行优化是多种优化的关键。通常,将两个大小为N X N的矩阵相乘需要N^3次运算。从那以后,我们在更好、更聪明的矩阵乘法算法方面取得了长足的进步。沃尔克·斯特拉森于1969年首次发表了他的算法。这是第一个证明基本O(n^3)运行时不是optiomal的算法。

Strassen算法的基本思想是将A和B分为8个子矩阵,然后递归计算C的子矩阵。这种策略称为分而治之。

2 伪代码

  1. 如上图所示,将矩阵A和B划分为大小为N/2 x N/2的4个子矩阵。
  2. 递归计算7个矩阵乘法。
  3. 计算C的子矩阵。
  4. 将这些子矩阵组合到我们的新矩阵C中

3 复杂性

  1. 最坏情况时间复杂度:Θ(n^2.8074)
  2. 最佳情况时间复杂度:Θ(1)
  3. 空间复杂度:Θ(logn)

年青时正在发愁的  Volker Strassen

4 算法的详细解释

矩阵相乘在进行3D变换的时候是经常用到的。在应用中常用矩阵相乘的定义算法对其进行计算。这个算法用到了大量的循环和相乘运算,这使得算法效率不高。而矩阵相乘的计算效率很大程度上的影响了整个程序的运行速度,所以对矩阵相乘算法进行一些改进是必要的。

        我们先讨论二阶矩阵的计算方法。

        对于二阶矩阵

        a11    a12                    b11    b12    
        A =    a21    a22    B =    b21    b22
        先计算下面7个量(1)

        x1 = (a11 + a22) * (b11 + b22);
        x2 = (a21 + a22) * b11;
        x3 = a11 * (b12 - b22);
        x4 = a22 * (b21 - b11);
        x5 = (a11 + a12) * b22;
        x6 = (a21 - a11) * (b11 + b12);
        x7 = (a12 - a22) * (b21 + b22);
        再设C = AB。根据矩阵相乘的规则,C的各元素为(2)

        c11 = a11 * b11 + a12 * b21
        c12 = a11 * b12 + a12 * b22
        c21 = a21 * b11 + a22 * b21
        c22 = a21 * b12 + a22 * b22
        比较(1)(2),C的各元素可以表示为(3)

        c11 = x1 + x4 - x5 + x7
        c12 = x3 + x5
        c21 = x2 + x4
        c22 = x1 + x3 - x2 + x6


        根据以上的方法,我们就可以计算4阶矩阵了,先将4阶矩阵A和B划分成四块2阶矩阵,分别利用公式计算它们的乘积,再使用(1)(3)来计算出最后结果。

本文给出了多种算法,大家自己选择吧。

5 源程序

using System;
using System.Text;namespace Legal.Truffer.Algorithm
{/// <summary>/// 矩阵相乘的斯特拉森(V. Strassen)方法/// </summary>public static class Matrix_Calculator{#region [4x4]x[4x4]矩阵相乘的斯特拉森(V. Strassen)方法(快速算法)// 计算2X2矩阵private static void Multiply2X2(out double fOut_11, out double fOut_12, out double fOut_21, out double fOut_22,double f1_11, double f1_12, double f1_21, double f1_22,double f2_11, double f2_12, double f2_21, double f2_22){double x1 = ((f1_11 + f1_22) * (f2_11 + f2_22));double x2 = ((f1_21 + f1_22) * f2_11);double x3 = (f1_11 * (f2_12 - f2_22));double x4 = (f1_22 * (f2_21 - f2_11));double x5 = ((f1_11 + f1_12) * f2_22);double x6 = ((f1_21 - f1_11) * (f2_11 + f2_12));double x7 = ((f1_12 - f1_22) * (f2_21 + f2_22));fOut_11 = x1 + x4 - x5 + x7;fOut_12 = x3 + x5;fOut_21 = x2 + x4;fOut_22 = x1 - x2 + x3 + x6;}// 计算4X4矩阵public static Matrix Multiply4x4(Matrix a, Matrix b){//double c[7,4] = new double[7,4];double c_0_0, c_0_1, c_0_2, c_0_3;double c_1_0, c_1_1, c_1_2, c_1_3;double c_2_0, c_2_1, c_2_2, c_2_3;double c_3_0, c_3_1, c_3_2, c_3_3;double c_4_0, c_4_1, c_4_2, c_4_3;double c_5_0, c_5_1, c_5_2, c_5_3;double c_6_0, c_6_1, c_6_2, c_6_3;// (ma11 + ma22) * (mb11 + mb22)Multiply2X2(out c_0_0, out c_0_1, out c_0_2, out c_0_3,a[0] + a[10], a[1] + a[11], a[4] + a[14], a[5] + a[15],b[0] + b[10], b[1] + b[11], b[4] + b[14], b[5] + b[15]);// (ma21 + ma22) * mb11Multiply2X2(out c_1_0, out c_1_1, out c_1_2, out c_1_3,a[8] + a[10], a[9] + a[11], a[12] + a[14], a[13] + a[15],b[0], b[1], b[4], b[5]);// ma11 * (mb12 - mb22)Multiply2X2(out c_2_0, out c_2_1, out c_2_2, out c_2_3,a[0], a[1], a[4], a[5],b[2] - b[10], b[3] - b[11], b[6] - b[14], b[7] - b[15]);// ma22 * (mb21 - mb11)Multiply2X2(out c_3_0, out c_3_1, out c_3_2, out c_3_3,a[10], a[11], a[14], a[15],b[8] - b[0], b[9] - b[1], b[12] - b[4], b[13] - b[5]);// (ma11 + ma12) * mb22Multiply2X2(out c_4_0, out c_4_1, out c_4_2, out c_4_3,a[0] + a[2], a[1] + a[3], a[4] + a[6], a[5] + a[7],b[10], b[11], b[14], b[15]);// (ma21 - ma11) * (mb11 + mb12)Multiply2X2(out c_5_0, out c_5_1, out c_5_2, out c_5_3,a[8] - a[0], a[9] - a[1], a[12] - a[4], a[13] - a[5],b[0] + b[2], b[1] + b[3], b[4] + b[6], b[5] + b[7]);// (ma12 - ma22) * (mb21 + mb22)Multiply2X2(out c_6_0, out c_6_1, out c_6_2, out c_6_3,a[2] - a[10], a[3] - a[11], a[6] - a[14], a[7] - a[15],b[8] + b[10], b[9] + b[11], b[12] + b[14], b[13] + b[15]);return new Matrix(4, 4, new double[4 * 4] {c_0_0 + c_3_0 - c_4_0 + c_6_0,c_0_1 + c_3_1 - c_4_1 + c_6_1,c_2_0 + c_4_0,c_2_1 + c_4_1,c_0_2 + c_3_2 - c_4_2 + c_6_2,c_0_3 + c_3_3 - c_4_3 + c_6_3,c_2_2 + c_4_2,c_2_3 + c_4_3,c_1_0 + c_3_0,c_1_1 + c_3_1,c_0_0 - c_1_0 + c_2_0 + c_5_0,c_0_1 - c_1_1 + c_2_1 + c_5_1,c_1_2 + c_3_2,c_1_3 + c_3_3,c_0_2 - c_1_2 + c_2_2 + c_5_2,c_0_3 - c_1_3 + c_2_3 + c_5_3});}#endregion#region 基于Strassen算法的矩阵“分治”乘法(只支持维度为2的幂次的方阵相乘。)private static Matrix create(Matrix input, int r1, int r2, int c1, int c2){Matrix res = new Matrix(r2 - r1, c2 - c1);for (int i = r1, ii = 0; i <= r2 && ii < r2 - r1; i++, ii++){for (int j = c1, jj = 0; j < c2 && jj < c2 - c1; j++, jj++){res[ii, jj] = input[i, j];}}return res;}public static Matrix Multipy(Matrix A, Matrix B, int len, int r1 = 0, int c1 = 0){if (len == 1){return new Matrix(1, 1,new double[1] { A[0] * B[0] });}int lend2 = len / 2;Matrix a = create(A, r1, r1 + lend2, c1, c1 + lend2);Matrix e = create(B, r1, r1 + lend2, c1, c1 + lend2);Matrix b = create(A, r1, r1 + lend2, c1 + lend2, len);Matrix f = create(B, r1, r1 + lend2, c1 + lend2, len);Matrix c = create(A, r1 + lend2, len, c1, c1 + lend2);Matrix g = create(B, r1 + lend2, len, c1, c1 + lend2);Matrix d = create(A, r1 + lend2, len, c1 + lend2, len);Matrix h = create(B, r1 + lend2, len, c1 + lend2, len);Matrix p1 = a * (f - h); // multi(a, sub(f, h, lend2), 0, 0, lend2); Matrix p2 = (a + b) * h; // multi(add(a, b, lend2), h, 0, 0, lend2);Matrix p3 = (c + d) * e; // multi(add(c, d, lend2), e, 0, 0, lend2);Matrix p4 = d * (g - e); // multi(d, sub(g, e, lend2), 0, 0, lend2);Matrix p5 = (a + d) * (e + h); // multi(add(a, d, lend2), add(e, h, lend2), 0, 0, lend2);Matrix p6 = (b - d) * (g + h); // multi(sub(b, d, lend2), add(g, h, lend2), 0, 0, lend2);Matrix p7 = (a - c) * (e + f); // multi(sub(a, c, lend2), add(e, f, lend2), 0, 0, lend2);Matrix r = (((p5 + p4) + p6) - p2); // sub(add(add(p5, p4, lend2), p6, lend2), p2, lend2);Matrix s = p1 + p2; // add(p1, p2, lend2);Matrix t = p3 + p4; // add(p3, p4, lend2);Matrix u = (p5 + p1) - (p3 + p7);// sub(add(p5, p1, lend2), add(p3, p7, lend2), lend2);Matrix rr = new Matrix(len, len);for (int j = 0; j < lend2; j++){for (int jj = 0; jj < lend2; jj++){rr[j, jj] = r[j, jj];}}for (int j = 0; j < lend2; j++){for (int jj = 0; jj < lend2; jj++){rr[j, jj + lend2] = s[j, jj];}}for (int j = 0; j < lend2; j++){for (int jj = 0; jj < lend2; jj++){rr[j + lend2, jj] = t[j, jj];}}for (int j = 0; j < lend2; j++){for (int jj = 0; jj < lend2; jj++){rr[j + lend2, jj + lend2] = u[j, jj];}}return rr;}#endregion#region 基于Strassen矩阵乘法的递归分治算法/// <summary>/// 基于Strassen矩阵乘法的递归分治算法/// </summary>/// <param name="n"></param>/// <param name="A"></param>/// <param name="B"></param>/// <returns></returns>public static Matrix Strassen(int n, Matrix A, Matrix B){//2-order if (n == 2){return A * B;}int N = n / 2;Matrix A11 = new Matrix(N, N);Matrix A12 = new Matrix(N, N);Matrix A21 = new Matrix(N, N);Matrix A22 = new Matrix(N, N);Matrix B11 = new Matrix(N, N);Matrix B12 = new Matrix(N, N);Matrix B21 = new Matrix(N, N);Matrix B22 = new Matrix(N, N);//将矩阵A和B分成阶数相同的四个子矩阵,即分治思想。  for (int i = 0; i < n / 2; i++){for (int j = 0; j < n / 2; j++){A11[i, j] = A[i, j];A12[i, j] = A[i, j + n / 2];A21[i, j] = A[i + n / 2, j];A22[i, j] = A[i + n / 2, j + n / 2];B11[i, j] = B[i, j];B12[i, j] = B[i, j + n / 2];B21[i, j] = B[i + n / 2, j];B22[i, j] = B[i + n / 2, j + n / 2];}}//Calculate M1 = (A0 + A3) × (B0 + B3)  Matrix M1 = Strassen(N, A11 + A22, B11 + B22);//Calculate M2 = (A2 + A3) × B0  Matrix M2 = Strassen(N, A21 + A22, B11);//Calculate M3 = A0 × (B1 - B3)  Matrix M3 = Strassen(N, A11, B12 - B22);//Calculate M4 = A3 × (B2 - B0)  Matrix M4 = Strassen(N, A22, B21 - B11);//Calculate M5 = (A0 + A1) × B3  Matrix M5 = Strassen(N, A11 + A12, B22);//Calculate M6 = (A2 - A0) × (B0 + B1)  Matrix M6 = Strassen(N, A21 - A11, B11 + B12);//Calculate M7 = (A1 - A3) × (B2 + B3)  Matrix M7 = Strassen(N, A12 - A22, B21 + B22);//Calculate C0 = M1 + M4 - M5 + M7  Matrix C11 = (M1 + M4) + (M7 - M5);//Calculate C1 = M3 + M5  Matrix C12 = M3 + M5;//Calculate C2 = M2 + M4  Matrix C21 = M2 + M4;//Calculate C3 = M1 - M2 + M3 + M6  Matrix C22 = (M1 - M2) + (M3 + M6);Matrix C = new Matrix(n, n);for (int i = 0; i < N; i++){for (int j = 0; j < N; j++){C[i, j] = C11[i, j];C[i, j + N] = C12[i, j];C[i + N, j] = C21[i, j];C[i + N, j + N] = C22[i, j];}}return C;}#endregion}
}

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com