From 34d384db8a8a5d174dcc5b3e5140b9ffef1af0b3 Mon Sep 17 00:00:00 2001
From: har0ke <dev@ohargens.de>
Date: Sun, 5 Jul 2020 20:39:04 +0200
Subject: [PATCH] Speed up remainder

---
 scripts/test.py                               |   5 +-
 src/register_blocking/BlockWise.h             |  65 +--------
 .../detail/BlockWiseConfigs.h                 | 134 ++++++++++++++++++
 src/register_blocking/detail/BlockWiseImpl.h  |  35 ++++-
 .../detail/RegisterBlockingManual.h           |  80 -----------
 5 files changed, 169 insertions(+), 150 deletions(-)
 create mode 100644 src/register_blocking/detail/BlockWiseConfigs.h
 delete mode 100644 src/register_blocking/detail/RegisterBlockingManual.h

diff --git a/scripts/test.py b/scripts/test.py
index a97102d..fabb2f1 100755
--- a/scripts/test.py
+++ b/scripts/test.py
@@ -73,6 +73,7 @@ if __name__ == '__main__':
     parser.add_argument("--function", type=str, nargs="*")
     parser.add_argument("--release", action="store_true")
     parser.add_argument("--no_manual", action="store_true")
+    parser.add_argument("--verbose", action="store_true")
 
     options = parser.parse_args()
 
@@ -125,7 +126,9 @@ if __name__ == '__main__':
             if with_double:
                 arguments.append("--double")
             output = compile_and_run("..", "builds", "simd_multiply", True, clang, options.avx512, options.release, options.no_manual, arguments + extra_args)
-            ms = output.decode()[output.decode().find("multiply:") + 10:]
+            if options.verbose:
+                print(output.decode("utf-8"))
+            ms = output.decode("utf-8")[output.decode("utf-8").find("multiply:") + 10:]
             
             if not already_dumped:
                 build_path = os.path.join(get_build_path("..", "builds", "simd_multiply", True, clang, options.avx512, options.release, options.no_manual, arguments + extra_args)[1], "simd_multiply")
diff --git a/src/register_blocking/BlockWise.h b/src/register_blocking/BlockWise.h
index eb5d9ae..1080b1b 100644
--- a/src/register_blocking/BlockWise.h
+++ b/src/register_blocking/BlockWise.h
@@ -6,70 +6,7 @@
 #define SMID_MATRIX_BLOCKWISE_H
 
 #include "detail/BlockWiseImpl.h"
