Speed up remainder
This commit is contained in:
parent
ea5556616b
commit
34d384db8a
@ -73,6 +73,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument("--function", type=str, nargs="*")
|
||||
parser.add_argument("--release", action="store_true")
|
||||
parser.add_argument("--no_manual", action="store_true")
|
||||
parser.add_argument("--verbose", action="store_true")
|
||||
|
||||
options = parser.parse_args()
|
||||
|
||||
@ -125,7 +126,9 @@ if __name__ == '__main__':
|
||||
if with_double:
|
||||
arguments.append("--double")
|
||||
output = compile_and_run("..", "builds", "simd_multiply", True, clang, options.avx512, options.release, options.no_manual, arguments + extra_args)
|
||||
ms = output.decode()[output.decode().find("multiply:") + 10:]
|
||||
if options.verbose:
|
||||
print(output.decode("utf-8"))
|
||||
ms = output.decode()[output.decode("utf-8").find("multiply:") + 10:]
|
||||
|
||||
if not already_dumped:
|
||||
build_path = os.path.join(get_build_path("..", "builds", "simd_multiply", True, clang, options.avx512, options.release, options.no_manual, arguments + extra_args)[1], "simd_multiply")
|
||||
|
@ -6,70 +6,7 @@
|
||||
#define SMID_MATRIX_BLOCKWISE_H
|
||||
|
||||
#include "detail/BlockWiseImpl.h"
|
||||
|
||||
#include <immintrin.h>
|
||||
|
||||
struct __m128_block_wise_config {
|
||||
using FloatType = float;
|
||||
using VectorType = __m128;
|
||||
static constexpr auto LoadVector = _mm_loadu_ps;
|
||||
static constexpr auto StoreVector = _mm_storeu_ps;
|
||||
static constexpr auto BroadcastToVector = _mm_set1_ps;
|
||||
static constexpr auto XOR = _mm_xor_ps;
|
||||
static constexpr unsigned Registers = 16;
|
||||
};
|
||||
|
||||
struct __m128d_block_wise_config {
|
||||
using FloatType = double;
|
||||
using VectorType = __m128d;
|
||||
static constexpr auto LoadVector = _mm_loadu_pd;
|
||||
static constexpr auto StoreVector = _mm_storeu_pd;
|
||||
static constexpr auto BroadcastToVector = _mm_set1_pd;
|
||||
static constexpr auto XOR = _mm_xor_pd;
|
||||
static constexpr unsigned Registers = 16;
|
||||
};
|
||||
|
||||
struct __m256_block_wise_config {
|
||||
using FloatType = float;
|
||||
using VectorType = __m256;
|
||||
static constexpr auto LoadVector = _mm256_loadu_ps;
|
||||
static constexpr auto StoreVector = _mm256_storeu_ps;
|
||||
static constexpr auto BroadcastToVector = _mm256_set1_ps;
|
||||
static constexpr auto XOR = _mm256_xor_ps;
|
||||
static constexpr unsigned Registers = 16;
|
||||
};
|
||||
|
||||
struct __m256d_block_wise_config {
|
||||
using FloatType = double;
|
||||
using VectorType = __m256d;
|
||||
static constexpr auto LoadVector = _mm256_loadu_pd;
|
||||
static constexpr auto StoreVector = _mm256_storeu_pd;
|
||||
static constexpr auto BroadcastToVector = _mm256_set1_pd;
|
||||
static constexpr auto XOR = _mm256_xor_pd;
|
||||
static constexpr unsigned Registers = 16;
|
||||
};
|
||||
|
||||
#ifdef WITH_AVX512
|
||||
struct __m512_block_wise_config {
|
||||
using FloatType = float;
|
||||
using VectorType = __m512;
|
||||
static constexpr auto LoadVector = _mm512_loadu_ps;
|
||||
static constexpr auto StoreVector = _mm512_storeu_ps;
|
||||
static constexpr auto BroadcastToVector = _mm512_set1_ps;
|
||||
static constexpr auto XOR = _mm512_xor_ps;
|
||||
static constexpr unsigned Registers = 32;
|
||||
};
|
||||
|
||||
struct __m512d_block_wise_config {
|
||||
using FloatType = double;
|
||||
using VectorType = __m512d;
|
||||
static constexpr auto LoadVector = _mm512_loadu_pd;
|
||||
static constexpr auto StoreVector = _mm512_storeu_pd;
|
||||
static constexpr auto BroadcastToVector = _mm512_set1_pd;
|
||||
static constexpr auto XOR = _mm512_xor_pd;
|
||||
static constexpr unsigned Registers = 32;
|
||||
};
|
||||
#endif
|
||||
#include "detail/BlockWiseConfigs.h"
|
||||
|
||||
enum AvxVersion {
|
||||
SSE,
|
||||
|
134
src/register_blocking/detail/BlockWiseConfigs.h
Normal file
134
src/register_blocking/detail/BlockWiseConfigs.h
Normal file
@ -0,0 +1,134 @@
|
||||
//
|
||||
// Created by oke on 05.07.20.
|
||||
//
|
||||
|
||||
#ifndef SMID_MATRIX_BLOCKWISECONFIGS_H
|
||||
#define SMID_MATRIX_BLOCKWISECONFIGS_H
|
||||
|
||||
#include <immintrin.h>
|
||||
|
||||
template<typename Config, unsigned CurrentNumRows>
|
||||
struct masked_remaining_columns {
|
||||
|
||||
static constexpr bool IsCustom = true;
|
||||
|
||||
typedef ExtendedBlockWiseConfig<Config> bwc;
|
||||
|
||||
template<typename M1, typename M2, typename M3>
|
||||
static void handle(size_t k, size_t m_i, size_t &n_i, size_t n, M1 &C, const M2 &A, const M3 &B) {
|
||||
typename Config::Mask mask = 0;
|
||||
mask = (~mask);
|
||||
mask = mask << (n - n_i);
|
||||
mask = ~mask;
|
||||
constexpr typename bwc::VectorType zeros = { 0.0 };
|
||||
typename bwc::VectorType CReg[CurrentNumRows] = { 0.0 };
|
||||
for (int p = 0; p < k; p++) {
|
||||
typename bwc::VectorType bb = Config::LoadVectorMasked(zeros, mask, &B(p, n_i));
|
||||
for (int ai = 0; ai < CurrentNumRows; ai++) { // row index in A (handling regsA rows)
|
||||
typename bwc::VectorType aa = bwc::BroadcastToVector(A(m_i + ai, p));
|
||||
CReg[ai] += aa * bb;
|
||||
}
|
||||
}
|
||||
for (int ai = 0; ai < CurrentNumRows; ai++) {
|
||||
auto memory = &C(m_i + ai, n_i);
|
||||
Config::StoreVectorMasked(memory, mask, Config::LoadVectorMasked(zeros, mask, memory) + CReg[ai]);
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
struct __m128_block_wise_config {
|
||||
static constexpr bool has_sub_config = false;
|
||||
using FloatType = float;
|
||||
using VectorType = __m128;
|
||||
static constexpr auto LoadVector = _mm_loadu_ps;
|
||||
static constexpr auto StoreVector = _mm_storeu_ps;
|
||||
static constexpr auto BroadcastToVector = _mm_set1_ps;
|
||||
static constexpr auto XOR = _mm_xor_ps;
|
||||
static constexpr unsigned Registers = 16;
|
||||
};
|
||||
|
||||
struct __m128d_block_wise_config {
|
||||
static constexpr bool has_sub_config = false;
|
||||
using FloatType = double;
|
||||
using VectorType = __m128d;
|
||||
static constexpr auto LoadVector = _mm_loadu_pd;
|
||||
static constexpr auto StoreVector = _mm_storeu_pd;
|
||||
static constexpr auto BroadcastToVector = _mm_set1_pd;
|
||||
static constexpr auto XOR = _mm_xor_pd;
|
||||
static constexpr unsigned Registers = 16;
|
||||
};
|
||||
|
||||
struct __m256_block_wise_config {
|
||||
static constexpr bool has_sub_config = true;
|
||||
using SubConfig = __m128_block_wise_config;
|
||||
|
||||
using FloatType = float;
|
||||
using VectorType = __m256;
|
||||
static constexpr auto LoadVector = _mm256_loadu_ps;
|
||||
static constexpr auto StoreVector = _mm256_storeu_ps;
|
||||
static constexpr auto BroadcastToVector = _mm256_set1_ps;
|
||||
static constexpr auto XOR = _mm256_xor_ps;
|
||||
static constexpr unsigned Registers = 16;
|
||||
};
|
||||
|
||||
struct __m256d_block_wise_config {
|
||||
static constexpr bool has_sub_config = true;
|
||||
using SubConfig = __m128d_block_wise_config;
|
||||
using FloatType = double;
|
||||
using VectorType = __m256d;
|
||||
|
||||
static constexpr auto LoadVector = _mm256_loadu_pd;
|
||||
static constexpr auto StoreVector = _mm256_storeu_pd;
|
||||
static constexpr auto BroadcastToVector = _mm256_set1_pd;
|
||||
static constexpr auto XOR = _mm256_xor_pd;
|
||||
static constexpr unsigned Registers = 16;
|
||||
};
|
||||
|
||||
#ifdef WITH_AVX512
|
||||
|
||||
struct __m512_block_wise_config {
|
||||
static constexpr bool has_sub_config = true;
|
||||
using SubConfig = __m256_block_wise_config;
|
||||
using FloatType = float;
|
||||
using VectorType = __m512;
|
||||
using Mask = __mmask16;
|
||||
static constexpr auto LoadVectorMasked = _mm512_mask_loadu_ps;
|
||||
static constexpr auto StoreVectorMasked = _mm512_mask_storeu_ps;
|
||||
static constexpr auto LoadVector = _mm512_loadu_ps;
|
||||
static constexpr auto StoreVector = _mm512_storeu_ps;
|
||||
static constexpr auto BroadcastToVector = _mm512_set1_ps;
|
||||
static constexpr auto XOR = _mm512_xor_ps;
|
||||
static constexpr unsigned Registers = 32;
|
||||
};
|
||||
|
||||
struct __m512d_block_wise_config {
|
||||
static constexpr bool has_sub_config = true;
|
||||
using SubConfig = __m256d_block_wise_config;
|
||||
using FloatType = double;
|
||||
using VectorType = __m512d;
|
||||
using Mask = __mmask16;
|
||||
static constexpr auto LoadVectorMasked = _mm512_mask_loadu_pd;
|
||||
static constexpr auto StoreVectorMasked = _mm512_mask_storeu_pd;
|
||||
static constexpr auto LoadVector = _mm512_loadu_pd;
|
||||
static constexpr auto StoreVector = _mm512_storeu_pd;
|
||||
static constexpr auto BroadcastToVector = _mm512_set1_pd;
|
||||
static constexpr auto XOR = _mm512_xor_pd;
|
||||
static constexpr unsigned Registers = 32;
|
||||
};
|
||||
|
||||
template<typename BlockWiseConfig, unsigned CurrentNumRows>
|
||||
struct remaining_columns;
|
||||
|
||||
template<unsigned CurrentNumRows>
|
||||
struct remaining_columns<__m512_block_wise_config, CurrentNumRows>
|
||||
: masked_remaining_columns<__m512_block_wise_config, CurrentNumRows> {};
|
||||
|
||||
template<unsigned CurrentNumRows>
|
||||
struct remaining_columns<__m512d_block_wise_config, CurrentNumRows>
|
||||
: masked_remaining_columns<__m512d_block_wise_config, CurrentNumRows> {};
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
#endif //SMID_MATRIX_BLOCKWISECONFIGS_H
|
@ -7,6 +7,23 @@
|
||||
|
||||
#include "RegisterBlocking.h"
|
||||
|
||||
/**
 * Default (scalar) handler for the columns that are left over after the
 * vectorised column blocks have been processed.
 *
 * IsCustom = false marks this as the generic fallback; specialisations for
 * individual configurations may provide a custom implementation instead.
 */
template<typename BlockWiseConfig, unsigned CurrentNumRows>
struct remaining_columns {
    static constexpr bool IsCustom = false;

    /**
     * Performs C(r, c) += sum over p in [0, k) of A(r, p) * B(p, c)
     * for rows r in [m_i, m_i + CurrentNumRows) and columns c in [n_i, n).
     *
     * @param k    number of dot-product terms
     * @param m_i  first row handled by this call
     * @param n_i  first remaining column; taken by reference but not modified here
     * @param n    one past the last column
     * @param C    result matrix, updated in place
     * @param A    left operand, accessed as A(row, p)
     * @param B    right operand, accessed as B(p, column)
     */
    template<typename M1, typename M2, typename M3>
    static void handle(size_t k, size_t m_i, size_t &n_i, size_t n, M1 &C, const M2 &A, const M3 &B) {
        // Counters use the same unsigned types as their bounds to avoid
        // signed/unsigned comparisons and int overflow for very large k.
        for (unsigned row_offset = 0; row_offset < CurrentNumRows; ++row_offset) {
            const size_t row = m_i + row_offset;
            for (size_t p = 0; p < k; ++p) {
                for (size_t column = n_i; column < n; ++column) {
                    C(row, column) += A(row, p) * B(p, column);
                }
            }
        }
    }
};
|
||||
|
||||
|
||||
namespace detail {
|
||||
|
||||
template<
|
||||
@ -89,11 +106,19 @@ namespace detail {
|
||||
} else {
|
||||
// do rest of columns manually
|
||||
if (n_i < n) {
|
||||
for (auto m_i_o = 0; m_i_o < CurrentNumRows; ++m_i_o) {
|
||||
for (int p = 0; p < k; ++p) {
|
||||
for (auto n_i_rest = n_i; n_i_rest < n; ++n_i_rest) {
|
||||
C(m_i + m_i_o, n_i_rest) += A(m_i + m_i_o, p) * B(p, n_i_rest);
|
||||
}
|
||||
// custom implementation of remaining columns? good, use it.
|
||||
if constexpr(remaining_columns<BlockWiseConfig, CurrentNumRows>::IsCustom) {
|
||||
remaining_columns<BlockWiseConfig, CurrentNumRows>::handle(k, m_i, n_i, n, C, A, B);
|
||||
} else {
|
||||
// otherwise: has sub_config with less vector width? use this first before going to fallback implementation
|
||||
if constexpr(BlockWiseConfig::has_sub_config) {
|
||||
typedef typename BlockWiseConfig::SubConfig SubConfig;
|
||||
typedef block_wise<SubConfig> SubBlockWise;
|
||||
typedef typename SubBlockWise::template iterate_columns<CurrentNumRows, GetInitialColumnVectors(
|
||||
CurrentNumRows, SubConfig::Registers)> SubIterateColumns;
|
||||
SubIterateColumns::iterate(k, m_i, n_i, n, C, A, B);
|
||||
} else {
|
||||
remaining_columns<BlockWiseConfig, CurrentNumRows>::handle(k, m_i, n_i, n, C, A, B);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1,80 +0,0 @@
|
||||
//
|
||||
// Created by oke on 01.07.20.
|
||||
//
|
||||
|
||||
#ifndef SMID_MATRIX_REGISTERBLOCKINGMANUAL_H
|
||||
#define SMID_MATRIX_REGISTERBLOCKINGMANUAL_H
|
||||
|
||||
#include "../../Matrix.h"
|
||||
#include "ExtendedBlockWiseConfig.h"
|
||||
|
||||
template<typename BitWiseConfig, size_t Rows, size_t ColumnVectors>
|
||||
struct a {
|
||||
typedef ExtendedBlockWiseConfig<BitWiseConfig> bwc;
|
||||
|
||||
|
||||
template<typename M1, typename M2, typename M3>
|
||||
void do_stuff(int p, int bi, M1 &C, const M2 &A, int aRowOffset, const M3 &B, int bColOffset) {
|
||||
|
||||
typename bwc::VectorType bb = bwc::LoadVector(&B(p, bColOffset + bi * bwc::VectorWidth));
|
||||
|
||||
for (int ai = 0; ai < Rows; ai++) { // row index in A (handling regsA rows)
|
||||
typename bwc::VectorType aa = bwc::BroadcastToVector(A(aRowOffset + ai, p));
|
||||
CReg[ai][bi] += aa * bb;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
|
||||
namespace detail {
|
||||
|
||||
template<
|
||||
// template parameter as struct: otherwise some warning about losing alignment information warning
|
||||
typename BitWiseConfig,
|
||||
unsigned _NumRows, unsigned _NumColumnVectors
|
||||
>
|
||||
struct RegisterBlockingManual {
|
||||
typedef ExtendedBlockWiseConfig<BitWiseConfig> bwc;
|
||||
|
||||
static constexpr auto AddAndStore = [](typename bwc::FloatType *memory, typename bwc::VectorType vector) {
|
||||
bwc::StoreVector(memory, bwc::LoadVector(memory) + vector);
|
||||
};
|
||||
|
||||
static constexpr auto NumRows = _NumRows;
|
||||
static constexpr auto NumColumns = _NumColumnVectors * bwc::VectorWidth;
|
||||
|
||||
template<typename M1, typename M2, typename M3>
|
||||
static void __attribute__ ((noinline))
|
||||
handle_block(int k, M1 &C, const M2 &A, int aRowOffset, const M3 &B, int bColOffset) {
|
||||
|
||||
// AVX2 has 16 registers
|
||||
// should be compiled as registers (total: regA * regB)
|
||||
typename bwc::VectorType CReg[_NumRows][_NumColumnVectors] = {{0.0}};
|
||||
// iterate over dot-product terms
|
||||
for (int p = 0; p < k; p++) { // row index in B and column index in A (handling all rows/columns)
|
||||
// Perform the DOT product
|
||||
for (int bi = 0; bi < _NumColumnVectors; bi++) { // column index in B (handling regsB * 'VectorWidth' columns)
|
||||
typename bwc::VectorType bb = bwc::LoadVector(&B(p, bColOffset + bi * bwc::VectorWidth));
|
||||
for (int ai = 0; ai < _NumRows; ai++) { // row index in A (handling regsA rows)
|
||||
typename bwc::VectorType aa = bwc::BroadcastToVector(A(aRowOffset + ai, p));
|
||||
CReg[ai][bi] += aa * bb;
|
||||
}
|
||||
}
|
||||
}
|
||||
// total regA * regB + regB registers
|
||||
|
||||
// Accumulate the results into C.
|
||||
for (int ai = 0; ai < _NumRows; ai++) {
|
||||
for (int bi = 0; bi < _NumColumnVectors; bi++) {
|
||||
AddAndStore(&C(aRowOffset + ai, bColOffset + bi * bwc::VectorWidth), CReg[ai][bi]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
#endif //SMID_MATRIX_REGISTERBLOCKINGMANUAL_H
|
Loading…
x
Reference in New Issue
Block a user