Optimization for custom small blas multiplication with dynamic template parameters in C level. - unroll for loops - matrix access more cache coherent - platform independant Briefly, this commit brings 1~50% performance improvments for most cases in small_blas_gem(m/v)_benchmark, but a small drop for corner cases with small dimensions especially 1,2,3. Here we list the results partially, which show decrease percentage of executing time, compared to unoptimized version. Platform: desktop PC (i7-7700 CPU MP8@3.60GHz + ubuntu 17.10) (Lenovo Research Device+ Lab, <yangfan34@lenovo.com>) Benchmark Time CPU ----------------------------------------------------------- BM_MatrixMatrixMultiplyDynamic/2/2/2 -0.1082 -0.1083 BM_MatrixMatrixMultiplyDynamic/2/2/15 -0.1270 -0.1270 BM_MatrixMatrixMultiplyDynamic/2/4/2 -0.1433 -0.1433 BM_MatrixMatrixMultiplyDynamic/2/4/15 -0.2069 -0.2068 BM_MatrixMatrixMultiplyDynamic/2/6/2 -0.1446 -0.1446 BM_MatrixMatrixMultiplyDynamic/2/6/15 -0.2156 -0.2156 BM_MatrixMatrixMultiplyDynamic/2/8/2 -0.1788 -0.1788 BM_MatrixMatrixMultiplyDynamic/2/8/15 -0.3316 -0.3316 BM_MatrixMatrixMultiplyDynamic/2/10/2 -0.2025 -0.2025 BM_MatrixMatrixMultiplyDynamic/2/10/15 -0.3444 -0.3444 BM_MatrixMatrixMultiplyDynamic/2/12/2 -0.0515 -0.0515 BM_MatrixMatrixMultiplyDynamic/2/12/15 -0.3733 -0.3733 BM_MatrixMatrixMultiplyDynamic/2/15/2 -0.2784 -0.2784 BM_MatrixMatrixMultiplyDynamic/2/15/15 -0.3704 -0.3704 BM_MatrixMatrixMultiplyDynamic/4/2/2 -0.1839 -0.1839 BM_MatrixMatrixMultiplyDynamic/4/2/15 -0.1922 -0.1922 BM_MatrixMatrixMultiplyDynamic/4/4/2 -0.2248 -0.2248 BM_MatrixMatrixMultiplyDynamic/4/4/15 -0.3132 -0.3132 BM_MatrixMatrixMultiplyDynamic/4/6/2 -0.2311 -0.2311 BM_MatrixMatrixMultiplyDynamic/4/6/15 -0.3239 -0.3239 BM_MatrixMatrixMultiplyDynamic/4/8/2 -0.0574 -0.0574 BM_MatrixMatrixMultiplyDynamic/4/8/15 -0.4173 -0.4173 BM_MatrixMatrixMultiplyDynamic/4/10/2 -0.2861 -0.2861 BM_MatrixMatrixMultiplyDynamic/4/10/15 -0.4065 -0.4064 BM_MatrixMatrixMultiplyDynamic/4/12/2 -0.2976 -0.2975 BM_MatrixMatrixMultiplyDynamic/4/12/15 -0.4218 -0.4218 BM_MatrixMatrixMultiplyDynamic/4/15/2 -0.3116 -0.3116 BM_MatrixMatrixMultiplyDynamic/4/15/15 -0.4242 -0.4241 BM_MatrixMatrixMultiplyDynamic/8/12/2 -0.3675 -0.3674 BM_MatrixMatrixMultiplyDynamic/8/12/4 -0.5055 -0.5055 BM_MatrixMatrixMultiplyDynamic/8/12/6 -0.4302 -0.4302 BM_MatrixMatrixMultiplyDynamic/8/12/8 -0.4854 -0.4854 BM_MatrixMatrixMultiplyDynamic/8/12/10 -0.4882 -0.4882 BM_MatrixMatrixMultiplyDynamic/8/12/12 -0.5209 -0.5209 BM_MatrixMatrixMultiplyDynamic/8/12/15 -0.4558 -0.4558 BM_MatrixMatrixMultiplyDynamic/8/15/2 -0.2319 -0.2319 BM_MatrixMatrixMultiplyDynamic/8/15/4 -0.5105 -0.5105 BM_MatrixMatrixMultiplyDynamic/8/15/6 -0.4477 -0.4477 BM_MatrixMatrixMultiplyDynamic/8/15/8 -0.5479 -0.5479 BM_MatrixMatrixMultiplyDynamic/8/15/10 -0.4843 -0.4843 BM_MatrixMatrixMultiplyDynamic/8/15/12 -0.5212 -0.5212 BM_MatrixMatrixMultiplyDynamic/8/15/15 -0.4459 -0.4459 BM_MatrixVectorMultiply/1/1 +0.0978 +0.0978 BM_MatrixVectorMultiply/1/2 +0.0551 +0.0551 BM_MatrixVectorMultiply/1/3 -0.0019 -0.0020 BM_MatrixVectorMultiply/1/4 +0.0563 +0.0562 BM_MatrixVectorMultiply/1/6 +0.1379 +0.1379 BM_MatrixVectorMultiply/1/7 +0.1090 +0.1090 BM_MatrixVectorMultiply/1/12 +0.0901 +0.0901 BM_MatrixVectorMultiply/1/16 +0.0493 +0.0493 BM_MatrixVectorMultiply/1/20 +0.2255 +0.2255 BM_MatrixVectorMultiply/2/1 +0.1261 +0.1261 BM_MatrixVectorMultiply/2/2 +0.2328 +0.2328 BM_MatrixVectorMultiply/2/3 +0.1404 +0.1403 BM_MatrixVectorMultiply/2/4 +0.0257 +0.0256 BM_MatrixVectorMultiply/2/6 -0.1691 -0.1691 BM_MatrixVectorMultiply/2/7 -0.2619 -0.2619 BM_MatrixVectorMultiply/2/12 -0.4261 -0.4261 BM_MatrixVectorMultiply/2/16 -0.5387 -0.5387 BM_MatrixVectorMultiply/2/20 -0.6171 -0.6171 BM_MatrixVectorMultiply/3/1 +0.1664 +0.1664 BM_MatrixVectorMultiply/3/2 +0.0848 +0.0848 BM_MatrixVectorMultiply/3/3 -0.0044 -0.0044 BM_MatrixVectorMultiply/3/4 -0.0683 -0.0684 BM_MatrixVectorMultiply/3/6 -0.1652 -0.1652 BM_MatrixVectorMultiply/3/7 -0.1633 -0.1633 BM_MatrixVectorMultiply/3/12 -0.1921 -0.1921 BM_MatrixVectorMultiply/3/16 -0.3659 -0.3659 BM_MatrixVectorMultiply/3/20 -0.4137 -0.4137 BM_MatrixVectorMultiply/4/1 -0.0577 -0.0577 BM_MatrixVectorMultiply/4/2 -0.1337 -0.1338 BM_MatrixVectorMultiply/4/3 -0.1443 -0.1443 BM_MatrixVectorMultiply/4/4 +0.0013 +0.0013 BM_MatrixVectorMultiply/4/6 -0.1071 -0.1071 BM_MatrixVectorMultiply/4/7 -0.1396 -0.1397 BM_MatrixVectorMultiply/4/12 -0.2792 -0.2792 BM_MatrixVectorMultiply/4/16 -0.4485 -0.4486 BM_MatrixVectorMultiply/4/20 -0.3588 -0.3588 Change-Id: I64a8cf11391e3d06341a2b8764cd1b4f1b8a23f1
diff --git a/internal/ceres/small_blas.h b/internal/ceres/small_blas.h index 264ac53..34b4ec7 100644 --- a/internal/ceres/small_blas.h +++ b/internal/ceres/small_blas.h
@@ -38,6 +38,7 @@ #include "ceres/internal/port.h" #include "ceres/internal/eigen.h" #include "glog/logging.h" +#include "small_blas_generic.h" namespace ceres { namespace internal { @@ -89,6 +90,26 @@ B, num_row_b, num_col_b, \ C, start_row_c, start_col_c, row_stride_c, col_stride_c); +#define CERES_GEMM_STORE_SINGLE(p, index, value) \ + if (kOperation > 0) { \ + p[index] += value; \ + } else if (kOperation < 0) { \ + p[index] -= value; \ + } else { \ + p[index] = value; \ + } + +#define CERES_GEMM_STORE_PAIR(p, index, v1, v2) \ + if (kOperation > 0) { \ + p[index] += v1; \ + p[index + 1] += v2; \ + } else if (kOperation < 0) { \ + p[index] -= v1; \ + p[index + 1] -= v2; \ + } else { \ + p[index] = v1; \ + p[index + 1] = v2; \ + } // For the matrix-matrix functions below, there are three variants for // each functionality. Foo, FooNaive and FooEigen. Foo is the one to @@ -160,24 +181,64 @@ const int NUM_COL_C = NUM_COL_B; DCHECK_LE(start_row_c + NUM_ROW_C, row_stride_c); DCHECK_LE(start_col_c + NUM_COL_C, col_stride_c); + const int span = 4; - for (int row = 0; row < NUM_ROW_C; ++row) { - for (int col = 0; col < NUM_COL_C; ++col) { + // Calculate the remainder part first. + + // Process the last odd column if present. + if (NUM_COL_C & 1) { + int col = NUM_COL_C - 1; + const double* pa = &A[0]; + for (int row = 0; row < NUM_ROW_C; ++row, pa += NUM_COL_A) { + const double* pb = &B[col]; double tmp = 0.0; - for (int k = 0; k < NUM_COL_A; ++k) { - tmp += A[row * NUM_COL_A + k] * B[k * NUM_COL_B + col]; + for (int k = 0; k < NUM_COL_A; ++k, pb += NUM_COL_B) { + tmp += pa[k] * pb[0]; } const int index = (row + start_row_c) * col_stride_c + start_col_c + col; - if (kOperation > 0) { - C[index] += tmp; - } else if (kOperation < 0) { - C[index] -= tmp; - } else { - C[index] = tmp; - } + CERES_GEMM_STORE_SINGLE(C, index, tmp); + } + + // Return directly for efficiency of extremely small matrix multiply. + if (NUM_COL_C == 1) { + return; } } + + // Process the couple columns in remainder if present. + if (NUM_COL_C & 2) { + int col = NUM_COL_C & (int)(~(span - 1)) ; + const double* pa = &A[0]; + for (int row = 0; row < NUM_ROW_C; ++row, pa += NUM_COL_A) { + const double* pb = &B[col]; + double tmp1 = 0.0, tmp2 = 0.0; + for (int k = 0; k < NUM_COL_A; ++k, pb += NUM_COL_B) { + double av = pa[k]; + tmp1 += av * pb[0]; + tmp2 += av * pb[1]; + } + + const int index = (row + start_row_c) * col_stride_c + start_col_c + col; + CERES_GEMM_STORE_PAIR(C, index, tmp1, tmp2); + } + + // Return directly for efficiency of extremely small matrix multiply. + if (NUM_COL_C < span) { + return; + } + } + + // Calculate the main part with multiples of 4. + int col_m = NUM_COL_C & (int)(~(span - 1)); + for (int col = 0; col < col_m; col += span) { + for (int row = 0; row < NUM_ROW_C; ++row) { + const int index = (row + start_row_c) * col_stride_c + start_col_c + col; + MMM_mat1x4(NUM_COL_A, &A[row * NUM_COL_A], + &B[col], NUM_COL_B, &C[index], kOperation); + } + } + } CERES_GEMM_BEGIN(MatrixMatrixMultiply) { @@ -220,24 +281,68 @@ const int NUM_COL_C = NUM_COL_B; DCHECK_LE(start_row_c + NUM_ROW_C, row_stride_c); DCHECK_LE(start_col_c + NUM_COL_C, col_stride_c); + const int span = 4; - for (int row = 0; row < NUM_ROW_C; ++row) { - for (int col = 0; col < NUM_COL_C; ++col) { + // Process the remainder part first. + + // Process the last odd column if present. + if (NUM_COL_C & 1) { + int col = NUM_COL_C - 1; + for (int row = 0; row < NUM_ROW_C; ++row) { + const double* pa = &A[row]; + const double* pb = &B[col]; double tmp = 0.0; for (int k = 0; k < NUM_ROW_A; ++k) { - tmp += A[k * NUM_COL_A + row] * B[k * NUM_COL_B + col]; + tmp += pa[0] * pb[0]; + pa += NUM_COL_A; + pb += NUM_COL_B; } const int index = (row + start_row_c) * col_stride_c + start_col_c + col; - if (kOperation > 0) { - C[index]+= tmp; - } else if (kOperation < 0) { - C[index]-= tmp; - } else { - C[index]= tmp; - } + CERES_GEMM_STORE_SINGLE(C, index, tmp); + } + + // Return directly for efficiency of extremely small matrix multiply. + if (NUM_COL_C == 1) { + return; } } + + // Process the couple columns in remainder if present. + if (NUM_COL_C & 2) { + int col = NUM_COL_C & (int)(~(span - 1)) ; + for (int row = 0; row < NUM_ROW_C; ++row) { + const double* pa = &A[row]; + const double* pb = &B[col]; + double tmp1 = 0.0, tmp2 = 0.0; + for (int k = 0; k < NUM_ROW_A; ++k) { + double av = *pa; + tmp1 += av * pb[0]; + tmp2 += av * pb[1]; + pa += NUM_COL_A; + pb += NUM_COL_B; + } + + const int index = (row + start_row_c) * col_stride_c + start_col_c + col; + CERES_GEMM_STORE_PAIR(C, index, tmp1, tmp2); + } + + // Return directly for efficiency of extremely small matrix multiply. + if (NUM_COL_C < span) { + return; + } + } + + // Process the main part with multiples of 4. + int col_m = NUM_COL_C & (int)(~(span - 1)); + for (int col = 0; col < col_m; col += span) { + for (int row = 0; row < NUM_ROW_C; ++row) { + const int index = (row + start_row_c) * col_stride_c + start_col_c + col; + MTM_mat1x4(NUM_ROW_A, &A[row], NUM_COL_A, + &B[col], NUM_COL_B, &C[index], kOperation); + } + } + } CERES_GEMM_BEGIN(MatrixTransposeMatrixMultiply) { @@ -301,21 +406,54 @@ const int NUM_ROW_A = (kRowA != Eigen::Dynamic ? kRowA : num_row_a); const int NUM_COL_A = (kColA != Eigen::Dynamic ? kColA : num_col_a); + const int span = 4; - for (int row = 0; row < NUM_ROW_A; ++row) { + // Calculate the remainder part first. + + // Process the last odd row if present. + if (NUM_ROW_A & 1) { + int row = NUM_ROW_A - 1; + const double* pa = &A[row * NUM_COL_A]; + const double* pb = &b[0]; double tmp = 0.0; for (int col = 0; col < NUM_COL_A; ++col) { - tmp += A[row * NUM_COL_A + col] * b[col]; + tmp += (*pa++) * (*pb++); } + CERES_GEMM_STORE_SINGLE(c, row, tmp); - if (kOperation > 0) { - c[row] += tmp; - } else if (kOperation < 0) { - c[row] -= tmp; - } else { - c[row] = tmp; + // Return directly for efficiency of extremely small matrix multiply. + if (NUM_ROW_A == 1) { + return; } } + + // Process the couple rows in remainder if present. + if (NUM_ROW_A & 2) { + int row = NUM_ROW_A & (int)(~(span - 1)); + const double* pa1 = &A[row * NUM_COL_A]; + const double* pa2 = pa1 + NUM_COL_A; + const double* pb = &b[0]; + double tmp1 = 0.0, tmp2 = 0.0; + for (int col = 0; col < NUM_ROW_A; ++col) { + double bv = *pb++; + tmp1 += *(pa1++) * bv; + tmp2 += *(pa2++) * bv; + } + CERES_GEMM_STORE_PAIR(c, row, tmp1, tmp2); + + // Return directly for efficiency of extremely small matrix multiply. + if (NUM_ROW_A < span) { + return; + } + } + + // Calculate the main part with multiples of 4. + int row_m = NUM_ROW_A & (int)(~(span - 1)); + for (int row = 0; row < row_m; row += span) { + MVM_mat4x1(NUM_COL_A, &A[row * NUM_COL_A], NUM_COL_A, + &b[0], &c[row], kOperation); + } + #endif // CERES_NO_CUSTOM_BLAS } @@ -352,21 +490,55 @@ const int NUM_ROW_A = (kRowA != Eigen::Dynamic ? kRowA : num_row_a); const int NUM_COL_A = (kColA != Eigen::Dynamic ? kColA : num_col_a); + const int span = 4; - for (int row = 0; row < NUM_COL_A; ++row) { + // Calculate the remainder part first. + + // Process the last odd column if present. + if (NUM_COL_A & 1) { + int row = NUM_COL_A - 1; + const double* pa = &A[row]; + const double* pb = &b[0]; double tmp = 0.0; for (int col = 0; col < NUM_ROW_A; ++col) { - tmp += A[col * NUM_COL_A + row] * b[col]; + tmp += *pa * (*pb++); + pa += NUM_COL_A; } + CERES_GEMM_STORE_SINGLE(c, row, tmp); - if (kOperation > 0) { - c[row] += tmp; - } else if (kOperation < 0) { - c[row] -= tmp; - } else { - c[row] = tmp; + // Return directly for efficiency of extremely small matrix multiply. + if (NUM_COL_A == 1) { + return; } } + + // Process the couple columns in remainder if present. + if (NUM_COL_A & 2) { + int row = NUM_COL_A & (int)(~(span - 1)); + const double* pa = &A[row]; + const double* pb = &b[0]; + double tmp1 = 0.0, tmp2 = 0.0; + for (int col = 0; col < NUM_ROW_A; ++col) { + double bv = *pb++; + tmp1 += *(pa ) * bv; + tmp2 += *(pa + 1) * bv; + pa += NUM_COL_A; + } + CERES_GEMM_STORE_PAIR(c, row, tmp1, tmp2); + + // Return directly for efficiency of extremely small matrix multiply. + if (NUM_COL_A < span) { + return; + } + } + + // Calculate the main part with multiples of 4. + int row_m = NUM_COL_A & (int)(~(span - 1)); + for (int row = 0; row < row_m; row += span) { + MTV_mat4x1(NUM_ROW_A, &A[row], NUM_COL_A, + &b[0], &c[row], kOperation); + } + #endif // CERES_NO_CUSTOM_BLAS } @@ -374,6 +546,8 @@ #undef CERES_GEMM_EIGEN_HEADER #undef CERES_GEMM_NAIVE_HEADER #undef CERES_CALL_GEMM +#undef CERES_GEMM_STORE_SINGLE +#undef CERES_GEMM_STORE_PAIR } // namespace internal } // namespace ceres
diff --git a/internal/ceres/small_blas_generic.h b/internal/ceres/small_blas_generic.h new file mode 100644 index 0000000..978c5d5 --- /dev/null +++ b/internal/ceres/small_blas_generic.h
@@ -0,0 +1,315 @@ +// Ceres Solver - A fast non-linear least squares minimizer +// Copyright 2018 Google Inc. All rights reserved. +// http://ceres-solver.org/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// * Redistributions of source code must retain the above copyright notice, +// this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// * Neither the name of Google Inc. nor the names of its contributors may be +// used to endorse or promote products derived from this software without +// specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. +// +// Author: yangfan34@lenovo.com (Lenovo Research Device+ Lab - Shanghai) +// +// Optimization for simple blas functions used in the Schur Eliminator. +// These are fairly basic implementations which already yield a significant +// speedup in the eliminator performance. + +#ifndef CERES_INTERNAL_SMALL_BLAS_GENERIC_H_ +#define CERES_INTERNAL_SMALL_BLAS_GENERIC_H_ + +namespace ceres { +namespace internal { + +// The following macros are used to share code +#define CERES_GEMM_OPT_NAIVE_HEADER \ + double c0 = 0.0; \ + double c1 = 0.0; \ + double c2 = 0.0; \ + double c3 = 0.0; \ + const double* pa = a; \ + const double* pb = b; \ + const int span = 4; \ + int col_r = col_a & (span - 1); \ + int col_m = col_a - col_r; + +#define CERES_GEMM_OPT_STORE_MAT1X4 \ + if (kOperation > 0) { \ + *c++ += c0; \ + *c++ += c1; \ + *c++ += c2; \ + *c++ += c3; \ + } else if (kOperation < 0) { \ + *c++ -= c0; \ + *c++ -= c1; \ + *c++ -= c2; \ + *c++ -= c3; \ + } else { \ + *c++ = c0; \ + *c++ = c1; \ + *c++ = c2; \ + *c++ = c3; \ + } + +// Matrix-Matrix Multiplication +// Figure out 1x4 of Matrix C in one batch +// +// c op a * B; +// where op can be +=, -=, or =, indicated by kOperation. +// +// Matrix C Matrix A Matrix B +// +// C0, C1, C2, C3 op A0, A1, A2, A3, ... * B0, B1, B2, B3 +// B4, B5, B6, B7 +// B8, B9, Ba, Bb +// Bc, Bd, Be, Bf +// . , . , . , . +// . , . , . , . +// . , . , . , . +// +// unroll for loops +// utilize the data resided in cache +// NOTE: col_a means the columns of A +static inline void MMM_mat1x4(const int col_a, + const double* a, + const double* b, + const int col_stride_b, + double* c, + const int kOperation) { + CERES_GEMM_OPT_NAIVE_HEADER + double av = 0.0; + int bi = 0; + +#define CERES_GEMM_OPT_MMM_MAT1X4_MUL \ + av = pa[k]; \ + pb = b + bi; \ + c0 += av * *pb++; \ + c1 += av * *pb++; \ + c2 += av * *pb++; \ + c3 += av * *pb++; \ + bi += col_stride_b; \ + k++; + + for (int k = 0; k < col_m;) { + CERES_GEMM_OPT_MMM_MAT1X4_MUL + CERES_GEMM_OPT_MMM_MAT1X4_MUL + CERES_GEMM_OPT_MMM_MAT1X4_MUL + CERES_GEMM_OPT_MMM_MAT1X4_MUL + } + + for (int k = col_m; k < col_a;) { + CERES_GEMM_OPT_MMM_MAT1X4_MUL + } + + CERES_GEMM_OPT_STORE_MAT1X4 + +#undef CERES_GEMM_OPT_MMM_MAT1X4_MUL +} + +// Matrix Transpose-Matrix multiplication +// Figure out 1x4 of Matrix C in one batch +// +// c op a' * B; +// where op can be +=, -=, or = indicated by kOperation. +// +// Matrix A +// +// A0 +// A1 +// A2 +// A3 +// . +// . +// . +// +// Matrix C Matrix A' Matrix B +// +// C0, C1, C2, C3 op A0, A1, A2, A3, ... * B0, B1, B2, B3 +// B4, B5, B6, B7 +// B8, B9, Ba, Bb +// Bc, Bd, Be, Bf +// . , . , . , . +// . , . , . , . +// . , . , . , . +// +// unroll for loops +// utilize the data resided in cache +// NOTE: col_a means the columns of A' +static inline void MTM_mat1x4(const int col_a, + const double* a, + const int col_stride_a, + const double* b, + const int col_stride_b, + double* c, + const int kOperation) { + CERES_GEMM_OPT_NAIVE_HEADER + double av = 0.0; + int ai = 0; + int bi = 0; + +#define CERES_GEMM_OPT_MTM_MAT1X4_MUL \ + av = pa[ai]; \ + pb = b + bi; \ + c0 += av * *pb++; \ + c1 += av * *pb++; \ + c2 += av * *pb++; \ + c3 += av * *pb++; \ + ai += col_stride_a; \ + bi += col_stride_b; + + for (int k = 0; k < col_m; k += span) { + CERES_GEMM_OPT_MTM_MAT1X4_MUL + CERES_GEMM_OPT_MTM_MAT1X4_MUL + CERES_GEMM_OPT_MTM_MAT1X4_MUL + CERES_GEMM_OPT_MTM_MAT1X4_MUL + } + + for (int k = col_m; k < col_a; k++) { + CERES_GEMM_OPT_MTM_MAT1X4_MUL + } + + CERES_GEMM_OPT_STORE_MAT1X4 + +#undef CERES_GEMM_OPT_MTM_MAT1X4_MUL +} + +// Matrix-Vector Multiplication +// Figure out 4x1 of vector c in one batch +// +// c op A * b; +// where op can be +=, -=, or =, indicated by kOperation. +// +// Vector c Matrix A Vector b +// +// C0 op A0, A1, A2, A3, ... * B0 +// C1 A4, A5, A6, A7, ... B1 +// C2 A8, A9, Aa, Ab, ... B2 +// C3 Ac, Ad, Ae, Af, ... B3 +// . +// . +// . +// +// unroll for loops +// utilize the data resided in cache +// NOTE: col_a means the columns of A +static inline void MVM_mat4x1(const int col_a, + const double* a, + const int col_stride_a, + const double* b, + double* c, + const int kOperation) { + CERES_GEMM_OPT_NAIVE_HEADER + double bv = 0.0; + +#define CERES_GEMM_OPT_MVM_MAT4X1_MUL \ + bv = *pb; \ + c0 += *(pa ) * bv; \ + c1 += *(pa + col_stride_a ) * bv; \ + c2 += *(pa + col_stride_a * 2) * bv; \ + c3 += *(pa + col_stride_a * 3) * bv; \ + pa++; \ + pb++; + + for (int k = 0; k < col_m; k += span) { + CERES_GEMM_OPT_MVM_MAT4X1_MUL + CERES_GEMM_OPT_MVM_MAT4X1_MUL + CERES_GEMM_OPT_MVM_MAT4X1_MUL + CERES_GEMM_OPT_MVM_MAT4X1_MUL + } + + for (int k = col_m; k < col_a; k++) { + CERES_GEMM_OPT_MVM_MAT4X1_MUL + } + + CERES_GEMM_OPT_STORE_MAT1X4 + +#undef CERES_GEMM_OPT_MVM_MAT4X1_MUL +} + +// Matrix Transpose-Vector multiplication +// Figure out 4x1 of vector c in one batch +// +// c op A' * b; +// where op can be +=, -=, or =, indicated by kOperation. +// +// Matrix A +// +// A0, A4, A8, Ac +// A1, A5, A9, Ad +// A2, A6, Aa, Ae +// A3, A7, Ab, Af +// . , . , . , . +// . , . , . , . +// . , . , . , . +// +// Vector c Matrix A' Vector b +// +// C0 op A0, A1, A2, A3, ... * B0 +// C1 A4, A5, A6, A7, ... B1 +// C2 A8, A9, Aa, Ab, ... B2 +// C3 Ac, Ad, Ae, Af, ... B3 +// . +// . +// . +// +// unroll for loops +// utilize the data resided in cache +// NOTE: col_a means the columns of A' +static inline void MTV_mat4x1(const int col_a, + const double* a, + const int col_stride_a, + const double* b, + double* c, + const int kOperation) { + CERES_GEMM_OPT_NAIVE_HEADER + double bv = 0.0; + +#define CERES_GEMM_OPT_MTV_MAT4X1_MUL \ + bv = *pb; \ + c0 += *(pa ) * bv; \ + c1 += *(pa + 1) * bv; \ + c2 += *(pa + 2) * bv; \ + c3 += *(pa + 3) * bv; \ + pa += col_stride_a; \ + pb++; + + for (int k = 0; k < col_m; k += span) { + CERES_GEMM_OPT_MTV_MAT4X1_MUL + CERES_GEMM_OPT_MTV_MAT4X1_MUL + CERES_GEMM_OPT_MTV_MAT4X1_MUL + CERES_GEMM_OPT_MTV_MAT4X1_MUL + } + + for (int k = col_m; k < col_a; k++) { + CERES_GEMM_OPT_MTV_MAT4X1_MUL + } + + CERES_GEMM_OPT_STORE_MAT1X4 + +#undef CERES_GEMM_OPT_MTV_MAT4X1_MUL +} + +#undef CERES_GEMM_OPT_NAIVE_HEADER +#undef CERES_GEMM_OPT_STORE_MAT1X4 + +} // namespace internal +} // namespace ceres + +#endif // CERES_INTERNAL_SMALL_BLAS_GENERIC_H_