Automatic grid size
This commit is contained in:
parent
17f187669d
commit
ad6c4ca996
@ -25,6 +25,7 @@ namespace detail {
|
||||
static constexpr auto LoadVector = _mm256_loadu_ps;
|
||||
static constexpr auto StoreVector = _mm256_storeu_ps;
|
||||
static constexpr auto BroadcastToVector = _mm256_set1_ps;
|
||||
static constexpr unsigned Registers = 16;
|
||||
};
|
||||
|
||||
struct __m256d_block_wise_config {
|
||||
@ -33,6 +34,7 @@ namespace detail {
|
||||
static constexpr auto LoadVector = _mm256_loadu_pd;
|
||||
static constexpr auto StoreVector = _mm256_storeu_pd;
|
||||
static constexpr auto BroadcastToVector = _mm256_set1_pd;
|
||||
static constexpr unsigned Registers = 16;
|
||||
};
|
||||
|
||||
#ifdef WITH_AVX512
|
||||
@ -42,6 +44,7 @@ namespace detail {
|
||||
static constexpr auto LoadVector = _mm512_loadu_ps;
|
||||
static constexpr auto StoreVector = _mm512_storeu_ps;
|
||||
static constexpr auto BroadcastToVector = _mm512_set1_ps;
|
||||
static constexpr unsigned Registers = 32;
|
||||
};
|
||||
|
||||
struct __m512d_block_wise_config {
|
||||
@ -50,13 +53,34 @@ namespace detail {
|
||||
static constexpr auto LoadVector = _mm512_loadu_pd;
|
||||
static constexpr auto StoreVector = _mm512_storeu_pd;
|
||||
static constexpr auto BroadcastToVector = _mm512_set1_pd;
|
||||
static constexpr unsigned Registers = 32;
|
||||
};
|
||||
#endif
|
||||
|
||||
// Maximize (R * C) / (R + C) subject to R + R * C <= 16 (the register count): for any fixed R, take the largest C that still fits.
|
||||
// C = floor ((16 - R) / R)
|
||||
|
||||
/// Number of column vectors that fit next to `R` rows in a register file of
/// `Registers` vector registers.
///
/// Evaluates floor((Registers - R) / R), i.e. the largest C for which
/// R + R * C <= Registers holds.
///
/// @param R          number of rows of the block
/// @param Registers  number of available vector registers
/// @return the column-vector count; 0 when R is 0 or when R >= Registers
///         (no register is left over for a column).
constexpr unsigned GetInitialColumnVectors(unsigned R, unsigned Registers) {
    // R == 0 would divide by zero. For R > Registers the unsigned subtraction
    // below would wrap around and produce a huge bogus column count.
    if (R == 0 || R >= Registers) return 0;
    // Unsigned integer division truncates, which is exactly the floor we need;
    // no round-trip through double is required.
    return (Registers - R) / R;
}
|
||||
|
||||
/// Picks the initial row count for a register file of `Registers` registers.
///
/// Searches downwards from `Registers` and returns the first (hence largest)
/// row count R that satisfies R + R * (R + 1) < Registers.
///
/// @param Registers  number of available vector registers
/// @return the selected row count, or 0 if no positive R satisfies the bound.
constexpr unsigned GetInitialRows(unsigned Registers) {
    unsigned candidate = Registers;
    while (candidate != 0) {
        // R + R * (R + 1) written as R * (R + 2); same value in unsigned arithmetic.
        const unsigned needed = candidate * (candidate + 2);
        if (needed < Registers) {
            return candidate;
        }
        --candidate;
    }
    return 0;
}
|
||||
|
||||
|
||||
template<typename BlockWiseConfig>
|
||||
struct ExtendedBlockWiseConfig {
|
||||
typedef typename BlockWiseConfig::FloatType FloatType;
|
||||
using VectorType = typename BlockWiseConfig::VectorType;
|
||||
|
||||
static constexpr auto Registers = BlockWiseConfig::Registers;
|
||||
static constexpr auto LoadVector = BlockWiseConfig::LoadVector;
|
||||
static constexpr auto StoreVector = BlockWiseConfig::StoreVector;
|
||||
static constexpr auto BroadcastToVector = BlockWiseConfig::BroadcastToVector;
|
||||
@ -113,30 +137,16 @@ namespace detail {
|
||||
|
||||
};
|
||||
|
||||
template<typename BitWiseConfig, unsigned Rows>
|
||||
struct BestRowRegisterBlocking;
|
||||
|
||||
template<typename BitWiseConfig>
|
||||
struct BestRowRegisterBlocking<BitWiseConfig, 1> : public RegisterBlocking<BitWiseConfig, 1, 15> {
|
||||
};
|
||||
template<typename BitWiseConfig>
|
||||
struct BestRowRegisterBlocking<BitWiseConfig, 2> : public RegisterBlocking<BitWiseConfig, 2, 7> {
|
||||
};
|
||||
template<typename BitWiseConfig>
|
||||
struct BestRowRegisterBlocking<BitWiseConfig, 3> : public RegisterBlocking<BitWiseConfig, 3, 4> {
|
||||
};
|
||||
// Maximize (R * C) / (R + C) subject to R + R * C <= 16 (the register count): for any fixed R, take the largest C that still fits.
|
||||
// C = floor ((16 - R) / R)
|
||||
|
||||
template<
|
||||
// template parameter as struct: otherwise some warning about losing alignment information warning
|
||||
typename BlockWiseConfig = __m256_block_wise_config,
|
||||
unsigned InitialNumRows = 3, unsigned InitialNumColumnVectors = 4
|
||||
typename BlockWiseConfig = __m256_block_wise_config
|
||||
>
|
||||
struct block_wise {
|
||||
|
||||
typedef ExtendedBlockWiseConfig<BlockWiseConfig> bwc;
|
||||
|
||||
static constexpr unsigned InitialNumRows = GetInitialRows(bwc::Registers);
|
||||
static constexpr unsigned InitialNumColumnVectors = GetInitialColumnVectors(InitialNumRows, bwc::Registers);
|
||||
static constexpr unsigned NumRows = InitialNumRows;
|
||||
static constexpr unsigned NumColumns = InitialNumColumnVectors * bwc::VectorWidth;
|
||||
|
||||
@ -248,7 +258,7 @@ namespace detail {
|
||||
consume(k, m_i, m, n, C, A, B);
|
||||
if (CurrentNumRows > 1) {
|
||||
// try with less num rows vectors
|
||||
iterate_rows<CurrentNumRows - 1, CurrentNumColumnVectors>::iterate(k, m_i, m, n, C, A, B);
|
||||
iterate_rows<CurrentNumRows - 1, GetInitialColumnVectors(CurrentNumRows - 1, bwc::Registers)>::iterate(k, m_i, m, n, C, A, B);
|
||||
}
|
||||
assert(m_i == m);
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user