Fix incorrect splitting of divide-and-conquer sub-blocks

This commit is contained in:
har0ke 2020-07-05 19:00:40 +02:00
parent dff24a914c
commit ea5556616b

View File

@ -6,16 +6,17 @@
#define SMID_MATRIX_DEVIDEANDCONQUER_H
#include "register_blocking/BlockWise.h"
#include "Boost.h"
#include "Naive.h"
#include <boost/numeric/ublas/matrix_proxy.hpp>
/// Decision on how one divide-and-conquer step partitions its operands.
enum SplitAction {
    /// Cut A (and the matching rows of C) at a row index.
    SplitA,
    /// Cut B (and the matching columns of C) at a column index.
    SplitB,
    /// Chosen when neither SplitA nor SplitB applies and the shared dimension
    /// is still larger than its block multiple.
    /// NOTE(review): presumably cuts both A and B along that dimension — confirm
    /// against the SplitBoth case of _divide_and_conquer.
    SplitBoth,
    /// The operands are small enough; multiply them directly.
    DoNotSplit
};
/// Chooses the index at which a dimension of length `value` is cut in two so
/// that the first part consists of whole `multiple`-sized blocks.
///
/// The index is half of `value`, rounded down to a multiple of `multiple`.
/// If that rounds down to zero (i.e. `value < 2 * multiple`), the index falls
/// back to exactly one block (`multiple`) rather than 1, so the first part is
/// never a partial block.
///
/// Declared `inline` because this definition lives in a header.
///
/// @param value    Length of the dimension to split.
/// @param multiple Block size the split index must be a multiple of; must be
///                 non-zero (it is used as a divisor).
/// @return A non-zero multiple of `multiple`. Note that the result exceeds
///         `value` when `value < multiple`; callers are expected not to split
///         a dimension that small.
inline size_t splitByMultiple(size_t value, size_t multiple) {
    const size_t block_pairs = value / (multiple * 2);
    const size_t result = block_pairs * multiple;
    return result == 0 ? multiple : result;
}
template<typename Multiplier, typename M1, typename M2, typename M3>
@ -34,7 +35,6 @@ void _divide_and_conquer(M1 &C, const M2 &A, const M3 &B) {
case SplitA:
{
size_t split_index = splitByMultiple(n, Multiplier::SplitMultipleM);
if (split_index == 0) split_index = 1;
const auto A1 = boost::numeric::ublas::project(A, {0, split_index}, {0, A.size2()});
const auto A2 = boost::numeric::ublas::project(A, {split_index, A.size1()}, {0, A.size2()});
auto C1 = boost::numeric::ublas::project(C, {0, split_index}, {0, C.size2()});
@ -46,7 +46,6 @@ void _divide_and_conquer(M1 &C, const M2 &A, const M3 &B) {
case SplitB:
{
size_t split_index = splitByMultiple(m, Multiplier::SplitMultipleN);
if (split_index == 0) split_index = 1;
auto B1 = boost::numeric::ublas::project(B, {0, B.size1()}, {0, split_index});
auto B2 = boost::numeric::ublas::project(B, {0, B.size1()}, {split_index, B.size2()});
auto C1 = boost::numeric::ublas::project(C, {0, C.size1()}, {0, split_index});
@ -123,15 +122,15 @@ struct multiplier_block_wise : block_wise<FloatType, version> {
size_t p_p = p / SplitMultipleP;
size_t n_p = n / SplitMultipleN;
unsigned _SplitMultipleM = SplitMultipleM;
unsigned _SplitMultipleP = 1;
unsigned _SplitMultipleP = PMultiplier;
unsigned _SplitMultipleN = SplitMultipleN;
auto max_dim = std::max(m_p, std::max(p_p, n_p));
if (m <= SplitMultipleM && n <= SplitMultipleN) {
return DoNotSplit;
} else if (max_dim == m_p) {
} else if (max_dim == m_p && m != SplitMultipleM) {
return SplitA;
} else if (max_dim == n_p) {
} else if (max_dim == n_p && n != SplitMultipleN) {
return SplitB;
} else if (p > SplitMultipleP){
return SplitBoth;