Automatic grid size
This commit is contained in:
parent
17f187669d
commit
ad6c4ca996
@ -25,6 +25,7 @@ namespace detail {
|
|||||||
static constexpr auto LoadVector = _mm256_loadu_ps;
|
static constexpr auto LoadVector = _mm256_loadu_ps;
|
||||||
static constexpr auto StoreVector = _mm256_storeu_ps;
|
static constexpr auto StoreVector = _mm256_storeu_ps;
|
||||||
static constexpr auto BroadcastToVector = _mm256_set1_ps;
|
static constexpr auto BroadcastToVector = _mm256_set1_ps;
|
||||||
|
static constexpr unsigned Registers = 16;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct __m256d_block_wise_config {
|
struct __m256d_block_wise_config {
|
||||||
@ -33,6 +34,7 @@ namespace detail {
|
|||||||
static constexpr auto LoadVector = _mm256_loadu_pd;
|
static constexpr auto LoadVector = _mm256_loadu_pd;
|
||||||
static constexpr auto StoreVector = _mm256_storeu_pd;
|
static constexpr auto StoreVector = _mm256_storeu_pd;
|
||||||
static constexpr auto BroadcastToVector = _mm256_set1_pd;
|
static constexpr auto BroadcastToVector = _mm256_set1_pd;
|
||||||
|
static constexpr unsigned Registers = 16;
|
||||||
};
|
};
|
||||||
|
|
||||||
#ifdef WITH_AVX512
|
#ifdef WITH_AVX512
|
||||||
@ -42,6 +44,7 @@ namespace detail {
|
|||||||
static constexpr auto LoadVector = _mm512_loadu_ps;
|
static constexpr auto LoadVector = _mm512_loadu_ps;
|
||||||
static constexpr auto StoreVector = _mm512_storeu_ps;
|
static constexpr auto StoreVector = _mm512_storeu_ps;
|
||||||
static constexpr auto BroadcastToVector = _mm512_set1_ps;
|
static constexpr auto BroadcastToVector = _mm512_set1_ps;
|
||||||
|
static constexpr unsigned Registers = 32;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct __m512d_block_wise_config {
|
struct __m512d_block_wise_config {
|
||||||
@ -50,13 +53,34 @@ namespace detail {
|
|||||||
static constexpr auto LoadVector = _mm512_loadu_pd;
|
static constexpr auto LoadVector = _mm512_loadu_pd;
|
||||||
static constexpr auto StoreVector = _mm512_storeu_pd;
|
static constexpr auto StoreVector = _mm512_storeu_pd;
|
||||||
static constexpr auto BroadcastToVector = _mm512_set1_pd;
|
static constexpr auto BroadcastToVector = _mm512_set1_pd;
|
||||||
|
static constexpr unsigned Registers = 32;
|
||||||
};
|
};
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
// maximize = (R * C) / (R + C) for R + R * C < 16 => any fixed r -> largest C with < 16
|
||||||
|
// C = floor ((16 - R) / R)
|
||||||
|
|
||||||
|
constexpr unsigned GetInitialColumnVectors(unsigned R, unsigned Registers) {
|
||||||
|
if (R == 0) return 0;
|
||||||
|
return (unsigned) ((double) (Registers - R) / (double) R);
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr unsigned GetInitialRows(unsigned Registers) {
|
||||||
|
for(unsigned R = Registers; R > 0; --R) {
|
||||||
|
if(R + R * (R + 1) < Registers) {
|
||||||
|
return R;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
template<typename BlockWiseConfig>
|
template<typename BlockWiseConfig>
|
||||||
struct ExtendedBlockWiseConfig {
|
struct ExtendedBlockWiseConfig {
|
||||||
typedef typename BlockWiseConfig::FloatType FloatType;
|
typedef typename BlockWiseConfig::FloatType FloatType;
|
||||||
using VectorType = typename BlockWiseConfig::VectorType;
|
using VectorType = typename BlockWiseConfig::VectorType;
|
||||||
|
|
||||||
|
static constexpr auto Registers = BlockWiseConfig::Registers;
|
||||||
static constexpr auto LoadVector = BlockWiseConfig::LoadVector;
|
static constexpr auto LoadVector = BlockWiseConfig::LoadVector;
|
||||||
static constexpr auto StoreVector = BlockWiseConfig::StoreVector;
|
static constexpr auto StoreVector = BlockWiseConfig::StoreVector;
|
||||||
static constexpr auto BroadcastToVector = BlockWiseConfig::BroadcastToVector;
|
static constexpr auto BroadcastToVector = BlockWiseConfig::BroadcastToVector;
|
||||||
@ -113,30 +137,16 @@ namespace detail {
|
|||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
template<typename BitWiseConfig, unsigned Rows>
|
|
||||||
struct BestRowRegisterBlocking;
|
|
||||||
|
|
||||||
template<typename BitWiseConfig>
|
|
||||||
struct BestRowRegisterBlocking<BitWiseConfig, 1> : public RegisterBlocking<BitWiseConfig, 1, 15> {
|
|
||||||
};
|
|
||||||
template<typename BitWiseConfig>
|
|
||||||
struct BestRowRegisterBlocking<BitWiseConfig, 2> : public RegisterBlocking<BitWiseConfig, 2, 7> {
|
|
||||||
};
|
|
||||||
template<typename BitWiseConfig>
|
|
||||||
struct BestRowRegisterBlocking<BitWiseConfig, 3> : public RegisterBlocking<BitWiseConfig, 3, 4> {
|
|
||||||
};
|
|
||||||
// maximize = (R * C) / (R + C) for R + R * C < 16 => any fixed r -> largest C with < 16
|
|
||||||
// C = floor ((16 - R) / R)
|
|
||||||
|
|
||||||
template<
|
template<
|
||||||
// template parameter as struct: otherwise some warning about losing alignment information warning
|
// template parameter as struct: otherwise some warning about losing alignment information warning
|
||||||
typename BlockWiseConfig = __m256_block_wise_config,
|
typename BlockWiseConfig = __m256_block_wise_config
|
||||||
unsigned InitialNumRows = 3, unsigned InitialNumColumnVectors = 4
|
|
||||||
>
|
>
|
||||||
struct block_wise {
|
struct block_wise {
|
||||||
|
|
||||||
typedef ExtendedBlockWiseConfig<BlockWiseConfig> bwc;
|
typedef ExtendedBlockWiseConfig<BlockWiseConfig> bwc;
|
||||||
|
|
||||||
|
static constexpr unsigned InitialNumRows = GetInitialRows(bwc::Registers);
|
||||||
|
static constexpr unsigned InitialNumColumnVectors = GetInitialColumnVectors(InitialNumRows, bwc::Registers);
|
||||||
static constexpr unsigned NumRows = InitialNumRows;
|
static constexpr unsigned NumRows = InitialNumRows;
|
||||||
static constexpr unsigned NumColumns = InitialNumColumnVectors * bwc::VectorWidth;
|
static constexpr unsigned NumColumns = InitialNumColumnVectors * bwc::VectorWidth;
|
||||||
|
|
||||||
@ -248,7 +258,7 @@ namespace detail {
|
|||||||
consume(k, m_i, m, n, C, A, B);
|
consume(k, m_i, m, n, C, A, B);
|
||||||
if (CurrentNumRows > 1) {
|
if (CurrentNumRows > 1) {
|
||||||
// try with less num rows vectors
|
// try with less num rows vectors
|
||||||
iterate_rows<CurrentNumRows - 1, CurrentNumColumnVectors>::iterate(k, m_i, m, n, C, A, B);
|
iterate_rows<CurrentNumRows - 1, GetInitialColumnVectors(CurrentNumRows - 1, bwc::Registers)>::iterate(k, m_i, m, n, C, A, B);
|
||||||
}
|
}
|
||||||
assert(m_i == m);
|
assert(m_i == m);
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user