Optimization for custom small blas multiplication with dynamic template parameters in C level. - unroll for loops - matrix access more cache coherent - platform independant Briefly, this commit brings 1~50% performance improvments for most cases in small_blas_gem(m/v)_benchmark, but a small drop for corner cases with small dimensions especially 1,2,3. Here we list the results partially, which show decrease percentage of executing time, compared to unoptimized version. Platform: desktop PC (i7-7700 CPU MP8@3.60GHz + ubuntu 17.10) (Lenovo Research Device+ Lab, <yangfan34@lenovo.com>) Benchmark Time CPU ----------------------------------------------------------- BM_MatrixMatrixMultiplyDynamic/1/1/1 -0.0850 -0.0851 BM_MatrixMatrixMultiplyDynamic/1/1/2 -0.1444 -0.1446 BM_MatrixMatrixMultiplyDynamic/1/1/3 -0.1934 -0.1935 BM_MatrixMatrixMultiplyDynamic/1/1/4 -0.2933 -0.2934 BM_MatrixMatrixMultiplyDynamic/1/1/8 -0.1579 -0.1580 BM_MatrixMatrixMultiplyDynamic/1/1/12 -0.1556 -0.1558 BM_MatrixMatrixMultiplyDynamic/1/1/15 -0.1598 -0.1599 BM_MatrixMatrixMultiplyDynamic/1/2/1 -0.0797 -0.0799 BM_MatrixMatrixMultiplyDynamic/1/2/2 -0.2950 -0.2951 BM_MatrixMatrixMultiplyDynamic/1/2/3 -0.1363 -0.1364 BM_MatrixMatrixMultiplyDynamic/1/2/4 -0.2435 -0.2437 BM_MatrixMatrixMultiplyDynamic/1/2/8 -0.2299 -0.2300 BM_MatrixMatrixMultiplyDynamic/1/2/12 -0.2441 -0.2442 BM_MatrixMatrixMultiplyDynamic/1/2/15 -0.1671 -0.1673 BM_MatrixMatrixMultiplyDynamic/1/3/1 -0.0774 -0.0775 BM_MatrixMatrixMultiplyDynamic/1/3/2 -0.2761 -0.2762 BM_MatrixMatrixMultiplyDynamic/1/3/3 -0.0840 -0.0841 BM_MatrixMatrixMultiplyDynamic/1/3/4 -0.2027 -0.2028 BM_MatrixMatrixMultiplyDynamic/1/3/8 -0.2481 -0.2482 BM_MatrixMatrixMultiplyDynamic/1/3/12 -0.2629 -0.2630 BM_MatrixMatrixMultiplyDynamic/1/3/15 -0.1958 -0.1959 BM_MatrixMatrixMultiplyDynamic/1/4/1 -0.1260 -0.1261 BM_MatrixMatrixMultiplyDynamic/1/4/2 -0.1834 -0.1835 BM_MatrixMatrixMultiplyDynamic/1/4/3 -0.1379 -0.1380 BM_MatrixMatrixMultiplyDynamic/1/4/4 -0.2636 -0.2637 BM_MatrixMatrixMultiplyDynamic/1/4/8 -0.2838 -0.2839 BM_MatrixMatrixMultiplyDynamic/1/4/12 -0.3320 -0.3321 BM_MatrixMatrixMultiplyDynamic/1/4/15 -0.2464 -0.2465 BM_MatrixMatrixMultiplyDynamic/1/8/1 -0.0766 -0.0767 BM_MatrixMatrixMultiplyDynamic/1/8/2 -0.1713 -0.1714 BM_MatrixMatrixMultiplyDynamic/1/8/3 -0.1158 -0.1159 BM_MatrixMatrixMultiplyDynamic/1/8/4 -0.3205 -0.3206 BM_MatrixMatrixMultiplyDynamic/1/8/8 -0.3514 -0.3515 BM_MatrixMatrixMultiplyDynamic/1/8/12 -0.3658 -0.3658 BM_MatrixMatrixMultiplyDynamic/1/8/15 -0.3187 -0.3188 BM_MatrixMatrixMultiplyDynamic/1/12/1 -0.0424 -0.0425 BM_MatrixMatrixMultiplyDynamic/1/12/2 -0.1800 -0.1800 BM_MatrixMatrixMultiplyDynamic/1/12/3 -0.1457 -0.1457 BM_MatrixMatrixMultiplyDynamic/1/12/4 -0.3768 -0.3769 BM_MatrixMatrixMultiplyDynamic/1/12/8 -0.4072 -0.4073 BM_MatrixMatrixMultiplyDynamic/1/12/12 -0.4391 -0.4392 BM_MatrixMatrixMultiplyDynamic/1/12/15 -0.3383 -0.3383 BM_MatrixMatrixMultiplyDynamic/1/15/1 -0.0442 -0.0443 BM_MatrixMatrixMultiplyDynamic/1/15/2 -0.2378 -0.2379 BM_MatrixMatrixMultiplyDynamic/1/15/3 -0.1553 -0.1554 BM_MatrixMatrixMultiplyDynamic/1/15/4 -0.3954 -0.3955 BM_MatrixMatrixMultiplyDynamic/1/15/8 -0.4334 -0.4335 BM_MatrixMatrixMultiplyDynamic/1/15/12 -0.4175 -0.4175 BM_MatrixMatrixMultiplyDynamic/1/15/15 -0.3242 -0.3243 BM_MatrixVectorMultiply/1/1 +0.1613 +0.1613 BM_MatrixVectorMultiply/1/2 +0.1715 +0.1715 BM_MatrixVectorMultiply/1/3 +0.1051 +0.1051 BM_MatrixVectorMultiply/1/4 +0.1369 +0.1369 BM_MatrixVectorMultiply/1/8 +0.1180 +0.1180 BM_MatrixVectorMultiply/1/12 +0.0869 +0.0869 BM_MatrixVectorMultiply/1/15 +0.1887 +0.1886 BM_MatrixVectorMultiply/2/1 +0.1152 +0.1152 BM_MatrixVectorMultiply/2/2 +0.1520 +0.1520 BM_MatrixVectorMultiply/2/3 +0.1867 +0.1867 BM_MatrixVectorMultiply/2/4 +0.0173 +0.0173 BM_MatrixVectorMultiply/2/8 -0.0528 -0.0528 BM_MatrixVectorMultiply/2/12 -0.0176 -0.0176 BM_MatrixVectorMultiply/2/15 -0.0753 -0.0753 BM_MatrixVectorMultiply/3/1 +0.0844 +0.0844 BM_MatrixVectorMultiply/3/2 +0.0750 +0.0750 BM_MatrixVectorMultiply/3/3 -0.0153 -0.0153 BM_MatrixVectorMultiply/3/4 +0.0060 +0.0060 BM_MatrixVectorMultiply/3/8 +0.0152 +0.0152 BM_MatrixVectorMultiply/3/12 +0.0101 +0.0101 BM_MatrixVectorMultiply/3/15 -0.0795 -0.0795 BM_MatrixVectorMultiply/4/1 -0.1425 -0.1425 BM_MatrixVectorMultiply/4/2 -0.0869 -0.0869 BM_MatrixVectorMultiply/4/3 -0.1371 -0.1371 BM_MatrixVectorMultiply/4/4 -0.0088 -0.0088 BM_MatrixVectorMultiply/4/8 -0.1049 -0.1049 BM_MatrixVectorMultiply/4/12 -0.2566 -0.2566 BM_MatrixVectorMultiply/4/15 -0.2940 -0.2940 BM_MatrixVectorMultiply/6/1 -0.1798 -0.1798 BM_MatrixVectorMultiply/6/2 -0.0627 -0.0627 BM_MatrixVectorMultiply/6/3 -0.0389 -0.0389 BM_MatrixVectorMultiply/6/4 -0.1088 -0.1088 BM_MatrixVectorMultiply/6/8 -0.1815 -0.1815 BM_MatrixVectorMultiply/6/12 -0.1650 -0.1650 BM_MatrixVectorMultiply/6/15 -0.1855 -0.1855 BM_MatrixVectorMultiply/8/1 -0.1630 -0.1630 BM_MatrixVectorMultiply/8/2 -0.1248 -0.1248 BM_MatrixVectorMultiply/8/3 -0.1911 -0.1911 BM_MatrixVectorMultiply/8/4 -0.1996 -0.1996 BM_MatrixVectorMultiply/8/8 -0.2590 -0.2590 BM_MatrixVectorMultiply/8/12 -0.3266 -0.3266 BM_MatrixVectorMultiply/8/15 -0.3999 -0.3999 BM_MatrixTransposeVectorMultiply/1/1 -0.0234 -0.0234 BM_MatrixTransposeVectorMultiply/1/2 -0.0243 -0.0243 BM_MatrixTransposeVectorMultiply/1/3 -0.1324 -0.1324 BM_MatrixTransposeVectorMultiply/1/4 -0.2635 -0.2635 BM_MatrixTransposeVectorMultiply/1/8 -0.2461 -0.2461 BM_MatrixTransposeVectorMultiply/1/12 -0.2702 -0.2702 BM_MatrixTransposeVectorMultiply/1/15 -0.2538 -0.2538 BM_MatrixTransposeVectorMultiply/2/1 -0.0170 -0.0170 BM_MatrixTransposeVectorMultiply/2/2 -0.1475 -0.1475 BM_MatrixTransposeVectorMultiply/2/3 -0.1082 -0.1082 BM_MatrixTransposeVectorMultiply/2/4 -0.2594 -0.2595 BM_MatrixTransposeVectorMultiply/2/8 -0.2710 -0.2710 BM_MatrixTransposeVectorMultiply/2/12 -0.3053 -0.3053 BM_MatrixTransposeVectorMultiply/2/15 -0.2706 -0.2706 BM_MatrixTransposeVectorMultiply/3/1 -0.0096 -0.0096 BM_MatrixTransposeVectorMultiply/3/2 -0.2885 -0.2886 BM_MatrixTransposeVectorMultiply/3/3 -0.0790 -0.0790 BM_MatrixTransposeVectorMultiply/3/4 -0.2329 -0.2330 BM_MatrixTransposeVectorMultiply/3/8 -0.2742 -0.2742 BM_MatrixTransposeVectorMultiply/3/12 -0.3177 -0.3177 BM_MatrixTransposeVectorMultiply/3/15 -0.2610 -0.2610 BM_MatrixTransposeVectorMultiply/4/1 -0.0024 -0.0024 BM_MatrixTransposeVectorMultiply/4/2 -0.1578 -0.1578 BM_MatrixTransposeVectorMultiply/4/3 -0.0918 -0.0918 BM_MatrixTransposeVectorMultiply/4/4 -0.2570 -0.2570 BM_MatrixTransposeVectorMultiply/4/8 -0.3064 -0.3064 BM_MatrixTransposeVectorMultiply/4/12 -0.3316 -0.3316 BM_MatrixTransposeVectorMultiply/4/15 -0.2794 -0.2794 BM_MatrixTransposeVectorMultiply/6/1 -0.0484 -0.0484 BM_MatrixTransposeVectorMultiply/6/2 -0.1102 -0.1102 BM_MatrixTransposeVectorMultiply/6/3 -0.1188 -0.1188 BM_MatrixTransposeVectorMultiply/6/4 -0.2967 -0.2967 BM_MatrixTransposeVectorMultiply/6/8 -0.3190 -0.3190 BM_MatrixTransposeVectorMultiply/6/12 -0.3441 -0.3441 BM_MatrixTransposeVectorMultiply/6/15 -0.2723 -0.2723 BM_MatrixTransposeVectorMultiply/8/1 -0.0397 -0.0397 BM_MatrixTransposeVectorMultiply/8/2 -0.1453 -0.1453 BM_MatrixTransposeVectorMultiply/8/3 -0.1337 -0.1337 BM_MatrixTransposeVectorMultiply/8/4 -0.3084 -0.3084 BM_MatrixTransposeVectorMultiply/8/8 -0.3444 -0.3444 BM_MatrixTransposeVectorMultiply/8/12 -0.3717 -0.3717 BM_MatrixTransposeVectorMultiply/8/15 -0.3440 -0.3440 Change-Id: I17de05bf94699a07eea880b92a6d08daf1f038bb
diff --git a/internal/ceres/small_blas.h b/internal/ceres/small_blas.h index 264ac53..81c5872 100644 --- a/internal/ceres/small_blas.h +++ b/internal/ceres/small_blas.h
@@ -38,6 +38,7 @@ #include "ceres/internal/port.h" #include "ceres/internal/eigen.h" #include "glog/logging.h" +#include "small_blas_generic.h" namespace ceres { namespace internal { @@ -89,6 +90,26 @@ B, num_row_b, num_col_b, \ C, start_row_c, start_col_c, row_stride_c, col_stride_c); +#define CERES_GEMM_STORE_SINGLE(p, index, value) \ + if (kOperation > 0) { \ + p[index] += value; \ + } else if (kOperation < 0) { \ + p[index] -= value; \ + } else { \ + p[index] = value; \ + } + +#define CERES_GEMM_STORE_PAIR(p, index, v1, v2) \ + if (kOperation > 0) { \ + p[index] += v1; \ + p[index + 1] += v2; \ + } else if (kOperation < 0) { \ + p[index] -= v1; \ + p[index + 1] -= v2; \ + } else { \ + p[index] = v1; \ + p[index + 1] = v2; \ + } // For the matrix-matrix functions below, there are three variants for // each functionality. Foo, FooNaive and FooEigen. Foo is the one to @@ -160,24 +181,64 @@ const int NUM_COL_C = NUM_COL_B; DCHECK_LE(start_row_c + NUM_ROW_C, row_stride_c); DCHECK_LE(start_col_c + NUM_COL_C, col_stride_c); + const int span = 4; - for (int row = 0; row < NUM_ROW_C; ++row) { - for (int col = 0; col < NUM_COL_C; ++col) { + // Calculate the remainder part first. + + // Process the last odd column if present. + if (NUM_COL_C & 1) { + int col = NUM_COL_C - 1; + const double* pa = &A[0]; + for (int row = 0; row < NUM_ROW_C; ++row, pa += NUM_COL_A) { + const double* pb = &B[col]; double tmp = 0.0; - for (int k = 0; k < NUM_COL_A; ++k) { - tmp += A[row * NUM_COL_A + k] * B[k * NUM_COL_B + col]; + for (int k = 0; k < NUM_COL_A; ++k, pb += NUM_COL_B) { + tmp += pa[k] * pb[0]; } const int index = (row + start_row_c) * col_stride_c + start_col_c + col; - if (kOperation > 0) { - C[index] += tmp; - } else if (kOperation < 0) { - C[index] -= tmp; - } else { - C[index] = tmp; - } + CERES_GEMM_STORE_SINGLE(C, index, tmp); + } + + // Return directly for efficiency of extremely small matrix multiply. + if (NUM_COL_C == 1) { + return; } } + + // Process the couple columns in remainder if present. + if (NUM_COL_C & 2) { + int col = NUM_COL_C & (int)(~(span - 1)) ; + const double* pa = &A[0]; + for (int row = 0; row < NUM_ROW_C; ++row, pa += NUM_COL_A) { + const double* pb = &B[col]; + double tmp1 = 0.0, tmp2 = 0.0; + for (int k = 0; k < NUM_COL_A; ++k, pb += NUM_COL_B) { + double av = pa[k]; + tmp1 += av * pb[0]; + tmp2 += av * pb[1]; + } + + const int index = (row + start_row_c) * col_stride_c + start_col_c + col; + CERES_GEMM_STORE_PAIR(C, index, tmp1, tmp2); + } + + // Return directly for efficiency of extremely small matrix multiply. + if (NUM_COL_C < span) { + return; + } + } + + // Calculate the main part with multiples of 4. + int col_m = NUM_COL_C & (int)(~(span - 1)); + for (int col = 0; col < col_m; col += span) { + for (int row = 0; row < NUM_ROW_C; ++row) { + const int index = (row + start_row_c) * col_stride_c + start_col_c + col; + MMM_mat1x4(NUM_COL_A, &A[row * NUM_COL_A], + &B[col], NUM_COL_B, &C[index], kOperation); + } + } + } CERES_GEMM_BEGIN(MatrixMatrixMultiply) { @@ -220,24 +281,68 @@ const int NUM_COL_C = NUM_COL_B; DCHECK_LE(start_row_c + NUM_ROW_C, row_stride_c); DCHECK_LE(start_col_c + NUM_COL_C, col_stride_c); + const int span = 4; - for (int row = 0; row < NUM_ROW_C; ++row) { - for (int col = 0; col < NUM_COL_C; ++col) { + // Process the remainder part first. + + // Process the last odd column if present. + if (NUM_COL_C & 1) { + int col = NUM_COL_C - 1; + for (int row = 0; row < NUM_ROW_C; ++row) { + const double* pa = &A[row]; + const double* pb = &B[col]; double tmp = 0.0; for (int k = 0; k < NUM_ROW_A; ++k) { - tmp += A[k * NUM_COL_A + row] * B[k * NUM_COL_B + col]; + tmp += pa[0] * pb[0]; + pa += NUM_COL_A; + pb += NUM_COL_B; } const int index = (row + start_row_c) * col_stride_c + start_col_c + col; - if (kOperation > 0) { - C[index]+= tmp; - } else if (kOperation < 0) { - C[index]-= tmp; - } else { - C[index]= tmp; - } + CERES_GEMM_STORE_SINGLE(C, index, tmp); + } + + // Return directly for efficiency of extremely small matrix multiply. + if (NUM_COL_C == 1) { + return; } } + + // Process the couple columns in remainder if present. + if (NUM_COL_C & 2) { + int col = NUM_COL_C & (int)(~(span - 1)) ; + for (int row = 0; row < NUM_ROW_C; ++row) { + const double* pa = &A[row]; + const double* pb = &B[col]; + double tmp1 = 0.0, tmp2 = 0.0; + for (int k = 0; k < NUM_ROW_A; ++k) { + double av = *pa; + tmp1 += av * pb[0]; + tmp2 += av * pb[1]; + pa += NUM_COL_A; + pb += NUM_COL_B; + } + + const int index = (row + start_row_c) * col_stride_c + start_col_c + col; + CERES_GEMM_STORE_PAIR(C, index, tmp1, tmp2); + } + + // Return directly for efficiency of extremely small matrix multiply. + if (NUM_COL_C < span) { + return; + } + } + + // Process the main part with multiples of 4. + int col_m = NUM_COL_C & (int)(~(span - 1)); + for (int col = 0; col < col_m; col += span) { + for (int row = 0; row < NUM_ROW_C; ++row) { + const int index = (row + start_row_c) * col_stride_c + start_col_c + col; + MTM_mat1x4(NUM_ROW_A, &A[row], NUM_COL_A, + &B[col], NUM_COL_B, &C[index], kOperation); + } + } + } CERES_GEMM_BEGIN(MatrixTransposeMatrixMultiply) { @@ -301,21 +406,54 @@ const int NUM_ROW_A = (kRowA != Eigen::Dynamic ? kRowA : num_row_a); const int NUM_COL_A = (kColA != Eigen::Dynamic ? kColA : num_col_a); + const int span = 4; - for (int row = 0; row < NUM_ROW_A; ++row) { + // Calculate the remainder part first. + + // Process the last odd row if present. + if (NUM_ROW_A & 1) { + int row = NUM_ROW_A - 1; + const double* pa = &A[row * NUM_COL_A]; + const double* pb = &b[0]; double tmp = 0.0; for (int col = 0; col < NUM_COL_A; ++col) { - tmp += A[row * NUM_COL_A + col] * b[col]; + tmp += (*pa++) * (*pb++); } + CERES_GEMM_STORE_SINGLE(c, row, tmp); - if (kOperation > 0) { - c[row] += tmp; - } else if (kOperation < 0) { - c[row] -= tmp; - } else { - c[row] = tmp; + // Return directly for efficiency of extremely small matrix multiply. + if (NUM_ROW_A == 1) { + return; } } + + // Process the couple rows in remainder if present. + if (NUM_ROW_A & 2) { + int row = NUM_ROW_A & (int)(~(span - 1)); + const double* pa1 = &A[row * NUM_COL_A]; + const double* pa2 = pa1 + NUM_COL_A; + const double* pb = &b[0]; + double tmp1 = 0.0, tmp2 = 0.0; + for (int col = 0; col < NUM_COL_A; ++col) { + double bv = *pb++; + tmp1 += *(pa1++) * bv; + tmp2 += *(pa2++) * bv; + } + CERES_GEMM_STORE_PAIR(c, row, tmp1, tmp2); + + // Return directly for efficiency of extremely small matrix multiply. + if (NUM_ROW_A < span) { + return; + } + } + + // Calculate the main part with multiples of 4. + int row_m = NUM_ROW_A & (int)(~(span - 1)); + for (int row = 0; row < row_m; row += span) { + MVM_mat4x1(NUM_COL_A, &A[row * NUM_COL_A], NUM_COL_A, + &b[0], &c[row], kOperation); + } + #endif // CERES_NO_CUSTOM_BLAS } @@ -352,21 +490,55 @@ const int NUM_ROW_A = (kRowA != Eigen::Dynamic ? kRowA : num_row_a); const int NUM_COL_A = (kColA != Eigen::Dynamic ? kColA : num_col_a); + const int span = 4; - for (int row = 0; row < NUM_COL_A; ++row) { + // Calculate the remainder part first. + + // Process the last odd column if present. + if (NUM_COL_A & 1) { + int row = NUM_COL_A - 1; + const double* pa = &A[row]; + const double* pb = &b[0]; double tmp = 0.0; for (int col = 0; col < NUM_ROW_A; ++col) { - tmp += A[col * NUM_COL_A + row] * b[col]; + tmp += *pa * (*pb++); + pa += NUM_COL_A; } + CERES_GEMM_STORE_SINGLE(c, row, tmp); - if (kOperation > 0) { - c[row] += tmp; - } else if (kOperation < 0) { - c[row] -= tmp; - } else { - c[row] = tmp; + // Return directly for efficiency of extremely small matrix multiply. + if (NUM_COL_A == 1) { + return; } } + + // Process the couple columns in remainder if present. + if (NUM_COL_A & 2) { + int row = NUM_COL_A & (int)(~(span - 1)); + const double* pa = &A[row]; + const double* pb = &b[0]; + double tmp1 = 0.0, tmp2 = 0.0; + for (int col = 0; col < NUM_ROW_A; ++col) { + double bv = *pb++; + tmp1 += *(pa ) * bv; + tmp2 += *(pa + 1) * bv; + pa += NUM_COL_A; + } + CERES_GEMM_STORE_PAIR(c, row, tmp1, tmp2); + + // Return directly for efficiency of extremely small matrix multiply. + if (NUM_COL_A < span) { + return; + } + } + + // Calculate the main part with multiples of 4. + int row_m = NUM_COL_A & (int)(~(span - 1)); + for (int row = 0; row < row_m; row += span) { + MTV_mat4x1(NUM_ROW_A, &A[row], NUM_COL_A, + &b[0], &c[row], kOperation); + } + #endif // CERES_NO_CUSTOM_BLAS } @@ -374,6 +546,8 @@ #undef CERES_GEMM_EIGEN_HEADER #undef CERES_GEMM_NAIVE_HEADER #undef CERES_CALL_GEMM +#undef CERES_GEMM_STORE_SINGLE +#undef CERES_GEMM_STORE_PAIR } // namespace internal } // namespace ceres
diff --git a/internal/ceres/small_blas_generic.h b/internal/ceres/small_blas_generic.h new file mode 100644 index 0000000..978c5d5 --- /dev/null +++ b/internal/ceres/small_blas_generic.h
@@ -0,0 +1,315 @@ +// Ceres Solver - A fast non-linear least squares minimizer +// Copyright 2018 Google Inc. All rights reserved. +// http://ceres-solver.org/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// * Neither the name of Google Inc. nor the names of its contributors may be +// used to endorse or promote products derived from this software without +// specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +// +// Author: yangfan34@lenovo.com (Lenovo Research Device+ Lab - Shanghai) +// +// Optimization for simple blas functions used in the Schur Eliminator. +// These are fairly basic implementations which already yield a significant +// speedup in the eliminator performance. + +#ifndef CERES_INTERNAL_SMALL_BLAS_GENERIC_H_ +#define CERES_INTERNAL_SMALL_BLAS_GENERIC_H_ + +namespace ceres { +namespace internal { + +// The following macros are used to share code +#define CERES_GEMM_OPT_NAIVE_HEADER \ + double c0 = 0.0; \ + double c1 = 0.0; \ + double c2 = 0.0; \ + double c3 = 0.0; \ + const double* pa = a; \ + const double* pb = b; \ + const int span = 4; \ + int col_r = col_a & (span - 1); \ + int col_m = col_a - col_r; + +#define CERES_GEMM_OPT_STORE_MAT1X4 \ + if (kOperation > 0) { \ + *c++ += c0; \ + *c++ += c1; \ + *c++ += c2; \ + *c++ += c3; \ + } else if (kOperation < 0) { \ + *c++ -= c0; \ + *c++ -= c1; \ + *c++ -= c2; \ + *c++ -= c3; \ + } else { \ + *c++ = c0; \ + *c++ = c1; \ + *c++ = c2; \ + *c++ = c3; \ + } + +// Matrix-Matrix Multiplication +// Figure out 1x4 of Matrix C in one batch +// +// c op a * B; +// where op can be +=, -=, or =, indicated by kOperation. +// +// Matrix C Matrix A Matrix B +// +// C0, C1, C2, C3 op A0, A1, A2, A3, ... * B0, B1, B2, B3 +// B4, B5, B6, B7 +// B8, B9, Ba, Bb +// Bc, Bd, Be, Bf +// . , . , . , . +// . , . , . , . +// . , . , . , . +// +// unroll for loops +// utilize the data resided in cache +// NOTE: col_a means the columns of A +static inline void MMM_mat1x4(const int col_a, + const double* a, + const double* b, + const int col_stride_b, + double* c, + const int kOperation) { + CERES_GEMM_OPT_NAIVE_HEADER + double av = 0.0; + int bi = 0; + +#define CERES_GEMM_OPT_MMM_MAT1X4_MUL \ + av = pa[k]; \ + pb = b + bi; \ + c0 += av * *pb++; \ + c1 += av * *pb++; \ + c2 += av * *pb++; \ + c3 += av * *pb++; \ + bi += col_stride_b; \ + k++; + + for (int k = 0; k < col_m;) { + CERES_GEMM_OPT_MMM_MAT1X4_MUL + CERES_GEMM_OPT_MMM_MAT1X4_MUL + CERES_GEMM_OPT_MMM_MAT1X4_MUL + CERES_GEMM_OPT_MMM_MAT1X4_MUL + } + + for (int k = col_m; k < col_a;) { + CERES_GEMM_OPT_MMM_MAT1X4_MUL + } + + CERES_GEMM_OPT_STORE_MAT1X4 + +#undef CERES_GEMM_OPT_MMM_MAT1X4_MUL +} + +// Matrix Transpose-Matrix multiplication +// Figure out 1x4 of Matrix C in one batch +// +// c op a' * B; +// where op can be +=, -=, or = indicated by kOperation. +// +// Matrix A +// +// A0 +// A1 +// A2 +// A3 +// . +// . +// . +// +// Matrix C Matrix A' Matrix B +// +// C0, C1, C2, C3 op A0, A1, A2, A3, ... * B0, B1, B2, B3 +// B4, B5, B6, B7 +// B8, B9, Ba, Bb +// Bc, Bd, Be, Bf +// . , . , . , . +// . , . , . , . +// . , . , . , . +// +// unroll for loops +// utilize the data resided in cache +// NOTE: col_a means the columns of A' +static inline void MTM_mat1x4(const int col_a, + const double* a, + const int col_stride_a, + const double* b, + const int col_stride_b, + double* c, + const int kOperation) { + CERES_GEMM_OPT_NAIVE_HEADER + double av = 0.0; + int ai = 0; + int bi = 0; + +#define CERES_GEMM_OPT_MTM_MAT1X4_MUL \ + av = pa[ai]; \ + pb = b + bi; \ + c0 += av * *pb++; \ + c1 += av * *pb++; \ + c2 += av * *pb++; \ + c3 += av * *pb++; \ + ai += col_stride_a; \ + bi += col_stride_b; + + for (int k = 0; k < col_m; k += span) { + CERES_GEMM_OPT_MTM_MAT1X4_MUL + CERES_GEMM_OPT_MTM_MAT1X4_MUL + CERES_GEMM_OPT_MTM_MAT1X4_MUL + CERES_GEMM_OPT_MTM_MAT1X4_MUL + } + + for (int k = col_m; k < col_a; k++) { + CERES_GEMM_OPT_MTM_MAT1X4_MUL + } + + CERES_GEMM_OPT_STORE_MAT1X4 + +#undef CERES_GEMM_OPT_MTM_MAT1X4_MUL +} + +// Matrix-Vector Multiplication +// Figure out 4x1 of vector c in one batch +// +// c op A * b; +// where op can be +=, -=, or =, indicated by kOperation. +// +// Vector c Matrix A Vector b +// +// C0 op A0, A1, A2, A3, ... * B0 +// C1 A4, A5, A6, A7, ... B1 +// C2 A8, A9, Aa, Ab, ... B2 +// C3 Ac, Ad, Ae, Af, ... B3 +// . +// . +// . +// +// unroll for loops +// utilize the data resided in cache +// NOTE: col_a means the columns of A +static inline void MVM_mat4x1(const int col_a, + const double* a, + const int col_stride_a, + const double* b, + double* c, + const int kOperation) { + CERES_GEMM_OPT_NAIVE_HEADER + double bv = 0.0; + +#define CERES_GEMM_OPT_MVM_MAT4X1_MUL \ + bv = *pb; \ + c0 += *(pa ) * bv; \ + c1 += *(pa + col_stride_a ) * bv; \ + c2 += *(pa + col_stride_a * 2) * bv; \ + c3 += *(pa + col_stride_a * 3) * bv; \ + pa++; \ + pb++; + + for (int k = 0; k < col_m; k += span) { + CERES_GEMM_OPT_MVM_MAT4X1_MUL + CERES_GEMM_OPT_MVM_MAT4X1_MUL + CERES_GEMM_OPT_MVM_MAT4X1_MUL + CERES_GEMM_OPT_MVM_MAT4X1_MUL + } + + for (int k = col_m; k < col_a; k++) { + CERES_GEMM_OPT_MVM_MAT4X1_MUL + } + + CERES_GEMM_OPT_STORE_MAT1X4 + +#undef CERES_GEMM_OPT_MVM_MAT4X1_MUL +} + +// Matrix Transpose-Vector multiplication +// Figure out 4x1 of vector c in one batch +// +// c op A' * b; +// where op can be +=, -=, or =, indicated by kOperation. +// +// Matrix A +// +// A0, A4, A8, Ac +// A1, A5, A9, Ad +// A2, A6, Aa, Ae +// A3, A7, Ab, Af +// . , . , . , . +// . , . , . , . +// . , . , . , . +// +// Vector c Matrix A' Vector b +// +// C0 op A0, A1, A2, A3, ... * B0 +// C1 A4, A5, A6, A7, ... B1 +// C2 A8, A9, Aa, Ab, ... B2 +// C3 Ac, Ad, Ae, Af, ... B3 +// . +// . +// . +// +// unroll for loops +// utilize the data resided in cache +// NOTE: col_a means the columns of A' +static inline void MTV_mat4x1(const int col_a, + const double* a, + const int col_stride_a, + const double* b, + double* c, + const int kOperation) { + CERES_GEMM_OPT_NAIVE_HEADER + double bv = 0.0; + +#define CERES_GEMM_OPT_MTV_MAT4X1_MUL \ + bv = *pb; \ + c0 += *(pa ) * bv; \ + c1 += *(pa + 1) * bv; \ + c2 += *(pa + 2) * bv; \ + c3 += *(pa + 3) * bv; \ + pa += col_stride_a; \ + pb++; + + for (int k = 0; k < col_m; k += span) { + CERES_GEMM_OPT_MTV_MAT4X1_MUL + CERES_GEMM_OPT_MTV_MAT4X1_MUL + CERES_GEMM_OPT_MTV_MAT4X1_MUL + CERES_GEMM_OPT_MTV_MAT4X1_MUL + } + + for (int k = col_m; k < col_a; k++) { + CERES_GEMM_OPT_MTV_MAT4X1_MUL + } + + CERES_GEMM_OPT_STORE_MAT1X4 + +#undef CERES_GEMM_OPT_MTV_MAT4X1_MUL +} + +#undef CERES_GEMM_OPT_NAIVE_HEADER +#undef CERES_GEMM_OPT_STORE_MAT1X4 + +} // namespace internal +} // namespace ceres + +#endif // CERES_INTERNAL_SMALL_BLAS_GENERIC_H_