eigen3/fix_ppc64le_always_inline_680.patch

3139 lines
135 KiB
Diff
Raw Normal View History

From 9e3873b1dce3ba65980c7e7b979325dac2fb4bbd Mon Sep 17 00:00:00 2001
From: Chip-Kerchner <chip.kerchner@ibm.com>
Date: Wed, 20 Oct 2021 11:06:50 -0500
Subject: [PATCH 1/2] New branch for inverting rows and depth in non-vectorized
portion of packing.
---
Eigen/src/Core/arch/AltiVec/Complex.h | 10 +-
Eigen/src/Core/arch/AltiVec/MatrixProduct.h | 1546 ++++++++---------
.../Core/arch/AltiVec/MatrixProductCommon.h | 206 +--
.../src/Core/arch/AltiVec/MatrixProductMMA.h | 335 ++--
4 files changed, 927 insertions(+), 1170 deletions(-)
diff --git a/Eigen/src/Core/arch/AltiVec/Complex.h b/Eigen/src/Core/arch/AltiVec/Complex.h
index f730ce8d3..4fd923e84 100644
--- a/Eigen/src/Core/arch/AltiVec/Complex.h
+++ b/Eigen/src/Core/arch/AltiVec/Complex.h
@@ -129,20 +129,20 @@ template<> EIGEN_STRONG_INLINE Packet2cf ploaddup<Packet2cf>(const std::complex<
template<> EIGEN_STRONG_INLINE void pstore <std::complex<float> >(std::complex<float> * to, const Packet2cf& from) { pstore((float*)to, from.v); }
template<> EIGEN_STRONG_INLINE void pstoreu<std::complex<float> >(std::complex<float> * to, const Packet2cf& from) { pstoreu((float*)to, from.v); }
-EIGEN_STRONG_INLINE Packet2cf pload2(const std::complex<float>* from0, const std::complex<float>* from1)
+EIGEN_STRONG_INLINE Packet2cf pload2(const std::complex<float>& from0, const std::complex<float>& from1)
{
Packet4f res0, res1;
#ifdef __VSX__
- __asm__ ("lxsdx %x0,%y1" : "=wa" (res0) : "Z" (*from0));
- __asm__ ("lxsdx %x0,%y1" : "=wa" (res1) : "Z" (*from1));
+ __asm__ ("lxsdx %x0,%y1" : "=wa" (res0) : "Z" (from0));
+ __asm__ ("lxsdx %x0,%y1" : "=wa" (res1) : "Z" (from1));
#ifdef _BIG_ENDIAN
__asm__ ("xxpermdi %x0, %x1, %x2, 0" : "=wa" (res0) : "wa" (res0), "wa" (res1));
#else
__asm__ ("xxpermdi %x0, %x2, %x1, 0" : "=wa" (res0) : "wa" (res0), "wa" (res1));
#endif
#else
- *reinterpret_cast<std::complex<float> *>(&res0) = *from0;
- *reinterpret_cast<std::complex<float> *>(&res1) = *from1;
+ *reinterpret_cast<std::complex<float> *>(&res0) = from0;
+ *reinterpret_cast<std::complex<float> *>(&res1) = from1;
res0 = vec_perm(res0, res1, p16uc_TRANSPOSE64_HI);
#endif
return Packet2cf(res0);
diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h
index 1d67d60d0..bd5da3623 100644
--- a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h
+++ b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h
@@ -166,24 +166,23 @@ EIGEN_STRONG_INLINE void symm_pack_complex_rhs_helper(std::complex<Scalar>* bloc
rir += vectorDelta;
}
- if (j < cols)
+
+ for(; j < cols; j++)
{
- rii = rir + ((cols - j) * rows);
+ rii = rir + rows;
for(Index i = k2; i < depth; i++)
{
- Index k = j;
- for(; k < cols; k++)
- {
- std::complex<Scalar> v = getAdjointVal<Scalar, Index, StorageOrder>(i, k, rhs);
+ std::complex<Scalar> v = getAdjointVal<Scalar, Index, StorageOrder>(i, j, rhs);
- blockBf[rir] = v.real();
- blockBf[rii] = v.imag();
+ blockBf[rir] = v.real();
+ blockBf[rii] = v.imag();
- rir += 1;
- rii += 1;
- }
+ rir += 1;
+ rii += 1;
}
+
+ rir += rows;
}
}
@@ -262,19 +261,15 @@ EIGEN_STRONG_INLINE void symm_pack_rhs_helper(Scalar* blockB, const Scalar* _rhs
}
}
- if (j < cols)
+ for(; j < cols; j++)
{
for(Index i = k2; i < depth; i++)
{
- Index k = j;
- for(; k < cols; k++)
- {
- if(k <= i)
- blockB[ri] = rhs(i, k);
- else
- blockB[ri] = rhs(k, i);
- ri += 1;
- }
+ if(j <= i)
+ blockB[ri] = rhs(i, j);
+ else
+ blockB[ri] = rhs(j, i);
+ ri += 1;
}
}
}
@@ -408,22 +403,18 @@ struct symm_pack_lhs<double, Index, Pack1, Pack2_dummy, StorageOrder>
* and offset and behaves accordingly.
**/
-template<typename Scalar, typename Packet, typename Index>
-EIGEN_ALWAYS_INLINE void storeBlock(Scalar* to, PacketBlock<Packet,4>& block)
-{
- const Index size = 16 / sizeof(Scalar);
- pstore<Scalar>(to + (0 * size), block.packet[0]);
- pstore<Scalar>(to + (1 * size), block.packet[1]);
- pstore<Scalar>(to + (2 * size), block.packet[2]);
- pstore<Scalar>(to + (3 * size), block.packet[3]);
-}
-
-template<typename Scalar, typename Packet, typename Index>
-EIGEN_ALWAYS_INLINE void storeBlock(Scalar* to, PacketBlock<Packet,2>& block)
+template<typename Scalar, typename Packet, typename Index, int N>
+EIGEN_ALWAYS_INLINE void storeBlock(Scalar* to, PacketBlock<Packet,N>& block)
{
const Index size = 16 / sizeof(Scalar);
pstore<Scalar>(to + (0 * size), block.packet[0]);
pstore<Scalar>(to + (1 * size), block.packet[1]);
+ if (N > 2) {
+ pstore<Scalar>(to + (2 * size), block.packet[2]);
+ }
+ if (N > 3) {
+ pstore<Scalar>(to + (3 * size), block.packet[3]);
+ }
}
// General template for lhs & rhs complex packing.
@@ -449,9 +440,9 @@ struct dhs_cpack {
PacketBlock<PacketC,8> cblock;
if (UseLhs) {
- bload<DataMapper, PacketC, Index, 2, 0, StorageOrder>(cblock, lhs, j, i);
+ bload<DataMapper, PacketC, Index, 2, StorageOrder, true, 4>(cblock, lhs, j, i);
} else {
- bload<DataMapper, PacketC, Index, 2, 0, StorageOrder>(cblock, lhs, i, j);
+ bload<DataMapper, PacketC, Index, 2, StorageOrder, true, 4>(cblock, lhs, i, j);
}
blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETREAL32);
@@ -478,8 +469,8 @@ struct dhs_cpack {
ptranspose(blocki);
}
- storeBlock<Scalar, Packet, Index>(blockAt + rir, blockr);
- storeBlock<Scalar, Packet, Index>(blockAt + rii, blocki);
+ storeBlock<Scalar, Packet, Index, 4>(blockAt + rir, blockr);
+ storeBlock<Scalar, Packet, Index, 4>(blockAt + rii, blocki);
rir += 4*vectorSize;
rii += 4*vectorSize;
@@ -499,21 +490,12 @@ struct dhs_cpack {
cblock.packet[1] = lhs.template loadPacket<PacketC>(i, j + 2);
}
} else {
- std::complex<Scalar> lhs0, lhs1;
if (UseLhs) {
- lhs0 = lhs(j + 0, i);
- lhs1 = lhs(j + 1, i);
- cblock.packet[0] = pload2(&lhs0, &lhs1);
- lhs0 = lhs(j + 2, i);
- lhs1 = lhs(j + 3, i);
- cblock.packet[1] = pload2(&lhs0, &lhs1);
+ cblock.packet[0] = pload2(lhs(j + 0, i), lhs(j + 1, i));
+ cblock.packet[1] = pload2(lhs(j + 2, i), lhs(j + 3, i));
} else {
- lhs0 = lhs(i, j + 0);
- lhs1 = lhs(i, j + 1);
- cblock.packet[0] = pload2(&lhs0, &lhs1);
- lhs0 = lhs(i, j + 2);
- lhs1 = lhs(i, j + 3);
- cblock.packet[1] = pload2(&lhs0, &lhs1);
+ cblock.packet[0] = pload2(lhs(i, j + 0), lhs(i, j + 1));
+ cblock.packet[1] = pload2(lhs(i, j + 2), lhs(i, j + 3));
}
}
@@ -535,34 +517,50 @@ struct dhs_cpack {
rir += ((PanelMode) ? (vectorSize*(2*stride - depth)) : vectorDelta);
}
- if (j < rows)
+ if (!UseLhs)
{
- if(PanelMode) rir += (offset*(rows - j - vectorSize));
- rii = rir + (((PanelMode) ? stride : depth) * (rows - j));
+ if(PanelMode) rir -= (offset*(vectorSize - 1));
- for(Index i = 0; i < depth; i++)
+ for(; j < rows; j++)
{
- Index k = j;
- for(; k < rows; k++)
+ rii = rir + ((PanelMode) ? stride : depth);
+
+ for(Index i = 0; i < depth; i++)
{
- if (UseLhs) {
+ blockAt[rir] = lhs(i, j).real();
+
+ if(Conjugate)
+ blockAt[rii] = -lhs(i, j).imag();
+ else
+ blockAt[rii] = lhs(i, j).imag();
+
+ rir += 1;
+ rii += 1;
+ }
+
+ rir += ((PanelMode) ? (2*stride - depth) : depth);
+ }
+ } else {
+ if (j < rows)
+ {
+ if(PanelMode) rir += (offset*(rows - j - vectorSize));
+ rii = rir + (((PanelMode) ? stride : depth) * (rows - j));
+
+ for(Index i = 0; i < depth; i++)
+ {
+ Index k = j;
+ for(; k < rows; k++)
+ {
blockAt[rir] = lhs(k, i).real();
if(Conjugate)
blockAt[rii] = -lhs(k, i).imag();
else
blockAt[rii] = lhs(k, i).imag();
- } else {
- blockAt[rir] = lhs(i, k).real();
- if(Conjugate)
- blockAt[rii] = -lhs(i, k).imag();
- else
- blockAt[rii] = lhs(i, k).imag();
+ rir += 1;
+ rii += 1;
}
-
- rir += 1;
- rii += 1;
}
}
}
@@ -588,16 +586,16 @@ struct dhs_pack{
PacketBlock<Packet,4> block;
if (UseLhs) {
- bload<DataMapper, Packet, Index, 4, 0, StorageOrder>(block, lhs, j, i);
+ bload<DataMapper, Packet, Index, 4, StorageOrder, false, 4>(block, lhs, j, i);
} else {
- bload<DataMapper, Packet, Index, 4, 0, StorageOrder>(block, lhs, i, j);
+ bload<DataMapper, Packet, Index, 4, StorageOrder, false, 4>(block, lhs, i, j);
}
if(((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs))
{
ptranspose(block);
}
- storeBlock<Scalar, Packet, Index>(blockA + ri, block);
+ storeBlock<Scalar, Packet, Index, 4>(blockA + ri, block);
ri += 4*vectorSize;
}
@@ -632,21 +630,33 @@ struct dhs_pack{
if(PanelMode) ri += vectorSize*(stride - offset - depth);
}
- if (j < rows)
+ if (!UseLhs)
{
- if(PanelMode) ri += offset*(rows - j);
+ if(PanelMode) ri += offset;
- for(Index i = 0; i < depth; i++)
+ for(; j < rows; j++)
{
- Index k = j;
- for(; k < rows; k++)
+ for(Index i = 0; i < depth; i++)
{
- if (UseLhs) {
+ blockA[ri] = lhs(i, j);
+ ri += 1;
+ }
+
+ if(PanelMode) ri += stride - depth;
+ }
+ } else {
+ if (j < rows)
+ {
+ if(PanelMode) ri += offset*(rows - j);
+
+ for(Index i = 0; i < depth; i++)
+ {
+ Index k = j;
+ for(; k < rows; k++)
+ {
blockA[ri] = lhs(k, i);
- } else {
- blockA[ri] = lhs(i, k);
+ ri += 1;
}
- ri += 1;
}
}
}
@@ -682,7 +692,7 @@ struct dhs_pack<double, Index, DataMapper, Packet2d, StorageOrder, PanelMode, tr
block.packet[1] = lhs.template loadPacket<Packet2d>(j, i + 1);
}
- storeBlock<double, Packet2d, Index>(blockA + ri, block);
+ storeBlock<double, Packet2d, Index, 2>(blockA + ri, block);
ri += 2*vectorSize;
}
@@ -759,7 +769,7 @@ struct dhs_pack<double, Index, DataMapper, Packet2d, StorageOrder, PanelMode, fa
block.packet[2] = rhs.template loadPacket<Packet2d>(i + 1, j + 0); //[b1 b2]
block.packet[3] = rhs.template loadPacket<Packet2d>(i + 1, j + 2); //[b3 b4]
- storeBlock<double, Packet2d, Index>(blockB + ri, block);
+ storeBlock<double, Packet2d, Index, 4>(blockB + ri, block);
}
ri += 4*vectorSize;
@@ -790,19 +800,17 @@ struct dhs_pack<double, Index, DataMapper, Packet2d, StorageOrder, PanelMode, fa
if(PanelMode) ri += (2*vectorSize)*(stride - offset - depth);
}
- if (j < cols)
- {
- if(PanelMode) ri += offset*(cols - j);
+ if(PanelMode) ri += offset;
+ for(; j < cols; j++)
+ {
for(Index i = 0; i < depth; i++)
{
- Index k = j;
- for(; k < cols; k++)
- {
- blockB[ri] = rhs(i, k);
- ri += 1;
- }
+ blockB[ri] = rhs(i, j);
+ ri += 1;
}
+
+ if(PanelMode) ri += stride - depth;
}
}
};
@@ -863,8 +871,8 @@ struct dhs_cpack<double, Index, DataMapper, Packet, PacketC, StorageOrder, Conju
blocki.packet[1] = -blocki.packet[1];
}
- storeBlock<double, Packet, Index>(blockAt + rir, blockr);
- storeBlock<double, Packet, Index>(blockAt + rii, blocki);
+ storeBlock<double, Packet, Index, 2>(blockAt + rir, blockr);
+ storeBlock<double, Packet, Index, 2>(blockAt + rii, blocki);
rir += 2*vectorSize;
rii += 2*vectorSize;
@@ -943,7 +951,7 @@ struct dhs_cpack<double, Index, DataMapper, Packet, PacketC, StorageOrder, Conju
PacketBlock<PacketC,4> cblock;
PacketBlock<Packet,2> blockr, blocki;
- bload<DataMapper, PacketC, Index, 2, 0, ColMajor>(cblock, rhs, i, j);
+ bload<DataMapper, PacketC, Index, 2, ColMajor, false, 4>(cblock, rhs, i, j);
blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL64);
blockr.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETREAL64);
@@ -957,8 +965,8 @@ struct dhs_cpack<double, Index, DataMapper, Packet, PacketC, StorageOrder, Conju
blocki.packet[1] = -blocki.packet[1];
}
- storeBlock<double, Packet, Index>(blockBt + rir, blockr);
- storeBlock<double, Packet, Index>(blockBt + rii, blocki);
+ storeBlock<double, Packet, Index, 2>(blockBt + rir, blockr);
+ storeBlock<double, Packet, Index, 2>(blockBt + rii, blocki);
rir += 2*vectorSize;
rii += 2*vectorSize;
@@ -967,27 +975,26 @@ struct dhs_cpack<double, Index, DataMapper, Packet, PacketC, StorageOrder, Conju
rir += ((PanelMode) ? (2*vectorSize*(2*stride - depth)) : vectorDelta);
}
- if (j < cols)
+ if(PanelMode) rir -= (offset*(2*vectorSize - 1));
+
+ for(; j < cols; j++)
{
- if(PanelMode) rir += (offset*(cols - j - 2*vectorSize));
- rii = rir + (((PanelMode) ? stride : depth) * (cols - j));
+ rii = rir + ((PanelMode) ? stride : depth);
for(Index i = 0; i < depth; i++)
{
- Index k = j;
- for(; k < cols; k++)
- {
- blockBt[rir] = rhs(i, k).real();
+ blockBt[rir] = rhs(i, j).real();
- if(Conjugate)
- blockBt[rii] = -rhs(i, k).imag();
- else
- blockBt[rii] = rhs(i, k).imag();
+ if(Conjugate)
+ blockBt[rii] = -rhs(i, j).imag();
+ else
+ blockBt[rii] = rhs(i, j).imag();
- rir += 1;
- rii += 1;
- }
+ rir += 1;
+ rii += 1;
}
+
+ rir += ((PanelMode) ? (2*stride - depth) : depth);
}
}
};
@@ -997,31 +1004,32 @@ struct dhs_cpack<double, Index, DataMapper, Packet, PacketC, StorageOrder, Conju
**************/
// 512-bits rank1-update of acc. It can either positive or negative accumulate (useful for complex gemm).
-template<typename Packet, bool NegativeAccumulate>
-EIGEN_ALWAYS_INLINE void pger_common(PacketBlock<Packet,4>* acc, const Packet& lhsV, const Packet* rhsV)
-{
- if(NegativeAccumulate)
- {
- acc->packet[0] = vec_nmsub(lhsV, rhsV[0], acc->packet[0]);
- acc->packet[1] = vec_nmsub(lhsV, rhsV[1], acc->packet[1]);
- acc->packet[2] = vec_nmsub(lhsV, rhsV[2], acc->packet[2]);
- acc->packet[3] = vec_nmsub(lhsV, rhsV[3], acc->packet[3]);
- } else {
- acc->packet[0] = vec_madd(lhsV, rhsV[0], acc->packet[0]);
- acc->packet[1] = vec_madd(lhsV, rhsV[1], acc->packet[1]);
- acc->packet[2] = vec_madd(lhsV, rhsV[2], acc->packet[2]);
- acc->packet[3] = vec_madd(lhsV, rhsV[3], acc->packet[3]);
- }
-}
-
-template<typename Packet, bool NegativeAccumulate>
-EIGEN_ALWAYS_INLINE void pger_common(PacketBlock<Packet,1>* acc, const Packet& lhsV, const Packet* rhsV)
+template<typename Packet, bool NegativeAccumulate, int N>
+EIGEN_ALWAYS_INLINE void pger_common(PacketBlock<Packet,N>* acc, const Packet& lhsV, const Packet* rhsV)
{
if(NegativeAccumulate)
{
acc->packet[0] = vec_nmsub(lhsV, rhsV[0], acc->packet[0]);
+ if (N > 1) {
+ acc->packet[1] = vec_nmsub(lhsV, rhsV[1], acc->packet[1]);
+ }
+ if (N > 2) {
+ acc->packet[2] = vec_nmsub(lhsV, rhsV[2], acc->packet[2]);
+ }
+ if (N > 3) {
+ acc->packet[3] = vec_nmsub(lhsV, rhsV[3], acc->packet[3]);
+ }
} else {
acc->packet[0] = vec_madd(lhsV, rhsV[0], acc->packet[0]);
+ if (N > 1) {
+ acc->packet[1] = vec_madd(lhsV, rhsV[1], acc->packet[1]);
+ }
+ if (N > 2) {
+ acc->packet[2] = vec_madd(lhsV, rhsV[2], acc->packet[2]);
+ }
+ if (N > 3) {
+ acc->packet[3] = vec_madd(lhsV, rhsV[3], acc->packet[3]);
+ }
}
}
@@ -1030,11 +1038,11 @@ EIGEN_ALWAYS_INLINE void pger(PacketBlock<Packet,N>* acc, const Scalar* lhs, con
{
Packet lhsV = pload<Packet>(lhs);
- pger_common<Packet, NegativeAccumulate>(acc, lhsV, rhsV);
+ pger_common<Packet, NegativeAccumulate, N>(acc, lhsV, rhsV);
}
-template<typename Scalar, typename Packet, typename Index>
-EIGEN_ALWAYS_INLINE void loadPacketRemaining(const Scalar* lhs, Packet &lhsV, Index remaining_rows)
+template<typename Scalar, typename Packet, typename Index, const Index remaining_rows>
+EIGEN_ALWAYS_INLINE void loadPacketRemaining(const Scalar* lhs, Packet &lhsV)
{
#ifdef _ARCH_PWR9
lhsV = vec_xl_len((Scalar *)lhs, remaining_rows * sizeof(Scalar));
@@ -1046,32 +1054,32 @@ EIGEN_ALWAYS_INLINE void loadPacketRemaining(const Scalar* lhs, Packet &lhsV, In
#endif
}
-template<int N, typename Scalar, typename Packet, typename Index, bool NegativeAccumulate>
-EIGEN_ALWAYS_INLINE void pger(PacketBlock<Packet,N>* acc, const Scalar* lhs, const Packet* rhsV, Index remaining_rows)
+template<int N, typename Scalar, typename Packet, typename Index, bool NegativeAccumulate, const Index remaining_rows>
+EIGEN_ALWAYS_INLINE void pger(PacketBlock<Packet,N>* acc, const Scalar* lhs, const Packet* rhsV)
{
Packet lhsV;
- loadPacketRemaining<Scalar, Packet, Index>(lhs, lhsV, remaining_rows);
+ loadPacketRemaining<Scalar, Packet, Index, remaining_rows>(lhs, lhsV);
- pger_common<Packet, NegativeAccumulate>(acc, lhsV, rhsV);
+ pger_common<Packet, NegativeAccumulate, N>(acc, lhsV, rhsV);
}
// 512-bits rank1-update of complex acc. It takes decoupled accumulators as entries. It also takes cares of mixed types real * complex and complex * real.
template<int N, typename Packet, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_ALWAYS_INLINE void pgerc_common(PacketBlock<Packet,N>* accReal, PacketBlock<Packet,N>* accImag, const Packet &lhsV, const Packet &lhsVi, const Packet* rhsV, const Packet* rhsVi)
{
- pger_common<Packet, false>(accReal, lhsV, rhsV);
+ pger_common<Packet, false, N>(accReal, lhsV, rhsV);
if(LhsIsReal)
{
- pger_common<Packet, ConjugateRhs>(accImag, lhsV, rhsVi);
+ pger_common<Packet, ConjugateRhs, N>(accImag, lhsV, rhsVi);
EIGEN_UNUSED_VARIABLE(lhsVi);
} else {
if (!RhsIsReal) {
- pger_common<Packet, ConjugateLhs == ConjugateRhs>(accReal, lhsVi, rhsVi);
- pger_common<Packet, ConjugateRhs>(accImag, lhsV, rhsVi);
+ pger_common<Packet, ConjugateLhs == ConjugateRhs, N>(accReal, lhsVi, rhsVi);
+ pger_common<Packet, ConjugateRhs, N>(accImag, lhsV, rhsVi);
} else {
EIGEN_UNUSED_VARIABLE(rhsVi);
}
- pger_common<Packet, ConjugateLhs>(accImag, lhsVi, rhsV);
+ pger_common<Packet, ConjugateLhs, N>(accImag, lhsVi, rhsV);
}
}
@@ -1086,8 +1094,8 @@ EIGEN_ALWAYS_INLINE void pgerc(PacketBlock<Packet,N>* accReal, PacketBlock<Packe
pgerc_common<N, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(accReal, accImag, lhsV, lhsVi, rhsV, rhsVi);
}
-template<typename Scalar, typename Packet, typename Index, bool LhsIsReal>
-EIGEN_ALWAYS_INLINE void loadPacketRemaining(const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, Packet &lhsV, Packet &lhsVi, Index remaining_rows)
+template<typename Scalar, typename Packet, typename Index, bool LhsIsReal, const Index remaining_rows>
+EIGEN_ALWAYS_INLINE void loadPacketRemaining(const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, Packet &lhsV, Packet &lhsVi)
{
#ifdef _ARCH_PWR9
lhsV = vec_xl_len((Scalar *)lhs_ptr, remaining_rows * sizeof(Scalar));
@@ -1103,11 +1111,11 @@ EIGEN_ALWAYS_INLINE void loadPacketRemaining(const Scalar* lhs_ptr, const Scalar
#endif
}
-template<int N, typename Scalar, typename Packet, typename Index, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
-EIGEN_ALWAYS_INLINE void pgerc(PacketBlock<Packet,N>* accReal, PacketBlock<Packet,N>* accImag, const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, const Packet* rhsV, const Packet* rhsVi, Index remaining_rows)
+template<int N, typename Scalar, typename Packet, typename Index, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal, const Index remaining_rows>
+EIGEN_ALWAYS_INLINE void pgerc(PacketBlock<Packet,N>* accReal, PacketBlock<Packet,N>* accImag, const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, const Packet* rhsV, const Packet* rhsVi)
{
Packet lhsV, lhsVi;
- loadPacketRemaining<Scalar, Packet, Index, LhsIsReal>(lhs_ptr, lhs_ptr_imag, lhsV, lhsVi, remaining_rows);
+ loadPacketRemaining<Scalar, Packet, Index, LhsIsReal, remaining_rows>(lhs_ptr, lhs_ptr_imag, lhsV, lhsVi);
pgerc_common<N, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(accReal, accImag, lhsV, lhsVi, rhsV, rhsVi);
}
@@ -1119,132 +1127,142 @@ EIGEN_ALWAYS_INLINE Packet ploadLhs(const Scalar* lhs)
}
// Zero the accumulator on PacketBlock.
-template<typename Scalar, typename Packet>
-EIGEN_ALWAYS_INLINE void bsetzero(PacketBlock<Packet,4>& acc)
-{
- acc.packet[0] = pset1<Packet>((Scalar)0);
- acc.packet[1] = pset1<Packet>((Scalar)0);
- acc.packet[2] = pset1<Packet>((Scalar)0);
- acc.packet[3] = pset1<Packet>((Scalar)0);
-}
-
-template<typename Scalar, typename Packet>
-EIGEN_ALWAYS_INLINE void bsetzero(PacketBlock<Packet,1>& acc)
+template<typename Scalar, typename Packet, int N>
+EIGEN_ALWAYS_INLINE void bsetzero(PacketBlock<Packet,N>& acc)
{
acc.packet[0] = pset1<Packet>((Scalar)0);
+ if (N > 1) {
+ acc.packet[1] = pset1<Packet>((Scalar)0);
+ }
+ if (N > 2) {
+ acc.packet[2] = pset1<Packet>((Scalar)0);
+ }
+ if (N > 3) {
+ acc.packet[3] = pset1<Packet>((Scalar)0);
+ }
}
// Scale the PacketBlock vectors by alpha.
-template<typename Packet>
-EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet,4>& acc, PacketBlock<Packet,4>& accZ, const Packet& pAlpha)
-{
- acc.packet[0] = pmadd(pAlpha, accZ.packet[0], acc.packet[0]);
- acc.packet[1] = pmadd(pAlpha, accZ.packet[1], acc.packet[1]);
- acc.packet[2] = pmadd(pAlpha, accZ.packet[2], acc.packet[2]);
- acc.packet[3] = pmadd(pAlpha, accZ.packet[3], acc.packet[3]);
-}
-
-template<typename Packet>
-EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet,1>& acc, PacketBlock<Packet,1>& accZ, const Packet& pAlpha)
+template<typename Packet, int N>
+EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet,N>& acc, PacketBlock<Packet,N>& accZ, const Packet& pAlpha)
{
acc.packet[0] = pmadd(pAlpha, accZ.packet[0], acc.packet[0]);
+ if (N > 1) {
+ acc.packet[1] = pmadd(pAlpha, accZ.packet[1], acc.packet[1]);
+ }
+ if (N > 2) {
+ acc.packet[2] = pmadd(pAlpha, accZ.packet[2], acc.packet[2]);
+ }
+ if (N > 3) {
+ acc.packet[3] = pmadd(pAlpha, accZ.packet[3], acc.packet[3]);
+ }
}
-template<typename Packet>
-EIGEN_ALWAYS_INLINE void bscalec_common(PacketBlock<Packet,4>& acc, PacketBlock<Packet,4>& accZ, const Packet& pAlpha)
-{
- acc.packet[0] = pmul<Packet>(accZ.packet[0], pAlpha);
- acc.packet[1] = pmul<Packet>(accZ.packet[1], pAlpha);
- acc.packet[2] = pmul<Packet>(accZ.packet[2], pAlpha);
- acc.packet[3] = pmul<Packet>(accZ.packet[3], pAlpha);
-}
-
-template<typename Packet>
-EIGEN_ALWAYS_INLINE void bscalec_common(PacketBlock<Packet,1>& acc, PacketBlock<Packet,1>& accZ, const Packet& pAlpha)
+template<typename Packet, int N>
+EIGEN_ALWAYS_INLINE void bscalec_common(PacketBlock<Packet,N>& acc, PacketBlock<Packet,N>& accZ, const Packet& pAlpha)
{
acc.packet[0] = pmul<Packet>(accZ.packet[0], pAlpha);
+ if (N > 1) {
+ acc.packet[1] = pmul<Packet>(accZ.packet[1], pAlpha);
+ }
+ if (N > 2) {
+ acc.packet[2] = pmul<Packet>(accZ.packet[2], pAlpha);
+ }
+ if (N > 3) {
+ acc.packet[3] = pmul<Packet>(accZ.packet[3], pAlpha);
+ }
}
// Complex version of PacketBlock scaling.
template<typename Packet, int N>
EIGEN_ALWAYS_INLINE void bscalec(PacketBlock<Packet,N>& aReal, PacketBlock<Packet,N>& aImag, const Packet& bReal, const Packet& bImag, PacketBlock<Packet,N>& cReal, PacketBlock<Packet,N>& cImag)
{
- bscalec_common<Packet>(cReal, aReal, bReal);
+ bscalec_common<Packet, N>(cReal, aReal, bReal);
- bscalec_common<Packet>(cImag, aImag, bReal);
+ bscalec_common<Packet, N>(cImag, aImag, bReal);
- pger_common<Packet, true>(&cReal, bImag, aImag.packet);
+ pger_common<Packet, true, N>(&cReal, bImag, aImag.packet);
- pger_common<Packet, false>(&cImag, bImag, aReal.packet);
+ pger_common<Packet, false, N>(&cImag, bImag, aReal.packet);
}
-template<typename Packet>
-EIGEN_ALWAYS_INLINE void band(PacketBlock<Packet,4>& acc, const Packet& pMask)
+template<typename Packet, int N>
+EIGEN_ALWAYS_INLINE void band(PacketBlock<Packet,N>& acc, const Packet& pMask)
{
acc.packet[0] = pand(acc.packet[0], pMask);
- acc.packet[1] = pand(acc.packet[1], pMask);
- acc.packet[2] = pand(acc.packet[2], pMask);
- acc.packet[3] = pand(acc.packet[3], pMask);
+ if (N > 1) {
+ acc.packet[1] = pand(acc.packet[1], pMask);
+ }
+ if (N > 2) {
+ acc.packet[2] = pand(acc.packet[2], pMask);
+ }
+ if (N > 3) {
+ acc.packet[3] = pand(acc.packet[3], pMask);
+ }
}
-template<typename Packet>
-EIGEN_ALWAYS_INLINE void bscalec(PacketBlock<Packet,4>& aReal, PacketBlock<Packet,4>& aImag, const Packet& bReal, const Packet& bImag, PacketBlock<Packet,4>& cReal, PacketBlock<Packet,4>& cImag, const Packet& pMask)
+template<typename Packet, int N>
+EIGEN_ALWAYS_INLINE void bscalec(PacketBlock<Packet,N>& aReal, PacketBlock<Packet,N>& aImag, const Packet& bReal, const Packet& bImag, PacketBlock<Packet,N>& cReal, PacketBlock<Packet,N>& cImag, const Packet& pMask)
{
- band<Packet>(aReal, pMask);
- band<Packet>(aImag, pMask);
+ band<Packet, N>(aReal, pMask);
+ band<Packet, N>(aImag, pMask);
- bscalec<Packet,4>(aReal, aImag, bReal, bImag, cReal, cImag);
+ bscalec<Packet,N>(aReal, aImag, bReal, bImag, cReal, cImag);
}
// Load a PacketBlock, the N parameters make tunning gemm easier so we can add more accumulators as needed.
-template<typename DataMapper, typename Packet, typename Index, const Index accCols, int N, int StorageOrder>
-EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,4>& acc, const DataMapper& res, Index row, Index col)
-{
- if (StorageOrder == RowMajor) {
- acc.packet[0] = res.template loadPacket<Packet>(row + 0, col + N*accCols);
- acc.packet[1] = res.template loadPacket<Packet>(row + 1, col + N*accCols);
- acc.packet[2] = res.template loadPacket<Packet>(row + 2, col + N*accCols);
- acc.packet[3] = res.template loadPacket<Packet>(row + 3, col + N*accCols);
- } else {
- acc.packet[0] = res.template loadPacket<Packet>(row + N*accCols, col + 0);
- acc.packet[1] = res.template loadPacket<Packet>(row + N*accCols, col + 1);
- acc.packet[2] = res.template loadPacket<Packet>(row + N*accCols, col + 2);
- acc.packet[3] = res.template loadPacket<Packet>(row + N*accCols, col + 3);
- }
-}
-
-// An overload of bload when you have a PacketBLock with 8 vectors.
-template<typename DataMapper, typename Packet, typename Index, const Index accCols, int N, int StorageOrder>
-EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,8>& acc, const DataMapper& res, Index row, Index col)
+template<typename DataMapper, typename Packet, typename Index, const Index accCols, int StorageOrder, bool Complex, int N>
+EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,N*(Complex?2:1)>& acc, const DataMapper& res, Index row, Index col)
{
if (StorageOrder == RowMajor) {
- acc.packet[0] = res.template loadPacket<Packet>(row + 0, col + N*accCols);
- acc.packet[1] = res.template loadPacket<Packet>(row + 1, col + N*accCols);
- acc.packet[2] = res.template loadPacket<Packet>(row + 2, col + N*accCols);
- acc.packet[3] = res.template loadPacket<Packet>(row + 3, col + N*accCols);
- acc.packet[4] = res.template loadPacket<Packet>(row + 0, col + (N+1)*accCols);
- acc.packet[5] = res.template loadPacket<Packet>(row + 1, col + (N+1)*accCols);
- acc.packet[6] = res.template loadPacket<Packet>(row + 2, col + (N+1)*accCols);
- acc.packet[7] = res.template loadPacket<Packet>(row + 3, col + (N+1)*accCols);
+ acc.packet[0] = res.template loadPacket<Packet>(row + 0, col);
+ if (N > 1) {
+ acc.packet[1] = res.template loadPacket<Packet>(row + 1, col);
+ }
+ if (N > 2) {
+ acc.packet[2] = res.template loadPacket<Packet>(row + 2, col);
+ }
+ if (N > 3) {
+ acc.packet[3] = res.template loadPacket<Packet>(row + 3, col);
+ }
+ if (Complex) {
+ acc.packet[0+N] = res.template loadPacket<Packet>(row + 0, col + accCols);
+ if (N > 1) {
+ acc.packet[1+N] = res.template loadPacket<Packet>(row + 1, col + accCols);
+ }
+ if (N > 2) {
+ acc.packet[2+N] = res.template loadPacket<Packet>(row + 2, col + accCols);
+ }
+ if (N > 3) {
+ acc.packet[3+N] = res.template loadPacket<Packet>(row + 3, col + accCols);
+ }
+ }
} else {
- acc.packet[0] = res.template loadPacket<Packet>(row + N*accCols, col + 0);
- acc.packet[1] = res.template loadPacket<Packet>(row + N*accCols, col + 1);
- acc.packet[2] = res.template loadPacket<Packet>(row + N*accCols, col + 2);
- acc.packet[3] = res.template loadPacket<Packet>(row + N*accCols, col + 3);
- acc.packet[4] = res.template loadPacket<Packet>(row + (N+1)*accCols, col + 0);
- acc.packet[5] = res.template loadPacket<Packet>(row + (N+1)*accCols, col + 1);
- acc.packet[6] = res.template loadPacket<Packet>(row + (N+1)*accCols, col + 2);
- acc.packet[7] = res.template loadPacket<Packet>(row + (N+1)*accCols, col + 3);
+ acc.packet[0] = res.template loadPacket<Packet>(row, col + 0);
+ if (N > 1) {
+ acc.packet[1] = res.template loadPacket<Packet>(row, col + 1);
+ }
+ if (N > 2) {
+ acc.packet[2] = res.template loadPacket<Packet>(row, col + 2);
+ }
+ if (N > 3) {
+ acc.packet[3] = res.template loadPacket<Packet>(row, col + 3);
+ }
+ if (Complex) {
+ acc.packet[0+N] = res.template loadPacket<Packet>(row + accCols, col + 0);
+ if (N > 1) {
+ acc.packet[1+N] = res.template loadPacket<Packet>(row + accCols, col + 1);
+ }
+ if (N > 2) {
+ acc.packet[2+N] = res.template loadPacket<Packet>(row + accCols, col + 2);
+ }
+ if (N > 3) {
+ acc.packet[3+N] = res.template loadPacket<Packet>(row + accCols, col + 3);
+ }
+ }
}
}
-template<typename DataMapper, typename Packet, typename Index, const Index accCols, int N, int StorageOrder>
-EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,2>& acc, const DataMapper& res, Index row, Index col)
-{
- acc.packet[0] = res.template loadPacket<Packet>(row + N*accCols, col + 0);
- acc.packet[1] = res.template loadPacket<Packet>(row + (N+1)*accCols, col + 0);
-}
-
const static Packet4i mask41 = { -1, 0, 0, 0 };
const static Packet4i mask42 = { -1, -1, 0, 0 };
const static Packet4i mask43 = { -1, -1, -1, 0 };
@@ -1275,22 +1293,44 @@ EIGEN_ALWAYS_INLINE Packet2d bmask<Packet2d>(const int remaining_rows)
}
}
-template<typename Packet>
-EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet,4>& acc, PacketBlock<Packet,4>& accZ, const Packet& pAlpha, const Packet& pMask)
+template<typename Packet, int N>
+EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet,N>& acc, PacketBlock<Packet,N>& accZ, const Packet& pAlpha, const Packet& pMask)
{
- band<Packet>(accZ, pMask);
+ band<Packet, N>(accZ, pMask);
- bscale<Packet>(acc, accZ, pAlpha);
+ bscale<Packet, N>(acc, accZ, pAlpha);
}
-template<typename Packet>
-EIGEN_ALWAYS_INLINE void pbroadcast4_old(const __UNPACK_TYPE__(Packet)* a, Packet& a0, Packet& a1, Packet& a2, Packet& a3)
+template<typename Packet, int N> EIGEN_ALWAYS_INLINE void
+pbroadcastN_old(const __UNPACK_TYPE__(Packet) *a,
+ Packet& a0, Packet& a1, Packet& a2, Packet& a3)
+{
+ a0 = pset1<Packet>(a[0]);
+ if (N > 1) {
+ a1 = pset1<Packet>(a[1]);
+ } else {
+ EIGEN_UNUSED_VARIABLE(a1);
+ }
+ if (N > 2) {
+ a2 = pset1<Packet>(a[2]);
+ } else {
+ EIGEN_UNUSED_VARIABLE(a2);
+ }
+ if (N > 3) {
+ a3 = pset1<Packet>(a[3]);
+ } else {
+ EIGEN_UNUSED_VARIABLE(a3);
+ }
+}
+
+template<>
+EIGEN_ALWAYS_INLINE void pbroadcastN_old<Packet4f,4>(const float* a, Packet4f& a0, Packet4f& a1, Packet4f& a2, Packet4f& a3)
{
- pbroadcast4<Packet>(a, a0, a1, a2, a3);
+ pbroadcast4<Packet4f>(a, a0, a1, a2, a3);
}
template<>
-EIGEN_ALWAYS_INLINE void pbroadcast4_old<Packet2d>(const double* a, Packet2d& a0, Packet2d& a1, Packet2d& a2, Packet2d& a3)
+EIGEN_ALWAYS_INLINE void pbroadcastN_old<Packet2d,4>(const double* a, Packet2d& a0, Packet2d& a1, Packet2d& a2, Packet2d& a3)
{
a1 = pload<Packet2d>(a);
a3 = pload<Packet2d>(a + 2);
@@ -1300,89 +1340,96 @@ EIGEN_ALWAYS_INLINE void pbroadcast4_old<Packet2d>(const double* a, Packet2d& a0
a3 = vec_splat(a3, 1);
}
-// PEEL loop factor.
-#define PEEL 7
-
-template<typename Scalar, typename Packet, typename Index>
-EIGEN_ALWAYS_INLINE void MICRO_EXTRA_COL(
- const Scalar* &lhs_ptr,
- const Scalar* &rhs_ptr,
- PacketBlock<Packet,1> &accZero,
- Index remaining_rows,
- Index remaining_cols)
+template<typename Packet, int N> EIGEN_ALWAYS_INLINE void
+pbroadcastN(const __UNPACK_TYPE__(Packet) *a,
+ Packet& a0, Packet& a1, Packet& a2, Packet& a3)
{
- Packet rhsV[1];
- rhsV[0] = pset1<Packet>(rhs_ptr[0]);
- pger<1,Scalar, Packet, false>(&accZero, lhs_ptr, rhsV);
- lhs_ptr += remaining_rows;
- rhs_ptr += remaining_cols;
+ a0 = pset1<Packet>(a[0]);
+ if (N > 1) {
+ a1 = pset1<Packet>(a[1]);
+ } else {
+ EIGEN_UNUSED_VARIABLE(a1);
+ }
+ if (N > 2) {
+ a2 = pset1<Packet>(a[2]);
+ } else {
+ EIGEN_UNUSED_VARIABLE(a2);
+ }
+ if (N > 3) {
+ a3 = pset1<Packet>(a[3]);
+ } else {
+ EIGEN_UNUSED_VARIABLE(a3);
+ }
}
-template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows>
-EIGEN_STRONG_INLINE void gemm_extra_col(
- const DataMapper& res,
- const Scalar* lhs_base,
- const Scalar* rhs_base,
- Index depth,
- Index strideA,
- Index offsetA,
- Index row,
- Index col,
- Index remaining_rows,
- Index remaining_cols,
- const Packet& pAlpha)
+template<> EIGEN_ALWAYS_INLINE void
+pbroadcastN<Packet4f,4>(const float *a,
+ Packet4f& a0, Packet4f& a1, Packet4f& a2, Packet4f& a3)
{
- const Scalar* rhs_ptr = rhs_base;
- const Scalar* lhs_ptr = lhs_base + row*strideA + remaining_rows*offsetA;
- PacketBlock<Packet,1> accZero;
+ a3 = pload<Packet4f>(a);
+ a0 = vec_splat(a3, 0);
+ a1 = vec_splat(a3, 1);
+ a2 = vec_splat(a3, 2);
+ a3 = vec_splat(a3, 3);
+}
- bsetzero<Scalar, Packet>(accZero);
+// PEEL loop factor.
+#define PEEL 7
+#define PEEL_ROW 7
- Index remaining_depth = (depth & -accRows);
- Index k = 0;
- for(; k + PEEL <= remaining_depth; k+= PEEL)
- {
- EIGEN_POWER_PREFETCH(rhs_ptr);
- EIGEN_POWER_PREFETCH(lhs_ptr);
- for (int l = 0; l < PEEL; l++) {
- MICRO_EXTRA_COL<Scalar, Packet, Index>(lhs_ptr, rhs_ptr, accZero, remaining_rows, remaining_cols);
- }
- }
- for(; k < remaining_depth; k++)
- {
- MICRO_EXTRA_COL<Scalar, Packet, Index>(lhs_ptr, rhs_ptr, accZero, remaining_rows, remaining_cols);
+#define MICRO_UNROLL_PEEL(func) \
+ func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7)
+
+#define MICRO_ZERO_PEEL(peel) \
+ if ((PEEL_ROW > peel) && (peel != 0)) { \
+ bsetzero<Scalar, Packet, accRows>(accZero##peel); \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(accZero##peel); \
}
- for(; k < depth; k++)
- {
- Packet rhsV[1];
- rhsV[0] = pset1<Packet>(rhs_ptr[0]);
- pger<1, Scalar, Packet, Index, false>(&accZero, lhs_ptr, rhsV, remaining_rows);
- lhs_ptr += remaining_rows;
- rhs_ptr += remaining_cols;
+
+#define MICRO_ZERO_PEEL_ROW \
+ MICRO_UNROLL_PEEL(MICRO_ZERO_PEEL);
+
+#define MICRO_WORK_PEEL(peel) \
+ if (PEEL_ROW > peel) { \
+ pbroadcastN<Packet,accRows>(rhs_ptr + (accRows * peel), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \
+ pger<accRows, Scalar, Packet, false>(&accZero##peel, lhs_ptr + (remaining_rows * peel), rhsV##peel); \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(rhsV##peel); \
}
- accZero.packet[0] = vec_mul(pAlpha, accZero.packet[0]);
- for(Index i = 0; i < remaining_rows; i++) {
- res(row + i, col) += accZero.packet[0][i];
+#define MICRO_WORK_PEEL_ROW \
+ Packet rhsV0[4], rhsV1[4], rhsV2[4], rhsV3[4], rhsV4[4], rhsV5[4], rhsV6[4], rhsV7[4]; \
+ MICRO_UNROLL_PEEL(MICRO_WORK_PEEL); \
+ lhs_ptr += (remaining_rows * PEEL_ROW); \
+ rhs_ptr += (accRows * PEEL_ROW);
+
+#define MICRO_ADD_PEEL(peel, sum) \
+ if (PEEL_ROW > peel) { \
+ for (Index i = 0; i < accRows; i++) { \
+ accZero##sum.packet[i] += accZero##peel.packet[i]; \
+ } \
}
-}
-template<typename Scalar, typename Packet, typename Index, const Index accRows>
+#define MICRO_ADD_PEEL_ROW \
+ MICRO_ADD_PEEL(4, 0) MICRO_ADD_PEEL(5, 1) MICRO_ADD_PEEL(6, 2) MICRO_ADD_PEEL(7, 3) \
+ MICRO_ADD_PEEL(2, 0) MICRO_ADD_PEEL(3, 1) MICRO_ADD_PEEL(1, 0)
+
+template<typename Scalar, typename Packet, typename Index, const Index accRows, const Index remaining_rows>
EIGEN_ALWAYS_INLINE void MICRO_EXTRA_ROW(
const Scalar* &lhs_ptr,
const Scalar* &rhs_ptr,
- PacketBlock<Packet,4> &accZero,
- Index remaining_rows)
+ PacketBlock<Packet,accRows> &accZero)
{
Packet rhsV[4];
- pbroadcast4<Packet>(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
- pger<4, Scalar, Packet, false>(&accZero, lhs_ptr, rhsV);
+ pbroadcastN<Packet,accRows>(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
+ pger<accRows, Scalar, Packet, false>(&accZero, lhs_ptr, rhsV);
lhs_ptr += remaining_rows;
rhs_ptr += accRows;
}
-template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols>
-EIGEN_STRONG_INLINE void gemm_extra_row(
+template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols, const Index remaining_rows>
+EIGEN_ALWAYS_INLINE void gemm_unrolled_row_iteration(
const DataMapper& res,
const Scalar* lhs_base,
const Scalar* rhs_base,
@@ -1393,59 +1440,89 @@ EIGEN_STRONG_INLINE void gemm_extra_row(
Index col,
Index rows,
Index cols,
- Index remaining_rows,
const Packet& pAlpha,
const Packet& pMask)
{
const Scalar* rhs_ptr = rhs_base;
const Scalar* lhs_ptr = lhs_base + row*strideA + remaining_rows*offsetA;
- PacketBlock<Packet,4> accZero, acc;
+ PacketBlock<Packet,accRows> accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7, acc;
- bsetzero<Scalar, Packet>(accZero);
+ bsetzero<Scalar, Packet, accRows>(accZero0);
- Index remaining_depth = (col + accRows < cols) ? depth : (depth & -accRows);
+ Index remaining_depth = (col + quad_traits<Scalar>::rows < cols) ? depth : (depth & -quad_traits<Scalar>::rows);
Index k = 0;
- for(; k + PEEL <= remaining_depth; k+= PEEL)
- {
- EIGEN_POWER_PREFETCH(rhs_ptr);
- EIGEN_POWER_PREFETCH(lhs_ptr);
- for (int l = 0; l < PEEL; l++) {
- MICRO_EXTRA_ROW<Scalar, Packet, Index, accRows>(lhs_ptr, rhs_ptr, accZero, remaining_rows);
- }
+ if (remaining_depth >= PEEL_ROW) {
+ MICRO_ZERO_PEEL_ROW
+ do
+ {
+ EIGEN_POWER_PREFETCH(rhs_ptr);
+ EIGEN_POWER_PREFETCH(lhs_ptr);
+ MICRO_WORK_PEEL_ROW
+ } while ((k += PEEL_ROW) + PEEL_ROW <= remaining_depth);
+ MICRO_ADD_PEEL_ROW
}
for(; k < remaining_depth; k++)
{
- MICRO_EXTRA_ROW<Scalar, Packet, Index, accRows>(lhs_ptr, rhs_ptr, accZero, remaining_rows);
+ MICRO_EXTRA_ROW<Scalar, Packet, Index, accRows, remaining_rows>(lhs_ptr, rhs_ptr, accZero0);
}
if ((remaining_depth == depth) && (rows >= accCols))
{
- for(Index j = 0; j < 4; j++) {
- acc.packet[j] = res.template loadPacket<Packet>(row, col + j);
- }
- bscale<Packet>(acc, accZero, pAlpha, pMask);
- res.template storePacketBlock<Packet,4>(row, col, acc);
+ bload<DataMapper, Packet, Index, 0, ColMajor, false, accRows>(acc, res, row, 0);
+ bscale<Packet,accRows>(acc, accZero0, pAlpha, pMask);
+ res.template storePacketBlock<Packet,accRows>(row, 0, acc);
} else {
for(; k < depth; k++)
{
Packet rhsV[4];
- pbroadcast4<Packet>(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
- pger<4, Scalar, Packet, Index, false>(&accZero, lhs_ptr, rhsV, remaining_rows);
+ pbroadcastN<Packet,accRows>(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
+ pger<accRows, Scalar, Packet, Index, false, remaining_rows>(&accZero0, lhs_ptr, rhsV);
lhs_ptr += remaining_rows;
rhs_ptr += accRows;
}
- for(Index j = 0; j < 4; j++) {
- accZero.packet[j] = vec_mul(pAlpha, accZero.packet[j]);
- }
- for(Index j = 0; j < 4; j++) {
+ for(Index j = 0; j < accRows; j++) {
+ accZero0.packet[j] = vec_mul(pAlpha, accZero0.packet[j]);
for(Index i = 0; i < remaining_rows; i++) {
- res(row + i, col + j) += accZero.packet[j][i];
+ res(row + i, j) += accZero0.packet[j][i];
}
}
}
}
+template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols>
+EIGEN_ALWAYS_INLINE void gemm_extra_row(
+ const DataMapper& res,
+ const Scalar* lhs_base,
+ const Scalar* rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index row,
+ Index col,
+ Index rows,
+ Index cols,
+ Index remaining_rows,
+ const Packet& pAlpha,
+ const Packet& pMask)
+{
+ switch(remaining_rows) {
+ case 1:
+ gemm_unrolled_row_iteration<Scalar, Packet, DataMapper, Index, accRows, accCols, 1>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, rows, cols, pAlpha, pMask);
+ break;
+ case 2:
+ if (sizeof(Scalar) == sizeof(float)) {
+ gemm_unrolled_row_iteration<Scalar, Packet, DataMapper, Index, accRows, accCols, 2>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, rows, cols, pAlpha, pMask);
+ }
+ break;
+ default:
+ if (sizeof(Scalar) == sizeof(float)) {
+ gemm_unrolled_row_iteration<Scalar, Packet, DataMapper, Index, accRows, accCols, 3>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, rows, cols, pAlpha, pMask);
+ }
+ break;
+ }
+}
+
#define MICRO_UNROLL(func) \
func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7)
@@ -1464,34 +1541,24 @@ EIGEN_STRONG_INLINE void gemm_extra_row(
#define MICRO_WORK_ONE(iter, peel) \
if (unroll_factor > iter) { \
- pger_common<Packet, false>(&accZero##iter, lhsV##iter, rhsV##peel); \
+ pger_common<Packet, false, accRows>(&accZero##iter, lhsV##iter, rhsV##peel); \
}
#define MICRO_TYPE_PEEL4(func, func2, peel) \
if (PEEL > peel) { \
Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \
- pbroadcast4<Packet>(rhs_ptr + (accRows * peel), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \
- MICRO_UNROLL_WORK(func, func2, peel) \
- } else { \
- EIGEN_UNUSED_VARIABLE(rhsV##peel); \
- }
-
-#define MICRO_TYPE_PEEL1(func, func2, peel) \
- if (PEEL > peel) { \
- Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \
- rhsV##peel[0] = pset1<Packet>(rhs_ptr[remaining_cols * peel]); \
+ pbroadcastN<Packet,accRows>(rhs_ptr + (accRows * peel), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \
MICRO_UNROLL_WORK(func, func2, peel) \
} else { \
EIGEN_UNUSED_VARIABLE(rhsV##peel); \
}
#define MICRO_UNROLL_TYPE_PEEL(M, func, func1, func2) \
- Packet rhsV0[M], rhsV1[M], rhsV2[M], rhsV3[M], rhsV4[M], rhsV5[M], rhsV6[M], rhsV7[M], rhsV8[M], rhsV9[M]; \
+ Packet rhsV0[M], rhsV1[M], rhsV2[M], rhsV3[M], rhsV4[M], rhsV5[M], rhsV6[M], rhsV7[M]; \
func(func1,func2,0); func(func1,func2,1); \
func(func1,func2,2); func(func1,func2,3); \
func(func1,func2,4); func(func1,func2,5); \
- func(func1,func2,6); func(func1,func2,7); \
- func(func1,func2,8); func(func1,func2,9);
+ func(func1,func2,6); func(func1,func2,7);
#define MICRO_UNROLL_TYPE_ONE(M, func, func1, func2) \
Packet rhsV0[M]; \
@@ -1505,17 +1572,9 @@ EIGEN_STRONG_INLINE void gemm_extra_row(
MICRO_UNROLL_TYPE_ONE(4, MICRO_TYPE_PEEL4, MICRO_WORK_ONE, MICRO_LOAD_ONE); \
rhs_ptr += accRows;
-#define MICRO_ONE_PEEL1 \
- MICRO_UNROLL_TYPE_PEEL(1, MICRO_TYPE_PEEL1, MICRO_WORK_ONE, MICRO_LOAD_ONE); \
- rhs_ptr += (remaining_cols * PEEL);
-
-#define MICRO_ONE1 \
- MICRO_UNROLL_TYPE_ONE(1, MICRO_TYPE_PEEL1, MICRO_WORK_ONE, MICRO_LOAD_ONE); \
- rhs_ptr += remaining_cols;
-
#define MICRO_DST_PTR_ONE(iter) \
if (unroll_factor > iter) { \
- bsetzero<Scalar, Packet>(accZero##iter); \
+ bsetzero<Scalar, Packet, accRows>(accZero##iter); \
} else { \
EIGEN_UNUSED_VARIABLE(accZero##iter); \
}
@@ -1524,7 +1583,7 @@ EIGEN_STRONG_INLINE void gemm_extra_row(
#define MICRO_SRC_PTR_ONE(iter) \
if (unroll_factor > iter) { \
- lhs_ptr##iter = lhs_base + ( (row/accCols) + iter )*strideA*accCols + accCols*offsetA; \
+ lhs_ptr##iter = lhs_base + ( (row/accCols) + iter )*strideA*accCols; \
} else { \
EIGEN_UNUSED_VARIABLE(lhs_ptr##iter); \
}
@@ -1540,25 +1599,13 @@ EIGEN_STRONG_INLINE void gemm_extra_row(
#define MICRO_STORE_ONE(iter) \
if (unroll_factor > iter) { \
- acc.packet[0] = res.template loadPacket<Packet>(row + iter*accCols, col + 0); \
- acc.packet[1] = res.template loadPacket<Packet>(row + iter*accCols, col + 1); \
- acc.packet[2] = res.template loadPacket<Packet>(row + iter*accCols, col + 2); \
- acc.packet[3] = res.template loadPacket<Packet>(row + iter*accCols, col + 3); \
- bscale<Packet>(acc, accZero##iter, pAlpha); \
- res.template storePacketBlock<Packet,4>(row + iter*accCols, col, acc); \
+ bload<DataMapper, Packet, Index, 0, ColMajor, false, accRows>(acc, res, row + iter*accCols, 0); \
+ bscale<Packet,accRows>(acc, accZero##iter, pAlpha); \
+ res.template storePacketBlock<Packet,accRows>(row + iter*accCols, 0, acc); \
}
#define MICRO_STORE MICRO_UNROLL(MICRO_STORE_ONE)
-#define MICRO_COL_STORE_ONE(iter) \
- if (unroll_factor > iter) { \
- acc.packet[0] = res.template loadPacket<Packet>(row + iter*accCols, col + 0); \
- bscale<Packet>(acc, accZero##iter, pAlpha); \
- res.template storePacketBlock<Packet,1>(row + iter*accCols, col, acc); \
- }
-
-#define MICRO_COL_STORE MICRO_UNROLL(MICRO_COL_STORE_ONE)
-
template<int unroll_factor, typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols>
EIGEN_STRONG_INLINE void gemm_unrolled_iteration(
const DataMapper& res,
@@ -1566,15 +1613,13 @@ EIGEN_STRONG_INLINE void gemm_unrolled_iteration(
const Scalar* rhs_base,
Index depth,
Index strideA,
- Index offsetA,
Index& row,
- Index col,
const Packet& pAlpha)
{
const Scalar* rhs_ptr = rhs_base;
const Scalar* lhs_ptr0 = NULL, * lhs_ptr1 = NULL, * lhs_ptr2 = NULL, * lhs_ptr3 = NULL, * lhs_ptr4 = NULL, * lhs_ptr5 = NULL, * lhs_ptr6 = NULL, * lhs_ptr7 = NULL;
- PacketBlock<Packet,4> accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7;
- PacketBlock<Packet,4> acc;
+ PacketBlock<Packet,accRows> accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7;
+ PacketBlock<Packet,accRows> acc;
MICRO_SRC_PTR
MICRO_DST_PTR
@@ -1595,101 +1640,100 @@ EIGEN_STRONG_INLINE void gemm_unrolled_iteration(
row += unroll_factor*accCols;
}
-template<int unroll_factor, typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accCols>
-EIGEN_STRONG_INLINE void gemm_unrolled_col_iteration(
+template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols>
+EIGEN_ALWAYS_INLINE void gemm_cols(
const DataMapper& res,
- const Scalar* lhs_base,
- const Scalar* rhs_base,
+ const Scalar* blockA,
+ const Scalar* blockB,
Index depth,
Index strideA,
Index offsetA,
- Index& row,
+ Index strideB,
+ Index offsetB,
Index col,
- Index remaining_cols,
- const Packet& pAlpha)
+ Index rows,
+ Index cols,
+ Index remaining_rows,
+ const Packet& pAlpha,
+ const Packet& pMask)
{
- const Scalar* rhs_ptr = rhs_base;
- const Scalar* lhs_ptr0 = NULL, * lhs_ptr1 = NULL, * lhs_ptr2 = NULL, * lhs_ptr3 = NULL, * lhs_ptr4 = NULL, * lhs_ptr5 = NULL, * lhs_ptr6 = NULL, *lhs_ptr7 = NULL;
- PacketBlock<Packet,1> accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7;
- PacketBlock<Packet,1> acc;
+ const DataMapper res3 = res.getSubMapper(0, col);
- MICRO_SRC_PTR
- MICRO_DST_PTR
+ const Scalar* rhs_base = blockB + col*strideB + accRows*offsetB;
+ const Scalar* lhs_base = blockA + accCols*offsetA;
+ Index row = 0;
- Index k = 0;
- for(; k + PEEL <= depth; k+= PEEL)
- {
- EIGEN_POWER_PREFETCH(rhs_ptr);
- MICRO_PREFETCH
- MICRO_ONE_PEEL1
- }
- for(; k < depth; k++)
- {
- MICRO_ONE1
- }
- MICRO_COL_STORE
-
- row += unroll_factor*accCols;
-}
-
-template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accCols>
-EIGEN_STRONG_INLINE void gemm_unrolled_col(
- const DataMapper& res,
- const Scalar* lhs_base,
- const Scalar* rhs_base,
- Index depth,
- Index strideA,
- Index offsetA,
- Index& row,
- Index rows,
- Index col,
- Index remaining_cols,
- const Packet& pAlpha)
-{
#define MAX_UNROLL 6
while(row + MAX_UNROLL*accCols <= rows) {
- gemm_unrolled_col_iteration<MAX_UNROLL, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
+ gemm_unrolled_iteration<MAX_UNROLL, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
}
switch( (rows-row)/accCols ) {
#if MAX_UNROLL > 7
case 7:
- gemm_unrolled_col_iteration<7, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
+ gemm_unrolled_iteration<7, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
break;
#endif
#if MAX_UNROLL > 6
case 6:
- gemm_unrolled_col_iteration<6, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
+ gemm_unrolled_iteration<6, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
break;
#endif
#if MAX_UNROLL > 5
- case 5:
- gemm_unrolled_col_iteration<5, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
+ case 5:
+ gemm_unrolled_iteration<5, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
break;
#endif
#if MAX_UNROLL > 4
- case 4:
- gemm_unrolled_col_iteration<4, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
+ case 4:
+ gemm_unrolled_iteration<4, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
break;
#endif
#if MAX_UNROLL > 3
- case 3:
- gemm_unrolled_col_iteration<3, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
- break;
+ case 3:
+ gemm_unrolled_iteration<3, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
+ break;
#endif
#if MAX_UNROLL > 2
- case 2:
- gemm_unrolled_col_iteration<2, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
- break;
+ case 2:
+ gemm_unrolled_iteration<2, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
+ break;
#endif
#if MAX_UNROLL > 1
- case 1:
- gemm_unrolled_col_iteration<1, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha);
- break;
+ case 1:
+ gemm_unrolled_iteration<1, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
+ break;
#endif
- default:
- break;
+ default:
+ break;
}
#undef MAX_UNROLL
+
+ if(remaining_rows > 0)
+ {
+ gemm_extra_row<Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, blockA, rhs_base, depth, strideA, offsetA, row, col, rows, cols, remaining_rows, pAlpha, pMask);
+ }
+}
+
+template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accCols>
+EIGEN_STRONG_INLINE void gemm_extra_cols(
+ const DataMapper& res,
+ const Scalar* blockA,
+ const Scalar* blockB,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index strideB,
+ Index offsetB,
+ Index col,
+ Index rows,
+ Index cols,
+ Index remaining_rows,
+ const Packet& pAlpha,
+ const Packet& pMask)
+{
+ for (; col < cols; col++) {
+ gemm_cols<Scalar, Packet, DataMapper, Index, 1, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask);
+ }
}
/****************
@@ -1699,7 +1743,6 @@ template<typename Scalar, typename Index, typename Packet, typename RhsPacket, t
EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
{
const Index remaining_rows = rows % accCols;
- const Index remaining_cols = cols % accRows;
if( strideA == -1 ) strideA = depth;
if( strideB == -1 ) strideB = depth;
@@ -1710,79 +1753,10 @@ EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const Scalar* blockA, const
Index col = 0;
for(; col + accRows <= cols; col += accRows)
{
- const Scalar* rhs_base = blockB + col*strideB + accRows*offsetB;
- const Scalar* lhs_base = blockA;
- Index row = 0;
-
-#define MAX_UNROLL 6
- while(row + MAX_UNROLL*accCols <= rows) {
- gemm_unrolled_iteration<MAX_UNROLL, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
- }
- switch( (rows-row)/accCols ) {
-#if MAX_UNROLL > 7
- case 7:
- gemm_unrolled_iteration<7, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
- break;
-#endif
-#if MAX_UNROLL > 6
- case 6:
- gemm_unrolled_iteration<6, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
- break;
-#endif
-#if MAX_UNROLL > 5
- case 5:
- gemm_unrolled_iteration<5, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
- break;
-#endif
-#if MAX_UNROLL > 4
- case 4:
- gemm_unrolled_iteration<4, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
- break;
-#endif
-#if MAX_UNROLL > 3
- case 3:
- gemm_unrolled_iteration<3, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
- break;
-#endif
-#if MAX_UNROLL > 2
- case 2:
- gemm_unrolled_iteration<2, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
- break;
-#endif
-#if MAX_UNROLL > 1
- case 1:
- gemm_unrolled_iteration<1, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
- break;
-#endif
- default:
- break;
- }
-#undef MAX_UNROLL
-
- if(remaining_rows > 0)
- {
- gemm_extra_row<Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, rows, cols, remaining_rows, pAlpha, pMask);
- }
- }
-
- if(remaining_cols > 0)
- {
- const Scalar* rhs_base = blockB + col*strideB + remaining_cols*offsetB;
- const Scalar* lhs_base = blockA;
-
- for(; col < cols; col++)
- {
- Index row = 0;
-
- gemm_unrolled_col<Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, rows, col, remaining_cols, pAlpha);
-
- if (remaining_rows > 0)
- {
- gemm_extra_col<Scalar, Packet, DataMapper, Index, accRows>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_rows, remaining_cols, pAlpha);
- }
- rhs_base++;
+ gemm_cols<Scalar, Packet, DataMapper, Index, accRows, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask);
}
- }
+
+ gemm_extra_cols<Scalar, Packet, DataMapper, Index, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask);
}
#define accColsC (accCols / 2)
@@ -1791,117 +1765,66 @@ EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const Scalar* blockA, const
// PEEL_COMPLEX loop factor.
#define PEEL_COMPLEX 3
+#define PEEL_COMPLEX_ROW 3
-template<typename Scalar, typename Packet, typename Index, const Index accRows, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
-EIGEN_ALWAYS_INLINE void MICRO_COMPLEX_EXTRA_COL(
- const Scalar* &lhs_ptr_real, const Scalar* &lhs_ptr_imag,
- const Scalar* &rhs_ptr_real, const Scalar* &rhs_ptr_imag,
- PacketBlock<Packet,1> &accReal, PacketBlock<Packet,1> &accImag,
- Index remaining_rows,
- Index remaining_cols)
-{
- Packet rhsV[1], rhsVi[1];
- rhsV[0] = pset1<Packet>(rhs_ptr_real[0]);
- if(!RhsIsReal) rhsVi[0] = pset1<Packet>(rhs_ptr_imag[0]);
- pgerc<1, Scalar, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal, &accImag, lhs_ptr_real, lhs_ptr_imag, rhsV, rhsVi);
- lhs_ptr_real += remaining_rows;
- if(!LhsIsReal) lhs_ptr_imag += remaining_rows;
- else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
- rhs_ptr_real += remaining_cols;
- if(!RhsIsReal) rhs_ptr_imag += remaining_cols;
- else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
-}
-
-template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
-EIGEN_STRONG_INLINE void gemm_complex_extra_col(
- const DataMapper& res,
- const Scalar* lhs_base,
- const Scalar* rhs_base,
- Index depth,
- Index strideA,
- Index offsetA,
- Index strideB,
- Index row,
- Index col,
- Index remaining_rows,
- Index remaining_cols,
- const Packet& pAlphaReal,
- const Packet& pAlphaImag)
-{
- const Scalar* rhs_ptr_real = rhs_base;
- const Scalar* rhs_ptr_imag;
- if(!RhsIsReal) rhs_ptr_imag = rhs_base + remaining_cols*strideB;
- else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
- const Scalar* lhs_ptr_real = lhs_base + advanceRows*row*strideA + remaining_rows*offsetA;
- const Scalar* lhs_ptr_imag;
- if(!LhsIsReal) lhs_ptr_imag = lhs_ptr_real + remaining_rows*strideA;
- else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
- PacketBlock<Packet,1> accReal, accImag;
- PacketBlock<Packet,1> taccReal, taccImag;
- PacketBlock<Packetc,1> acc0, acc1;
-
- bsetzero<Scalar, Packet>(accReal);
- bsetzero<Scalar, Packet>(accImag);
+#define MICRO_COMPLEX_UNROLL_PEEL(func) \
+ func(0) func(1) func(2) func(3)
- Index remaining_depth = (depth & -accRows);
- Index k = 0;
- for(; k + PEEL_COMPLEX <= remaining_depth; k+= PEEL_COMPLEX)
- {
- EIGEN_POWER_PREFETCH(rhs_ptr_real);
- if(!RhsIsReal) {
- EIGEN_POWER_PREFETCH(rhs_ptr_imag);
- }
- EIGEN_POWER_PREFETCH(lhs_ptr_real);
- if(!LhsIsReal) {
- EIGEN_POWER_PREFETCH(lhs_ptr_imag);
- }
- for (int l = 0; l < PEEL_COMPLEX; l++) {
- MICRO_COMPLEX_EXTRA_COL<Scalar, Packet, Index, accRows, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real, rhs_ptr_imag, accReal, accImag, remaining_rows, remaining_cols);
- }
- }
- for(; k < remaining_depth; k++)
- {
- MICRO_COMPLEX_EXTRA_COL<Scalar, Packet, Index, accRows, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real, rhs_ptr_imag, accReal, accImag, remaining_rows, remaining_cols);
+#define MICRO_COMPLEX_ZERO_PEEL(peel) \
+ if ((PEEL_COMPLEX_ROW > peel) && (peel != 0)) { \
+ bsetzero<Scalar, Packet, accRows>(accReal##peel); \
+ bsetzero<Scalar, Packet, accRows>(accImag##peel); \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(accReal##peel); \
+ EIGEN_UNUSED_VARIABLE(accImag##peel); \
}
- for(; k < depth; k++)
- {
- Packet rhsV[1], rhsVi[1];
- rhsV[0] = pset1<Packet>(rhs_ptr_real[0]);
- if(!RhsIsReal) rhsVi[0] = pset1<Packet>(rhs_ptr_imag[0]);
- pgerc<1, Scalar, Packet, Index, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal, &accImag, lhs_ptr_real, lhs_ptr_imag, rhsV, rhsVi, remaining_rows);
- lhs_ptr_real += remaining_rows;
- if(!LhsIsReal) lhs_ptr_imag += remaining_rows;
- rhs_ptr_real += remaining_cols;
- if(!RhsIsReal) rhs_ptr_imag += remaining_cols;
+#define MICRO_COMPLEX_ZERO_PEEL_ROW \
+ MICRO_COMPLEX_UNROLL_PEEL(MICRO_COMPLEX_ZERO_PEEL);
+
+#define MICRO_COMPLEX_WORK_PEEL(peel) \
+ if (PEEL_COMPLEX_ROW > peel) { \
+ pbroadcastN_old<Packet,accRows>(rhs_ptr_real + (accRows * peel), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \
+ if(!RhsIsReal) pbroadcastN_old<Packet,accRows>(rhs_ptr_imag + (accRows * peel), rhsVi##peel[0], rhsVi##peel[1], rhsVi##peel[2], rhsVi##peel[3]); \
+ pgerc<accRows, Scalar, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##peel, &accImag##peel, lhs_ptr_real + (remaining_rows * peel), lhs_ptr_imag + (remaining_rows * peel), rhsV##peel, rhsVi##peel); \
+ } else { \
+ EIGEN_UNUSED_VARIABLE(rhsV##peel); \
+ EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
}
- bscalec<Packet,1>(accReal, accImag, pAlphaReal, pAlphaImag, taccReal, taccImag);
- bcouple_common<Packet, Packetc>(taccReal, taccImag, acc0, acc1);
+#define MICRO_COMPLEX_WORK_PEEL_ROW \
+ Packet rhsV0[4], rhsV1[4], rhsV2[4], rhsV3[4]; \
+ Packet rhsVi0[4], rhsVi1[4], rhsVi2[4], rhsVi3[4]; \
+ MICRO_COMPLEX_UNROLL_PEEL(MICRO_COMPLEX_WORK_PEEL); \
+ lhs_ptr_real += (remaining_rows * PEEL_COMPLEX_ROW); \
+ if(!LhsIsReal) lhs_ptr_imag += (remaining_rows * PEEL_COMPLEX_ROW); \
+ else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag); \
+ rhs_ptr_real += (accRows * PEEL_COMPLEX_ROW); \
+ if(!RhsIsReal) rhs_ptr_imag += (accRows * PEEL_COMPLEX_ROW); \
+ else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
- if ((sizeof(Scalar) == sizeof(float)) && (remaining_rows == 1))
- {
- res(row + 0, col + 0) += pfirst<Packetc>(acc0.packet[0]);
- } else {
- acc0.packet[0] += res.template loadPacket<Packetc>(row + 0, col + 0);
- res.template storePacketBlock<Packetc,1>(row + 0, col + 0, acc0);
- if(remaining_rows > accColsC) {
- res(row + accColsC, col + 0) += pfirst<Packetc>(acc1.packet[0]);
- }
+#define MICRO_COMPLEX_ADD_PEEL(peel, sum) \
+ if (PEEL_COMPLEX_ROW > peel) { \
+ for (Index i = 0; i < accRows; i++) { \
+ accReal##sum.packet[i] += accReal##peel.packet[i]; \
+ accImag##sum.packet[i] += accImag##peel.packet[i]; \
+ } \
}
-}
-template<typename Scalar, typename Packet, typename Index, const Index accRows, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+#define MICRO_COMPLEX_ADD_PEEL_ROW \
+ MICRO_COMPLEX_ADD_PEEL(2, 0) MICRO_COMPLEX_ADD_PEEL(3, 1) \
+ MICRO_COMPLEX_ADD_PEEL(1, 0)
+
+template<typename Scalar, typename Packet, typename Index, const Index accRows, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal, const Index remaining_rows>
EIGEN_ALWAYS_INLINE void MICRO_COMPLEX_EXTRA_ROW(
const Scalar* &lhs_ptr_real, const Scalar* &lhs_ptr_imag,
const Scalar* &rhs_ptr_real, const Scalar* &rhs_ptr_imag,
- PacketBlock<Packet,4> &accReal, PacketBlock<Packet,4> &accImag,
- Index remaining_rows)
+ PacketBlock<Packet,accRows> &accReal, PacketBlock<Packet,accRows> &accImag)
{
Packet rhsV[4], rhsVi[4];
- pbroadcast4_old<Packet>(rhs_ptr_real, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
- if(!RhsIsReal) pbroadcast4_old<Packet>(rhs_ptr_imag, rhsVi[0], rhsVi[1], rhsVi[2], rhsVi[3]);
- pgerc<4, Scalar, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal, &accImag, lhs_ptr_real, lhs_ptr_imag, rhsV, rhsVi);
+ pbroadcastN_old<Packet,accRows>(rhs_ptr_real, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
+ if(!RhsIsReal) pbroadcastN_old<Packet,accRows>(rhs_ptr_imag, rhsVi[0], rhsVi[1], rhsVi[2], rhsVi[3]);
+ pgerc<accRows, Scalar, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal, &accImag, lhs_ptr_real, lhs_ptr_imag, rhsV, rhsVi);
lhs_ptr_real += remaining_rows;
if(!LhsIsReal) lhs_ptr_imag += remaining_rows;
else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
@@ -1910,8 +1833,8 @@ EIGEN_ALWAYS_INLINE void MICRO_COMPLEX_EXTRA_ROW(
else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
}
-template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
-EIGEN_STRONG_INLINE void gemm_complex_extra_row(
+template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal, const Index remaining_rows>
+EIGEN_ALWAYS_INLINE void gemm_unrolled_complex_row_iteration(
const DataMapper& res,
const Scalar* lhs_base,
const Scalar* rhs_base,
@@ -1923,7 +1846,6 @@ EIGEN_STRONG_INLINE void gemm_complex_extra_row(
Index col,
Index rows,
Index cols,
- Index remaining_rows,
const Packet& pAlphaReal,
const Packet& pAlphaImag,
const Packet& pMask)
@@ -1936,93 +1858,129 @@ EIGEN_STRONG_INLINE void gemm_complex_extra_row(
const Scalar* lhs_ptr_imag;
if(!LhsIsReal) lhs_ptr_imag = lhs_ptr_real + remaining_rows*strideA;
else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
- PacketBlock<Packet,4> accReal, accImag;
- PacketBlock<Packet,4> taccReal, taccImag;
- PacketBlock<Packetc,4> acc0, acc1;
- PacketBlock<Packetc,8> tRes;
+ PacketBlock<Packet,accRows> accReal0, accImag0, accReal1, accImag1, accReal2, accImag2, accReal3, accImag3;
+ PacketBlock<Packet,accRows> taccReal, taccImag;
+ PacketBlock<Packetc,accRows> acc0, acc1;
+ PacketBlock<Packetc,accRows*2> tRes;
- bsetzero<Scalar, Packet>(accReal);
- bsetzero<Scalar, Packet>(accImag);
+ bsetzero<Scalar, Packet, accRows>(accReal0);
+ bsetzero<Scalar, Packet, accRows>(accImag0);
- Index remaining_depth = (col + accRows < cols) ? depth : (depth & -accRows);
+ Index remaining_depth = (col + quad_traits<Scalar>::rows < cols) ? depth : (depth & -quad_traits<Scalar>::rows);
Index k = 0;
- for(; k + PEEL_COMPLEX <= remaining_depth; k+= PEEL_COMPLEX)
- {
- EIGEN_POWER_PREFETCH(rhs_ptr_real);
- if(!RhsIsReal) {
- EIGEN_POWER_PREFETCH(rhs_ptr_imag);
- }
- EIGEN_POWER_PREFETCH(lhs_ptr_real);
- if(!LhsIsReal) {
- EIGEN_POWER_PREFETCH(lhs_ptr_imag);
- }
- for (int l = 0; l < PEEL_COMPLEX; l++) {
- MICRO_COMPLEX_EXTRA_ROW<Scalar, Packet, Index, accRows, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real, rhs_ptr_imag, accReal, accImag, remaining_rows);
- }
+ if (remaining_depth >= PEEL_COMPLEX_ROW) {
+ MICRO_COMPLEX_ZERO_PEEL_ROW
+ do
+ {
+ EIGEN_POWER_PREFETCH(rhs_ptr_real);
+ if(!RhsIsReal) {
+ EIGEN_POWER_PREFETCH(rhs_ptr_imag);
+ }
+ EIGEN_POWER_PREFETCH(lhs_ptr_real);
+ if(!LhsIsReal) {
+ EIGEN_POWER_PREFETCH(lhs_ptr_imag);
+ }
+ MICRO_COMPLEX_WORK_PEEL_ROW
+ } while ((k += PEEL_COMPLEX_ROW) + PEEL_COMPLEX_ROW <= remaining_depth);
+ MICRO_COMPLEX_ADD_PEEL_ROW
}
for(; k < remaining_depth; k++)
{
- MICRO_COMPLEX_EXTRA_ROW<Scalar, Packet, Index, accRows, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real, rhs_ptr_imag, accReal, accImag, remaining_rows);
+ MICRO_COMPLEX_EXTRA_ROW<Scalar, Packet, Index, accRows, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal, remaining_rows>(lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real, rhs_ptr_imag, accReal0, accImag0);
}
if ((remaining_depth == depth) && (rows >= accCols))
{
- bload<DataMapper, Packetc, Index, accColsC, 0, ColMajor>(tRes, res, row, col);
- bscalec<Packet>(accReal, accImag, pAlphaReal, pAlphaImag, taccReal, taccImag, pMask);
- bcouple<Packet, Packetc>(taccReal, taccImag, tRes, acc0, acc1);
- res.template storePacketBlock<Packetc,4>(row + 0, col, acc0);
- res.template storePacketBlock<Packetc,4>(row + accColsC, col, acc1);
+ bload<DataMapper, Packetc, Index, accColsC, ColMajor, true, accRows>(tRes, res, row, 0);
+ bscalec<Packet,accRows>(accReal0, accImag0, pAlphaReal, pAlphaImag, taccReal, taccImag, pMask);
+ bcouple<Packet, Packetc, accRows>(taccReal, taccImag, tRes, acc0, acc1);
+ res.template storePacketBlock<Packetc,accRows>(row + 0, 0, acc0);
+ res.template storePacketBlock<Packetc,accRows>(row + accColsC, 0, acc1);
} else {
for(; k < depth; k++)
{
Packet rhsV[4], rhsVi[4];
- pbroadcast4_old<Packet>(rhs_ptr_real, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
- if(!RhsIsReal) pbroadcast4_old<Packet>(rhs_ptr_imag, rhsVi[0], rhsVi[1], rhsVi[2], rhsVi[3]);
- pgerc<4, Scalar, Packet, Index, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal, &accImag, lhs_ptr_real, lhs_ptr_imag, rhsV, rhsVi, remaining_rows);
+ pbroadcastN_old<Packet,accRows>(rhs_ptr_real, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
+ if(!RhsIsReal) pbroadcastN_old<Packet,accRows>(rhs_ptr_imag, rhsVi[0], rhsVi[1], rhsVi[2], rhsVi[3]);
+ pgerc<accRows, Scalar, Packet, Index, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal, remaining_rows>(&accReal0, &accImag0, lhs_ptr_real, lhs_ptr_imag, rhsV, rhsVi);
lhs_ptr_real += remaining_rows;
if(!LhsIsReal) lhs_ptr_imag += remaining_rows;
rhs_ptr_real += accRows;
if(!RhsIsReal) rhs_ptr_imag += accRows;
}
- bscalec<Packet,4>(accReal, accImag, pAlphaReal, pAlphaImag, taccReal, taccImag);
- bcouple_common<Packet, Packetc>(taccReal, taccImag, acc0, acc1);
+ bscalec<Packet,accRows>(accReal0, accImag0, pAlphaReal, pAlphaImag, taccReal, taccImag);
+ bcouple_common<Packet, Packetc, accRows>(taccReal, taccImag, acc0, acc1);
if ((sizeof(Scalar) == sizeof(float)) && (remaining_rows == 1))
{
- for(Index j = 0; j < 4; j++) {
- res(row + 0, col + j) += pfirst<Packetc>(acc0.packet[j]);
+ for(Index j = 0; j < accRows; j++) {
+ res(row + 0, j) += pfirst<Packetc>(acc0.packet[j]);
}
} else {
- for(Index j = 0; j < 4; j++) {
+ for(Index j = 0; j < accRows; j++) {
PacketBlock<Packetc,1> acc2;
- acc2.packet[0] = res.template loadPacket<Packetc>(row + 0, col + j) + acc0.packet[j];
- res.template storePacketBlock<Packetc,1>(row + 0, col + j, acc2);
+ acc2.packet[0] = res.template loadPacket<Packetc>(row + 0, j) + acc0.packet[j];
+ res.template storePacketBlock<Packetc,1>(row + 0, j, acc2);
if(remaining_rows > accColsC) {
- res(row + accColsC, col + j) += pfirst<Packetc>(acc1.packet[j]);
+ res(row + accColsC, j) += pfirst<Packetc>(acc1.packet[j]);
}
}
}
}
}
+template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+EIGEN_ALWAYS_INLINE void gemm_complex_extra_row(
+ const DataMapper& res,
+ const Scalar* lhs_base,
+ const Scalar* rhs_base,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index strideB,
+ Index row,
+ Index col,
+ Index rows,
+ Index cols,
+ Index remaining_rows,
+ const Packet& pAlphaReal,
+ const Packet& pAlphaImag,
+ const Packet& pMask)
+{
+ switch(remaining_rows) {
+ case 1:
+ gemm_unrolled_complex_row_iteration<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal, 1>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, pAlphaReal, pAlphaImag, pMask);
+ break;
+ case 2:
+ if (sizeof(Scalar) == sizeof(float)) {
+ gemm_unrolled_complex_row_iteration<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal, 2>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, pAlphaReal, pAlphaImag, pMask);
+ }
+ break;
+ default:
+ if (sizeof(Scalar) == sizeof(float)) {
+ gemm_unrolled_complex_row_iteration<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal, 3>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, pAlphaReal, pAlphaImag, pMask);
+ }
+ break;
+ }
+}
+
#define MICRO_COMPLEX_UNROLL(func) \
- func(0) func(1) func(2) func(3) func(4)
+ func(0) func(1) func(2) func(3)
#define MICRO_COMPLEX_UNROLL_WORK(func, func2, peel) \
MICRO_COMPLEX_UNROLL(func2); \
- func(0,peel) func(1,peel) func(2,peel) func(3,peel) func(4,peel)
+ func(0,peel) func(1,peel) func(2,peel) func(3,peel)
#define MICRO_COMPLEX_LOAD_ONE(iter) \
if (unroll_factor > iter) { \
lhsV##iter = ploadLhs<Scalar, Packet>(lhs_ptr_real##iter); \
- lhs_ptr_real##iter += accCols; \
if(!LhsIsReal) { \
- lhsVi##iter = ploadLhs<Scalar, Packet>(lhs_ptr_imag##iter); \
- lhs_ptr_imag##iter += accCols; \
+ lhsVi##iter = ploadLhs<Scalar, Packet>(lhs_ptr_real##iter + imag_delta); \
} else { \
EIGEN_UNUSED_VARIABLE(lhsVi##iter); \
} \
+ lhs_ptr_real##iter += accCols; \
} else { \
EIGEN_UNUSED_VARIABLE(lhsV##iter); \
EIGEN_UNUSED_VARIABLE(lhsVi##iter); \
@@ -2030,37 +1988,16 @@ EIGEN_STRONG_INLINE void gemm_complex_extra_row(
#define MICRO_COMPLEX_WORK_ONE4(iter, peel) \
if (unroll_factor > iter) { \
- pgerc_common<4, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV##peel, rhsVi##peel); \
- }
-
-#define MICRO_COMPLEX_WORK_ONE1(iter, peel) \
- if (unroll_factor > iter) { \
- pgerc_common<1, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV##peel, rhsVi##peel); \
+ pgerc_common<accRows, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV##peel, rhsVi##peel); \
}
#define MICRO_COMPLEX_TYPE_PEEL4(func, func2, peel) \
if (PEEL_COMPLEX > peel) { \
- Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4; \
- Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3, lhsVi4; \
- pbroadcast4_old<Packet>(rhs_ptr_real + (accRows * peel), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \
+ Packet lhsV0, lhsV1, lhsV2, lhsV3; \
+ Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3; \
+ pbroadcastN_old<Packet,accRows>(rhs_ptr_real + (accRows * peel), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \
if(!RhsIsReal) { \
- pbroadcast4_old<Packet>(rhs_ptr_imag + (accRows * peel), rhsVi##peel[0], rhsVi##peel[1], rhsVi##peel[2], rhsVi##peel[3]); \
- } else { \
- EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
- } \
- MICRO_COMPLEX_UNROLL_WORK(func, func2, peel) \
- } else { \
- EIGEN_UNUSED_VARIABLE(rhsV##peel); \
- EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
- }
-
-#define MICRO_COMPLEX_TYPE_PEEL1(func, func2, peel) \
- if (PEEL_COMPLEX > peel) { \
- Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4; \
- Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3, lhsVi4; \
- rhsV##peel[0] = pset1<Packet>(rhs_ptr_real[remaining_cols * peel]); \
- if(!RhsIsReal) { \
- rhsVi##peel[0] = pset1<Packet>(rhs_ptr_imag[remaining_cols * peel]); \
+ pbroadcastN_old<Packet,accRows>(rhs_ptr_imag + (accRows * peel), rhsVi##peel[0], rhsVi##peel[1], rhsVi##peel[2], rhsVi##peel[3]); \
} else { \
EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
} \
@@ -2071,13 +2008,10 @@ EIGEN_STRONG_INLINE void gemm_complex_extra_row(
}
#define MICRO_COMPLEX_UNROLL_TYPE_PEEL(M, func, func1, func2) \
- Packet rhsV0[M], rhsV1[M], rhsV2[M], rhsV3[M], rhsV4[M], rhsV5[M], rhsV6[M], rhsV7[M], rhsV8[M], rhsV9[M]; \
- Packet rhsVi0[M], rhsVi1[M], rhsVi2[M], rhsVi3[M], rhsVi4[M], rhsVi5[M], rhsVi6[M], rhsVi7[M], rhsVi8[M], rhsVi9[M]; \
+ Packet rhsV0[M], rhsV1[M], rhsV2[M], rhsV3[M]; \
+ Packet rhsVi0[M], rhsVi1[M], rhsVi2[M], rhsVi3[M]; \
func(func1,func2,0); func(func1,func2,1); \
- func(func1,func2,2); func(func1,func2,3); \
- func(func1,func2,4); func(func1,func2,5); \
- func(func1,func2,6); func(func1,func2,7); \
- func(func1,func2,8); func(func1,func2,9);
+ func(func1,func2,2); func(func1,func2,3);
#define MICRO_COMPLEX_UNROLL_TYPE_ONE(M, func, func1, func2) \
Packet rhsV0[M], rhsVi0[M];\
@@ -2093,20 +2027,10 @@ EIGEN_STRONG_INLINE void gemm_complex_extra_row(
rhs_ptr_real += accRows; \
if(!RhsIsReal) rhs_ptr_imag += accRows;
-#define MICRO_COMPLEX_ONE_PEEL1 \
- MICRO_COMPLEX_UNROLL_TYPE_PEEL(1, MICRO_COMPLEX_TYPE_PEEL1, MICRO_COMPLEX_WORK_ONE1, MICRO_COMPLEX_LOAD_ONE); \
- rhs_ptr_real += (remaining_cols * PEEL_COMPLEX); \
- if(!RhsIsReal) rhs_ptr_imag += (remaining_cols * PEEL_COMPLEX);
-
-#define MICRO_COMPLEX_ONE1 \
- MICRO_COMPLEX_UNROLL_TYPE_ONE(1, MICRO_COMPLEX_TYPE_PEEL1, MICRO_COMPLEX_WORK_ONE1, MICRO_COMPLEX_LOAD_ONE); \
- rhs_ptr_real += remaining_cols; \
- if(!RhsIsReal) rhs_ptr_imag += remaining_cols;
-
#define MICRO_COMPLEX_DST_PTR_ONE(iter) \
if (unroll_factor > iter) { \
- bsetzero<Scalar, Packet>(accReal##iter); \
- bsetzero<Scalar, Packet>(accImag##iter); \
+ bsetzero<Scalar, Packet, accRows>(accReal##iter); \
+ bsetzero<Scalar, Packet, accRows>(accImag##iter); \
} else { \
EIGEN_UNUSED_VARIABLE(accReal##iter); \
EIGEN_UNUSED_VARIABLE(accImag##iter); \
@@ -2116,15 +2040,9 @@ EIGEN_STRONG_INLINE void gemm_complex_extra_row(
#define MICRO_COMPLEX_SRC_PTR_ONE(iter) \
if (unroll_factor > iter) { \
- lhs_ptr_real##iter = lhs_base + ( ((advanceRows*row)/accCols) + iter*advanceRows )*strideA*accCols + accCols*offsetA; \
- if(!LhsIsReal) { \
- lhs_ptr_imag##iter = lhs_ptr_real##iter + accCols*strideA; \
- } else { \
- EIGEN_UNUSED_VARIABLE(lhs_ptr_imag##iter); \
- } \
+ lhs_ptr_real##iter = lhs_base + ( ((advanceRows*row)/accCols) + iter*advanceRows )*strideA*accCols; \
} else { \
EIGEN_UNUSED_VARIABLE(lhs_ptr_real##iter); \
- EIGEN_UNUSED_VARIABLE(lhs_ptr_imag##iter); \
}
#define MICRO_COMPLEX_SRC_PTR MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_SRC_PTR_ONE)
@@ -2132,35 +2050,21 @@ EIGEN_STRONG_INLINE void gemm_complex_extra_row(
#define MICRO_COMPLEX_PREFETCH_ONE(iter) \
if (unroll_factor > iter) { \
EIGEN_POWER_PREFETCH(lhs_ptr_real##iter); \
- if(!LhsIsReal) { \
- EIGEN_POWER_PREFETCH(lhs_ptr_imag##iter); \
- } \
}
#define MICRO_COMPLEX_PREFETCH MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_PREFETCH_ONE)
#define MICRO_COMPLEX_STORE_ONE(iter) \
if (unroll_factor > iter) { \
- bload<DataMapper, Packetc, Index, accColsC, 0, ColMajor>(tRes, res, row + iter*accCols, col); \
- bscalec<Packet,4>(accReal##iter, accImag##iter, pAlphaReal, pAlphaImag, taccReal, taccImag); \
- bcouple<Packet, Packetc>(taccReal, taccImag, tRes, acc0, acc1); \
- res.template storePacketBlock<Packetc,4>(row + iter*accCols + 0, col, acc0); \
- res.template storePacketBlock<Packetc,4>(row + iter*accCols + accColsC, col, acc1); \
+ bload<DataMapper, Packetc, Index, accColsC, ColMajor, true, accRows>(tRes, res, row + iter*accCols, 0); \
+ bscalec<Packet,accRows>(accReal##iter, accImag##iter, pAlphaReal, pAlphaImag, taccReal, taccImag); \
+ bcouple<Packet, Packetc, accRows>(taccReal, taccImag, tRes, acc0, acc1); \
+ res.template storePacketBlock<Packetc,accRows>(row + iter*accCols + 0, 0, acc0); \
+ res.template storePacketBlock<Packetc,accRows>(row + iter*accCols + accColsC, 0, acc1); \
}
#define MICRO_COMPLEX_STORE MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_STORE_ONE)
-#define MICRO_COMPLEX_COL_STORE_ONE(iter) \
- if (unroll_factor > iter) { \
- bload<DataMapper, Packetc, Index, accColsC, 0, ColMajor>(tRes, res, row + iter*accCols, col); \
- bscalec<Packet,1>(accReal##iter, accImag##iter, pAlphaReal, pAlphaImag, taccReal, taccImag); \
- bcouple<Packet, Packetc>(taccReal, taccImag, tRes, acc0, acc1); \
- res.template storePacketBlock<Packetc,1>(row + iter*accCols + 0, col, acc0); \
- res.template storePacketBlock<Packetc,1>(row + iter*accCols + accColsC, col, acc1); \
- }
-
-#define MICRO_COMPLEX_COL_STORE MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_COL_STORE_ONE)
-
template<int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_STRONG_INLINE void gemm_complex_unrolled_iteration(
const DataMapper& res,
@@ -2168,29 +2072,26 @@ EIGEN_STRONG_INLINE void gemm_complex_unrolled_iteration(
const Scalar* rhs_base,
Index depth,
Index strideA,
- Index offsetA,
Index strideB,
Index& row,
- Index col,
const Packet& pAlphaReal,
const Packet& pAlphaImag)
{
const Scalar* rhs_ptr_real = rhs_base;
const Scalar* rhs_ptr_imag;
+ const Index imag_delta = accCols*strideA;
if(!RhsIsReal) {
rhs_ptr_imag = rhs_base + accRows*strideB;
} else {
EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
}
- const Scalar* lhs_ptr_real0 = NULL, * lhs_ptr_imag0 = NULL, * lhs_ptr_real1 = NULL, * lhs_ptr_imag1 = NULL;
- const Scalar* lhs_ptr_real2 = NULL, * lhs_ptr_imag2 = NULL, * lhs_ptr_real3 = NULL, * lhs_ptr_imag3 = NULL;
- const Scalar* lhs_ptr_real4 = NULL, * lhs_ptr_imag4 = NULL;
- PacketBlock<Packet,4> accReal0, accImag0, accReal1, accImag1;
- PacketBlock<Packet,4> accReal2, accImag2, accReal3, accImag3;
- PacketBlock<Packet,4> accReal4, accImag4;
- PacketBlock<Packet,4> taccReal, taccImag;
- PacketBlock<Packetc,4> acc0, acc1;
- PacketBlock<Packetc,8> tRes;
+ const Scalar* lhs_ptr_real0 = NULL, * lhs_ptr_real1 = NULL;
+ const Scalar* lhs_ptr_real2 = NULL, * lhs_ptr_real3 = NULL;
+ PacketBlock<Packet,accRows> accReal0, accImag0, accReal1, accImag1;
+ PacketBlock<Packet,accRows> accReal2, accImag2, accReal3, accImag3;
+ PacketBlock<Packet,accRows> taccReal, taccImag;
+ PacketBlock<Packetc,accRows> acc0, acc1;
+ PacketBlock<Packetc,accRows*2> tRes;
MICRO_COMPLEX_SRC_PTR
MICRO_COMPLEX_DST_PTR
@@ -2214,112 +2115,93 @@ EIGEN_STRONG_INLINE void gemm_complex_unrolled_iteration(
row += unroll_factor*accCols;
}
-template<int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
-EIGEN_STRONG_INLINE void gemm_complex_unrolled_col_iteration(
+template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+EIGEN_ALWAYS_INLINE void gemm_complex_cols(
const DataMapper& res,
- const Scalar* lhs_base,
- const Scalar* rhs_base,
+ const Scalar* blockA,
+ const Scalar* blockB,
Index depth,
Index strideA,
Index offsetA,
Index strideB,
- Index& row,
+ Index offsetB,
Index col,
- Index remaining_cols,
+ Index rows,
+ Index cols,
+ Index remaining_rows,
const Packet& pAlphaReal,
- const Packet& pAlphaImag)
+ const Packet& pAlphaImag,
+ const Packet& pMask)
{
- const Scalar* rhs_ptr_real = rhs_base;
- const Scalar* rhs_ptr_imag;
- if(!RhsIsReal) {
- rhs_ptr_imag = rhs_base + remaining_cols*strideB;
- } else {
- EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
- }
- const Scalar* lhs_ptr_real0 = NULL, * lhs_ptr_imag0 = NULL, * lhs_ptr_real1 = NULL, * lhs_ptr_imag1 = NULL;
- const Scalar* lhs_ptr_real2 = NULL, * lhs_ptr_imag2 = NULL, * lhs_ptr_real3 = NULL, * lhs_ptr_imag3 = NULL;
- const Scalar* lhs_ptr_real4 = NULL, * lhs_ptr_imag4 = NULL;
- PacketBlock<Packet,1> accReal0, accImag0, accReal1, accImag1;
- PacketBlock<Packet,1> accReal2, accImag2, accReal3, accImag3;
- PacketBlock<Packet,1> accReal4, accImag4;
- PacketBlock<Packet,1> taccReal, taccImag;
- PacketBlock<Packetc,1> acc0, acc1;
- PacketBlock<Packetc,2> tRes;
+ const DataMapper res3 = res.getSubMapper(0, col);
- MICRO_COMPLEX_SRC_PTR
- MICRO_COMPLEX_DST_PTR
+ const Scalar* rhs_base = blockB + advanceCols*col*strideB + accRows*offsetB;
+ const Scalar* lhs_base = blockA + accCols*offsetA;
+ Index row = 0;
- Index k = 0;
- for(; k + PEEL_COMPLEX <= depth; k+= PEEL_COMPLEX)
- {
- EIGEN_POWER_PREFETCH(rhs_ptr_real);
- if(!RhsIsReal) {
- EIGEN_POWER_PREFETCH(rhs_ptr_imag);
- }
- MICRO_COMPLEX_PREFETCH
- MICRO_COMPLEX_ONE_PEEL1
+#define MAX_COMPLEX_UNROLL 3
+ while(row + MAX_COMPLEX_UNROLL*accCols <= rows) {
+ gemm_complex_unrolled_iteration<MAX_COMPLEX_UNROLL, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
}
- for(; k < depth; k++)
- {
- MICRO_COMPLEX_ONE1
+ switch( (rows-row)/accCols ) {
+#if MAX_COMPLEX_UNROLL > 4
+ case 4:
+ gemm_complex_unrolled_iteration<4, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
+ break;
+#endif
+#if MAX_COMPLEX_UNROLL > 3
+ case 3:
+ gemm_complex_unrolled_iteration<3, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
+ break;
+#endif
+#if MAX_COMPLEX_UNROLL > 2
+ case 2:
+ gemm_complex_unrolled_iteration<2, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
+ break;
+#endif
+#if MAX_COMPLEX_UNROLL > 1
+ case 1:
+ gemm_complex_unrolled_iteration<1, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
+ break;
+#endif
+ default:
+ break;
}
- MICRO_COMPLEX_COL_STORE
+#undef MAX_COMPLEX_UNROLL
- row += unroll_factor*accCols;
+ if(remaining_rows > 0)
+ {
+ gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
+ }
}
template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
-EIGEN_STRONG_INLINE void gemm_complex_unrolled_col(
+EIGEN_STRONG_INLINE void gemm_complex_extra_cols(
const DataMapper& res,
- const Scalar* lhs_base,
- const Scalar* rhs_base,
+ const Scalar* blockA,
+ const Scalar* blockB,
Index depth,
Index strideA,
Index offsetA,
Index strideB,
- Index& row,
- Index rows,
+ Index offsetB,
Index col,
- Index remaining_cols,
+ Index rows,
+ Index cols,
+ Index remaining_rows,
const Packet& pAlphaReal,
- const Packet& pAlphaImag)
+ const Packet& pAlphaImag,
+ const Packet& pMask)
{
-#define MAX_COMPLEX_UNROLL 3
- while(row + MAX_COMPLEX_UNROLL*accCols <= rows) {
- gemm_complex_unrolled_col_iteration<MAX_COMPLEX_UNROLL, Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_cols, pAlphaReal, pAlphaImag);
+ for (; col < cols; col++) {
+ gemm_complex_cols<Scalar, Packet, Packetc, DataMapper, Index, 1, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
}
- switch( (rows-row)/accCols ) {
-#if MAX_COMPLEX_UNROLL > 4
- case 4:
- gemm_complex_unrolled_col_iteration<4, Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_cols, pAlphaReal, pAlphaImag);
- break;
-#endif
-#if MAX_COMPLEX_UNROLL > 3
- case 3:
- gemm_complex_unrolled_col_iteration<3, Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_cols, pAlphaReal, pAlphaImag);
- break;
-#endif
-#if MAX_COMPLEX_UNROLL > 2
- case 2:
- gemm_complex_unrolled_col_iteration<2, Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_cols, pAlphaReal, pAlphaImag);
- break;
-#endif
-#if MAX_COMPLEX_UNROLL > 1
- case 1:
- gemm_complex_unrolled_col_iteration<1, Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_cols, pAlphaReal, pAlphaImag);
- break;
-#endif
- default:
- break;
- }
-#undef MAX_COMPLEX_UNROLL
}
template<typename LhsScalar, typename RhsScalar, typename Scalarc, typename Scalar, typename Index, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
EIGEN_STRONG_INLINE void gemm_complex(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
{
const Index remaining_rows = rows % accCols;
- const Index remaining_cols = cols % accRows;
if( strideA == -1 ) strideA = depth;
if( strideB == -1 ) strideB = depth;
@@ -2334,64 +2216,10 @@ EIGEN_STRONG_INLINE void gemm_complex(const DataMapper& res, const LhsScalar* bl
Index col = 0;
for(; col + accRows <= cols; col += accRows)
{
- const Scalar* rhs_base = blockB + advanceCols*col*strideB + accRows*offsetB;
- const Scalar* lhs_base = blockA;
- Index row = 0;
-
-#define MAX_COMPLEX_UNROLL 3
- while(row + MAX_COMPLEX_UNROLL*accCols <= rows) {
- gemm_complex_unrolled_iteration<MAX_COMPLEX_UNROLL, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
- }
- switch( (rows-row)/accCols ) {
-#if MAX_COMPLEX_UNROLL > 4
- case 4:
- gemm_complex_unrolled_iteration<4, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
- break;
-#endif
-#if MAX_COMPLEX_UNROLL > 3
- case 3:
- gemm_complex_unrolled_iteration<3, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
- break;
-#endif
-#if MAX_COMPLEX_UNROLL > 2
- case 2:
- gemm_complex_unrolled_iteration<2, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
- break;
-#endif
-#if MAX_COMPLEX_UNROLL > 1
- case 1:
- gemm_complex_unrolled_iteration<1, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
- break;
-#endif
- default:
- break;
- }
-#undef MAX_COMPLEX_UNROLL
-
- if(remaining_rows > 0)
- {
- gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
- }
+ gemm_complex_cols<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
}
- if(remaining_cols > 0)
- {
- const Scalar* rhs_base = blockB + advanceCols*col*strideB + remaining_cols*offsetB;
- const Scalar* lhs_base = blockA;
-
- for(; col < cols; col++)
- {
- Index row = 0;
-
- gemm_complex_unrolled_col<Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, rows, col, remaining_cols, pAlphaReal, pAlphaImag);
-
- if (remaining_rows > 0)
- {
- gemm_complex_extra_col<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_rows, remaining_cols, pAlphaReal, pAlphaImag);
- }
- rhs_base++;
- }
- }
+ gemm_complex_extra_cols<Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
}
#undef accColsC
diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h b/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h
index d4287cc6f..768d9c7c4 100644
--- a/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h
+++ b/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h
@@ -11,22 +11,8 @@ namespace Eigen {
namespace internal {
-template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows>
-EIGEN_STRONG_INLINE void gemm_extra_col(
- const DataMapper& res,
- const Scalar* lhs_base,
- const Scalar* rhs_base,
- Index depth,
- Index strideA,
- Index offsetA,
- Index row,
- Index col,
- Index remaining_rows,
- Index remaining_cols,
- const Packet& pAlpha);
-
template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols>
-EIGEN_STRONG_INLINE void gemm_extra_row(
+EIGEN_ALWAYS_INLINE void gemm_extra_row(
const DataMapper& res,
const Scalar* lhs_base,
const Scalar* rhs_base,
@@ -41,41 +27,28 @@ EIGEN_STRONG_INLINE void gemm_extra_row(
const Packet& pAlpha,
const Packet& pMask);
-template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accCols>
-EIGEN_STRONG_INLINE void gemm_unrolled_col(
+template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+EIGEN_STRONG_INLINE void gemm_extra_cols(
const DataMapper& res,
- const Scalar* lhs_base,
- const Scalar* rhs_base,
+ const Scalar* blockA,
+ const Scalar* blockB,
Index depth,
Index strideA,
Index offsetA,
- Index& row,
- Index rows,
+ Index strideB,
+ Index offsetB,
Index col,
- Index remaining_cols,
- const Packet& pAlpha);
+ Index rows,
+ Index cols,
+ Index remaining_rows,
+ const Packet& pAlpha,
+ const Packet& pMask);
template<typename Packet>
EIGEN_ALWAYS_INLINE Packet bmask(const int remaining_rows);
template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
-EIGEN_STRONG_INLINE void gemm_complex_extra_col(
- const DataMapper& res,
- const Scalar* lhs_base,
- const Scalar* rhs_base,
- Index depth,
- Index strideA,
- Index offsetA,
- Index strideB,
- Index row,
- Index col,
- Index remaining_rows,
- Index remaining_cols,
- const Packet& pAlphaReal,
- const Packet& pAlphaImag);
-
-template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
-EIGEN_STRONG_INLINE void gemm_complex_extra_row(
+EIGEN_ALWAYS_INLINE void gemm_complex_extra_row(
const DataMapper& res,
const Scalar* lhs_base,
const Scalar* rhs_base,
@@ -93,123 +66,88 @@ EIGEN_STRONG_INLINE void gemm_complex_extra_row(
const Packet& pMask);
template<typename Scalar, typename Packet, typename Packetc, typename DataMapper, typename Index, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
-EIGEN_STRONG_INLINE void gemm_complex_unrolled_col(
+EIGEN_STRONG_INLINE void gemm_complex_extra_cols(
const DataMapper& res,
- const Scalar* lhs_base,
- const Scalar* rhs_base,
+ const Scalar* blockA,
+ const Scalar* blockB,
Index depth,
Index strideA,
Index offsetA,
Index strideB,
- Index& row,
- Index rows,
+ Index offsetB,
Index col,
- Index remaining_cols,
+ Index rows,
+ Index cols,
+ Index remaining_rows,
const Packet& pAlphaReal,
- const Packet& pAlphaImag);
+ const Packet& pAlphaImag,
+ const Packet& pMask);
template<typename Scalar, typename Packet>
EIGEN_ALWAYS_INLINE Packet ploadLhs(const Scalar* lhs);
-template<typename DataMapper, typename Packet, typename Index, const Index accCols, int N, int StorageOrder>
-EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,4>& acc, const DataMapper& res, Index row, Index col);
+template<typename DataMapper, typename Packet, typename Index, const Index accCols, int StorageOrder, bool Complex, int N>
+EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,N>& acc, const DataMapper& res, Index row, Index col);
-template<typename DataMapper, typename Packet, typename Index, const Index accCols, int N, int StorageOrder>
-EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,8>& acc, const DataMapper& res, Index row, Index col);
-
-template<typename Packet>
-EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet,4>& acc, PacketBlock<Packet,4>& accZ, const Packet& pAlpha);
+template<typename Packet, int N>
+EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet,N>& acc, PacketBlock<Packet,N>& accZ, const Packet& pAlpha);
template<typename Packet, int N>
EIGEN_ALWAYS_INLINE void bscalec(PacketBlock<Packet,N>& aReal, PacketBlock<Packet,N>& aImag, const Packet& bReal, const Packet& bImag, PacketBlock<Packet,N>& cReal, PacketBlock<Packet,N>& cImag);
-const static Packet16uc p16uc_SETCOMPLEX32_FIRST = { 0, 1, 2, 3,
- 16, 17, 18, 19,
- 4, 5, 6, 7,
- 20, 21, 22, 23};
-
-const static Packet16uc p16uc_SETCOMPLEX32_SECOND = { 8, 9, 10, 11,
- 24, 25, 26, 27,
- 12, 13, 14, 15,
- 28, 29, 30, 31};
-//[a,b],[ai,bi] = [a,ai] - This is equivalent to p16uc_GETREAL64
-const static Packet16uc p16uc_SETCOMPLEX64_FIRST = { 0, 1, 2, 3, 4, 5, 6, 7,
- 16, 17, 18, 19, 20, 21, 22, 23};
-
-//[a,b],[ai,bi] = [b,bi] - This is equivalent to p16uc_GETIMAG64
-const static Packet16uc p16uc_SETCOMPLEX64_SECOND = { 8, 9, 10, 11, 12, 13, 14, 15,
- 24, 25, 26, 27, 28, 29, 30, 31};
-
-
// Grab two decouples real/imaginary PacketBlocks and return two coupled (real/imaginary pairs) PacketBlocks.
-template<typename Packet, typename Packetc>
-EIGEN_ALWAYS_INLINE void bcouple_common(PacketBlock<Packet,4>& taccReal, PacketBlock<Packet,4>& taccImag, PacketBlock<Packetc, 4>& acc1, PacketBlock<Packetc, 4>& acc2)
-{
- acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX32_FIRST);
- acc1.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX32_FIRST);
- acc1.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX32_FIRST);
- acc1.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX32_FIRST);
-
- acc2.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX32_SECOND);
- acc2.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX32_SECOND);
- acc2.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX32_SECOND);
- acc2.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX32_SECOND);
-}
-
-template<typename Packet, typename Packetc>
-EIGEN_ALWAYS_INLINE void bcouple(PacketBlock<Packet,4>& taccReal, PacketBlock<Packet,4>& taccImag, PacketBlock<Packetc,8>& tRes, PacketBlock<Packetc, 4>& acc1, PacketBlock<Packetc, 4>& acc2)
-{
- bcouple_common<Packet, Packetc>(taccReal, taccImag, acc1, acc2);
-
- acc1.packet[0] = padd<Packetc>(tRes.packet[0], acc1.packet[0]);
- acc1.packet[1] = padd<Packetc>(tRes.packet[1], acc1.packet[1]);
- acc1.packet[2] = padd<Packetc>(tRes.packet[2], acc1.packet[2]);
- acc1.packet[3] = padd<Packetc>(tRes.packet[3], acc1.packet[3]);
-
- acc2.packet[0] = padd<Packetc>(tRes.packet[4], acc2.packet[0]);
- acc2.packet[1] = padd<Packetc>(tRes.packet[5], acc2.packet[1]);
- acc2.packet[2] = padd<Packetc>(tRes.packet[6], acc2.packet[2]);
- acc2.packet[3] = padd<Packetc>(tRes.packet[7], acc2.packet[3]);
-}
-
-template<typename Packet, typename Packetc>
-EIGEN_ALWAYS_INLINE void bcouple_common(PacketBlock<Packet,1>& taccReal, PacketBlock<Packet,1>& taccImag, PacketBlock<Packetc, 1>& acc1, PacketBlock<Packetc, 1>& acc2)
+template<typename Packet, typename Packetc, int N>
+EIGEN_ALWAYS_INLINE void bcouple_common(PacketBlock<Packet,N>& taccReal, PacketBlock<Packet,N>& taccImag, PacketBlock<Packetc, N>& acc1, PacketBlock<Packetc, N>& acc2)
{
- acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX32_FIRST);
-
- acc2.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX32_SECOND);
+ acc1.packet[0].v = vec_mergeh(taccReal.packet[0], taccImag.packet[0]);
+ if (N > 1) {
+ acc1.packet[1].v = vec_mergeh(taccReal.packet[1], taccImag.packet[1]);
+ }
+ if (N > 2) {
+ acc1.packet[2].v = vec_mergeh(taccReal.packet[2], taccImag.packet[2]);
+ }
+ if (N > 3) {
+ acc1.packet[3].v = vec_mergeh(taccReal.packet[3], taccImag.packet[3]);
+ }
+
+ acc2.packet[0].v = vec_mergel(taccReal.packet[0], taccImag.packet[0]);
+ if (N > 1) {
+ acc2.packet[1].v = vec_mergel(taccReal.packet[1], taccImag.packet[1]);
+ }
+ if (N > 2) {
+ acc2.packet[2].v = vec_mergel(taccReal.packet[2], taccImag.packet[2]);
+ }
+ if (N > 3) {
+ acc2.packet[3].v = vec_mergel(taccReal.packet[3], taccImag.packet[3]);
+ }
}
-template<typename Packet, typename Packetc>
-EIGEN_ALWAYS_INLINE void bcouple(PacketBlock<Packet,1>& taccReal, PacketBlock<Packet,1>& taccImag, PacketBlock<Packetc,2>& tRes, PacketBlock<Packetc, 1>& acc1, PacketBlock<Packetc, 1>& acc2)
+template<typename Packet, typename Packetc, int N>
+EIGEN_ALWAYS_INLINE void bcouple(PacketBlock<Packet,N>& taccReal, PacketBlock<Packet,N>& taccImag, PacketBlock<Packetc,N*2>& tRes, PacketBlock<Packetc, N>& acc1, PacketBlock<Packetc, N>& acc2)
{
- bcouple_common<Packet, Packetc>(taccReal, taccImag, acc1, acc2);
+ bcouple_common<Packet, Packetc, N>(taccReal, taccImag, acc1, acc2);
acc1.packet[0] = padd<Packetc>(tRes.packet[0], acc1.packet[0]);
-
- acc2.packet[0] = padd<Packetc>(tRes.packet[1], acc2.packet[0]);
-}
-
-template<>
-EIGEN_ALWAYS_INLINE void bcouple_common<Packet2d, Packet1cd>(PacketBlock<Packet2d,4>& taccReal, PacketBlock<Packet2d,4>& taccImag, PacketBlock<Packet1cd, 4>& acc1, PacketBlock<Packet1cd, 4>& acc2)
-{
- acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX64_FIRST);
- acc1.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX64_FIRST);
- acc1.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX64_FIRST);
- acc1.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX64_FIRST);
-
- acc2.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX64_SECOND);
- acc2.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX64_SECOND);
- acc2.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX64_SECOND);
- acc2.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX64_SECOND);
-}
-
-template<>
-EIGEN_ALWAYS_INLINE void bcouple_common<Packet2d, Packet1cd>(PacketBlock<Packet2d,1>& taccReal, PacketBlock<Packet2d,1>& taccImag, PacketBlock<Packet1cd, 1>& acc1, PacketBlock<Packet1cd, 1>& acc2)
-{
- acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX64_FIRST);
-
- acc2.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX64_SECOND);
+ if (N > 1) {
+ acc1.packet[1] = padd<Packetc>(tRes.packet[1], acc1.packet[1]);
+ }
+ if (N > 2) {
+ acc1.packet[2] = padd<Packetc>(tRes.packet[2], acc1.packet[2]);
+ }
+ if (N > 3) {
+ acc1.packet[3] = padd<Packetc>(tRes.packet[3], acc1.packet[3]);
+ }
+
+ acc2.packet[0] = padd<Packetc>(tRes.packet[0+N], acc2.packet[0]);
+ if (N > 1) {
+ acc2.packet[1] = padd<Packetc>(tRes.packet[1+N], acc2.packet[1]);
+ }
+ if (N > 2) {
+ acc2.packet[2] = padd<Packetc>(tRes.packet[2+N], acc2.packet[2]);
+ }
+ if (N > 3) {
+ acc2.packet[3] = padd<Packetc>(tRes.packet[3+N], acc2.packet[3]);
+ }
}
// This is necessary because ploadRhs for double returns a pair of vectors when MMA is enabled.
diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h b/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h
index f1f8352c9..e18b7f267 100644
--- a/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h
+++ b/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h
@@ -11,7 +11,7 @@
#ifndef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
#define EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H
-#pragma GCC target("cpu=power10")
+#pragma GCC target("cpu=power10,htm")
#ifdef __has_builtin
#if !__has_builtin(__builtin_vsx_assemble_pair)
@@ -32,37 +32,37 @@ EIGEN_ALWAYS_INLINE void bsetzeroMMA(__vector_quad* acc)
}
template<typename DataMapper, typename Index, typename Packet, const Index accCols>
-EIGEN_ALWAYS_INLINE void storeAccumulator(Index i, Index j, const DataMapper& data, const Packet& alpha, __vector_quad* acc)
+EIGEN_ALWAYS_INLINE void storeAccumulator(Index i, const DataMapper& data, const Packet& alpha, __vector_quad* acc)
{
PacketBlock<Packet, 4> result;
__builtin_mma_disassemble_acc(&result.packet, acc);
PacketBlock<Packet, 4> tRes;
- bload<DataMapper, Packet, Index, accCols, 0, ColMajor>(tRes, data, i, j);
+ bload<DataMapper, Packet, Index, accCols, ColMajor, false, 4>(tRes, data, i, 0);
- bscale<Packet>(tRes, result, alpha);
+ bscale<Packet, 4>(tRes, result, alpha);
- data.template storePacketBlock<Packet, 4>(i, j, tRes);
+ data.template storePacketBlock<Packet, 4>(i, 0, tRes);
}
-template<typename DataMapper, typename Index, typename Packet, typename Packetc, const Index accColsC, int N>
-EIGEN_ALWAYS_INLINE void storeComplexAccumulator(Index i, Index j, const DataMapper& data, const Packet& alphaReal, const Packet& alphaImag, __vector_quad* accReal, __vector_quad* accImag)
+template<typename DataMapper, typename Index, typename Packet, typename Packetc, const Index accColsC>
+EIGEN_ALWAYS_INLINE void storeComplexAccumulator(Index i, const DataMapper& data, const Packet& alphaReal, const Packet& alphaImag, __vector_quad* accReal, __vector_quad* accImag)
{
PacketBlock<Packet, 4> resultReal, resultImag;
__builtin_mma_disassemble_acc(&resultReal.packet, accReal);
__builtin_mma_disassemble_acc(&resultImag.packet, accImag);
PacketBlock<Packetc, 8> tRes;
- bload<DataMapper, Packetc, Index, accColsC, N, ColMajor>(tRes, data, i, j);
+ bload<DataMapper, Packetc, Index, accColsC, ColMajor, true, 4>(tRes, data, i, 0);
PacketBlock<Packet,4> taccReal, taccImag;
bscalec<Packet,4>(resultReal, resultImag, alphaReal, alphaImag, taccReal, taccImag);
PacketBlock<Packetc, 4> acc1, acc2;
- bcouple<Packet, Packetc>(taccReal, taccImag, tRes, acc1, acc2);
+ bcouple<Packet, Packetc, 4>(taccReal, taccImag, tRes, acc1, acc2);
- data.template storePacketBlock<Packetc, 4>(i + N*accColsC, j, acc1);
- data.template storePacketBlock<Packetc, 4>(i + (N+1)*accColsC, j, acc2);
+ data.template storePacketBlock<Packetc, 4>(i, 0, acc1);
+ data.template storePacketBlock<Packetc, 4>(i + accColsC, 0, acc2);
}
// Defaults to float32, since Eigen still supports C++03 we can't use default template arguments
@@ -127,7 +127,7 @@ EIGEN_ALWAYS_INLINE void pgercMMA(__vector_quad* accReal, __vector_quad* accImag
template<typename Scalar, typename Packet>
EIGEN_ALWAYS_INLINE void ploadRhsMMA(const Scalar* rhs, Packet& rhsV)
{
- rhsV = ploadRhs<Scalar, Packet>((const Scalar*)(rhs));
+ rhsV = ploadRhs<Scalar, Packet>(rhs);
}
template<>
@@ -186,12 +186,11 @@ EIGEN_ALWAYS_INLINE void ploadRhsMMA(const float*, __vector_pair&)
}
#define MICRO_MMA_UNROLL_TYPE_PEEL(func, func2, type) \
- type rhsV0, rhsV1, rhsV2, rhsV3, rhsV4, rhsV5, rhsV6, rhsV7, rhsV8, rhsV9; \
+ type rhsV0, rhsV1, rhsV2, rhsV3, rhsV4, rhsV5, rhsV6, rhsV7; \
MICRO_MMA_TYPE_PEEL(func,func2,type,0); MICRO_MMA_TYPE_PEEL(func,func2,type,1); \
MICRO_MMA_TYPE_PEEL(func,func2,type,2); MICRO_MMA_TYPE_PEEL(func,func2,type,3); \
MICRO_MMA_TYPE_PEEL(func,func2,type,4); MICRO_MMA_TYPE_PEEL(func,func2,type,5); \
- MICRO_MMA_TYPE_PEEL(func,func2,type,6); MICRO_MMA_TYPE_PEEL(func,func2,type,7); \
- MICRO_MMA_TYPE_PEEL(func,func2,type,8); MICRO_MMA_TYPE_PEEL(func,func2,type,9);
+ MICRO_MMA_TYPE_PEEL(func,func2,type,6); MICRO_MMA_TYPE_PEEL(func,func2,type,7);
#define MICRO_MMA_UNROLL_TYPE_ONE(func, func2, type) \
type rhsV0; \
@@ -224,7 +223,7 @@ EIGEN_ALWAYS_INLINE void ploadRhsMMA(const float*, __vector_pair&)
#define MICRO_MMA_SRC_PTR_ONE(iter) \
if (unroll_factor > iter) { \
- lhs_ptr##iter = lhs_base + ( (row/accCols) + iter )*strideA*accCols + accCols*offsetA; \
+ lhs_ptr##iter = lhs_base + ( (row/accCols) + iter )*strideA*accCols; \
} else { \
EIGEN_UNUSED_VARIABLE(lhs_ptr##iter); \
}
@@ -240,21 +239,19 @@ EIGEN_ALWAYS_INLINE void ploadRhsMMA(const float*, __vector_pair&)
#define MICRO_MMA_STORE_ONE(iter) \
if (unroll_factor > iter) { \
- storeAccumulator<DataMapper, Index, Packet, accCols>(row + iter*accCols, col, res, pAlpha, &accZero##iter); \
+ storeAccumulator<DataMapper, Index, Packet, accCols>(row + iter*accCols, res, pAlpha, &accZero##iter); \
}
#define MICRO_MMA_STORE MICRO_MMA_UNROLL(MICRO_MMA_STORE_ONE)
template<int unroll_factor, typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols>
-EIGEN_STRONG_INLINE void gemm_unrolled_MMA_iteration(
+EIGEN_ALWAYS_INLINE void gemm_unrolled_MMA_iteration(
const DataMapper& res,
const Scalar* lhs_base,
const Scalar* rhs_base,
Index depth,
Index strideA,
- Index offsetA,
Index& row,
- Index col,
const Packet& pAlpha)
{
const Scalar* rhs_ptr = rhs_base;
@@ -280,94 +277,98 @@ EIGEN_STRONG_INLINE void gemm_unrolled_MMA_iteration(
row += unroll_factor*accCols;
}
-template<typename Scalar, typename Index, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols>
-void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
+template<typename Scalar, typename Packet, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols>
+EIGEN_ALWAYS_INLINE void gemmMMA_cols(
+ const DataMapper& res,
+ const Scalar* blockA,
+ const Scalar* blockB,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index strideB,
+ Index offsetB,
+ Index col,
+ Index rows,
+ Index cols,
+ Index remaining_rows,
+ const Packet& pAlpha,
+ const Packet& pMask)
{
- const Index remaining_rows = rows % accCols;
- const Index remaining_cols = cols % accRows;
-
- if( strideA == -1 ) strideA = depth;
- if( strideB == -1 ) strideB = depth;
-
- const Packet pAlpha = pset1<Packet>(alpha);
- const Packet pMask = bmask<Packet>((const int)(remaining_rows));
+ const DataMapper res3 = res.getSubMapper(0, col);
- Index col = 0;
- for(; col + accRows <= cols; col += accRows)
- {
- const Scalar* rhs_base = blockB + col*strideB + accRows*offsetB;
- const Scalar* lhs_base = blockA;
+ const Scalar* rhs_base = blockB + col*strideB + accRows*offsetB;
+ const Scalar* lhs_base = blockA + accCols*offsetA;
+ Index row = 0;
- Index row = 0;
#define MAX_MMA_UNROLL 7
- while(row + MAX_MMA_UNROLL*accCols <= rows) {
- gemm_unrolled_MMA_iteration<MAX_MMA_UNROLL, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
- }
- switch( (rows-row)/accCols ) {
+ while(row + MAX_MMA_UNROLL*accCols <= rows) {
+ gemm_unrolled_MMA_iteration<MAX_MMA_UNROLL, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
+ }
+ switch( (rows-row)/accCols ) {
#if MAX_MMA_UNROLL > 7
- case 7:
- gemm_unrolled_MMA_iteration<7, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
- break;
+ case 7:
+ gemm_unrolled_MMA_iteration<7, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
+ break;
#endif
#if MAX_MMA_UNROLL > 6
- case 6:
- gemm_unrolled_MMA_iteration<6, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
- break;
+ case 6:
+ gemm_unrolled_MMA_iteration<6, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
+ break;
#endif
#if MAX_MMA_UNROLL > 5
- case 5:
- gemm_unrolled_MMA_iteration<5, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
- break;
+ case 5:
+ gemm_unrolled_MMA_iteration<5, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
+ break;
#endif
#if MAX_MMA_UNROLL > 4
- case 4:
- gemm_unrolled_MMA_iteration<4, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
- break;
+ case 4:
+ gemm_unrolled_MMA_iteration<4, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
+ break;
#endif
#if MAX_MMA_UNROLL > 3
- case 3:
- gemm_unrolled_MMA_iteration<3, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
- break;
+ case 3:
+ gemm_unrolled_MMA_iteration<3, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
+ break;
#endif
#if MAX_MMA_UNROLL > 2
- case 2:
- gemm_unrolled_MMA_iteration<2, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
- break;
+ case 2:
+ gemm_unrolled_MMA_iteration<2, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
+ break;
#endif
#if MAX_MMA_UNROLL > 1
- case 1:
- gemm_unrolled_MMA_iteration<1, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha);
- break;
+ case 1:
+ gemm_unrolled_MMA_iteration<1, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha);
+ break;
#endif
- default:
- break;
- }
+ default:
+ break;
+ }
#undef MAX_MMA_UNROLL
- if(remaining_rows > 0)
- {
- gemm_extra_row<Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, rows, cols, remaining_rows, pAlpha, pMask);
- }
- }
+ if(remaining_rows > 0)
+ {
+ gemm_extra_row<Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, blockA, rhs_base, depth, strideA, offsetA, row, col, rows, cols, remaining_rows, pAlpha, pMask);
+ }
+}
- if(remaining_cols > 0)
- {
- const Scalar* rhs_base = blockB + col*strideB + remaining_cols*offsetB;
- const Scalar* lhs_base = blockA;
+template<typename Scalar, typename Index, typename Packet, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols>
+void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
+{
+ const Index remaining_rows = rows % accCols;
- for(; col < cols; col++)
- {
- Index row = 0;
+ if( strideA == -1 ) strideA = depth;
+ if( strideB == -1 ) strideB = depth;
- gemm_unrolled_col<Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, rows, col, remaining_cols, pAlpha);
+ const Packet pAlpha = pset1<Packet>(alpha);
+ const Packet pMask = bmask<Packet>((const int)(remaining_rows));
- if (remaining_rows > 0)
- {
- gemm_extra_col<Scalar, Packet, DataMapper, Index, accRows>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_rows, remaining_cols, pAlpha);
- }
- rhs_base++;
- }
+ Index col = 0;
+ for(; col + accRows <= cols; col += accRows)
+ {
+ gemmMMA_cols<Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask);
}
+
+ gemm_extra_cols<Scalar, Packet, DataMapper, Index, accCols>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask);
}
#define accColsC (accCols / 2)
@@ -375,21 +376,20 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
#define advanceCols ((RhsIsReal) ? 1 : 2)
// PEEL_COMPLEX_MMA loop factor.
-#define PEEL_COMPLEX_MMA 7
+#define PEEL_COMPLEX_MMA 3
#define MICRO_COMPLEX_MMA_UNROLL(func) \
- func(0) func(1) func(2) func(3) func(4)
+ func(0) func(1) func(2) func(3)
#define MICRO_COMPLEX_MMA_LOAD_ONE(iter) \
if (unroll_factor > iter) { \
lhsV##iter = ploadLhs<Scalar, Packet>(lhs_ptr_real##iter); \
- lhs_ptr_real##iter += accCols; \
if(!LhsIsReal) { \
- lhsVi##iter = ploadLhs<Scalar, Packet>(lhs_ptr_imag##iter); \
- lhs_ptr_imag##iter += accCols; \
+ lhsVi##iter = ploadLhs<Scalar, Packet>(lhs_ptr_real##iter + imag_delta); \
} else { \
EIGEN_UNUSED_VARIABLE(lhsVi##iter); \
} \
+ lhs_ptr_real##iter += accCols; \
} else { \
EIGEN_UNUSED_VARIABLE(lhsV##iter); \
EIGEN_UNUSED_VARIABLE(lhsVi##iter); \
@@ -402,8 +402,8 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
#define MICRO_COMPLEX_MMA_TYPE_PEEL(func, func2, type, peel) \
if (PEEL_COMPLEX_MMA > peel) { \
- Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4; \
- Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3, lhsVi4; \
+ Packet lhsV0, lhsV1, lhsV2, lhsV3; \
+ Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3; \
ploadRhsMMA<Scalar, type>(rhs_ptr_real + (accRows * peel), rhsV##peel); \
if(!RhsIsReal) { \
ploadRhsMMA<Scalar, type>(rhs_ptr_imag + (accRows * peel), rhsVi##peel); \
@@ -411,20 +411,17 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
} \
MICRO_COMPLEX_MMA_UNROLL(func2); \
- func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel) func(4,type,peel) \
+ func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel) \
} else { \
EIGEN_UNUSED_VARIABLE(rhsV##peel); \
EIGEN_UNUSED_VARIABLE(rhsVi##peel); \
}
#define MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(func, func2, type) \
- type rhsV0, rhsV1, rhsV2, rhsV3, rhsV4, rhsV5, rhsV6, rhsV7, rhsV8, rhsV9; \
- type rhsVi0, rhsVi1, rhsVi2, rhsVi3, rhsVi4, rhsVi5, rhsVi6, rhsVi7, rhsVi8, rhsVi9; \
+ type rhsV0, rhsV1, rhsV2, rhsV3; \
+ type rhsVi0, rhsVi1, rhsVi2, rhsVi3; \
MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,0); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,1); \
- MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,2); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,3); \
- MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,4); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,5); \
- MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,6); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,7); \
- MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,8); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,9);
+ MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,2); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,3);
#define MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(func, func2, type) \
type rhsV0, rhsVi0; \
@@ -461,15 +458,9 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
#define MICRO_COMPLEX_MMA_SRC_PTR_ONE(iter) \
if (unroll_factor > iter) { \
- lhs_ptr_real##iter = lhs_base + ( ((advanceRows*row)/accCols) + iter*advanceRows )*strideA*accCols + accCols*offsetA; \
- if(!LhsIsReal) { \
- lhs_ptr_imag##iter = lhs_ptr_real##iter + accCols*strideA; \
- } else { \
- EIGEN_UNUSED_VARIABLE(lhs_ptr_imag##iter); \
- } \
+ lhs_ptr_real##iter = lhs_base + ( ((advanceRows*row)/accCols) + iter*advanceRows )*strideA*accCols; \
} else { \
EIGEN_UNUSED_VARIABLE(lhs_ptr_real##iter); \
- EIGEN_UNUSED_VARIABLE(lhs_ptr_imag##iter); \
}
#define MICRO_COMPLEX_MMA_SRC_PTR MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_SRC_PTR_ONE)
@@ -477,45 +468,40 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB,
#define MICRO_COMPLEX_MMA_PREFETCH_ONE(iter) \
if (unroll_factor > iter) { \
EIGEN_POWER_PREFETCH(lhs_ptr_real##iter); \
- if(!LhsIsReal) { \
- EIGEN_POWER_PREFETCH(lhs_ptr_imag##iter); \
- } \
}
#define MICRO_COMPLEX_MMA_PREFETCH MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_PREFETCH_ONE)
#define MICRO_COMPLEX_MMA_STORE_ONE(iter) \
if (unroll_factor > iter) { \
- storeComplexAccumulator<DataMapper, Index, Packet, Packetc, accColsC, 0>(row + iter*accCols, col, res, pAlphaReal, pAlphaImag, &accReal##iter, &accImag##iter); \
+ storeComplexAccumulator<DataMapper, Index, Packet, Packetc, accColsC>(row + iter*accCols, res, pAlphaReal, pAlphaImag, &accReal##iter, &accImag##iter); \
}
#define MICRO_COMPLEX_MMA_STORE MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_STORE_ONE)
template<int unroll_factor, typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
-EIGEN_STRONG_INLINE void gemm_complex_unrolled_MMA_iteration(
+EIGEN_ALWAYS_INLINE void gemm_complex_unrolled_MMA_iteration(
const DataMapper& res,
const Scalar* lhs_base,
const Scalar* rhs_base,
Index depth,
Index strideA,
- Index offsetA,
Index strideB,
Index& row,
- Index col,
const Packet& pAlphaReal,
const Packet& pAlphaImag)
{
const Scalar* rhs_ptr_real = rhs_base;
const Scalar* rhs_ptr_imag;
+ const Index imag_delta = accCols*strideA;
if(!RhsIsReal) {
rhs_ptr_imag = rhs_base + accRows*strideB;
} else {
EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
}
- const Scalar* lhs_ptr_real0 = NULL, * lhs_ptr_imag0 = NULL, * lhs_ptr_real1 = NULL, * lhs_ptr_imag1 = NULL;
- const Scalar* lhs_ptr_real2 = NULL, * lhs_ptr_imag2 = NULL, * lhs_ptr_real3 = NULL, * lhs_ptr_imag3 = NULL;
- const Scalar* lhs_ptr_real4 = NULL, * lhs_ptr_imag4 = NULL;
- __vector_quad accReal0, accImag0, accReal1, accImag1, accReal2, accImag2, accReal3, accImag3, accReal4, accImag4;
+ const Scalar* lhs_ptr_real0 = NULL, * lhs_ptr_real1 = NULL;
+ const Scalar* lhs_ptr_real2 = NULL, * lhs_ptr_real3 = NULL;
+ __vector_quad accReal0, accImag0, accReal1, accImag1, accReal2, accImag2, accReal3, accImag3;
MICRO_COMPLEX_MMA_SRC_PTR
MICRO_COMPLEX_MMA_DST_PTR
@@ -539,11 +525,70 @@ EIGEN_STRONG_INLINE void gemm_complex_unrolled_MMA_iteration(
row += unroll_factor*accCols;
}
+template<typename Scalar, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, typename Index, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
+EIGEN_ALWAYS_INLINE void gemmMMA_complex_cols( // Complex MMA GEMM for the column panel starting at `col`: whole accCols row blocks first, then leftover rows.
+ const DataMapper& res,
+ const Scalar* blockA,
+ const Scalar* blockB,
+ Index depth,
+ Index strideA,
+ Index offsetA,
+ Index strideB,
+ Index offsetB,
+ Index col, // gemm_complexMMA calls this with col stepping by accRows
+ Index rows,
+ Index cols,
+ Index remaining_rows, // gemm_complexMMA passes rows % accCols
+ const Packet& pAlphaReal,
+ const Packet& pAlphaImag,
+ const Packet& pMask)
+{
+ const DataMapper res3 = res.getSubMapper(0, col); // sub-mapper taken at (0, col); the unrolled kernels receive res3 and no col argument
+
+ const Scalar* rhs_base = blockB + advanceCols*col*strideB + accRows*offsetB; // advanceCols is 1 when RhsIsReal, otherwise 2
+ const Scalar* lhs_base = blockA + accCols*offsetA; // offsetA is applied once here; the unrolled kernels take no offsetA argument
+ Index row = 0;
+
+#define MAX_COMPLEX_MMA_UNROLL 4 // widest unroll factor instantiated below
+ while(row + MAX_COMPLEX_MMA_UNROLL*accCols <= rows) { // each call advances row by unroll_factor*accCols
+ gemm_complex_unrolled_MMA_iteration<MAX_COMPLEX_MMA_UNROLL, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
+ }
+ switch( (rows-row)/accCols ) { // whole accCols row blocks still left; fewer than MAX_COMPLEX_MMA_UNROLL after the loop
+#if MAX_COMPLEX_MMA_UNROLL > 4 // false for the value defined above, so case 4 is not compiled
+ case 4:
+ gemm_complex_unrolled_MMA_iteration<4, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
+ break;
+#endif
+#if MAX_COMPLEX_MMA_UNROLL > 3
+ case 3:
+ gemm_complex_unrolled_MMA_iteration<3, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
+ break;
+#endif
+#if MAX_COMPLEX_MMA_UNROLL > 2
+ case 2:
+ gemm_complex_unrolled_MMA_iteration<2, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
+ break;
+#endif
+#if MAX_COMPLEX_MMA_UNROLL > 1
+ case 1:
+ gemm_complex_unrolled_MMA_iteration<1, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag);
+ break;
+#endif
+ default: // zero whole blocks left: nothing to do here
+ break;
+ }
+#undef MAX_COMPLEX_MMA_UNROLL
+
+ if(remaining_rows > 0) // leftover rows that do not fill a whole accCols block
+ { // NOTE(review): this call is handed blockA plus offsetA (not the pre-offset lhs_base) and still receives col alongside res3 - confirm against gemm_complex_extra_row
+ gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
+ }
+}
+
template<typename LhsScalar, typename RhsScalar, typename Scalarc, typename Scalar, typename Index, typename Packet, typename Packetc, typename RhsPacket, typename DataMapper, const Index accRows, const Index accCols, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
void gemm_complexMMA(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB)
{
const Index remaining_rows = rows % accCols;
- const Index remaining_cols = cols % accRows;
if( strideA == -1 ) strideA = depth;
if( strideB == -1 ) strideB = depth;
@@ -558,64 +603,10 @@ void gemm_complexMMA(const DataMapper& res, const LhsScalar* blockAc, const RhsS
Index col = 0;
for(; col + accRows <= cols; col += accRows)
{
- const Scalar* rhs_base = blockB + advanceCols*col*strideB + accRows*offsetB;
- const Scalar* lhs_base = blockA;
- Index row = 0;
-
-#define MAX_COMPLEX_MMA_UNROLL 4
- while(row + MAX_COMPLEX_MMA_UNROLL*accCols <= rows) {
- gemm_complex_unrolled_MMA_iteration<MAX_COMPLEX_MMA_UNROLL, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
- }
- switch( (rows-row)/accCols ) {
-#if MAX_COMPLEX_MMA_UNROLL > 4
- case 4:
- gemm_complex_unrolled_MMA_iteration<4, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
- break;
-#endif
-#if MAX_COMPLEX_MMA_UNROLL > 3
- case 3:
- gemm_complex_unrolled_MMA_iteration<3, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
- break;
-#endif
-#if MAX_COMPLEX_MMA_UNROLL > 2
- case 2:
- gemm_complex_unrolled_MMA_iteration<2, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
- break;
-#endif
-#if MAX_COMPLEX_MMA_UNROLL > 1
- case 1:
- gemm_complex_unrolled_MMA_iteration<1, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag);
- break;
-#endif
- default:
- break;
- }
-#undef MAX_COMPLEX_MMA_UNROLL
-
- if(remaining_rows > 0)
- {
- gemm_complex_extra_row<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
- }
+ gemmMMA_complex_cols<Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
}
- if(remaining_cols > 0)
- {
- const Scalar* rhs_base = blockB + advanceCols*col*strideB + remaining_cols*offsetB;
- const Scalar* lhs_base = blockA;
-
- for(; col < cols; col++)
- {
- Index row = 0;
-
- gemm_complex_unrolled_col<Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, rows, col, remaining_cols, pAlphaReal, pAlphaImag);
-
- if (remaining_rows > 0)
- {
- gemm_complex_extra_col<Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_rows, remaining_cols, pAlphaReal, pAlphaImag);
- }
- rhs_base++;
- }
- }
+ gemm_complex_extra_cols<Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask);
}
#undef accColsC
--
GitLab
From efdb8ac4662beb0c171adec9d36bbb8a6269488b Mon Sep 17 00:00:00 2001
From: Chip-Kerchner <chip.kerchner@ibm.com>
Date: Tue, 26 Oct 2021 16:42:23 -0500
Subject: [PATCH 2/2] Fix used uninitialized warnings.
---
Eigen/src/Core/arch/AltiVec/MatrixProduct.h | 6 +++---
Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h | 2 +-
2 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h
index bd5da3623..3745a87cb 100644
--- a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h
+++ b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h
@@ -1851,11 +1851,11 @@ EIGEN_ALWAYS_INLINE void gemm_unrolled_complex_row_iteration(
const Packet& pMask)
{
const Scalar* rhs_ptr_real = rhs_base;
- const Scalar* rhs_ptr_imag;
+ const Scalar* rhs_ptr_imag = NULL;
if(!RhsIsReal) rhs_ptr_imag = rhs_base + accRows*strideB;
else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag);
const Scalar* lhs_ptr_real = lhs_base + advanceRows*row*strideA + remaining_rows*offsetA;
- const Scalar* lhs_ptr_imag;
+ const Scalar* lhs_ptr_imag = NULL;
if(!LhsIsReal) lhs_ptr_imag = lhs_ptr_real + remaining_rows*strideA;
else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag);
PacketBlock<Packet,accRows> accReal0, accImag0, accReal1, accImag1, accReal2, accImag2, accReal3, accImag3;
@@ -2078,7 +2078,7 @@ EIGEN_STRONG_INLINE void gemm_complex_unrolled_iteration(
const Packet& pAlphaImag)
{
const Scalar* rhs_ptr_real = rhs_base;
- const Scalar* rhs_ptr_imag;
+ const Scalar* rhs_ptr_imag = NULL;
const Index imag_delta = accCols*strideA;
if(!RhsIsReal) {
rhs_ptr_imag = rhs_base + accRows*strideB;
diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h b/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h
index e18b7f267..9a3132276 100644
--- a/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h
+++ b/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h
@@ -492,7 +492,7 @@ EIGEN_ALWAYS_INLINE void gemm_complex_unrolled_MMA_iteration(
const Packet& pAlphaImag)
{
const Scalar* rhs_ptr_real = rhs_base;
- const Scalar* rhs_ptr_imag;
+ const Scalar* rhs_ptr_imag = NULL;
const Index imag_delta = accCols*strideA;
if(!RhsIsReal) {
rhs_ptr_imag = rhs_base + accRows*strideB;
--
GitLab