-
-#include <immintrin.h>
-
-struct __m128_block_wise_config {
-    using FloatType = float;
-    using VectorType = __m128;
-    static constexpr auto LoadVector = _mm_loadu_ps;
-    static constexpr auto StoreVector = _mm_storeu_ps;
-    static constexpr auto BroadcastToVector = _mm_set1_ps;
-    static constexpr auto XOR = _mm_xor_ps;
-    static constexpr unsigned Registers = 16;
-};
-
-struct __m128d_block_wise_config {
-    using FloatType = double;
-    using VectorType = __m128d;
-    static constexpr auto LoadVector = _mm_loadu_pd;
-    static constexpr auto StoreVector = _mm_storeu_pd;
-    static constexpr auto BroadcastToVector = _mm_set1_pd;
-    static constexpr auto XOR = _mm_xor_pd;
-    static constexpr unsigned Registers = 16;
-};
-
-struct __m256_block_wise_config {
-    using FloatType = float;
-    using VectorType = __m256;
-    static constexpr auto LoadVector = _mm256_loadu_ps;
-    static constexpr auto StoreVector = _mm256_storeu_ps;
-    static constexpr auto BroadcastToVector = _mm256_set1_ps;
-    static constexpr auto XOR = _mm256_xor_ps;
-    static constexpr unsigned Registers = 16;
-};
-
-struct __m256d_block_wise_config {
-    using FloatType = double;
-    using VectorType = __m256d;
-    static constexpr auto LoadVector = _mm256_loadu_pd;
-    static constexpr auto StoreVector = _mm256_storeu_pd;
-    static constexpr auto BroadcastToVector = _mm256_set1_pd;
-    static constexpr auto XOR = _mm256_xor_pd;
-    static constexpr unsigned Registers = 16;
-};
-
-#ifdef WITH_AVX512
-struct __m512_block_wise_config {
-    using FloatType = float;
-    using VectorType = __m512;
-    static constexpr auto LoadVector = _mm512_loadu_ps;
-    static constexpr auto StoreVector = _mm512_storeu_ps;
-    static constexpr auto BroadcastToVector = _mm512_set1_ps;
-    static constexpr auto XOR = _mm512_xor_ps;
-    static constexpr unsigned Registers = 32;
-};
-
-struct __m512d_block_wise_config {
-    using FloatType = double;
-    using VectorType = __m512d;
-    static constexpr auto LoadVector = _mm512_loadu_pd;
-    static constexpr auto StoreVector = _mm512_storeu_pd;
-    static constexpr auto BroadcastToVector = _mm512_set1_pd;
-    static constexpr auto XOR = _mm512_xor_pd;
-    static constexpr unsigned Registers = 32;
-};
-#endif
+#include "detail/BlockWiseConfigs.h"
 
 enum AvxVersion {
     SSE,
diff --git a/src/register_blocking/detail/BlockWiseConfigs.h b/src/register_blocking/detail/BlockWiseConfigs.h
new file mode 100644
index 0000000..04d98df
--- /dev/null
+++ b/src/register_blocking/detail/BlockWiseConfigs.h
@@ -0,0 +1,134 @@
+//
+// Created by oke on 05.07.20.
+//
+
+#ifndef SMID_MATRIX_BLOCKWISECONFIGS_H
+#define SMID_MATRIX_BLOCKWISECONFIGS_H
+
+#include <immintrin.h>
+
+template<typename Config, unsigned CurrentNumRows>
+struct masked_remaining_columns {
+    // Remainder-column handler built on masked loads/stores; Config must provide Mask, LoadVectorMasked and StoreVectorMasked.
+    static constexpr bool IsCustom = true; // the caller checks this flag and uses handle() instead of its sub-config/scalar path
+    // Vector type and broadcast come from the extended config; the masked operations are taken from Config itself.
+    typedef ExtendedBlockWiseConfig<Config> bwc;
+    // For each of CurrentNumRows rows from m_i: adds the length-k dot products into the masked lanes at &C(row, n_i); n_i is not modified.
+    template<typename M1, typename M2, typename M3>
+    static void handle(size_t k, size_t m_i, size_t &n_i, size_t n, M1 &C, const M2 &A, const M3 &B) {
+        typename Config::Mask mask = 0;
+        mask = (~mask);           // all mask bits set
+        mask = mask << (n - n_i); // clear the low (n - n_i) bits
+        mask = ~mask;             // only the low (n - n_i) bits stay set; assumes n - n_i < bit width of Mask -- TODO confirm
+        constexpr typename bwc::VectorType zeros = { 0.0 };
+        typename bwc::VectorType CReg[CurrentNumRows] = { 0.0 };
+        for (size_t p = 0; p < k; p++) { // row index in B and column index in A
+            typename bwc::VectorType bb = Config::LoadVectorMasked(zeros, mask, &B(p, n_i)); // inactive lanes read as zero
+            for (unsigned ai = 0; ai < CurrentNumRows; ai++) { // row index in A (handling regsA rows)
+                typename bwc::VectorType aa = bwc::BroadcastToVector(A(m_i + ai, p));
+                CReg[ai] += aa * bb;
+            }
+        }
+        for (unsigned ai = 0; ai < CurrentNumRows; ai++) {
+            auto memory = &C(m_i + ai, n_i);
+            Config::StoreVectorMasked(memory, mask, Config::LoadVectorMasked(zeros, mask, memory) + CReg[ai]); // C += CReg on active lanes only
+        }
+    }
+
+};
+
+struct __m128_block_wise_config { // block-wise multiply configuration for 128-bit float vectors (__m128)
+    static constexpr bool has_sub_config = false; // no narrower config: leftover columns go straight to remaining_columns::handle
+    using FloatType = float;
+    using VectorType = __m128;
+    static constexpr auto LoadVector = _mm_loadu_ps; // unaligned load
+    static constexpr auto StoreVector = _mm_storeu_ps; // unaligned store
+    static constexpr auto BroadcastToVector = _mm_set1_ps; // copies one scalar into every lane
+    static constexpr auto XOR = _mm_xor_ps;
+    static constexpr unsigned Registers = 16; // NOTE(review): presumably the number of vector registers to budget for -- confirm
+};
+
+struct __m128d_block_wise_config { // block-wise multiply configuration for 128-bit double vectors (__m128d)
+    static constexpr bool has_sub_config = false; // no narrower config: leftover columns go straight to remaining_columns::handle
+    using FloatType = double;
+    using VectorType = __m128d;
+    static constexpr auto LoadVector = _mm_loadu_pd; // unaligned load
+    static constexpr auto StoreVector = _mm_storeu_pd; // unaligned store
+    static constexpr auto BroadcastToVector = _mm_set1_pd; // copies one scalar into every lane
+    static constexpr auto XOR = _mm_xor_pd;
+    static constexpr unsigned Registers = 16; // NOTE(review): presumably the number of vector registers to budget for -- confirm
+};
+
+struct __m256_block_wise_config { // block-wise multiply configuration for 256-bit float vectors (__m256)
+    static constexpr bool has_sub_config = true; // leftover columns are first retried with SubConfig before the scalar fallback
+    using SubConfig = __m128_block_wise_config; // same element type, half the vector width
+
+    using FloatType = float;
+    using VectorType = __m256;
+    static constexpr auto LoadVector = _mm256_loadu_ps; // unaligned load
+    static constexpr auto StoreVector = _mm256_storeu_ps; // unaligned store
+    static constexpr auto BroadcastToVector = _mm256_set1_ps; // copies one scalar into every lane
+    static constexpr auto XOR = _mm256_xor_ps;
+    static constexpr unsigned Registers = 16; // NOTE(review): presumably the number of vector registers to budget for -- confirm
+};
+
+struct __m256d_block_wise_config { // block-wise multiply configuration for 256-bit double vectors (__m256d)
+    static constexpr bool has_sub_config = true; // leftover columns are first retried with SubConfig before the scalar fallback
+    using SubConfig = __m128d_block_wise_config; // same element type, half the vector width
+    using FloatType = double;
+    using VectorType = __m256d;
+
+    static constexpr auto LoadVector = _mm256_loadu_pd; // unaligned load
+    static constexpr auto StoreVector = _mm256_storeu_pd; // unaligned store
+    static constexpr auto BroadcastToVector = _mm256_set1_pd; // copies one scalar into every lane
+    static constexpr auto XOR = _mm256_xor_pd;
+    static constexpr unsigned Registers = 16; // NOTE(review): presumably the number of vector registers to budget for -- confirm
+};
+
+#ifdef WITH_AVX512
+
+struct __m512_block_wise_config { // block-wise multiply configuration for 512-bit float vectors (__m512)
+    static constexpr bool has_sub_config = true; // NOTE(review): the dispatch checks IsCustom first, so SubConfig looks unused here -- confirm
+    using SubConfig = __m256_block_wise_config; // same element type, half the vector width
+    using FloatType = float;
+    using VectorType = __m512;
+    using Mask = __mmask16; // one mask bit per float lane (16 floats in a __m512)
+    static constexpr auto LoadVectorMasked = _mm512_mask_loadu_ps; // (fallback vector, mask, address): masked unaligned load
+    static constexpr auto StoreVectorMasked = _mm512_mask_storeu_ps; // (address, mask, vector): masked unaligned store
+    static constexpr auto LoadVector = _mm512_loadu_ps; // unaligned load
+    static constexpr auto StoreVector = _mm512_storeu_ps; // unaligned store
+    static constexpr auto BroadcastToVector = _mm512_set1_ps; // copies one scalar into every lane
+    static constexpr auto XOR = _mm512_xor_ps;
+    static constexpr unsigned Registers = 32; // NOTE(review): presumably the number of vector registers to budget for -- confirm
+};
+
+struct __m512d_block_wise_config { // block-wise multiply configuration for 512-bit double vectors (__m512d)
+    static constexpr bool has_sub_config = true; // NOTE(review): the dispatch checks IsCustom first, so SubConfig looks unused here -- confirm
+    using SubConfig = __m256d_block_wise_config; // same element type, half the vector width
+    using FloatType = double;
+    using VectorType = __m512d;
+    using Mask = __mmask8; // one mask bit per double lane (8 doubles in a __m512d); the _pd masked intrinsics take __mmask8
+    static constexpr auto LoadVectorMasked = _mm512_mask_loadu_pd; // (fallback vector, mask, address): masked unaligned load
+    static constexpr auto StoreVectorMasked = _mm512_mask_storeu_pd; // (address, mask, vector): masked unaligned store
+    static constexpr auto LoadVector = _mm512_loadu_pd; // unaligned load
+    static constexpr auto StoreVector = _mm512_storeu_pd; // unaligned store
+    static constexpr auto BroadcastToVector = _mm512_set1_pd; // copies one scalar into every lane
+    static constexpr auto XOR = _mm512_xor_pd;
+    static constexpr unsigned Registers = 32; // NOTE(review): presumably the number of vector registers to budget for -- confirm
+};
+
+template<typename BlockWiseConfig, unsigned CurrentNumRows>
+struct remaining_columns; // declaration only; the primary template is defined in BlockWiseImpl.h
+
+template<unsigned CurrentNumRows>
+struct remaining_columns<__m512_block_wise_config, CurrentNumRows> // 512-bit float config: leftover columns use the masked handler
+        : masked_remaining_columns<__m512_block_wise_config, CurrentNumRows> {};
+
+template<unsigned CurrentNumRows>
+struct remaining_columns<__m512d_block_wise_config, CurrentNumRows> // 512-bit double config: leftover columns use the masked handler
+        : masked_remaining_columns<__m512d_block_wise_config, CurrentNumRows> {};
+
+#endif
+
+
+#endif //SMID_MATRIX_BLOCKWISECONFIGS_H
diff --git a/src/register_blocking/detail/BlockWiseImpl.h b/src/register_blocking/detail/BlockWiseImpl.h
index 544c3d3..a740e0e 100644
--- a/src/register_blocking/detail/BlockWiseImpl.h
+++ b/src/register_blocking/detail/BlockWiseImpl.h
@@ -7,6 +7,23 @@
 
 #include "RegisterBlocking.h"
 
+// Fallback handler for the columns left over after the vectorized blocks: plain scalar multiply-accumulate.
+template<typename BlockWiseConfig, unsigned CurrentNumRows>
+struct remaining_columns {
+    static constexpr bool IsCustom = false; // handlers that replace this scalar loop report true instead
+    // C(m_i + r, c) += A(m_i + r, p) * B(p, c) for r < CurrentNumRows, p < k, n_i <= c < n; n_i is left unchanged.
+    template<typename M1, typename M2, typename M3>
+    static void handle(size_t k, size_t m_i, size_t &n_i, size_t n, M1 &C, const M2 &A, const M3 &B) {
+        for (unsigned m_i_o = 0; m_i_o < CurrentNumRows; ++m_i_o) { // row offset inside the current row block
+            for (size_t p = 0; p < k; ++p) { // row index in B and column index in A
+                for (size_t n_i_rest = n_i; n_i_rest < n; ++n_i_rest) { // each leftover column
+                    C(m_i + m_i_o, n_i_rest) += A(m_i + m_i_o, p) * B(p, n_i_rest);
+                }
+            }
+        }
+    }
+};
+
 namespace detail {
 
     template<
@@ -89,11 +106,19 @@ namespace detail {
                 } else {
                     // do rest of columns manually
                     if (n_i < n) {
-                        for (auto m_i_o = 0; m_i_o < CurrentNumRows; ++m_i_o) {
-                            for (int p = 0; p < k; ++p) {
-                                for (auto n_i_rest = n_i; n_i_rest < n; ++n_i_rest) {
-                                    C(m_i + m_i_o, n_i_rest) += A(m_i + m_i_o, p) * B(p, n_i_rest);
-                                }
+                        // custom implementation of remaining columns? good, use it.
+                        if constexpr(remaining_columns<BlockWiseConfig, CurrentNumRows>::IsCustom) {
+                            remaining_columns<BlockWiseConfig, CurrentNumRows>::handle(k, m_i, n_i, n, C, A, B);
+                        } else {
+                            // otherwise: has sub_config with less vector width? use this first before going to fallback implementation
+                            if constexpr(BlockWiseConfig::has_sub_config) {
+                                typedef typename BlockWiseConfig::SubConfig SubConfig;
+                                typedef block_wise<SubConfig> SubBlockWise;
+                                typedef typename SubBlockWise::template iterate_columns<CurrentNumRows, GetInitialColumnVectors(
+                                        CurrentNumRows, SubConfig::Registers)> SubIterateColumns;
+                                SubIterateColumns::iterate(k, m_i, n_i, n, C, A, B);
+                            } else {
+                                remaining_columns<BlockWiseConfig, CurrentNumRows>::handle(k, m_i, n_i, n, C, A, B);
                             }
                         }
                     }
diff --git a/src/register_blocking/detail/RegisterBlockingManual.h b/src/register_blocking/detail/RegisterBlockingManual.h
deleted file mode 100644
index fd624e3..0000000
--- a/src/register_blocking/detail/RegisterBlockingManual.h
+++ /dev/null
@@ -1,80 +0,0 @@
-//
-// Created by oke on 01.07.20.
-//
-
-#ifndef SMID_MATRIX_REGISTERBLOCKINGMANUAL_H
-#define SMID_MATRIX_REGISTERBLOCKINGMANUAL_H
-
-#include "../../Matrix.h"
-#include "ExtendedBlockWiseConfig.h"
-
-template<typename BitWiseConfig, size_t Rows, size_t ColumnVectors>
-struct a {
-    typedef ExtendedBlockWiseConfig<BitWiseConfig> bwc;
-
-
-    template<typename M1, typename M2, typename M3>
-    void do_stuff(int p, int bi, M1 &C, const M2 &A, int aRowOffset, const M3 &B, int bColOffset) {
-
-        typename bwc::VectorType bb = bwc::LoadVector(&B(p, bColOffset + bi * bwc::VectorWidth));
-
-        for (int ai = 0; ai < Rows; ai++) { // row index in A (handling regsA rows)
-            typename bwc::VectorType aa = bwc::BroadcastToVector(A(aRowOffset + ai, p));
-            CReg[ai][bi] += aa * bb;
-        }
-
-    }
-
-};
-
-
-namespace detail {
-
-    template<
-            // template parameter as struct: otherwise some warning about losing alignment information warning
-            typename BitWiseConfig,
-            unsigned _NumRows, unsigned _NumColumnVectors
-    >
-    struct RegisterBlockingManual {
-        typedef ExtendedBlockWiseConfig<BitWiseConfig> bwc;
-
-        static constexpr auto AddAndStore = [](typename bwc::FloatType *memory, typename bwc::VectorType vector) {
-            bwc::StoreVector(memory, bwc::LoadVector(memory) + vector);
-        };
-
-        static constexpr auto NumRows = _NumRows;
-        static constexpr auto NumColumns = _NumColumnVectors * bwc::VectorWidth;
-
-        template<typename M1, typename M2, typename M3>
-        static void __attribute__ ((noinline))
-        handle_block(int k, M1 &C, const M2 &A, int aRowOffset, const M3 &B, int bColOffset) {
-
-            // AVX2 has 16 registers
-            // should be compiled as registers (total: regA * regB)
-            typename bwc::VectorType CReg[_NumRows][_NumColumnVectors] = {{0.0}};
-            // iterate over dot-product terms
-            for (int p = 0; p < k; p++) { // row index in B and column index in A (handling all rows/columns)
-                // Perform the DOT product
-                for (int bi = 0; bi < _NumColumnVectors; bi++) { // column index in B (handling regsB * 'VectorWidth' columns)
-                    typename bwc::VectorType bb = bwc::LoadVector(&B(p, bColOffset + bi * bwc::VectorWidth));
-                    for (int ai = 0; ai < _NumRows; ai++) { // row index in A (handling regsA rows)
-                        typename bwc::VectorType aa = bwc::BroadcastToVector(A(aRowOffset + ai, p));
-                        CReg[ai][bi] += aa * bb;
-                    }
-                }
-            }
-            // total regA * regB + regB registers
-
-            // Accumulate the results into C.
-            for (int ai = 0; ai < _NumRows; ai++) {
-                for (int bi = 0; bi < _NumColumnVectors; bi++) {
-                    AddAndStore(&C(aRowOffset + ai, bColOffset + bi * bwc::VectorWidth), CReg[ai][bi]);
-                }
-            }
-        }
-
-    };
-
-}
-
-#endif //SMID_MATRIX_REGISTERBLOCKINGMANUAL_H