Volker Strassen
1 矩阵乘法
矩阵乘法是机器学习中最基本的运算之一,对其进行优化是多种优化的关键。通常,将两个大小为N x N的矩阵相乘需要N^3次运算。从那以后,我们在更好、更聪明的矩阵乘法算法方面取得了长足的进步。沃尔克·斯特拉森(Volker Strassen)于1969年首次发表了他的算法。这是第一个证明基本的O(n^3)运行时间并非最优(optimal)的算法。
Strassen算法的基本思想是将A和B各划分为4个子矩阵(共8个子矩阵),然后递归计算C的子矩阵。这种策略称为分而治之。
2 伪代码
- 如上图所示,将矩阵A和B划分为大小为N/2 x N/2的4个子矩阵。
- 递归计算7个矩阵乘法。
- 计算C的子矩阵。
- 将这些子矩阵组合到我们的新矩阵C中
3 复杂性
- 最坏情况时间复杂度:Θ(n^2.8074)
- 最佳情况时间复杂度:Θ(1)
- 空间复杂度:Θ(n^2)(需要存放各级子矩阵;递归深度为Θ(logn))
年青时正在发愁的 Volker Strassen
4 算法的详细解释
矩阵相乘在进行3D变换的时候是经常用到的。在应用中常用矩阵相乘的定义算法对其进行计算。这个算法用到了大量的循环和相乘运算,这使得算法效率不高。而矩阵相乘的计算效率在很大程度上影响了整个程序的运行速度,所以对矩阵相乘算法进行一些改进是必要的。
我们先讨论二阶矩阵的计算方法。
对于二阶矩阵
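$$A = \begin{pmatrix} a_{11} & a_{12} \\ a_{21} & a_{22} \end{pmatrix}, \qquad B = \begin{pmatrix} b_{11} & b_{12} \\ b_{21} & b_{22} \end{pmatrix}$$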
先计算下面7个量(1)
x1 = (a11 + a22) * (b11 + b22);
x2 = (a21 + a22) * b11;
x3 = a11 * (b12 - b22);
x4 = a22 * (b21 - b11);
x5 = (a11 + a12) * b22;
x6 = (a21 - a11) * (b11 + b12);
x7 = (a12 - a22) * (b21 + b22);
再设C = AB。根据矩阵相乘的规则,C的各元素为(2)
c11 = a11 * b11 + a12 * b21
c12 = a11 * b12 + a12 * b22
c21 = a21 * b11 + a22 * b21
c22 = a21 * b12 + a22 * b22
比较(1)(2),C的各元素可以表示为(3)
c11 = x1 + x4 - x5 + x7
c12 = x3 + x5
c21 = x2 + x4
c22 = x1 + x3 - x2 + x6
根据以上的方法,我们就可以计算4阶矩阵了,先将4阶矩阵A和B划分成四块2阶矩阵,分别利用公式计算它们的乘积,再使用(1)(3)来计算出最后结果。
本文给出了多种算法,大家自己选择吧。
5 源程序
using System;
using System.Text;

namespace Legal.Truffer.Algorithm
{
    /// <summary>
    /// Matrix multiplication routines built on Volker Strassen's seven-product scheme.
    /// </summary>
    /// <remarks>
    /// The <c>Matrix</c> type is declared outside this file. This class only relies on its
    /// <c>(rows, cols)</c> and <c>(rows, cols, double[])</c> constructors, its one-index and
    /// two-index accessors, and its <c>+</c>, <c>-</c> and <c>*</c> operators.
    /// </remarks>
    public static class Matrix_Calculator
    {
        #region [4x4]x[4x4] product via Strassen's scheme (unrolled)

        /// <summary>
        /// Multiplies two 2x2 matrices passed element by element, using seven scalar
        /// multiplications instead of eight.
        /// </summary>
        private static void Multiply2X2(
            out double fOut_11, out double fOut_12, out double fOut_21, out double fOut_22,
            double f1_11, double f1_12, double f1_21, double f1_22,
            double f2_11, double f2_12, double f2_21, double f2_22)
        {
            // The seven Strassen products.
            double x1 = (f1_11 + f1_22) * (f2_11 + f2_22);
            double x2 = (f1_21 + f1_22) * f2_11;
            double x3 = f1_11 * (f2_12 - f2_22);
            double x4 = f1_22 * (f2_21 - f2_11);
            double x5 = (f1_11 + f1_12) * f2_22;
            double x6 = (f1_21 - f1_11) * (f2_11 + f2_12);
            double x7 = (f1_12 - f1_22) * (f2_21 + f2_22);

            // Recombine the products into the four result elements.
            fOut_11 = x1 + x4 - x5 + x7;
            fOut_12 = x3 + x5;
            fOut_21 = x2 + x4;
            fOut_22 = x1 - x2 + x3 + x6;
        }

        /// <summary>
        /// Multiplies two 4x4 matrices by splitting each into four 2x2 blocks and applying
        /// Strassen's scheme once at block level and once more inside every block product.
        /// </summary>
        /// <param name="a">Left operand; elements 0..15 are read through the one-index accessor.</param>
        /// <param name="b">Right operand; elements 0..15 are read through the one-index accessor.</param>
        /// <returns>A new 4x4 matrix built from 16 values in the same element order.</returns>
        /// <remarks>
        /// NOTE(review): the index pattern (0,1,4,5 = block 11; 2,3,6,7 = block 12;
        /// 8,9,12,13 = block 21; 10,11,14,15 = block 22) assumes the one-index accessor
        /// is row-major - confirm against the Matrix type.
        /// </remarks>
        public static Matrix Multiply4x4(Matrix a, Matrix b)
        {
            // pK_ij is element (i, j) of the K-th 2x2 block product.
            double p1_11, p1_12, p1_21, p1_22;
            double p2_11, p2_12, p2_21, p2_22;
            double p3_11, p3_12, p3_21, p3_22;
            double p4_11, p4_12, p4_21, p4_22;
            double p5_11, p5_12, p5_21, p5_22;
            double p6_11, p6_12, p6_21, p6_22;
            double p7_11, p7_12, p7_21, p7_22;

            // P1 = (A11 + A22) * (B11 + B22)
            Multiply2X2(out p1_11, out p1_12, out p1_21, out p1_22,
                a[0] + a[10], a[1] + a[11], a[4] + a[14], a[5] + a[15],
                b[0] + b[10], b[1] + b[11], b[4] + b[14], b[5] + b[15]);

            // P2 = (A21 + A22) * B11
            Multiply2X2(out p2_11, out p2_12, out p2_21, out p2_22,
                a[8] + a[10], a[9] + a[11], a[12] + a[14], a[13] + a[15],
                b[0], b[1], b[4], b[5]);

            // P3 = A11 * (B12 - B22)
            Multiply2X2(out p3_11, out p3_12, out p3_21, out p3_22,
                a[0], a[1], a[4], a[5],
                b[2] - b[10], b[3] - b[11], b[6] - b[14], b[7] - b[15]);

            // P4 = A22 * (B21 - B11)
            Multiply2X2(out p4_11, out p4_12, out p4_21, out p4_22,
                a[10], a[11], a[14], a[15],
                b[8] - b[0], b[9] - b[1], b[12] - b[4], b[13] - b[5]);

            // P5 = (A11 + A12) * B22
            Multiply2X2(out p5_11, out p5_12, out p5_21, out p5_22,
                a[0] + a[2], a[1] + a[3], a[4] + a[6], a[5] + a[7],
                b[10], b[11], b[14], b[15]);

            // P6 = (A21 - A11) * (B11 + B12)
            Multiply2X2(out p6_11, out p6_12, out p6_21, out p6_22,
                a[8] - a[0], a[9] - a[1], a[12] - a[4], a[13] - a[5],
                b[0] + b[2], b[1] + b[3], b[4] + b[6], b[5] + b[7]);

            // P7 = (A12 - A22) * (B21 + B22)
            Multiply2X2(out p7_11, out p7_12, out p7_21, out p7_22,
                a[2] - a[10], a[3] - a[11], a[6] - a[14], a[7] - a[15],
                b[8] + b[10], b[9] + b[11], b[12] + b[14], b[13] + b[15]);

            // C11 = P1 + P4 - P5 + P7, C12 = P3 + P5, C21 = P2 + P4, C22 = P1 - P2 + P3 + P6,
            // written out in the same element order the operands were read in.
            return new Matrix(4, 4, new double[4 * 4]
            {
                // result elements 0..3: top rows of C11 and C12
                p1_11 + p4_11 - p5_11 + p7_11,
                p1_12 + p4_12 - p5_12 + p7_12,
                p3_11 + p5_11,
                p3_12 + p5_12,
                // result elements 4..7: bottom rows of C11 and C12
                p1_21 + p4_21 - p5_21 + p7_21,
                p1_22 + p4_22 - p5_22 + p7_22,
                p3_21 + p5_21,
                p3_22 + p5_22,
                // result elements 8..11: top rows of C21 and C22
                p2_11 + p4_11,
                p2_12 + p4_12,
                p1_11 - p2_11 + p3_11 + p6_11,
                p1_12 - p2_12 + p3_12 + p6_12,
                // result elements 12..15: bottom rows of C21 and C22
                p2_21 + p4_21,
                p2_22 + p4_22,
                p1_21 - p2_21 + p3_21 + p6_21,
                p1_22 - p2_22 + p3_22 + p6_22
            });
        }

        #endregion

        #region Shared block helpers

        /// <summary>
        /// Copies rows [r1, r2) and columns [c1, c2) of <paramref name="input"/> into a new matrix.
        /// </summary>
        private static Matrix ExtractBlock(Matrix input, int r1, int r2, int c1, int c2)
        {
            int rowCount = r2 - r1;
            int colCount = c2 - c1;
            Matrix block = new Matrix(rowCount, colCount);
            for (int i = 0; i < rowCount; i++)
            {
                for (int j = 0; j < colCount; j++)
                {
                    block[i, j] = input[r1 + i, c1 + j];
                }
            }
            return block;
        }

        /// <summary>
        /// Builds a <paramref name="size"/> x <paramref name="size"/> matrix from four
        /// (size / 2) x (size / 2) quadrants.
        /// </summary>
        private static Matrix AssembleQuadrants(Matrix c11, Matrix c12, Matrix c21, Matrix c22, int size)
        {
            int half = size / 2;
            Matrix result = new Matrix(size, size);
            for (int i = 0; i < half; i++)
            {
                for (int j = 0; j < half; j++)
                {
                    result[i, j] = c11[i, j];
                    result[i, j + half] = c12[i, j];
                    result[i + half, j] = c21[i, j];
                    result[i + half, j + half] = c22[i, j];
                }
            }
            return result;
        }

        #endregion

        #region One-level Strassen block product

        /// <summary>
        /// Multiplies the <paramref name="len"/> x <paramref name="len"/> regions of
        /// <paramref name="A"/> and <paramref name="B"/> whose top-left corner is at
        /// (<paramref name="r1"/>, <paramref name="c1"/>), using one level of Strassen's scheme.
        /// </summary>
        /// <param name="A">Left operand.</param>
        /// <param name="B">Right operand.</param>
        /// <param name="len">Side length of the square region to multiply.</param>
        /// <param name="r1">First row of the region in both operands (default 0).</param>
        /// <param name="c1">First column of the region in both operands (default 0).</param>
        /// <returns>A new <paramref name="len"/> x <paramref name="len"/> matrix holding the product.</returns>
        /// <remarks>
        /// The seven block products are evaluated with the Matrix <c>*</c> operator; this
        /// method does not call itself. The method name keeps its original spelling because
        /// it is part of the public interface.
        /// </remarks>
        public static Matrix Multipy(Matrix A, Matrix B, int len, int r1 = 0, int c1 = 0)
        {
            if (len == 1)
            {
                // Read the element at the region's corner so non-zero offsets are honoured.
                return new Matrix(1, 1, new double[1] { A[r1, c1] * B[r1, c1] });
            }

            int rowEnd = r1 + len;
            int colEnd = c1 + len;

            if (len % 2 != 0)
            {
                // An odd side cannot be split into equal quadrants, so multiply the region directly.
                return ExtractBlock(A, r1, rowEnd, c1, colEnd) * ExtractBlock(B, r1, rowEnd, c1, colEnd);
            }

            int half = len / 2;
            int rowMid = r1 + half;
            int colMid = c1 + half;

            // Quadrants of A: [a b; c d]. Quadrants of B: [e f; g h].
            Matrix a = ExtractBlock(A, r1, rowMid, c1, colMid);
            Matrix b = ExtractBlock(A, r1, rowMid, colMid, colEnd);
            Matrix c = ExtractBlock(A, rowMid, rowEnd, c1, colMid);
            Matrix d = ExtractBlock(A, rowMid, rowEnd, colMid, colEnd);
            Matrix e = ExtractBlock(B, r1, rowMid, c1, colMid);
            Matrix f = ExtractBlock(B, r1, rowMid, colMid, colEnd);
            Matrix g = ExtractBlock(B, rowMid, rowEnd, c1, colMid);
            Matrix h = ExtractBlock(B, rowMid, rowEnd, colMid, colEnd);

            // The seven block products.
            Matrix p1 = a * (f - h);
            Matrix p2 = (a + b) * h;
            Matrix p3 = (c + d) * e;
            Matrix p4 = d * (g - e);
            Matrix p5 = (a + d) * (e + h);
            Matrix p6 = (b - d) * (g + h);
            Matrix p7 = (a - c) * (e + f);

            // Quadrants of the result: [r s; t u].
            Matrix r = ((p5 + p4) + p6) - p2;
            Matrix s = p1 + p2;
            Matrix t = p3 + p4;
            Matrix u = (p5 + p1) - (p3 + p7);

            return AssembleQuadrants(r, s, t, u, len);
        }

        #endregion

        #region Recursive Strassen product

        /// <summary>
        /// Multiplies two <paramref name="n"/> x <paramref name="n"/> matrices with the
        /// recursive Strassen divide-and-conquer scheme.
        /// </summary>
        /// <param name="n">Side length of both operands.</param>
        /// <param name="A">Left operand.</param>
        /// <param name="B">Right operand.</param>
        /// <returns>The product of <paramref name="A"/> and <paramref name="B"/>.</returns>
        /// <remarks>
        /// Recursion continues while the current side is even and greater than 2. Otherwise
        /// the Matrix <c>*</c> operator is used, so sides that are not powers of two also
        /// terminate instead of recursing without end.
        /// </remarks>
        public static Matrix Strassen(int n, Matrix A, Matrix B)
        {
            // Base case: small or odd sides cannot be halved into equal quadrants.
            if (n <= 2 || n % 2 != 0)
            {
                return A * B;
            }

            int N = n / 2;

            // Split both operands into four N x N quadrants.
            Matrix A11 = ExtractBlock(A, 0, N, 0, N);
            Matrix A12 = ExtractBlock(A, 0, N, N, n);
            Matrix A21 = ExtractBlock(A, N, n, 0, N);
            Matrix A22 = ExtractBlock(A, N, n, N, n);
            Matrix B11 = ExtractBlock(B, 0, N, 0, N);
            Matrix B12 = ExtractBlock(B, 0, N, N, n);
            Matrix B21 = ExtractBlock(B, N, n, 0, N);
            Matrix B22 = ExtractBlock(B, N, n, N, n);

            // M1 = (A11 + A22) * (B11 + B22)
            Matrix M1 = Strassen(N, A11 + A22, B11 + B22);
            // M2 = (A21 + A22) * B11
            Matrix M2 = Strassen(N, A21 + A22, B11);
            // M3 = A11 * (B12 - B22)
            Matrix M3 = Strassen(N, A11, B12 - B22);
            // M4 = A22 * (B21 - B11)
            Matrix M4 = Strassen(N, A22, B21 - B11);
            // M5 = (A11 + A12) * B22
            Matrix M5 = Strassen(N, A11 + A12, B22);
            // M6 = (A21 - A11) * (B11 + B12)
            Matrix M6 = Strassen(N, A21 - A11, B11 + B12);
            // M7 = (A12 - A22) * (B21 + B22)
            Matrix M7 = Strassen(N, A12 - A22, B21 + B22);

            // C11 = M1 + M4 - M5 + M7
            Matrix C11 = (M1 + M4) + (M7 - M5);
            // C12 = M3 + M5
            Matrix C12 = M3 + M5;
            // C21 = M2 + M4
            Matrix C21 = M2 + M4;
            // C22 = M1 - M2 + M3 + M6
            Matrix C22 = (M1 - M2) + (M3 + M6);

            return AssembleQuadrants(C11, C12, C21, C22, n);
        }

        #endregion
    }
}