eigen3/fix_ppc64le_always_inline_680.patch

From 9e3873b1dce3ba65980c7e7b979325dac2fb4bbd Mon Sep 17 00:00:00 2001
From: Chip-Kerchner <chip.kerchner@ibm.com>
Date: Wed, 20 Oct 2021 11:06:50 -0500
Subject: [PATCH 1/2] New branch for inverting rows and depth in non-vectorized
portion of packing.
---
Eigen/src/Core/arch/AltiVec/Complex.h | 10 +-
Eigen/src/Core/arch/AltiVec/MatrixProduct.h | 1546 ++++++++---------
.../Core/arch/AltiVec/MatrixProductCommon.h | 206 +--
.../src/Core/arch/AltiVec/MatrixProductMMA.h | 335 ++--
4 files changed, 927 insertions(+), 1170 deletions(-)
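The subject's "inverting rows and depth" is a loop interchange in the scalar remainder of the packing routines: the old code guarded the leftover columns with a single `if` and walked depth in the outer loop, while the new code makes each leftover column the outer loop so its depth entries land contiguously in the packed block. A minimal stand-alone sketch of the before/after shape, with illustrative names (not Eigen's):

```cpp
#include <cstddef>

// Before: depth-outer, leftover-columns-inner, so entries from different
// columns interleave in the packed block.
void pack_remainder_before(float* block, const float* rhs, std::size_t lead,
                           std::size_t j0, std::size_t cols, std::size_t depth) {
  std::size_t ri = 0;
  if (j0 < cols) {
    for (std::size_t i = 0; i < depth; i++)
      for (std::size_t k = j0; k < cols; k++)
        block[ri++] = rhs[i * lead + k];
  }
}

// After: column-outer, depth-inner, so each leftover column is packed
// contiguously, matching what the kernels consume.
void pack_remainder_after(float* block, const float* rhs, std::size_t lead,
                          std::size_t j0, std::size_t cols, std::size_t depth) {
  std::size_t ri = 0;
  for (std::size_t j = j0; j < cols; j++)
    for (std::size_t i = 0; i < depth; i++)
      block[ri++] = rhs[i * lead + j];
}
```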
diff --git a/Eigen/src/Core/arch/AltiVec/Complex.h b/Eigen/src/Core/arch/AltiVec/Complex.h
index f730ce8d3..4fd923e84 100644
--- a/Eigen/src/Core/arch/AltiVec/Complex.h
+++ b/Eigen/src/Core/arch/AltiVec/Complex.h
@@ -129,20 +129,20 @@ template<> EIGEN_STRONG_INLINE Packet2cf ploaddup<Packet2cf>(const std::complex<
template<> EIGEN_STRONG_INLINE void pstore <std::complex<float> >(std::complex<float> * to, const Packet2cf& from) { pstore((float*)to, from.v); }
template<> EIGEN_STRONG_INLINE void pstoreu<std::complex<float> >(std::complex<float> * to, const Packet2cf& from) { pstoreu((float*)to, from.v); }
-EIGEN_STRONG_INLINE Packet2cf pload2(const std::complex<float>* from0, const std::complex<float>* from1)
+EIGEN_STRONG_INLINE Packet2cf pload2(const std::complex<float>& from0, const std::complex<float>& from1)
{
Packet4f res0, res1;
#ifdef __VSX__
- __asm__ ("lxsdx %x0,%y1" : "=wa" (res0) : "Z" (*from0));
- __asm__ ("lxsdx %x0,%y1" : "=wa" (res1) : "Z" (*from1));
+ __asm__ ("lxsdx %x0,%y1" : "=wa" (res0) : "Z" (from0));
+ __asm__ ("lxsdx %x0,%y1" : "=wa" (res1) : "Z" (from1));
#ifdef _BIG_ENDIAN
__asm__ ("xxpermdi %x0, %x1, %x2, 0" : "=wa" (res0) : "wa" (res0), "wa" (res1));
#else
__asm__ ("xxpermdi %x0, %x2, %x1, 0" : "=wa" (res0) : "wa" (res0), "wa" (res1));
#endif
#else
- *reinterpret_cast<std::complex<float> *>(&res0) = *from0;
- *reinterpret_cast<std::complex<float> *>(&res1) = *from1;
+ *reinterpret_cast<std::complex<float> *>(&res0) = from0;
+ *reinterpret_cast<std::complex<float> *>(&res1) = from1;
res0 = vec_perm(res0, res1, p16uc_TRANSPOSE64_HI);
#endif
return Packet2cf(res0);
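Passing the two complex values by const reference instead of by pointer lets the call sites (reworked later in this patch) feed `lhs(j, i)` results straight into pload2 without first naming temporaries just to take their addresses. A hedged illustration of the difference, using hypothetical stand-ins rather than the real signatures:

```cpp
#include <complex>

float first_by_ptr(const std::complex<float>* a) { return a->real(); }
float first_by_ref(const std::complex<float>& a) { return a.real(); }

std::complex<float> mapper_value() { return std::complex<float>(1.0f, 2.0f); }

float demo() {
  // By pointer, the caller must materialize a named lvalue first:
  std::complex<float> tmp = mapper_value();
  float x = first_by_ptr(&tmp);
  // By const reference, the temporary binds directly to the parameter:
  float y = first_by_ref(mapper_value());
  return x + y;
}
```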
diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h
index 1d67d60d0..bd5da3623 100644
--- a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h
+++ b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h
@@ -166,24 +166,23 @@ EIGEN_STRONG_INLINE void symm_pack_complex_rhs_helper(std::complex<Scalar>* bloc
rir += vectorDelta;
}
- if (j < cols)
+
+ for(; j < cols; j++)
{
- rii = rir + ((cols - j) * rows);
+ rii = rir + rows;
for(Index i = k2; i < depth; i++)
{
- Index k = j;
- for(; k < cols; k++)
- {
- std::complex<Scalar> v = getAdjointVal<Scalar, Index, StorageOrder>(i, k, rhs);
+ std::complex<Scalar> v = getAdjointVal<Scalar, Index, StorageOrder>(i, j, rhs);
- blockBf[rir] = v.real();
- blockBf[rii] = v.imag();
+ blockBf[rir] = v.real();
+ blockBf[rii] = v.imag();
- rir += 1;
- rii += 1;
- }
+ rir += 1;
+ rii += 1;
}
+
+ rir += rows;
}
}
@@ -262,19 +261,15 @@ EIGEN_STRONG_INLINE void symm_pack_rhs_helper(Scalar* blockB, const Scalar* _rhs
}
}
- if (j < cols)
+ for(; j < cols; j++)
{
for(Index i = k2; i < depth; i++)
{
- Index k = j;
- for(; k < cols; k++)
- {
- if(k <= i)
- blockB[ri] = rhs(i, k);
- else
- blockB[ri] = rhs(k, i);
- ri += 1;
- }
+ if(j <= i)
+ blockB[ri] = rhs(i, j);
+ else
+ blockB[ri] = rhs(j, i);
+ ri += 1;
}
}
}
@@ -408,22 +403,18 @@ struct symm_pack_lhs<double, Index, Pack1, Pack2_dummy, StorageOrder>
* and offset and behaves accordingly.
**/
-template<typename Scalar, typename Packet, typename Index>
-EIGEN_ALWAYS_INLINE void storeBlock(Scalar* to, PacketBlock<Packet,4>& block)
-{
- const Index size = 16 / sizeof(Scalar);
- pstore<Scalar>(to + (0 * size), block.packet[0]);
- pstore<Scalar>(to + (1 * size), block.packet[1]);
- pstore<Scalar>(to + (2 * size), block.packet[2]);
- pstore<Scalar>(to + (3 * size), block.packet[3]);
-}
-
-template<typename Scalar, typename Packet, typename Index>
-EIGEN_ALWAYS_INLINE void storeBlock(Scalar* to, PacketBlock<Packet,2>& block)
+template<typename Scalar, typename Packet, typename Index, int N>
+EIGEN_ALWAYS_INLINE void storeBlock(Scalar* to, PacketBlock<Packet,N>& block)
{
const Index size = 16 / sizeof(Scalar);
pstore<Scalar>(to + (0 * size), block.packet[0]);
pstore<Scalar>(to + (1 * size), block.packet[1]);
+ if (N > 2) {
+ pstore<Scalar>(to + (2 * size), block.packet[2]);
+ }
+ if (N > 3) {
+ pstore<Scalar>(to + (3 * size), block.packet[3]);
+ }
}
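This hunk sets the pattern the rest of the patch repeats: two fixed-size overloads (N = 4 and N = 2) collapse into one function templated on N, guarded by plain `if (N > k)` tests. Since N is a compile-time constant and the function is EIGEN_ALWAYS_INLINE, the untaken branches constant-fold away and each instantiation emits exactly the stores it needs. A minimal sketch of the idiom with simplified scalar types (the codebase predates `if constexpr`, hence the plain `if`):

```cpp
template <typename Scalar, int N>
inline void store_block(Scalar* to, const Scalar (&vals)[4]) {
  // N is a compile-time constant, so dead branches are removed entirely.
  to[0] = vals[0];
  to[1] = vals[1];
  if (N > 2) to[2] = vals[2];
  if (N > 3) to[3] = vals[3];
}

void demo(float* out) {
  float v[4] = {1.f, 2.f, 3.f, 4.f};
  store_block<float, 2>(out, v);      // compiles to two stores
  store_block<float, 4>(out + 4, v);  // compiles to four stores
}
```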
// General template for lhs & rhs complex packing.
@@ -449,9 +440,9 @@ struct dhs_cpack {
PacketBlock<PacketC,8> cblock;
if (UseLhs) {
- bload<DataMapper, PacketC, Index, 2, 0, StorageOrder>(cblock, lhs, j, i);
+ bload<DataMapper, PacketC, Index, 2, StorageOrder, true, 4>(cblock, lhs, j, i);
} else {
- bload<DataMapper, PacketC, Index, 2, 0, StorageOrder>(cblock, lhs, i, j);
+ bload<DataMapper, PacketC, Index, 2, StorageOrder, true, 4>(cblock, lhs, i, j);
}
blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETREAL32);
@@ -478,8 +469,8 @@ struct dhs_cpack {
ptranspose(blocki);
}
- storeBlock<Scalar, Packet, Index>(blockAt + rir, blockr);
- storeBlock<Scalar, Packet, Index>(blockAt + rii, blocki);
+ storeBlock<Scalar, Packet, Index, 4>(blockAt + rir, blockr);
+ storeBlock<Scalar, Packet, Index, 4>(blockAt + rii, blocki);
rir += 4*vectorSize;
rii += 4*vectorSize;
@@ -499,21 +490,12 @@ struct dhs_cpack {
cblock.packet[1] = lhs.template loadPacket<PacketC>(i, j + 2);
}
} else {
- std::complex<Scalar> lhs0, lhs1;
if (UseLhs) {
- lhs0 = lhs(j + 0, i);
- lhs1 = lhs(j + 1, i);
- cblock.packet[0] = pload2(&lhs0, &lhs1);
- lhs0 = lhs(j + 2, i);
- lhs1 = lhs(j + 3, i);
- cblock.packet[1] = pload2(&lhs0, &lhs1);
+ cblock.packet[0] = pload2(lhs(j + 0, i), lhs(j + 1, i));
+ cblock.packet[1] = pload2(lhs(j + 2, i), lhs(j + 3, i));
} else {
- lhs0 = lhs(i, j + 0);
- lhs1 = lhs(i, j + 1);
- cblock.packet[0] = pload2(&lhs0, &lhs1);
- lhs0 = lhs(i, j + 2);
- lhs1 = lhs(i, j + 3);
- cblock.packet[1] = pload2(&lhs0, &lhs1);
+ cblock.packet[0] = pload2(lhs(i, j + 0), lhs(i, j + 1));
+ cblock.packet[1] = pload2(lhs(i, j + 2), lhs(i, j + 3));
}
}
@@ -535,34 +517,50 @@ struct dhs_cpack {
rir += ((PanelMode) ? (vectorSize*(2*stride - depth)) : vectorDelta);
}
- if (j < rows)
+ if (!UseLhs)
{
- if(PanelMode) rir += (offset*(rows - j - vectorSize));
- rii = rir + (((PanelMode) ? stride : depth) * (rows - j));
+ if(PanelMode) rir -= (offset*(vectorSize - 1));
- for(Index i = 0; i < depth; i++)
+ for(; j < rows; j++)
{
- Index k = j;
- for(; k < rows; k++)
+ rii = rir + ((PanelMode) ? stride : depth);
+
+ for(Index i = 0; i < depth; i++)
{
- if (UseLhs) {
+ blockAt[rir] = lhs(i, j).real();
+
+ if(Conjugate)
+ blockAt[rii] = -lhs(i, j).imag();
+ else
+ blockAt[rii] = lhs(i, j).imag();
+
+ rir += 1;
+ rii += 1;
+ }
+
+ rir += ((PanelMode) ? (2*stride - depth) : depth);
+ }
+ } else {
+ if (j < rows)
+ {
+ if(PanelMode) rir += (offset*(rows - j - vectorSize));
+ rii = rir + (((PanelMode) ? stride : depth) * (rows - j));
+
+ for(Index i = 0; i < depth; i++)
+ {
+ Index k = j;
+ for(; k < rows; k++)
+ {
blockAt[rir] = lhs(k, i).real();
if(Conjugate)
blockAt[rii] = -lhs(k, i).imag();
else
blockAt[rii] = lhs(k, i).imag();
- } else {
- blockAt[rir] = lhs(i, k).real();
- if(Conjugate)
- blockAt[rii] = -lhs(i, k).imag();
- else
- blockAt[rii] = lhs(i, k).imag();
+ rir += 1;
+ rii += 1;
}
-
- rir += 1;
- rii += 1;
}
}
}
@@ -588,16 +586,16 @@ struct dhs_pack{
PacketBlock<Packet,4> block;
if (UseLhs) {
- bload<DataMapper, Packet, Index, 4, 0, StorageOrder>(block, lhs, j, i);
+ bload<DataMapper, Packet, Index, 4, StorageOrder, false, 4>(block, lhs, j, i);
} else {
- bload<DataMapper, Packet, Index, 4, 0, StorageOrder>(block, lhs, i, j);
+ bload<DataMapper, Packet, Index, 4, StorageOrder, false, 4>(block, lhs, i, j);
}
if(((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs))
{
ptranspose(block);
}
- storeBlock<Scalar, Packet, Index>(blockA + ri, block);
+ storeBlock<Scalar, Packet, Index, 4>(blockA + ri, block);
ri += 4*vectorSize;
}
@@ -632,21 +630,33 @@ struct dhs_pack{
if(PanelMode) ri += vectorSize*(stride - offset - depth);
}
- if (j < rows)
+ if (!UseLhs)
{
- if(PanelMode) ri += offset*(rows - j);
+ if(PanelMode) ri += offset;
- for(Index i = 0; i < depth; i++)
+ for(; j < rows; j++)
{
- Index k = j;
- for(; k < rows; k++)
+ for(Index i = 0; i < depth; i++)
{
- if (UseLhs) {
+ blockA[ri] = lhs(i, j);
+ ri += 1;
+ }
+
+ if(PanelMode) ri += stride - depth;
+ }
+ } else {
+ if (j < rows)
+ {
+ if(PanelMode) ri += offset*(rows - j);
+
+ for(Index i = 0; i < depth; i++)
+ {
+ Index k = j;
+ for(; k < rows; k++)
+ {
blockA[ri] = lhs(k, i);
- } else {
- blockA[ri] = lhs(i, k);
+ ri += 1;
}
- ri += 1;
}
}
}
@@ -682,7 +692,7 @@ struct dhs_pack<double, Index, DataMapper, Packet2d, StorageOrder, PanelMode, tr
block.packet[1] = lhs.template loadPacket<Packet2d>(j, i + 1);
}
- storeBlock<double, Packet2d, Index>(blockA + ri, block);
+ storeBlock<double, Packet2d, Index, 2>(blockA + ri, block);
ri += 2*vectorSize;
}
@@ -759,7 +769,7 @@ struct dhs_pack<double, Index, DataMapper, Packet2d, StorageOrder, PanelMode, fa
block.packet[2] = rhs.template loadPacket<Packet2d>(i + 1, j + 0); //[b1 b2]
block.packet[3] = rhs.template loadPacket<Packet2d>(i + 1, j + 2); //[b3 b4]
- storeBlock<double, Packet2d, Index>(blockB + ri, block);
+ storeBlock<double, Packet2d, Index, 4>(blockB + ri, block);
}
ri += 4*vectorSize;
@@ -790,19 +800,17 @@ struct dhs_pack<double, Index, DataMapper, Packet2d, StorageOrder, PanelMode, fa
if(PanelMode) ri += (2*vectorSize)*(stride - offset - depth);
}
- if (j < cols)
- {
- if(PanelMode) ri += offset*(cols - j);
+ if(PanelMode) ri += offset;
+ for(; j < cols; j++)
+ {
for(Index i = 0; i < depth; i++)
{
- Index k = j;
- for(; k < cols; k++)
- {
- blockB[ri] = rhs(i, k);
- ri += 1;
- }
+ blockB[ri] = rhs(i, j);
+ ri += 1;
}
+
+ if(PanelMode) ri += stride - depth;
}
}
};
@@ -863,8 +871,8 @@ struct dhs_cpack<double, Index, DataMapper, Packet, PacketC, StorageOrder, Conju
blocki.packet[1] = -blocki.packet[1];
}
- storeBlock<double, Packet, Index>(blockAt + rir, blockr);
- storeBlock<double, Packet, Index>(blockAt + rii, blocki);
+ storeBlock<double, Packet, Index, 2>(blockAt + rir, blockr);
+ storeBlock<double, Packet, Index, 2>(blockAt + rii, blocki);
rir += 2*vectorSize;
rii += 2*vectorSize;
@@ -943,7 +951,7 @@ struct dhs_cpack<double, Index, DataMapper, Packet, PacketC, StorageOrder, Conju
PacketBlock<PacketC,4> cblock;
PacketBlock<Packet,2> blockr, blocki;
- bload<DataMapper, PacketC, Index, 2, 0, ColMajor>(cblock, rhs, i, j);
+ bload<DataMapper, PacketC, Index, 2, ColMajor, false, 4>(cblock, rhs, i, j);
blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL64);
blockr.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETREAL64);
@@ -957,8 +965,8 @@ struct dhs_cpack<double, Index, DataMapper, Packet, PacketC, StorageOrder, Conju
blocki.packet[1] = -blocki.packet[1];
}
- storeBlock<double, Packet, Index>(blockBt + rir, blockr);
- storeBlock<double, Packet, Index>(blockBt + rii, blocki);
+ storeBlock<double, Packet, Index, 2>(blockBt + rir, blockr);
+ storeBlock<double, Packet, Index, 2>(blockBt + rii, blocki);
rir += 2*vectorSize;
rii += 2*vectorSize;
@@ -967,27 +975,26 @@ struct dhs_cpack<double, Index, DataMapper, Packet, PacketC, StorageOrder, Conju
rir += ((PanelMode) ? (2*vectorSize*(2*stride - depth)) : vectorDelta);
}
- if (j < cols)
+ if(PanelMode) rir -= (offset*(2*vectorSize - 1));
+
+ for(; j < cols; j++)
{
- if(PanelMode) rir += (offset*(cols - j - 2*vectorSize));
- rii = rir + (((PanelMode) ? stride : depth) * (cols - j));
+ rii = rir + ((PanelMode) ? stride : depth);
for(Index i = 0; i < depth; i++)
{
- Index k = j;
- for(; k < cols; k++)
- {
- blockBt[rir] = rhs(i, k).real();
+ blockBt[rir] = rhs(i, j).real();
- if(Conjugate)
- blockBt[rii] = -rhs(i, k).imag();
- else
- blockBt[rii] = rhs(i, k).imag();
+ if(Conjugate)
+ blockBt[rii] = -rhs(i, j).imag();
+ else
+ blockBt[rii] = rhs(i, j).imag();
- rir += 1;
- rii += 1;
- }
+ rir += 1;
+ rii += 1;
}
+
+ rir += ((PanelMode) ? (2*stride - depth) : depth);
}
}
};
@@ -997,31 +1004,32 @@ struct dhs_cpack<double, Index, DataMapper, Packet, PacketC, StorageOrder, Conju
**************/
// 512-bit rank-1 update of acc. It can accumulate either positively or negatively (useful for complex gemm).
-template<typename Packet, bool NegativeAccumulate>
-EIGEN_ALWAYS_INLINE void pger_common(PacketBlock<Packet,4>* acc, const Packet& lhsV, const Packet* rhsV)
-{
- if(NegativeAccumulate)
- {
- acc->packet[0] = vec_nmsub(lhsV, rhsV[0], acc->packet[0]);
- acc->packet[1] = vec_nmsub(lhsV, rhsV[1], acc->packet[1]);
- acc->packet[2] = vec_nmsub(lhsV, rhsV[2], acc->packet[2]);
- acc->packet[3] = vec_nmsub(lhsV, rhsV[3], acc->packet[3]);
- } else {
- acc->packet[0] = vec_madd(lhsV, rhsV[0], acc->packet[0]);
- acc->packet[1] = vec_madd(lhsV, rhsV[1], acc->packet[1]);
- acc->packet[2] = vec_madd(lhsV, rhsV[2], acc->packet[2]);
- acc->packet[3] = vec_madd(lhsV, rhsV[3], acc->packet[3]);
- }
-}
-
-template<typename Packet, bool NegativeAccumulate>
-EIGEN_ALWAYS_INLINE void pger_common(PacketBlock<Packet,1>* acc, const Packet& lhsV, const Packet* rhsV)
+template<typename Packet, bool NegativeAccumulate, int N>
+EIGEN_ALWAYS_INLINE void pger_common(PacketBlock<Packet,N>* acc, const Packet& lhsV, const Packet* rhsV)
{
if(NegativeAccumulate)
{
acc->packet[0] = vec_nmsub(lhsV, rhsV[0], acc->packet[0]);
+ if (N > 1) {
+ acc->packet[1] = vec_nmsub(lhsV, rhsV[1], acc->packet[1]);
+ }
+ if (N > 2) {
+ acc->packet[2] = vec_nmsub(lhsV, rhsV[2], acc->packet[2]);
+ }
+ if (N > 3) {
+ acc->packet[3] = vec_nmsub(lhsV, rhsV[3], acc->packet[3]);
+ }
} else {
acc->packet[0] = vec_madd(lhsV, rhsV[0], acc->packet[0]);
+ if (N > 1) {
+ acc->packet[1] = vec_madd(lhsV, rhsV[1], acc->packet[1]);
+ }
+ if (N > 2) {
+ acc->packet[2] = vec_madd(lhsV, rhsV[2], acc->packet[2]);
+ }
+ if (N > 3) {
+ acc->packet[3] = vec_madd(lhsV, rhsV[3], acc->packet[3]);
+ }
}
}
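The merged pger_common applies the same N-pruning to the rank-1 update, with vec_madd/vec_nmsub selecting the accumulate sign (on VSX, vec_nmsub(a, b, c) computes c - a*b). As a rough scalar model of what one SIMD lane does per column, with std::fma standing in for the fused intrinsics:

```cpp
#include <cmath>

template <bool NegativeAccumulate, int N>
inline void rank1_update(float (&acc)[4], float lhs, const float (&rhs)[4]) {
  for (int k = 0; k < N; k++) {  // N is constant, so this fully unrolls
    if (NegativeAccumulate)
      acc[k] = std::fma(-lhs, rhs[k], acc[k]);  // acc[k] -= lhs * rhs[k]
    else
      acc[k] = std::fma(lhs, rhs[k], acc[k]);   // acc[k] += lhs * rhs[k]
  }
}
```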
@@ -1030,11 +1038,11 @@ EIGEN_ALWAYS_INLINE void pger(PacketBlock<Packet,N>* acc, const Scalar* lhs, con
{
Packet lhsV = pload<Packet>(lhs);
- pger_common<Packet, NegativeAccumulate>(acc, lhsV, rhsV);
+ pger_common<Packet, NegativeAccumulate, N>(acc, lhsV, rhsV);
}
-template<typename Scalar, typename Packet, typename Index>
-EIGEN_ALWAYS_INLINE void loadPacketRemaining(const Scalar* lhs, Packet &lhsV, Index remaining_rows)
+template<typename Scalar, typename Packet, typename Index, const Index remaining_rows>
+EIGEN_ALWAYS_INLINE void loadPacketRemaining(const Scalar* lhs, Packet &lhsV)
{
#ifdef _ARCH_PWR9
lhsV = vec_xl_len((Scalar *)lhs, remaining_rows * sizeof(Scalar));
@@ -1046,32 +1054,32 @@ EIGEN_ALWAYS_INLINE void loadPacketRemaining(const Scalar* lhs, Packet &lhsV, In
#endif
}
-template<int N, typename Scalar, typename Packet, typename Index, bool NegativeAccumulate>
-EIGEN_ALWAYS_INLINE void pger(PacketBlock<Packet,N>* acc, const Scalar* lhs, const Packet* rhsV, Index remaining_rows)
+template<int N, typename Scalar, typename Packet, typename Index, bool NegativeAccumulate, const Index remaining_rows>
+EIGEN_ALWAYS_INLINE void pger(PacketBlock<Packet,N>* acc, const Scalar* lhs, const Packet* rhsV)
{
Packet lhsV;
- loadPacketRemaining<Scalar, Packet, Index>(lhs, lhsV, remaining_rows);
+ loadPacketRemaining<Scalar, Packet, Index, remaining_rows>(lhs, lhsV);
- pger_common<Packet, NegativeAccumulate>(acc, lhsV, rhsV);
+ pger_common<Packet, NegativeAccumulate, N>(acc, lhsV, rhsV);
}
// 512-bit rank-1 update of complex acc. It takes decoupled accumulators as entries. It also takes care of the mixed types real * complex and complex * real.
template<int N, typename Packet, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void pgerc_common(PacketBlock<Packet,N>* accReal, PacketBlock<Packet,N>* accImag, const Packet &lhsV, const Packet &lhsVi, const Packet* rhsV, const Packet* rhsVi)
{
- pger_common<Packet, false>(accReal, lhsV, rhsV);
+ pger_common<Packet, false, N>(accReal, lhsV, rhsV);
if(LhsIsReal)
{
- pger_common<Packet, ConjugateRhs>(accImag, lhsV, rhsVi);
+ pger_common<Packet, ConjugateRhs, N>(accImag, lhsV, rhsVi);
EIGEN_UNUSED_VARIABLE(lhsVi);
} else {
if (!RhsIsReal) {
- pger_common<Packet, ConjugateLhs == ConjugateRhs>(accReal, lhsVi, rhsVi);
- pger_common<Packet, ConjugateRhs>(accImag, lhsV, rhsVi);
+ pger_common<Packet, ConjugateLhs == ConjugateRhs, N>(accReal, lhsVi, rhsVi);
+ pger_common<Packet, ConjugateRhs, N>(accImag, lhsV, rhsVi);
} else {
EIGEN_UNUSED_VARIABLE(rhsVi);
}
- pger_common<Packet, ConjugateLhs>(accImag, lhsVi, rhsV);
+ pger_common<Packet, ConjugateLhs, N>(accImag, lhsVi, rhsV);
}
}
@@ -1086,8 +1094,8 @@ EIGEN_ALWAYS_INLINE void pgerc(PacketBlock<Packet,N>* accReal, PacketBlock<Packe
pgerc_common<N, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(accReal, accImag, lhsV, lhsVi, rhsV, rhsVi);
}
-template<typename Scalar, typename Packet, typename Index, bool LhsIsReal>
-EIGEN_ALWAYS_INLINE void loadPacketRemaining(const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, Packet &lhsV, Packet &lhsVi, Index remaining_rows)
+template<typename Scalar, typename Packet, typename Index, bool LhsIsReal, const Index remaining_rows>
+EIGEN_ALWAYS_INLINE void loadPacketRemaining(const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, Packet &lhsV, Packet &lhsVi)
{
#ifdef _ARCH_PWR9
lhsV = vec_xl_len((Scalar *)lhs_ptr, remaining_rows * sizeof(Scalar));
@@ -1103,11 +1111,11 @@ EIGEN_ALWAYS_INLINE void loadPacketRemaining(const Scalar* lhs_ptr, const Scalar
#endif
}
-template<int N, typename Scalar, typename Packet, typename Index, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
-EIGEN_ALWAYS_INLINE void pgerc(PacketBlock<Packet,N>* accReal, PacketBlock<Packet,N>* accImag, const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, const Packet* rhsV, const Packet* rhsVi, Index remaining_rows)
+template<int N, typename Scalar, typename Packet, typename Index, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal, const Index remaining_rows>
+EIGEN_ALWAYS_INLINE void pgerc(PacketBlock<Packet,N>* accReal, PacketBlock<Packet,N>* accImag, const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, const Packet* rhsV, const Packet* rhsVi)
{
Packet lhsV, lhsVi;
- loadPacketRemaining<Scalar, Packet, Index, LhsIsReal>(lhs_ptr, lhs_ptr_imag, lhsV, lhsVi, remaining_rows);
+ loadPacketRemaining<Scalar, Packet, Index, LhsIsReal, remaining_rows>(lhs_ptr, lhs_ptr_imag, lhsV, lhsVi);
pgerc_common<N, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(accReal, accImag, lhsV, lhsVi, rhsV, rhsVi);
}
@@ -1119,132 +1127,142 @@ EIGEN_ALWAYS_INLINE Packet ploadLhs(const Scalar* lhs)
}
// Zero the accumulator on PacketBlock.
-template<typename Scalar, typename Packet>
-EIGEN_ALWAYS_INLINE void bsetzero(PacketBlock<Packet,4>& acc)
-{
- acc.packet[0] = pset1<Packet>((Scalar)0);
- acc.packet[1] = pset1<Packet>((Scalar)0);
- acc.packet[2] = pset1<Packet>((Scalar)0);
- acc.packet[3] = pset1<Packet>((Scalar)0);
-}
-
-template<typename Scalar, typename Packet>
-EIGEN_ALWAYS_INLINE void bsetzero(PacketBlock<Packet,1>& acc)
+template<typename Scalar, typename Packet, int N>
+EIGEN_ALWAYS_INLINE void bsetzero(PacketBlock<Packet,N>& acc)
{
acc.packet[0] = pset1<Packet>((Scalar)0);
+ if (N > 1) {
+ acc.packet[1] = pset1<Packet>((Scalar)0);
+ }
+ if (N > 2) {
+ acc.packet[2] = pset1<Packet>((Scalar)0);
+ }
+ if (N > 3) {
+ acc.packet[3] = pset1<Packet>((Scalar)0);
+ }
}
// Scale the PacketBlock vectors by alpha.
-template<typename Packet>
-EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet,4>& acc, PacketBlock<Packet,4>& accZ, const Packet& pAlpha)
-{
- acc.packet[0] = pmadd(pAlpha, accZ.packet[0], acc.packet[0]);
- acc.packet[1] = pmadd(pAlpha, accZ.packet[1], acc.packet[1]);
- acc.packet[2] = pmadd(pAlpha, accZ.packet[2], acc.packet[2]);
- acc.packet[3] = pmadd(pAlpha, accZ.packet[3], acc.packet[3]);
-}
-
-template<typename Packet>
-EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet,1>& acc, PacketBlock<Packet,1>& accZ, const Packet& pAlpha)
+template<typename Packet, int N>
+EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet,N>& acc, PacketBlock<Packet,N>& accZ, const Packet& pAlpha)
{
acc.packet[0] = pmadd(pAlpha, accZ.packet[0], acc.packet[0]);
+ if (N > 1) {
+ acc.packet[1] = pmadd(pAlpha, accZ.packet[1], acc.packet[1]);
+ }
+ if (N > 2) {
+ acc.packet[2] = pmadd(pAlpha, accZ.packet[2], acc.packet[2]);
+ }
+ if (N > 3) {
+ acc.packet[3] = pmadd(pAlpha, accZ.packet[3], acc.packet[3]);
+ }
}
-template<typename Packet>
-EIGEN_ALWAYS_INLINE void bscalec_common(PacketBlock<Packet,4>& acc, PacketBlock<Packet,4>& accZ, const Packet& pAlpha)
-{
- acc.packet[0] = pmul<Packet>(accZ.packet[0], pAlpha);
- acc.packet[1] = pmul<Packet>(accZ.packet[1], pAlpha);
- acc.packet[2] = pmul<Packet>(accZ.packet[2], pAlpha);
- acc.packet[3] = pmul<Packet>(accZ.packet[3], pAlpha);
-}
-
-template<typename Packet>
-EIGEN_ALWAYS_INLINE void bscalec_common(PacketBlock<Packet,1>& acc, PacketBlock<Packet,1>& accZ, const Packet& pAlpha)
+template<typename Packet, int N>
+EIGEN_ALWAYS_INLINE void bscalec_common(PacketBlock<Packet,N>& acc, PacketBlock<Packet,N>& accZ, const Packet& pAlpha)
{
acc.packet[0] = pmul<Packet>(accZ.packet[0], pAlpha);
+ if (N > 1) {
+ acc.packet[1] = pmul<Packet>(accZ.packet[1], pAlpha);
+ }
+ if (N > 2) {
+ acc.packet[2] = pmul<Packet>(accZ.packet[2], pAlpha);
+ }
+ if (N > 3) {
+ acc.packet[3] = pmul<Packet>(accZ.packet[3], pAlpha);
+ }
}
// Complex version of PacketBlock scaling.
template<typename Packet, int N>
EIGEN_ALWAYS_INLINE void bscalec(PacketBlock<Packet,N>& aReal, PacketBlock<Packet,N>& aImag, const Packet& bReal, const Packet& bImag, PacketBlock<Packet,N>& cReal, PacketBlock<Packet,N>& cImag)
{
- bscalec_common<Packet>(cReal, aReal, bReal);
+ bscalec_common<Packet, N>(cReal, aReal, bReal);
- bscalec_common<Packet>(cImag, aImag, bReal);
+ bscalec_common<Packet, N>(cImag, aImag, bReal);
- pger_common<Packet, true>(&cReal, bImag, aImag.packet);
+ pger_common<Packet, true, N>(&cReal, bImag, aImag.packet);
- pger_common<Packet, false>(&cImag, bImag, aReal.packet);
+ pger_common<Packet, false, N>(&cImag, bImag, aReal.packet);
}
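bscalec composes the complex scale from those pieces: both halves are first multiplied by bReal, then the bImag cross terms are folded in with one negative and one positive rank-1 update. A scalar restatement of the arithmetic (illustrative only):

```cpp
// (cReal + i*cImag) = (aReal + i*aImag) * (bReal + i*bImag)
void cscale(float aReal, float aImag, float bReal, float bImag,
            float& cReal, float& cImag) {
  cReal = aReal * bReal;   // bscalec_common(cReal, aReal, bReal)
  cImag = aImag * bReal;   // bscalec_common(cImag, aImag, bReal)
  cReal -= aImag * bImag;  // pger_common<true>  (negative accumulate)
  cImag += aReal * bImag;  // pger_common<false> (positive accumulate)
}
```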
-template<typename Packet>
-EIGEN_ALWAYS_INLINE void band(PacketBlock<Packet,4>& acc, const Packet& pMask)
+template<typename Packet, int N>
+EIGEN_ALWAYS_INLINE void band(PacketBlock<Packet,N>& acc, const Packet& pMask)
{
acc.packet[0] = pand(acc.packet[0], pMask);
- acc.packet[1] = pand(acc.packet[1], pMask);
- acc.packet[2] = pand(acc.packet[2], pMask);
- acc.packet[3] = pand(acc.packet[3], pMask);
+ if (N > 1) {
+ acc.packet[1] = pand(acc.packet[1], pMask);
+ }
+ if (N > 2) {
+ acc.packet[2] = pand(acc.packet[2], pMask);
+ }
+ if (N > 3) {
+ acc.packet[3] = pand(acc.packet[3], pMask);
+ }
}
-template<typename Packet>
-EIGEN_ALWAYS_INLINE void bscalec(PacketBlock<Packet,4>& aReal, PacketBlock<Packet,4>& aImag, const Packet& bReal, const Packet& bImag, PacketBlock<Packet,4>& cReal, PacketBlock<Packet,4>& cImag, const Packet& pMask)
+template<typename Packet, int N>
+EIGEN_ALWAYS_INLINE void bscalec(PacketBlock<Packet,N>& aReal, PacketBlock<Packet,N>& aImag, const Packet& bReal, const Packet& bImag, PacketBlock<Packet,N>& cReal, PacketBlock<Packet,N>& cImag, const Packet& pMask)
{
- band<Packet>(aReal, pMask);
- band<Packet>(aImag, pMask);
+ band<Packet, N>(aReal, pMask);
+ band<Packet, N>(aImag, pMask);
- bscalec<Packet,4>(aReal, aImag, bReal, bImag, cReal, cImag);
+ bscalec<Packet,N>(aReal, aImag, bReal, bImag, cReal, cImag);
}
// Load a PacketBlock; the N parameters make tuning gemm easier, so we can add more accumulators as needed.
-template<typename DataMapper, typename Packet, typename Index, const Index accCols, int N, int StorageOrder>
-EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,4>& acc, const DataMapper& res, Index row, Index col)
-{
- if (StorageOrder == RowMajor) {
- acc.packet[0] = res.template loadPacket<Packet>(row + 0, col + N*accCols);
- acc.packet[1] = res.template loadPacket<Packet>(row + 1, col + N*accCols);
- acc.packet[2] = res.template loadPacket<Packet>(row + 2, col + N*accCols);
- acc.packet[3] = res.template loadPacket<Packet>(row + 3, col + N*accCols);
- } else {
- acc.packet[0] = res.template loadPacket<Packet>(row + N*accCols, col + 0);
- acc.packet[1] = res.template loadPacket<Packet>(row + N*accCols, col + 1);
- acc.packet[2] = res.template loadPacket<Packet>(row + N*accCols, col + 2);
- acc.packet[3] = res.template loadPacket<Packet>(row + N*accCols, col + 3);
- }
-}
-
-// An overload of bload when you have a PacketBLock with 8 vectors.
-template<typename DataMapper, typename Packet, typename Index, const Index accCols, int N, int StorageOrder>
-EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,8>& acc, const DataMapper& res, Index row, Index col)
+template<typename DataMapper, typename Packet, typename Index, const Index accCols, int StorageOrder, bool Complex, int N>
+EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,N*(Complex?2:1)>& acc, const DataMapper& res, Index row, Index col)
{
if (StorageOrder == RowMajor) {
- acc.packet[0] = res.template loadPacket<Packet>(row + 0, col + N*accCols);
- acc.packet[1] = res.template loadPacket<Packet>(row + 1, col + N*accCols);
- acc.packet[2] = res.template loadPacket<Packet>(row + 2, col + N*accCols);
- acc.packet[3] = res.template loadPacket<Packet>(row + 3, col + N*accCols);
- acc.packet[4] = res.template loadPacket<Packet>(row + 0, col + (N+1)*accCols);
- acc.packet[5] = res.template loadPacket<Packet>(row + 1, col + (N+1)*accCols);
- acc.packet[6] = res.template loadPacket<Packet>(row + 2, col + (N+1)*accCols);
- acc.packet[7] = res.template loadPacket<Packet>(row + 3, col + (N+1)*accCols);
+ acc.packet[0] = res.template loadPacket<Packet>(row + 0, col);
+ if (N > 1) {
+ acc.packet[1] = res.template loadPacket<Packet>(row + 1, col);
+ }
+ if (N > 2) {
+ acc.packet[2] = res.template loadPacket<Packet>(row + 2, col);
+ }
+ if (N > 3) {
+ acc.packet[3] = res.template loadPacket<Packet>(row + 3, col);
+ }
+ if (Complex) {
+ acc.packet[0+N] = res.template loadPacket<Packet>(row + 0, col + accCols);
+ if (N > 1) {
+ acc.packet[1+N] = res.template loadPacket<Packet>(row + 1, col + accCols);
+ }
+ if (N > 2) {
+ acc.packet[2+N] = res.template loadPacket<Packet>(row + 2, col + accCols);
+ }
+ if (N > 3) {
+ acc.packet[3+N] = res.template loadPacket<Packet>(row + 3, col + accCols);
+ }
+ }
} else {
- acc.packet[0] = res.template loadPacket<Packet>(row + N*accCols, col + 0);
- acc.packet[1] = res.template loadPacket<Packet>(row + N*accCols, col + 1);
- acc.packet[2] = res.template loadPacket<Packet>(row + N*accCols, col + 2);
- acc.packet[3] = res.template loadPacket<Packet>(row + N*accCols, col + 3);
- acc.packet[4] = res.template loadPacket<Packet>(row + (N+1)*accCols, col + 0);
- acc.packet[5] = res.template loadPacket<Packet>(row + (N+1)*accCols, col + 1);
- acc.packet[6] = res.template loadPacket<Packet>(row + (N+1)*accCols, col + 2);
- acc.packet[7] = res.template loadPacket<Packet>(row + (N+1)*accCols, col + 3);
+ acc.packet[0] = res.template loadPacket<Packet>(row, col + 0);
+ if (N > 1) {
+ acc.packet[1] = res.template loadPacket<Packet>(row, col + 1);
+ }
+ if (N > 2) {
+ acc.packet[2] = res.template loadPacket<Packet>(row, col + 2);
+ }
+ if (N > 3) {
+ acc.packet[3] = res.template loadPacket<Packet>(row, col + 3);
+ }
+ if (Complex) {
+ acc.packet[0+N] = res.template loadPacket<Packet>(row + accCols, col + 0);
+ if (N > 1) {
+ acc.packet[1+N] = res.template loadPacket<Packet>(row + accCols, col + 1);
+ }
+ if (N > 2) {
+ acc.packet[2+N] = res.template loadPacket<Packet>(row + accCols, col + 2);
+ }
+ if (N > 3) {
+ acc.packet[3+N] = res.template loadPacket<Packet>(row + accCols, col + 3);
+ }
+ }
}
}
-template<typename DataMapper, typename Packet, typename Index, const Index accCols, int N, int StorageOrder>
-EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,2>& acc, const DataMapper& res, Index row, Index col)
-{
- acc.packet[0] = res.template loadPacket<Packet>(row + N*accCols, col + 0);
- acc.packet[1] = res.template loadPacket<Packet>(row + (N+1)*accCols, col + 0);
-}
-
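The three bload overloads (4, 8, and 2 packets) likewise merge into one signature where a Complex flag doubles the packet count: the first N packets come from (row, col) and, when Complex is set, N more come from the same position shifted by accCols, where the decoupled imaginary halves live. A simplified scalar model of the column-major indexing, under the assumption that one "packet" is one element:

```cpp
template <int N, bool Complex>
inline void bload_model(float* out, const float* res, long lead,
                        long row, long col, long accCols) {
  for (int k = 0; k < N; k++)
    out[k] = res[(col + k) * lead + row];                  // real parts
  if (Complex)
    for (int k = 0; k < N; k++)
      out[N + k] = res[(col + k) * lead + row + accCols];  // imaginary halves
}
```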
const static Packet4i mask41 = { -1, 0, 0, 0 };
const static Packet4i mask42 = { -1, -1, 0, 0 };
const static Packet4i mask43 = { -1, -1, -1, 0 };
@@ -1275,22 +1293,44 @@ EIGEN_ALWAYS_INLINE Packet2d bmask<Packet2d>(const int remaining_rows)
}
}
-template<typename Packet>
-EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet,4>& acc, PacketBlock<Packet,4>& accZ, const Packet& pAlpha, const Packet& pMask)
+template<typename Packet, int N>
+EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet,N>& acc, PacketBlock<Packet,N>& accZ, const Packet& pAlpha, const Packet& pMask)
{
- band<Packet>(accZ, pMask);
+ band<Packet, N>(accZ, pMask);
- bscale<Packet>(acc, accZ, pAlpha);
+ bscale<Packet, N>(acc, accZ, pAlpha);
}
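The mask constants above and bmask feed this masked bscale: accumulator lanes at or past remaining_rows are ANDed to zero before the alpha-scaled accumulate, so lanes beyond the remainder contribute nothing and the destination values there are written back unchanged. A scalar model of the assumed semantics:

```cpp
// Lanes < remaining_rows keep their accumulator value; the rest are
// zeroed, so acc += alpha*accZ leaves those lanes as they were loaded.
void bscale_masked(float (&acc)[4], float (&accZ)[4], float alpha,
                   int remaining_rows) {
  for (int lane = 0; lane < 4; lane++) {
    if (lane >= remaining_rows) accZ[lane] = 0.0f;  // band(accZ, pMask)
    acc[lane] += alpha * accZ[lane];                // pmadd(pAlpha, accZ, acc)
  }
}
```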
-template<typename Packet>
-EIGEN_ALWAYS_INLINE void pbroadcast4_old(const __UNPACK_TYPE__(Packet)* a, Packet& a0, Packet& a1, Packet& a2, Packet& a3)
+template<typename Packet, int N> EIGEN_ALWAYS_INLINE void
+pbroadcastN_old(const __UNPACK_TYPE__(Packet) *a,
+ Packet& a0, Packet& a1, Packet& a2, Packet& a3)
+{
+ a0 = pset1<Packet>(a[0]);
+ if (N > 1) {
+ a1 = pset1<Packet>(a[1]);
+ } else {
+ EIGEN_UNUSED_VARIABLE(a1);
+ }
+ if (N > 2) {
+ a2 = pset1<Packet>(a[2]);
+ } else {
+ EIGEN_UNUSED_VARIABLE(a2);
+ }
+ if (N > 3) {
+ a3 = pset1<Packet>(a[3]);
+ } else {
+ EIGEN_UNUSED_VARIABLE(a3);
+ }
+}
+
+template<>
+EIGEN_ALWAYS_INLINE void pbroadcastN_old<Packet4f,4>(const float* a, Packet4f& a0, Packet4f& a1, Packet4f& a2, Packet4f& a3)
{
- pbroadcast4<Packet>(a, a0, a1, a2, a3);
+ pbroadcast4<Packet4f>(a, a0, a1, a2, a3);
}
template<>
-EIGEN_ALWAYS_INLINE void pbroadcast4_old<Packet2d>(const double* a, Packet2d& a0, Packet2d& a1, Packet2d& a2, Packet2d& a3)
+EIGEN_ALWAYS_INLINE void pbroadcastN_old<Packet2d,4>(const double* a, Packet2d& a0, Packet2d& a1, Packet2d& a2, Packet2d& a3)
{
a1 = pload<Packet2d>(a);
a3 = pload<Packet2d>(a + 2);
@@ -1300,89 +1340,96 @@ EIGEN_ALWAYS_INLINE void pbroadcast4_old<Packet2d>(const double* a, Packet2d& a0
a3 = vec_splat(a3, 1);
}
-// PEEL loop factor.
-#define PEEL 7
-
-template<typename Scalar, typename Packet, typename Index>
-EIGEN_ALWAYS_INLINE void MICRO_EXTRA_COL(
- const Scalar* &lhs_ptr,
- const Scalar* &rhs_ptr,
- PacketBlock<Packet,1> &accZero,
- Index remaining_rows,
- Index remaining_cols)
+template<typename Packet, int N> EIGEN_ALWAYS_INLINE void
+pbroadcastN(const __UNPACK_TYPE__(Packet) *a,
+ Packet& a0, Packet& a1, Packet& a2, Packet& a3)
{
- Packet rhsV[1];
- rhsV[0] = pset1<Packet>(rhs_ptr[0]);
- pger<1,Scalar, Packet, false>(&accZero, lhs_ptr, rhsV);
- lhs_ptr += remaining_rows;
- rhs_ptr += remaining_cols;
+ a0 = pset1<Packet>(a[0]);
+ if (N > 1) {
+ a1 = pset1<Packet>(a[1]);
+ } else {
+ EIGEN_UNUSED_VARIABLE(a1);
+ }
+ if (N > 2) {
+ a2 = pset1<Packet>(a[2]);
+ } else {
+ EIGEN_UNUSED_VARIABLE(a2);
+ }
+ if (N > 3) {
+ a3 = pset1<Packet>(a[3]);
+ } else {
+ EIGEN_UNUSED_VARIABLE(a3);
+ }
}
-template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows>
-EIGEN_STRONG_INLINE void gemm_extra_col(
- const DataMapper& res,
- const Scalar* lhs_base,
- const Scalar* rhs_base,
- Index depth,
- Index strideA,
- Index offsetA,
- Index row,
- Index col,
- Index remaining_rows,
- Index remaining_cols,
- const Packet& pAlpha)
+template<> EIGEN_ALWAYS_INLINE void
+pbroadcastN<Packet4f,4>(const float *a,
+ Packet4f& a0, Packet4f& a1, Packet4f& a2, Packet4f& a3)
{
- const Scalar* rhs_ptr = rhs_base;
- const Scalar* lhs_ptr = lhs_base + row*strideA + remaining_rows*offsetA;
- PacketBlock<Packet,1> accZero;
+ a3 = pload<Packet4f>(a);
+ a0 = vec_splat(a3, 0);
+ a1 = vec_splat(a3, 1);
+ a2 = vec_splat(a3, 2);
+ a3 = vec_splat(a3, 3);
+}
- bsetzero<Scalar, Packet>(accZero);
+// PEEL loop factor.
+#define PEEL 7
+#define PEEL_ROW 7
- Index remaining_depth = (depth & -accRows);
- Index k = 0;
- for(; k + PEEL <= remaining_depth; k+= PEEL)
- {
- EIGEN_POWER_PREFETCH(rhs_ptr);
- EIGEN_POWER_PREFETCH(lhs_ptr);
- for (int l = 0; l < PEEL; l++) {
- MICRO_EXTRA_COL<Scalar, Packet, Index>(lhs_ptr, rhs_ptr, accZero, remaining_rows, remaining_cols);
- }
- }
- for(; k < remaining_depth; k++)
- {
- MICRO_EXTRA_COL<Scalar, Packet, Index>(lhs_ptr, rhs_ptr, accZero, remaining_rows, remaining_cols);
+#define MICRO_UNROLL_PEEL(func) \
+ func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7)
+
+#define MICRO_ZERO_PEEL(peel) \
+ if ((PEEL_ROW > peel) && (peel != 0)) { \
+ bsetzero<Scalar, Packet, accRows>(accZero##peel); \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(accZero##peel); \
}
- for(; k < depth; k++)
- {
- Packet rhsV[1];
- rhsV[0] = pset1<Packet>(rhs_ptr[0]);
- pger<1, Scalar, Packet, Index, false>(&accZero, lhs_ptr, rhsV, remaining_rows);
- lhs_ptr += remaining_rows;
- rhs_ptr += remaining_cols;
+
+#define MICRO_ZERO_PEEL_ROW \
+ MICRO_UNROLL_PEEL(MICRO_ZERO_PEEL);
+
+#define MICRO_WORK_PEEL(peel) \
+ if (PEEL_ROW > peel) { \
+ pbroadcastN<Packet,accRows>(rhs_ptr + (accRows * peel), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \
+ pger<accRows, Scalar, Packet, false>(&accZero##peel, lhs_ptr + (remaining_rows * peel), rhsV##peel); \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(rhsV##peel); \
}
- accZero.packet[0] = vec_mul(pAlpha, accZero.packet[0]);
- for(Index i = 0; i < remaining_rows; i++) {
- res(row + i, col) += accZero.packet[0][i];
+#define MICRO_WORK_PEEL_ROW \
+ Packet rhsV0[4], rhsV1[4], rhsV2[4], rhsV3[4], rhsV4[4], rhsV5[4], rhsV6[4], rhsV7[4]; \
+ MICRO_UNROLL_PEEL(MICRO_WORK_PEEL); \
+ lhs_ptr += (remaining_rows * PEEL_ROW); \
+ rhs_ptr += (accRows * PEEL_ROW);
+
+#define MICRO_ADD_PEEL(peel, sum) \
+ if (PEEL_ROW > peel) { \
+ for (Index i = 0; i < accRows; i++) { \
+ accZero##sum.packet[i] += accZero##peel.packet[i]; \
+ } \
}
-}
-template<typename Scalar, typename Packet, typename Index, const Index accRows>
+#define MICRO_ADD_PEEL_ROW \
+ MICRO_ADD_PEEL(4, 0) MICRO_ADD_PEEL(5, 1) MICRO_ADD_PEEL(6, 2) MICRO_ADD_PEEL(7, 3) \
+ MICRO_ADD_PEEL(2, 0) MICRO_ADD_PEEL(3, 1) MICRO_ADD_PEEL(1, 0)
+
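PEEL_ROW drives the classic latency-hiding trick: each peeled pass updates eight independent accumulators so consecutive FMAs never wait on each other, and MICRO_ADD_PEEL_ROW then folds them back in a pairwise tree, mirroring the (4,0)(5,1)(6,2)(7,3) then (2,0)(3,1)(1,0) pairing above. A plain-C++ model of the control flow, with a dot product standing in for the packed kernel:

```cpp
#include <cstddef>

float dot_peeled(const float* a, const float* b, std::size_t n) {
  const std::size_t PEEL = 8;
  float acc[8] = {};  // independent accumulators break the FMA chain
  std::size_t k = 0;
  if (n >= PEEL) {
    do {
      for (std::size_t p = 0; p < PEEL; p++)
        acc[p] += a[k + p] * b[k + p];
    } while ((k += PEEL) + PEEL <= n);  // same guard shape as the patch
  }
  for (; k < n; k++) acc[0] += a[k] * b[k];  // scalar remainder
  // Pairwise reduction keeps the final additions independent too.
  acc[0] += acc[4]; acc[1] += acc[5]; acc[2] += acc[6]; acc[3] += acc[7];
  acc[0] += acc[2]; acc[1] += acc[3];
  return acc[0] + acc[1];
}
```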
+template<typename Scalar, typename Packet, typename Index, const Index accRows, const Index remaining_rows>
EIGEN_ALWAYS_INLINE void MICRO_EXTRA_ROW(
const Scalar* &lhs_ptr,
const Scalar* &rhs_ptr,
- PacketBlock<Packet,4> &accZero,
- Index remaining_rows)
+ PacketBlock<Packet,accRows> &accZero)
{
Packet rhsV[4];
- pbroadcast4<Packet>(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
- pger<4, Scalar, Packet, false>(&accZero, lhs_ptr, rhsV);
+ pbroadcastN<Packet,accRows>(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
+ pger<accRows, Scalar, Packet, false>(&accZero, lhs_ptr, rhsV);
lhs_ptr += remaining_rows;
rhs_ptr += accRows;
}
-template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols>
-EIGEN_STRONG_INLINE void gemm_extra_row(
+template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols, const Index remaining_rows>
+EIGEN_ALWAYS_INLINE void gemm_unrolled_row_iteration(
const DataMapper& res,
const Scalar* lhs_base,
const Scalar* rhs_base,
@@ -1393,59 +1440,89 @@ EIGEN_STRONG_INLINE void gemm_extra_row(
Index col,
Index rows,
Index cols,
- Index remaining_rows,
const Packet& pAlpha,
const Packet& pMask)
{
const Scalar* rhs_ptr = rhs_base;
const Scalar* lhs_ptr = lhs_base + row*strideA + remaining_rows*offsetA;
- PacketBlock<Packet,4> accZero, acc;
+ PacketBlock<Packet,accRows> accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7, acc;
- bsetzero<Scalar, Packet>(accZero);
+ bsetzero<Scalar, Packet, accRows>(accZero0);
- Index remaining_depth = (col + accRows < cols) ? depth : (depth & -accRows);
+ Index remaining_depth = (col + quad_traits<Scalar>::rows < cols) ? depth : (depth & -quad_traits<Scalar>::rows);
Index k = 0;
- for(; k + PEEL <= remaining_depth; k+= PEEL)
- {
- EIGEN_POWER_PREFETCH(rhs_ptr);
- EIGEN_POWER_PREFETCH(lhs_ptr);
- for (int l = 0; l < PEEL; l++) {
- MICRO_EXTRA_ROW<Scalar, Packet, Index, accRows>(lhs_ptr, rhs_ptr, accZero, remaining_rows);
- }
+ if (remaining_depth >= PEEL_ROW) {
+ MICRO_ZERO_PEEL_ROW
+ do
+ {
+ EIGEN_POWER_PREFETCH(rhs_ptr);
+ EIGEN_POWER_PREFETCH(lhs_ptr);
+ MICRO_WORK_PEEL_ROW
+ } while ((k += PEEL_ROW) + PEEL_ROW <= remaining_depth);
+ MICRO_ADD_PEEL_ROW
}
for(; k < remaining_depth; k++)
{
- MICRO_EXTRA_ROW<Scalar, Packet, Index, accRows>(lhs_ptr, rhs_ptr, accZero, remaining_rows);
+ MICRO_EXTRA_ROW<Scalar, Packet, Index, accRows, remaining_rows>(lhs_ptr, rhs_ptr, accZero0);
}
if ((remaining_depth == depth) && (rows >= accCols))
{
- for(Index j = 0; j < 4; j++) {
- acc.packet[j] = res.template loadPacket<Packet>(row, col + j);
- }
- bscale<Packet>(acc, accZero, pAlpha, pMask);
- res.template storePacketBlock<Packet,4>(row, col, acc);
+ bload<DataMapper, Packet, Index, 0, ColMajor, false, accRows>(acc, res, row, 0);
+ bscale<Packet,accRows>(acc, accZero0, pAlpha, pMask);
+ res.template storePacketBlock<Packet,accRows>(row, 0, acc);
} else {
for(; k < depth; k++)
{
Packet rhsV[4];
- pbroadcast4<Packet>(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
- pger<4, Scalar, Packet, Index, false>(&accZero, lhs_ptr, rhsV, remaining_rows);
+ pbroadcastN<Packet,accRows>(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
+ pger<accRows, Scalar, Packet, Index, false, remaining_rows>(&accZero0, lhs_ptr, rhsV);
lhs_ptr += remaining_rows;
rhs_ptr += accRows;
}
- for(Index j = 0; j < 4; j++) {
- accZero.packet[j] = vec_mul(pAlpha, accZero.packet[j]);
- }
- for(Index j = 0; j < 4; j++) {
+ for(Index j = 0; j < accRows; j++) {
+ accZero0.packet[j] = vec_mul(pAlpha, accZero0.packet[j]);
for(Index i = 0; i < remaining_rows; i++) {
- res(row + i, col + j) += accZero.packet[j][i];
+ res(row + i, j) += accZero0.packet[j][i];
}
}
}
}
+template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols>
+EIGEN_ALWAYS_INLINE void gemm_extra_row(
+ const DataMapper& res,
+ const Scalar* lhs_base,
+ const Scalar* rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index row,
+ Index col,
+ Index rows,
+ Index cols,
+ Index remaining_rows,
+ const Packet& pAlpha,
+ const Packet& pMask)
+{
+ switch(remaining_rows) {
+ case 1:
+ gemm_unrolled_row_iteration<Scalar, Packet, DataMapper, Index, accRows, accCols, 1>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, rows, cols, pAlpha, pMask);
+ break;
+ case 2:
+ if (sizeof(Scalar) == sizeof(float)) {
+ gemm_unrolled_row_iteration<Scalar, Packet, DataMapper, Index, accRows, accCols, 2>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, rows, cols, pAlpha, pMask);
+ }
+ break;
+ default:
+ if (sizeof(Scalar) == sizeof(float)) {
+ gemm_unrolled_row_iteration<Scalar, Packet, DataMapper, Index, accRows, accCols, 3>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, rows, cols, pAlpha, pMask);
+ }
+ break;
+ }
+}
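gemm_extra_row is now only a dispatcher: the runtime remainder count is switched into a template argument, so inside gemm_unrolled_row_iteration the always-inline code sees remaining_rows as a compile-time constant (the sizeof(Scalar) guards keep the 2- and 3-row cases float-only, since double packs fewer rows per vector). The idiom in isolation, with hypothetical names:

```cpp
template <int remaining_rows>
void kernel(float* out, const float* in) {
  // Trip count is a compile-time constant here, so the loop unrolls.
  for (int i = 0; i < remaining_rows; i++)
    out[i] = in[i];
}

void dispatch(float* out, const float* in, int remaining_rows) {
  switch (remaining_rows) {  // runtime value -> template instantiation
    case 1:  kernel<1>(out, in); break;
    case 2:  kernel<2>(out, in); break;
    default: kernel<3>(out, in); break;
  }
}
```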
+
#define MICRO_UNROLL(func) \
func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7)
@@ -1464,34 +1541,24 @@ EIGEN_STRONG_INLINE void gemm_extra_row(
#define MICRO_WORK_ONE(iter, peel) \
if (unroll_factor > iter) { \
- pger_common<Packet, false>(&accZero##iter, lhsV##iter, rhsV##peel); \
+ pger_common<Packet, false, accRows>(&accZero##iter, lhsV##iter, rhsV##peel); \
}
#define MICRO_TYPE_PEEL4(func, func2, peel) \
if (PEEL > peel) { \
Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \
- pbroadcast4<Packet>(rhs_ptr + (accRows * peel), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \
- MICRO_UNROLL_WORK(func, func2, peel) \
- } else { \
- EIGEN_UNUSED_VARIABLE(rhsV##peel); \
- }
-
-#define MICRO_TYPE_PEEL1(func, func2, peel) \
- if (PEEL > peel) { \
- Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \
- rhsV##peel[0] = pset1<Packet>(rhs_ptr[remaining_cols * peel]); \
+ pbroadcastN<Packet,accRows>(rhs_ptr + (accRows * peel), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \
MICRO_UNROLL_WORK(func, func2, peel) \
} else { \
EIGEN_UNUSED_VARIABLE(rhsV##peel); \
}
#define MICRO_UNROLL_TYPE_PEEL(M, func, func1, func2) \
- Packet rhsV0[M], rhsV1[M], rhsV2[M], rhsV3[M], rhsV4[M], rhsV5[M], rhsV6[M], rhsV7[M], rhsV8[M], rhsV9[M]; \
+ Packet rhsV0[M], rhsV1[M], rhsV2[M], rhsV3[M], rhsV4[M], rhsV5[M], rhsV6[M], rhsV7[M]; \
func(func1,func2,0); func(func1,func2,1); \
func(func1,func2,2); func(func1,func2,3); \
func(func1,func2,4); func(func1,func2,5); \
- func(func1,func2,6); func(func1,func2,7); \
- func(func1,func2,8); func(func1,func2,9);
+ func(func1,func2,6); func(func1,func2,7);
#define MICRO_UNROLL_TYPE_ONE(M, func, func1, func2) \
Packet rhsV0[M]; \
@@ -1505,17 +1572,9 @@ EIGEN_STRONG_INLINE void gemm_extra_row(
MICRO_UNROLL_TYPE_ONE(4, MICRO_TYPE_PEEL4, MICRO_WORK_ONE, MICRO_LOAD_ONE); \
rhs_ptr += accRows;
-#define MICRO_ONE_PEEL1 \
- MICRO_UNROLL_TYPE_PEEL(1, MICRO_TYPE_PEEL1, MICRO_WORK_ONE, MICRO_LOAD_ONE); \
- rhs_ptr += (remaining_cols * PEEL);
-
-#define MICRO_ONE1 \
- MICRO_UNROLL_TYPE_ONE(1, MICRO_TYPE_PEEL1, MICRO_WORK_ONE, MICRO_LOAD_ONE); \
- rhs_ptr += remaining_cols;
-
#define MICRO_DST_PTR_ONE(iter) \
if (unroll_factor > iter) { \
- bsetzero<Scalar, Packet>(accZero##iter); \
+ bsetzero<Scalar, Packet, accRows>(accZero##iter); \
} else { \
EIGEN_UNUSED_VARIABLE(accZero##iter); \
}
@@ -1524,7 +1583,7 @@ EIGEN_STRONG_INLINE void gemm_extra_row(
#define MICRO_SRC_PTR_ONE(iter) \
if (unroll_factor > iter) { \
- lhs_ptr##iter = lhs_base + ( (row/accCols) + iter )*strideA*accCols + accCols*offsetA; \
+ lhs_ptr##iter = lhs_base + ( (row/accCols) + iter )*strideA*accCols; \
} else { \
EIGEN_UNUSED_VARIABLE(lhs_ptr##iter); \
}
@@ -1540,25 +1599,13 @@ EIGEN_STRONG_INLINE void gemm_extra_row(
#define MICRO_STORE_ONE(iter) \
if (unroll_factor > iter) { \
- acc.packet[0] = res.template loadPacket<Packet>(row + iter*accCols, col + 0); \
- acc.packet[1] = res.template loadPacket<Packet>(row + iter*accCols, col + 1); \
- acc.packet[2] = res.template loadPacket<Packet>(row + iter*accCols, col + 2); \
- acc.packet[3] = res.template loadPacket<Packet>(row + iter*accCols, col + 3); \
- bscale<Packet>(acc, accZero##iter, pAlpha); \
- res.template storePacketBlock<Packet,4>(row + iter*accCols, col, acc); \
+ bload<DataMapper, Packet, Index, 0, ColMajor, false, accRows>(acc, res, row + iter*accCols, 0); \
+ bscale<Packet,accRows>(acc, accZero##iter, pAlpha); \
+ res.template storePacketBlock<Packet,accRows>(row + iter*accCols, 0, acc); \
}
#define MICRO_STORE MICRO_UNROLL(MICRO_STORE_ONE)
-#define MICRO_COL_STORE_ONE(iter) \
- if (unroll_factor > iter) { \
- acc.packet[0] = res.template loadPacket<Packet>(row + iter*accCols, col + 0); \
- bscale<Packet>(acc, accZero##iter, pAlpha); \
- res.template storePacketBlock<Packet,1>(row + iter*accCols, col, acc); \
- }
-
-#define MICRO_COL_STORE MICRO_UNROLL(MICRO_COL_STORE_ONE)
-
template<int unroll_factor, typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols>
EIGEN_STRONG_INLINE void gemm_unrolled_iteration(
const DataMapper& res,
@@ -1566,15 +1613,13 @@ EIGEN_STRONG_INLINE void gemm_unrolled_iteration(
const Scalar* rhs_base,
Index depth,
Index strideA,
- Index offsetA,
Index& row,
- Index col,
const Packet& pAlpha)
{
const Scalar* rhs_ptr = rhs_base;
const Scalar* lhs_ptr0 = NULL, * lhs_ptr1 = NULL, * lhs_ptr2 = NULL, * lhs_ptr3 = NULL, * lhs_ptr4 = NULL, * lhs_ptr5 = NULL, * lhs_ptr6 = NULL, * lhs_ptr7 = NULL;
- PacketBlock<Packet,4> accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7;
- PacketBlock<Packet,4> acc;
+ PacketBlock<Packet,accRows> accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7;
+ PacketBlock<Packet,accRows> acc;
MICRO_SRC_PTR
MICRO_DST_PTR
@@ -1595,101 +1640,100 @@ EIGEN_STRONG_INLINE void gemm_unrolled_iteration(
row += unroll_factor*accCols;
}
-template<int unroll_factor, typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accCols>
-EIGEN_STRONG_INLINE void gemm_unrolled_col_iteration(
+template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols>
+EIGEN_ALWAYS_INLINE void gemm_cols(
const DataMapper& res,
- const Scalar* lhs_base,
- const Scalar* rhs_base,
+ const Scalar* blockA,
+ const Scalar* blockB,
Index depth,
Index strideA,
Index offsetA,
- Index& row,
+ Index strideB,
+ Index offsetB,