3139 lines
135 KiB
Diff
3139 lines
135 KiB
Diff
|
From 9e3873b1dce3ba65980c7e7b979325dac2fb4bbd Mon Sep 17 00:00:00 2001
|
||
|
From: Chip-Kerchner <chip.kerchner@ibm.com>
|
||
|
Date: Wed, 20 Oct 2021 11:06:50 -0500
|
||
|
Subject: [PATCH 1/2] New branch for inverting rows and depth in non-vectorized
|
||
|
portion of packing.
|
||
|
|
||
|
---
|
||
|
Eigen/src/Core/arch/AltiVec/Complex.h | 10 +-
|
||
|
Eigen/src/Core/arch/AltiVec/MatrixProduct.h | 1546 ++++++++---------
|
||
|
.../Core/arch/AltiVec/MatrixProductCommon.h | 206 +--
|
||
|
.../src/Core/arch/AltiVec/MatrixProductMMA.h | 335 ++--
|
||
|
4 files changed, 927 insertions(+), 1170 deletions(-)
|
||
|
|
||
|
diff --git a/Eigen/src/Core/arch/AltiVec/Complex.h b/Eigen/src/Core/arch/AltiVec/Complex.h
|
||
|
index f730ce8d3..4fd923e84 100644
|
||
|
--- a/Eigen/src/Core/arch/AltiVec/Complex.h
|
||
|
+++ b/Eigen/src/Core/arch/AltiVec/Complex.h
|
||
|
@@ -129,20 +129,20 @@ template<> EIGEN_STRONG_INLINE Packet2cf ploaddup<Packet2cf>(const std::complex<
|
||
|
template<> EIGEN_STRONG_INLINE void pstore <std::complex<float> >(std::complex<float> * to, const Packet2cf& from) { pstore((float*)to, from.v); }
|
||
|
template<> EIGEN_STRONG_INLINE void pstoreu<std::complex<float> >(std::complex<float> * to, const Packet2cf& from) { pstoreu((float*)to, from.v); }
|
||
|
|
||
|
-EIGEN_STRONG_INLINE Packet2cf pload2(const std::complex<float>* from0, const std::complex<float>* from1)
|
||
|
+EIGEN_STRONG_INLINE Packet2cf pload2(const std::complex<float>& from0, const std::complex<float>& from1)
|
||
|
{
|
||
|
Packet4f res0, res1;
|
||
|
#ifdef __VSX__
|
||
|
- __asm__ ("lxsdx %x0,%y1" : "=wa" (res0) : "Z" (*from0));
|
||
|
- __asm__ ("lxsdx %x0,%y1" : "=wa" (res1) : "Z" (*from1));
|
||
|
+ __asm__ ("lxsdx %x0,%y1" : "=wa" (res0) : "Z" (from0));
|
||
|
+ __asm__ ("lxsdx %x0,%y1" : "=wa" (res1) : "Z" (from1));
|
||
|
#ifdef _BIG_ENDIAN
|
||
|
__asm__ ("xxpermdi %x0, %x1, %x2, 0" : "=wa" (res0) : "wa" (res0), "wa" (res1));
|
||
|
#else
|
||
|
__asm__ ("xxpermdi %x0, %x2, %x1, 0" : "=wa" (res0) : "wa" (res0), "wa" (res1));
|
||
|
#endif
|
||
|
#else
|
||
|
- *reinterpret_cast<std::complex<float> *>(&res0) = *from0;
|
||
|
- *reinterpret_cast<std::complex<float> *>(&res1) = *from1;
|
||
|
+ *reinterpret_cast<std::complex<float> *>(&res0) = from0;
|
||
|
+ *reinterpret_cast<std::complex<float> *>(&res1) = from1;
|
||
|
res0 = vec_perm(res0, res1, p16uc_TRANSPOSE64_HI);
|
||
|
#endif
|
||
|
return Packet2cf(res0);
|
||
|
diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h
|
||
|
index 1d67d60d0..bd5da3623 100644
|
||
|
--- a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h
|
||
|
+++ b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h
|
||
|
@@ -166,24 +166,23 @@ EIGEN_STRONG_INLINE void symm_pack_complex_rhs_helper(std::complex<Scalar>* bloc
|
||
|
|
||
|
rir += vectorDelta;
|
||
|
}
|
||
|
- if (j < cols)
|
||
|
+
|
||
|
+ for(; j < cols; j++)
|
||
|
{
|
||
|
- rii = rir + ((cols - j) * rows);
|
||
|
+ rii = rir + rows;
|
||
|
|
||
|
for(Index i = k2; i < depth; i++)
|
||
|
{
|
||
|
- Index k = j;
|
||
|
- for(; k < cols; k++)
|
||
|
- {
|
||
|
- std::complex<Scalar> v = getAdjointVal<Scalar, Index, StorageOrder>(i, k, rhs);
|
||
|
+ std::complex<Scalar> v = getAdjointVal<Scalar, Index, StorageOrder>(i, j, rhs);
|
||
|
|
||
|
- blockBf[rir] = v.real();
|
||
|
- blockBf[rii] = v.imag();
|
||
|
+ blockBf[rir] = v.real();
|
||
|
+ blockBf[rii] = v.imag();
|
||
|
|
||
|
- rir += 1;
|
||
|
- rii += 1;
|
||
|
- }
|
||
|
+ rir += 1;
|
||
|
+ rii += 1;
|
||
|
}
|
||
|
+
|
||
|
+ rir += rows;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
@@ -262,19 +261,15 @@ EIGEN_STRONG_INLINE void symm_pack_rhs_helper(Scalar* blockB, const Scalar* _rhs
|
||
|
}
|
||
|
}
|
||
|
|
||
|
- if (j < cols)
|
||
|
+ for(; j < cols; j++)
|
||
|
{
|
||
|
for(Index i = k2; i < depth; i++)
|
||
|
{
|
||
|
- Index k = j;
|
||
|
- for(; k < cols; k++)
|
||
|
- {
|
||
|
- if(k <= i)
|
||
|
- blockB[ri] = rhs(i, k);
|
||
|
- else
|
||
|
- blockB[ri] = rhs(k, i);
|
||
|
- ri += 1;
|
||
|
- }
|
||
|
+ if(j <= i)
|
||
|
+ blockB[ri] = rhs(i, j);
|
||
|
+ else
|
||
|
+ blockB[ri] = rhs(j, i);
|
||
|
+ ri += 1;
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
@@ -408,22 +403,18 @@ struct symm_pack_lhs<double, Index, Pack1, Pack2_dummy, StorageOrder>
|
||
|
* and offset and behaves accordingly.
|
||
|
**/
|
||
|
|
||
|
-template<typename Scalar, typename Packet, typename Index>
|
||
|
-EIGEN_ALWAYS_INLINE void storeBlock(Scalar* to, PacketBlock<Packet,4>& block)
|
||
|
-{
|
||
|
- const Index size = 16 / sizeof(Scalar);
|
||
|
- pstore<Scalar>(to + (0 * size), block.packet[0]);
|
||
|
- pstore<Scalar>(to + (1 * size), block.packet[1]);
|
||
|
- pstore<Scalar>(to + (2 * size), block.packet[2]);
|
||
|
- pstore<Scalar>(to + (3 * size), block.packet[3]);
|
||
|
-}
|
||
|
-
|
||
|
-template<typename Scalar, typename Packet, typename Index>
|
||
|
-EIGEN_ALWAYS_INLINE void storeBlock(Scalar* to, PacketBlock<Packet,2>& block)
|
||
|
+template<typename Scalar, typename Packet, typename Index, int N>
|
||
|
+EIGEN_ALWAYS_INLINE void storeBlock(Scalar* to, PacketBlock<Packet,N>& block)
|
||
|
{
|
||
|
const Index size = 16 / sizeof(Scalar);
|
||
|
pstore<Scalar>(to + (0 * size), block.packet[0]);
|
||
|
pstore<Scalar>(to + (1 * size), block.packet[1]);
|
||
|
+ if (N > 2) {
|
||
|
+ pstore<Scalar>(to + (2 * size), block.packet[2]);
|
||
|
+ }
|
||
|
+ if (N > 3) {
|
||
|
+ pstore<Scalar>(to + (3 * size), block.packet[3]);
|
||
|
+ }
|
||
|
}
|
||
|
|
||
|
// General template for lhs & rhs complex packing.
|
||
|
@@ -449,9 +440,9 @@ struct dhs_cpack {
|
||
|
PacketBlock<PacketC,8> cblock;
|
||
|
|
||
|
if (UseLhs) {
|
||
|
- bload<DataMapper, PacketC, Index, 2, 0, StorageOrder>(cblock, lhs, j, i);
|
||
|
+ bload<DataMapper, PacketC, Index, 2, StorageOrder, true, 4>(cblock, lhs, j, i);
|
||
|
} else {
|
||
|
- bload<DataMapper, PacketC, Index, 2, 0, StorageOrder>(cblock, lhs, i, j);
|
||
|
+ bload<DataMapper, PacketC, Index, 2, StorageOrder, true, 4>(cblock, lhs, i, j);
|
||
|
}
|
||
|
|
||
|
blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETREAL32);
|
||
|
@@ -478,8 +469,8 @@ struct dhs_cpack {
|
||
|
ptranspose(blocki);
|
||
|
}
|
||
|
|
||
|
- storeBlock<Scalar, Packet, Index>(blockAt + rir, blockr);
|
||
|
- storeBlock<Scalar, Packet, Index>(blockAt + rii, blocki);
|
||
|
+ storeBlock<Scalar, Packet, Index, 4>(blockAt + rir, blockr);
|
||
|
+ storeBlock<Scalar, Packet, Index, 4>(blockAt + rii, blocki);
|
||
|
|
||
|
rir += 4*vectorSize;
|
||
|
rii += 4*vectorSize;
|
||
|
@@ -499,21 +490,12 @@ struct dhs_cpack {
|
||
|
cblock.packet[1] = lhs.template loadPacket<PacketC>(i, j + 2);
|
||
|
}
|
||
|
} else {
|
||
|
- std::complex<Scalar> lhs0, lhs1;
|
||
|
if (UseLhs) {
|
||
|
- lhs0 = lhs(j + 0, i);
|
||
|
- lhs1 = lhs(j + 1, i);
|
||
|
- cblock.packet[0] = pload2(&lhs0, &lhs1);
|
||
|
- lhs0 = lhs(j + 2, i);
|
||
|
- lhs1 = lhs(j + 3, i);
|
||
|
- cblock.packet[1] = pload2(&lhs0, &lhs1);
|
||
|
+ cblock.packet[0] = pload2(lhs(j + 0, i), lhs(j + 1, i));
|
||
|
+ cblock.packet[1] = pload2(lhs(j + 2, i), lhs(j + 3, i));
|
||
|
} else {
|
||
|
- lhs0 = lhs(i, j + 0);
|
||
|
- lhs1 = lhs(i, j + 1);
|
||
|
- cblock.packet[0] = pload2(&lhs0, &lhs1);
|
||
|
- lhs0 = lhs(i, j + 2);
|
||
|
- lhs1 = lhs(i, j + 3);
|
||
|
- cblock.packet[1] = pload2(&lhs0, &lhs1);
|
||
|
+ cblock.packet[0] = pload2(lhs(i, j + 0), lhs(i, j + 1));
|
||
|
+ cblock.packet[1] = pload2(lhs(i, j + 2), lhs(i, j + 3));
|
||
|
}
|
||
|
}
|
||
|
|
||
|
@@ -535,34 +517,50 @@ struct dhs_cpack {
|
||
|
rir += ((PanelMode) ? (vectorSize*(2*stride - depth)) : vectorDelta);
|
||
|
}
|
||
|
|
||
|
- if (j < rows)
|
||
|
+ if (!UseLhs)
|
||
|
{
|
||
|
- if(PanelMode) rir += (offset*(rows - j - vectorSize));
|
||
|
- rii = rir + (((PanelMode) ? stride : depth) * (rows - j));
|
||
|
+ if(PanelMode) rir -= (offset*(vectorSize - 1));
|
||
|
|
||
|
- for(Index i = 0; i < depth; i++)
|
||
|
+ for(; j < rows; j++)
|
||
|
{
|
||
|
- Index k = j;
|
||
|
- for(; k < rows; k++)
|
||
|
+ rii = rir + ((PanelMode) ? stride : depth);
|
||
|
+
|
||
|
+ for(Index i = 0; i < depth; i++)
|
||
|
{
|
||
|
- if (UseLhs) {
|
||
|
+ blockAt[rir] = lhs(i, j).real();
|
||
|
+
|
||
|
+ if(Conjugate)
|
||
|
+ blockAt[rii] = -lhs(i, j).imag();
|
||
|
+ else
|
||
|
+ blockAt[rii] = lhs(i, j).imag();
|
||
|
+
|
||
|
+ rir += 1;
|
||
|
+ rii += 1;
|
||
|
+ }
|
||
|
+
|
||
|
+ rir += ((PanelMode) ? (2*stride - depth) : depth);
|
||
|
+ }
|
||
|
+ } else {
|
||
|
+ if (j < rows)
|
||
|
+ {
|
||
|
+ if(PanelMode) rir += (offset*(rows - j - vectorSize));
|
||
|
+ rii = rir + (((PanelMode) ? stride : depth) * (rows - j));
|
||
|
+
|
||
|
+ for(Index i = 0; i < depth; i++)
|
||
|
+ {
|
||
|
+ Index k = j;
|
||
|
+ for(; k < rows; k++)
|
||
|
+ {
|
||
|
blockAt[rir] = lhs(k, i).real();
|
||
|
|
||
|
if(Conjugate)
|
||
|
blockAt[rii] = -lhs(k, i).imag();
|
||
|
else
|
||
|
blockAt[rii] = lhs(k, i).imag();
|
||
|
- } else {
|
||
|
- blockAt[rir] = lhs(i, k).real();
|
||
|
|
||
|
- if(Conjugate)
|
||
|
- blockAt[rii] = -lhs(i, k).imag();
|
||
|
- else
|
||
|
- blockAt[rii] = lhs(i, k).imag();
|
||
|
+ rir += 1;
|
||
|
+ rii += 1;
|
||
|
}
|
||
|
-
|
||
|
- rir += 1;
|
||
|
- rii += 1;
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
@@ -588,16 +586,16 @@ struct dhs_pack{
|
||
|
PacketBlock<Packet,4> block;
|
||
|
|
||
|
if (UseLhs) {
|
||
|
- bload<DataMapper, Packet, Index, 4, 0, StorageOrder>(block, lhs, j, i);
|
||
|
+ bload<DataMapper, Packet, Index, 4, StorageOrder, false, 4>(block, lhs, j, i);
|
||
|
} else {
|
||
|
- bload<DataMapper, Packet, Index, 4, 0, StorageOrder>(block, lhs, i, j);
|
||
|
+ bload<DataMapper, Packet, Index, 4, StorageOrder, false, 4>(block, lhs, i, j);
|
||
|
}
|
||
|
if(((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs))
|
||
|
{
|
||
|
ptranspose(block);
|
||
|
}
|
||
|
|
||
|
- storeBlock<Scalar, Packet, Index>(blockA + ri, block);
|
||
|
+ storeBlock<Scalar, Packet, Index, 4>(blockA + ri, block);
|
||
|
|
||
|
ri += 4*vectorSize;
|
||
|
}
|
||
|
@@ -632,21 +630,33 @@ struct dhs_pack{
|
||
|
if(PanelMode) ri += vectorSize*(stride - offset - depth);
|
||
|
}
|
||
|
|
||
|
- if (j < rows)
|
||
|
+ if (!UseLhs)
|
||
|
{
|
||
|
- if(PanelMode) ri += offset*(rows - j);
|
||
|
+ if(PanelMode) ri += offset;
|
||
|
|
||
|
- for(Index i = 0; i < depth; i++)
|
||
|
+ for(; j < rows; j++)
|
||
|
{
|
||
|
- Index k = j;
|
||
|
- for(; k < rows; k++)
|
||
|
+ for(Index i = 0; i < depth; i++)
|
||
|
{
|
||
|
- if (UseLhs) {
|
||
|
+ blockA[ri] = lhs(i, j);
|
||
|
+ ri += 1;
|
||
|
+ }
|
||
|
+
|
||
|
+ if(PanelMode) ri += stride - depth;
|
||
|
+ }
|
||
|
+ } else {
|
||
|
+ if (j < rows)
|
||
|
+ {
|
||
|
+ if(PanelMode) ri += offset*(rows - j);
|
||
|
+
|
||
|
+ for(Index i = 0; i < depth; i++)
|
||
|
+ {
|
||
|
+ Index k = j;
|
||
|
+ for(; k < rows; k++)
|
||
|
+ {
|
||
|
blockA[ri] = lhs(k, i);
|
||
|
- } else {
|
||
|
- blockA[ri] = lhs(i, k);
|
||
|
+ ri += 1;
|
||
|
}
|
||
|
- ri += 1;
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
@@ -682,7 +692,7 @@ struct dhs_pack<double, Index, DataMapper, Packet2d, StorageOrder, PanelMode, tr
|
||
|
block.packet[1] = lhs.template loadPacket<Packet2d>(j, i + 1);
|
||
|
}
|
||
|
|
||
|
- storeBlock<double, Packet2d, Index>(blockA + ri, block);
|
||
|
+ storeBlock<double, Packet2d, Index, 2>(blockA + ri, block);
|
||
|
|
||
|
ri += 2*vectorSize;
|
||
|
}
|
||
|
@@ -759,7 +769,7 @@ struct dhs_pack<double, Index, DataMapper, Packet2d, StorageOrder, PanelMode, fa
|
||
|
block.packet[2] = rhs.template loadPacket<Packet2d>(i + 1, j + 0); //[b1 b2]
|
||
|
block.packet[3] = rhs.template loadPacket<Packet2d>(i + 1, j + 2); //[b3 b4]
|
||
|
|
||
|
- storeBlock<double, Packet2d, Index>(blockB + ri, block);
|
||
|
+ storeBlock<double, Packet2d, Index, 4>(blockB + ri, block);
|
||
|
}
|
||
|
|
||
|
ri += 4*vectorSize;
|
||
|
@@ -790,19 +800,17 @@ struct dhs_pack<double, Index, DataMapper, Packet2d, StorageOrder, PanelMode, fa
|
||
|
if(PanelMode) ri += (2*vectorSize)*(stride - offset - depth);
|
||
|
}
|
||
|
|
||
|
- if (j < cols)
|
||
|
- {
|
||
|
- if(PanelMode) ri += offset*(cols - j);
|
||
|
+ if(PanelMode) ri += offset;
|
||
|
|
||
|
+ for(; j < cols; j++)
|
||
|
+ {
|
||
|
for(Index i = 0; i < depth; i++)
|
||
|
{
|
||
|
- Index k = j;
|
||
|
- for(; k < cols; k++)
|
||
|
- {
|
||
|
- blockB[ri] = rhs(i, k);
|
||
|
- ri += 1;
|
||
|
- }
|
||
|
+ blockB[ri] = rhs(i, j);
|
||
|
+ ri += 1;
|
||
|
}
|
||
|
+
|
||
|
+ if(PanelMode) ri += stride - depth;
|
||
|
}
|
||
|
}
|
||
|
};
|
||
|
@@ -863,8 +871,8 @@ struct dhs_cpack<double, Index, DataMapper, Packet, PacketC, StorageOrder, Conju
|
||
|
blocki.packet[1] = -blocki.packet[1];
|
||
|
}
|
||
|
|
||
|
- storeBlock<double, Packet, Index>(blockAt + rir, blockr);
|
||
|
- storeBlock<double, Packet, Index>(blockAt + rii, blocki);
|
||
|
+ storeBlock<double, Packet, Index, 2>(blockAt + rir, blockr);
|
||
|
+ storeBlock<double, Packet, Index, 2>(blockAt + rii, blocki);
|
||
|
|
||
|
rir += 2*vectorSize;
|
||
|
rii += 2*vectorSize;
|
||
|
@@ -943,7 +951,7 @@ struct dhs_cpack<double, Index, DataMapper, Packet, PacketC, StorageOrder, Conju
|
||
|
PacketBlock<PacketC,4> cblock;
|
||
|
PacketBlock<Packet,2> blockr, blocki;
|
||
|
|
||
|
- bload<DataMapper, PacketC, Index, 2, 0, ColMajor>(cblock, rhs, i, j);
|
||
|
+ bload<DataMapper, PacketC, Index, 2, ColMajor, false, 4>(cblock, rhs, i, j);
|
||
|
|
||
|
blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL64);
|
||
|
blockr.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETREAL64);
|
||
|
@@ -957,8 +965,8 @@ struct dhs_cpack<double, Index, DataMapper, Packet, PacketC, StorageOrder, Conju
|
||
|
blocki.packet[1] = -blocki.packet[1];
|
||
|
}
|
||
|
|
||
|
- storeBlock<double, Packet, Index>(blockBt + rir, blockr);
|
||
|
- storeBlock<double, Packet, Index>(blockBt + rii, blocki);
|
||
|
+ storeBlock<double, Packet, Index, 2>(blockBt + rir, blockr);
|
||
|
+ storeBlock<double, Packet, Index, 2>(blockBt + rii, blocki);
|
||
|
|
||
|
rir += 2*vectorSize;
|
||
|
rii += 2*vectorSize;
|
||
|
@@ -967,27 +975,26 @@ struct dhs_cpack<double, Index, DataMapper, Packet, PacketC, StorageOrder, Conju
|
||
|
rir += ((PanelMode) ? (2*vectorSize*(2*stride - depth)) : vectorDelta);
|
||
|
}
|
||
|
|
||
|
- if (j < cols)
|
||
|
+ if(PanelMode) rir -= (offset*(2*vectorSize - 1));
|
||
|
+
|
||
|
+ for(; j < cols; j++)
|
||
|
{
|
||
|
- if(PanelMode) rir += (offset*(cols - j - 2*vectorSize));
|
||
|
- rii = rir + (((PanelMode) ? stride : depth) * (cols - j));
|
||
|
+ rii = rir + ((PanelMode) ? stride : depth);
|
||
|
|
||
|
for(Index i = 0; i < depth; i++)
|
||
|
{
|
||
|
- Index k = j;
|
||
|
- for(; k < cols; k++)
|
||
|
- {
|
||
|
- blockBt[rir] = rhs(i, k).real();
|
||
|
+ blockBt[rir] = rhs(i, j).real();
|
||
|
|
||
|
- if(Conjugate)
|
||
|
- blockBt[rii] = -rhs(i, k).imag();
|
||
|
- else
|
||
|
- blockBt[rii] = rhs(i, k).imag();
|
||
|
+ if(Conjugate)
|
||
|
+ blockBt[rii] = -rhs(i, j).imag();
|
||
|
+ else
|
||
|
+ blockBt[rii] = rhs(i, j).imag();
|
||
|
|
||
|
- rir += 1;
|
||
|
- rii += 1;
|
||
|
- }
|
||
|
+ rir += 1;
|
||
|
+ rii += 1;
|
||
|
}
|
||
|
+
|
||
|
+ rir += ((PanelMode) ? (2*stride - depth) : depth);
|
||
|
}
|
||
|
}
|
||
|
};
|
||
|
@@ -997,31 +1004,32 @@ struct dhs_cpack<double, Index, DataMapper, Packet, PacketC, StorageOrder, Conju
|
||
|
**************/
|
||
|
|
||
|
// 512-bits rank1-update of acc. It can either positive or negative accumulate (useful for complex gemm).
|
||
|
-template<typename Packet, bool NegativeAccumulate>
|
||
|
-EIGEN_ALWAYS_INLINE void pger_common(PacketBlock<Packet,4>* acc, const Packet& lhsV, const Packet* rhsV)
|
||
|
-{
|
||
|
- if(NegativeAccumulate)
|
||
|
- {
|
||
|
- acc->packet[0] = vec_nmsub(lhsV, rhsV[0], acc->packet[0]);
|
||
|
- acc->packet[1] = vec_nmsub(lhsV, rhsV[1], acc->packet[1]);
|
||
|
- acc->packet[2] = vec_nmsub(lhsV, rhsV[2], acc->packet[2]);
|
||
|
- acc->packet[3] = vec_nmsub(lhsV, rhsV[3], acc->packet[3]);
|
||
|
- } else {
|
||
|
- acc->packet[0] = vec_madd(lhsV, rhsV[0], acc->packet[0]);
|
||
|
- acc->packet[1] = vec_madd(lhsV, rhsV[1], acc->packet[1]);
|
||
|
- acc->packet[2] = vec_madd(lhsV, rhsV[2], acc->packet[2]);
|
||
|
- acc->packet[3] = vec_madd(lhsV, rhsV[3], acc->packet[3]);
|
||
|
- }
|
||
|
-}
|
||
|
-
|
||
|
-template<typename Packet, bool NegativeAccumulate>
|
||
|
-EIGEN_ALWAYS_INLINE void pger_common(PacketBlock<Packet,1>* acc, const Packet& lhsV, const Packet* rhsV)
|
||
|
+template<typename Packet, bool NegativeAccumulate, int N>
|
||
|
+EIGEN_ALWAYS_INLINE void pger_common(PacketBlock<Packet,N>* acc, const Packet& lhsV, const Packet* rhsV)
|
||
|
{
|
||
|
if(NegativeAccumulate)
|
||
|
{
|
||
|
acc->packet[0] = vec_nmsub(lhsV, rhsV[0], acc->packet[0]);
|
||
|
+ if (N > 1) {
|
||
|
+ acc->packet[1] = vec_nmsub(lhsV, rhsV[1], acc->packet[1]);
|
||
|
+ }
|
||
|
+ if (N > 2) {
|
||
|
+ acc->packet[2] = vec_nmsub(lhsV, rhsV[2], acc->packet[2]);
|
||
|
+ }
|
||
|
+ if (N > 3) {
|
||
|
+ acc->packet[3] = vec_nmsub(lhsV, rhsV[3], acc->packet[3]);
|
||
|
+ }
|
||
|
} else {
|
||
|
acc->packet[0] = vec_madd(lhsV, rhsV[0], acc->packet[0]);
|
||
|
+ if (N > 1) {
|
||
|
+ acc->packet[1] = vec_madd(lhsV, rhsV[1], acc->packet[1]);
|
||
|
+ }
|
||
|
+ if (N > 2) {
|
||
|
+ acc->packet[2] = vec_madd(lhsV, rhsV[2], acc->packet[2]);
|
||
|
+ }
|
||
|
+ if (N > 3) {
|
||
|
+ acc->packet[3] = vec_madd(lhsV, rhsV[3], acc->packet[3]);
|
||
|
+ }
|
||
|
}
|
||
|
}
|
||
|
|
||
|
@@ -1030,11 +1038,11 @@ EIGEN_ALWAYS_INLINE void pger(PacketBlock<Packet,N>* acc, const Scalar* lhs, con
|
||
|
{
|
||
|
Packet lhsV = pload<Packet>(lhs);
|
||
|
|
||
|
- pger_common<Packet, NegativeAccumulate>(acc, lhsV, rhsV);
|
||
|
+ pger_common<Packet, NegativeAccumulate, N>(acc, lhsV, rhsV);
|
||
|
}
|
||
|
|
||
|
-template<typename Scalar, typename Packet, typename Index>
|
||
|
-EIGEN_ALWAYS_INLINE void loadPacketRemaining(const Scalar* lhs, Packet &lhsV, Index remaining_rows)
|
||
|
+template<typename Scalar, typename Packet, typename Index, const Index remaining_rows>
|
||
|
+EIGEN_ALWAYS_INLINE void loadPacketRemaining(const Scalar* lhs, Packet &lhsV)
|
||
|
{
|
||
|
#ifdef _ARCH_PWR9
|
||
|
lhsV = vec_xl_len((Scalar *)lhs, remaining_rows * sizeof(Scalar));
|
||
|
@@ -1046,32 +1054,32 @@ EIGEN_ALWAYS_INLINE void loadPacketRemaining(const Scalar* lhs, Packet &lhsV, In
|
||
|
#endif
|
||
|
}
|
||
|
|
||
|
-template<int N, typename Scalar, typename Packet, typename Index, bool NegativeAccumulate>
|
||
|
-EIGEN_ALWAYS_INLINE void pger(PacketBlock<Packet,N>* acc, const Scalar* lhs, const Packet* rhsV, Index remaining_rows)
|
||
|
+template<int N, typename Scalar, typename Packet, typename Index, bool NegativeAccumulate, const Index remaining_rows>
|
||
|
+EIGEN_ALWAYS_INLINE void pger(PacketBlock<Packet,N>* acc, const Scalar* lhs, const Packet* rhsV)
|
||
|
{
|
||
|
Packet lhsV;
|
||
|
- loadPacketRemaining<Scalar, Packet, Index>(lhs, lhsV, remaining_rows);
|
||
|
+ loadPacketRemaining<Scalar, Packet, Index, remaining_rows>(lhs, lhsV);
|
||
|
|
||
|
- pger_common<Packet, NegativeAccumulate>(acc, lhsV, rhsV);
|
||
|
+ pger_common<Packet, NegativeAccumulate, N>(acc, lhsV, rhsV);
|
||
|
}
|
||
|
|
||
|
// 512-bits rank1-update of complex acc. It takes decoupled accumulators as entries. It also takes cares of mixed types real * complex and complex * real.
|
||
|
template<int N, typename Packet, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
|
||
|
EIGEN_ALWAYS_INLINE void pgerc_common(PacketBlock<Packet,N>* accReal, PacketBlock<Packet,N>* accImag, const Packet &lhsV, const Packet &lhsVi, const Packet* rhsV, const Packet* rhsVi)
|
||
|
{
|
||
|
- pger_common<Packet, false>(accReal, lhsV, rhsV);
|
||
|
+ pger_common<Packet, false, N>(accReal, lhsV, rhsV);
|
||
|
if(LhsIsReal)
|
||
|
{
|
||
|
- pger_common<Packet, ConjugateRhs>(accImag, lhsV, rhsVi);
|
||
|
+ pger_common<Packet, ConjugateRhs, N>(accImag, lhsV, rhsVi);
|
||
|
EIGEN_UNUSED_VARIABLE(lhsVi);
|
||
|
} else {
|
||
|
if (!RhsIsReal) {
|
||
|
- pger_common<Packet, ConjugateLhs == ConjugateRhs>(accReal, lhsVi, rhsVi);
|
||
|
- pger_common<Packet, ConjugateRhs>(accImag, lhsV, rhsVi);
|
||
|
+ pger_common<Packet, ConjugateLhs == ConjugateRhs, N>(accReal, lhsVi, rhsVi);
|
||
|
+ pger_common<Packet, ConjugateRhs, N>(accImag, lhsV, rhsVi);
|
||
|
} else {
|
||
|
EIGEN_UNUSED_VARIABLE(rhsVi);
|
||
|
}
|
||
|
- pger_common<Packet, ConjugateLhs>(accImag, lhsVi, rhsV);
|
||
|
+ pger_common<Packet, ConjugateLhs, N>(accImag, lhsVi, rhsV);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
@@ -1086,8 +1094,8 @@ EIGEN_ALWAYS_INLINE void pgerc(PacketBlock<Packet,N>* accReal, PacketBlock<Packe
|
||
|
pgerc_common<N, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(accReal, accImag, lhsV, lhsVi, rhsV, rhsVi);
|
||
|
}
|
||
|
|
||
|
-template<typename Scalar, typename Packet, typename Index, bool LhsIsReal>
|
||
|
-EIGEN_ALWAYS_INLINE void loadPacketRemaining(const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, Packet &lhsV, Packet &lhsVi, Index remaining_rows)
|
||
|
+template<typename Scalar, typename Packet, typename Index, bool LhsIsReal, const Index remaining_rows>
|
||
|
+EIGEN_ALWAYS_INLINE void loadPacketRemaining(const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, Packet &lhsV, Packet &lhsVi)
|
||
|
{
|
||
|
#ifdef _ARCH_PWR9
|
||
|
lhsV = vec_xl_len((Scalar *)lhs_ptr, remaining_rows * sizeof(Scalar));
|
||
|
@@ -1103,11 +1111,11 @@ EIGEN_ALWAYS_INLINE void loadPacketRemaining(const Scalar* lhs_ptr, const Scalar
|
||
|
#endif
|
||
|
}
|
||
|
|
||
|
-template<int N, typename Scalar, typename Packet, typename Index, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal>
|
||
|
-EIGEN_ALWAYS_INLINE void pgerc(PacketBlock<Packet,N>* accReal, PacketBlock<Packet,N>* accImag, const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, const Packet* rhsV, const Packet* rhsVi, Index remaining_rows)
|
||
|
+template<int N, typename Scalar, typename Packet, typename Index, bool ConjugateLhs, bool ConjugateRhs, bool LhsIsReal, bool RhsIsReal, const Index remaining_rows>
|
||
|
+EIGEN_ALWAYS_INLINE void pgerc(PacketBlock<Packet,N>* accReal, PacketBlock<Packet,N>* accImag, const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, const Packet* rhsV, const Packet* rhsVi)
|
||
|
{
|
||
|
Packet lhsV, lhsVi;
|
||
|
- loadPacketRemaining<Scalar, Packet, Index, LhsIsReal>(lhs_ptr, lhs_ptr_imag, lhsV, lhsVi, remaining_rows);
|
||
|
+ loadPacketRemaining<Scalar, Packet, Index, LhsIsReal, remaining_rows>(lhs_ptr, lhs_ptr_imag, lhsV, lhsVi);
|
||
|
|
||
|
pgerc_common<N, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(accReal, accImag, lhsV, lhsVi, rhsV, rhsVi);
|
||
|
}
|
||
|
@@ -1119,132 +1127,142 @@ EIGEN_ALWAYS_INLINE Packet ploadLhs(const Scalar* lhs)
|
||
|
}
|
||
|
|
||
|
// Zero the accumulator on PacketBlock.
|
||
|
-template<typename Scalar, typename Packet>
|
||
|
-EIGEN_ALWAYS_INLINE void bsetzero(PacketBlock<Packet,4>& acc)
|
||
|
-{
|
||
|
- acc.packet[0] = pset1<Packet>((Scalar)0);
|
||
|
- acc.packet[1] = pset1<Packet>((Scalar)0);
|
||
|
- acc.packet[2] = pset1<Packet>((Scalar)0);
|
||
|
- acc.packet[3] = pset1<Packet>((Scalar)0);
|
||
|
-}
|
||
|
-
|
||
|
-template<typename Scalar, typename Packet>
|
||
|
-EIGEN_ALWAYS_INLINE void bsetzero(PacketBlock<Packet,1>& acc)
|
||
|
+template<typename Scalar, typename Packet, int N>
|
||
|
+EIGEN_ALWAYS_INLINE void bsetzero(PacketBlock<Packet,N>& acc)
|
||
|
{
|
||
|
acc.packet[0] = pset1<Packet>((Scalar)0);
|
||
|
+ if (N > 1) {
|
||
|
+ acc.packet[1] = pset1<Packet>((Scalar)0);
|
||
|
+ }
|
||
|
+ if (N > 2) {
|
||
|
+ acc.packet[2] = pset1<Packet>((Scalar)0);
|
||
|
+ }
|
||
|
+ if (N > 3) {
|
||
|
+ acc.packet[3] = pset1<Packet>((Scalar)0);
|
||
|
+ }
|
||
|
}
|
||
|
|
||
|
// Scale the PacketBlock vectors by alpha.
|
||
|
-template<typename Packet>
|
||
|
-EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet,4>& acc, PacketBlock<Packet,4>& accZ, const Packet& pAlpha)
|
||
|
-{
|
||
|
- acc.packet[0] = pmadd(pAlpha, accZ.packet[0], acc.packet[0]);
|
||
|
- acc.packet[1] = pmadd(pAlpha, accZ.packet[1], acc.packet[1]);
|
||
|
- acc.packet[2] = pmadd(pAlpha, accZ.packet[2], acc.packet[2]);
|
||
|
- acc.packet[3] = pmadd(pAlpha, accZ.packet[3], acc.packet[3]);
|
||
|
-}
|
||
|
-
|
||
|
-template<typename Packet>
|
||
|
-EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet,1>& acc, PacketBlock<Packet,1>& accZ, const Packet& pAlpha)
|
||
|
+template<typename Packet, int N>
|
||
|
+EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet,N>& acc, PacketBlock<Packet,N>& accZ, const Packet& pAlpha)
|
||
|
{
|
||
|
acc.packet[0] = pmadd(pAlpha, accZ.packet[0], acc.packet[0]);
|
||
|
+ if (N > 1) {
|
||
|
+ acc.packet[1] = pmadd(pAlpha, accZ.packet[1], acc.packet[1]);
|
||
|
+ }
|
||
|
+ if (N > 2) {
|
||
|
+ acc.packet[2] = pmadd(pAlpha, accZ.packet[2], acc.packet[2]);
|
||
|
+ }
|
||
|
+ if (N > 3) {
|
||
|
+ acc.packet[3] = pmadd(pAlpha, accZ.packet[3], acc.packet[3]);
|
||
|
+ }
|
||
|
}
|
||
|
|
||
|
-template<typename Packet>
|
||
|
-EIGEN_ALWAYS_INLINE void bscalec_common(PacketBlock<Packet,4>& acc, PacketBlock<Packet,4>& accZ, const Packet& pAlpha)
|
||
|
-{
|
||
|
- acc.packet[0] = pmul<Packet>(accZ.packet[0], pAlpha);
|
||
|
- acc.packet[1] = pmul<Packet>(accZ.packet[1], pAlpha);
|
||
|
- acc.packet[2] = pmul<Packet>(accZ.packet[2], pAlpha);
|
||
|
- acc.packet[3] = pmul<Packet>(accZ.packet[3], pAlpha);
|
||
|
-}
|
||
|
-
|
||
|
-template<typename Packet>
|
||
|
-EIGEN_ALWAYS_INLINE void bscalec_common(PacketBlock<Packet,1>& acc, PacketBlock<Packet,1>& accZ, const Packet& pAlpha)
|
||
|
+template<typename Packet, int N>
|
||
|
+EIGEN_ALWAYS_INLINE void bscalec_common(PacketBlock<Packet,N>& acc, PacketBlock<Packet,N>& accZ, const Packet& pAlpha)
|
||
|
{
|
||
|
acc.packet[0] = pmul<Packet>(accZ.packet[0], pAlpha);
|
||
|
+ if (N > 1) {
|
||
|
+ acc.packet[1] = pmul<Packet>(accZ.packet[1], pAlpha);
|
||
|
+ }
|
||
|
+ if (N > 2) {
|
||
|
+ acc.packet[2] = pmul<Packet>(accZ.packet[2], pAlpha);
|
||
|
+ }
|
||
|
+ if (N > 3) {
|
||
|
+ acc.packet[3] = pmul<Packet>(accZ.packet[3], pAlpha);
|
||
|
+ }
|
||
|
}
|
||
|
|
||
|
// Complex version of PacketBlock scaling.
|
||
|
template<typename Packet, int N>
|
||
|
EIGEN_ALWAYS_INLINE void bscalec(PacketBlock<Packet,N>& aReal, PacketBlock<Packet,N>& aImag, const Packet& bReal, const Packet& bImag, PacketBlock<Packet,N>& cReal, PacketBlock<Packet,N>& cImag)
|
||
|
{
|
||
|
- bscalec_common<Packet>(cReal, aReal, bReal);
|
||
|
+ bscalec_common<Packet, N>(cReal, aReal, bReal);
|
||
|
|
||
|
- bscalec_common<Packet>(cImag, aImag, bReal);
|
||
|
+ bscalec_common<Packet, N>(cImag, aImag, bReal);
|
||
|
|
||
|
- pger_common<Packet, true>(&cReal, bImag, aImag.packet);
|
||
|
+ pger_common<Packet, true, N>(&cReal, bImag, aImag.packet);
|
||
|
|
||
|
- pger_common<Packet, false>(&cImag, bImag, aReal.packet);
|
||
|
+ pger_common<Packet, false, N>(&cImag, bImag, aReal.packet);
|
||
|
}
|
||
|
|
||
|
-template<typename Packet>
|
||
|
-EIGEN_ALWAYS_INLINE void band(PacketBlock<Packet,4>& acc, const Packet& pMask)
|
||
|
+template<typename Packet, int N>
|
||
|
+EIGEN_ALWAYS_INLINE void band(PacketBlock<Packet,N>& acc, const Packet& pMask)
|
||
|
{
|
||
|
acc.packet[0] = pand(acc.packet[0], pMask);
|
||
|
- acc.packet[1] = pand(acc.packet[1], pMask);
|
||
|
- acc.packet[2] = pand(acc.packet[2], pMask);
|
||
|
- acc.packet[3] = pand(acc.packet[3], pMask);
|
||
|
+ if (N > 1) {
|
||
|
+ acc.packet[1] = pand(acc.packet[1], pMask);
|
||
|
+ }
|
||
|
+ if (N > 2) {
|
||
|
+ acc.packet[2] = pand(acc.packet[2], pMask);
|
||
|
+ }
|
||
|
+ if (N > 3) {
|
||
|
+ acc.packet[3] = pand(acc.packet[3], pMask);
|
||
|
+ }
|
||
|
}
|
||
|
|
||
|
-template<typename Packet>
|
||
|
-EIGEN_ALWAYS_INLINE void bscalec(PacketBlock<Packet,4>& aReal, PacketBlock<Packet,4>& aImag, const Packet& bReal, const Packet& bImag, PacketBlock<Packet,4>& cReal, PacketBlock<Packet,4>& cImag, const Packet& pMask)
|
||
|
+template<typename Packet, int N>
|
||
|
+EIGEN_ALWAYS_INLINE void bscalec(PacketBlock<Packet,N>& aReal, PacketBlock<Packet,N>& aImag, const Packet& bReal, const Packet& bImag, PacketBlock<Packet,N>& cReal, PacketBlock<Packet,N>& cImag, const Packet& pMask)
|
||
|
{
|
||
|
- band<Packet>(aReal, pMask);
|
||
|
- band<Packet>(aImag, pMask);
|
||
|
+ band<Packet, N>(aReal, pMask);
|
||
|
+ band<Packet, N>(aImag, pMask);
|
||
|
|
||
|
- bscalec<Packet,4>(aReal, aImag, bReal, bImag, cReal, cImag);
|
||
|
+ bscalec<Packet,N>(aReal, aImag, bReal, bImag, cReal, cImag);
|
||
|
}
|
||
|
|
||
|
// Load a PacketBlock, the N parameters make tunning gemm easier so we can add more accumulators as needed.
|
||
|
-template<typename DataMapper, typename Packet, typename Index, const Index accCols, int N, int StorageOrder>
|
||
|
-EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,4>& acc, const DataMapper& res, Index row, Index col)
|
||
|
-{
|
||
|
- if (StorageOrder == RowMajor) {
|
||
|
- acc.packet[0] = res.template loadPacket<Packet>(row + 0, col + N*accCols);
|
||
|
- acc.packet[1] = res.template loadPacket<Packet>(row + 1, col + N*accCols);
|
||
|
- acc.packet[2] = res.template loadPacket<Packet>(row + 2, col + N*accCols);
|
||
|
- acc.packet[3] = res.template loadPacket<Packet>(row + 3, col + N*accCols);
|
||
|
- } else {
|
||
|
- acc.packet[0] = res.template loadPacket<Packet>(row + N*accCols, col + 0);
|
||
|
- acc.packet[1] = res.template loadPacket<Packet>(row + N*accCols, col + 1);
|
||
|
- acc.packet[2] = res.template loadPacket<Packet>(row + N*accCols, col + 2);
|
||
|
- acc.packet[3] = res.template loadPacket<Packet>(row + N*accCols, col + 3);
|
||
|
- }
|
||
|
-}
|
||
|
-
|
||
|
-// An overload of bload when you have a PacketBLock with 8 vectors.
|
||
|
-template<typename DataMapper, typename Packet, typename Index, const Index accCols, int N, int StorageOrder>
|
||
|
-EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,8>& acc, const DataMapper& res, Index row, Index col)
|
||
|
+template<typename DataMapper, typename Packet, typename Index, const Index accCols, int StorageOrder, bool Complex, int N>
|
||
|
+EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,N*(Complex?2:1)>& acc, const DataMapper& res, Index row, Index col)
|
||
|
{
|
||
|
if (StorageOrder == RowMajor) {
|
||
|
- acc.packet[0] = res.template loadPacket<Packet>(row + 0, col + N*accCols);
|
||
|
- acc.packet[1] = res.template loadPacket<Packet>(row + 1, col + N*accCols);
|
||
|
- acc.packet[2] = res.template loadPacket<Packet>(row + 2, col + N*accCols);
|
||
|
- acc.packet[3] = res.template loadPacket<Packet>(row + 3, col + N*accCols);
|
||
|
- acc.packet[4] = res.template loadPacket<Packet>(row + 0, col + (N+1)*accCols);
|
||
|
- acc.packet[5] = res.template loadPacket<Packet>(row + 1, col + (N+1)*accCols);
|
||
|
- acc.packet[6] = res.template loadPacket<Packet>(row + 2, col + (N+1)*accCols);
|
||
|
- acc.packet[7] = res.template loadPacket<Packet>(row + 3, col + (N+1)*accCols);
|
||
|
+ acc.packet[0] = res.template loadPacket<Packet>(row + 0, col);
|
||
|
+ if (N > 1) {
|
||
|
+ acc.packet[1] = res.template loadPacket<Packet>(row + 1, col);
|
||
|
+ }
|
||
|
+ if (N > 2) {
|
||
|
+ acc.packet[2] = res.template loadPacket<Packet>(row + 2, col);
|
||
|
+ }
|
||
|
+ if (N > 3) {
|
||
|
+ acc.packet[3] = res.template loadPacket<Packet>(row + 3, col);
|
||
|
+ }
|
||
|
+ if (Complex) {
|
||
|
+ acc.packet[0+N] = res.template loadPacket<Packet>(row + 0, col + accCols);
|
||
|
+ if (N > 1) {
|
||
|
+ acc.packet[1+N] = res.template loadPacket<Packet>(row + 1, col + accCols);
|
||
|
+ }
|
||
|
+ if (N > 2) {
|
||
|
+ acc.packet[2+N] = res.template loadPacket<Packet>(row + 2, col + accCols);
|
||
|
+ }
|
||
|
+ if (N > 3) {
|
||
|
+ acc.packet[3+N] = res.template loadPacket<Packet>(row + 3, col + accCols);
|
||
|
+ }
|
||
|
+ }
|
||
|
} else {
|
||
|
- acc.packet[0] = res.template loadPacket<Packet>(row + N*accCols, col + 0);
|
||
|
- acc.packet[1] = res.template loadPacket<Packet>(row + N*accCols, col + 1);
|
||
|
- acc.packet[2] = res.template loadPacket<Packet>(row + N*accCols, col + 2);
|
||
|
- acc.packet[3] = res.template loadPacket<Packet>(row + N*accCols, col + 3);
|
||
|
- acc.packet[4] = res.template loadPacket<Packet>(row + (N+1)*accCols, col + 0);
|
||
|
- acc.packet[5] = res.template loadPacket<Packet>(row + (N+1)*accCols, col + 1);
|
||
|
- acc.packet[6] = res.template loadPacket<Packet>(row + (N+1)*accCols, col + 2);
|
||
|
- acc.packet[7] = res.template loadPacket<Packet>(row + (N+1)*accCols, col + 3);
|
||
|
+ acc.packet[0] = res.template loadPacket<Packet>(row, col + 0);
|
||
|
+ if (N > 1) {
|
||
|
+ acc.packet[1] = res.template loadPacket<Packet>(row, col + 1);
|
||
|
+ }
|
||
|
+ if (N > 2) {
|
||
|
+ acc.packet[2] = res.template loadPacket<Packet>(row, col + 2);
|
||
|
+ }
|
||
|
+ if (N > 3) {
|
||
|
+ acc.packet[3] = res.template loadPacket<Packet>(row, col + 3);
|
||
|
+ }
|
||
|
+ if (Complex) {
|
||
|
+ acc.packet[0+N] = res.template loadPacket<Packet>(row + accCols, col + 0);
|
||
|
+ if (N > 1) {
|
||
|
+ acc.packet[1+N] = res.template loadPacket<Packet>(row + accCols, col + 1);
|
||
|
+ }
|
||
|
+ if (N > 2) {
|
||
|
+ acc.packet[2+N] = res.template loadPacket<Packet>(row + accCols, col + 2);
|
||
|
+ }
|
||
|
+ if (N > 3) {
|
||
|
+ acc.packet[3+N] = res.template loadPacket<Packet>(row + accCols, col + 3);
|
||
|
+ }
|
||
|
+ }
|
||
|
}
|
||
|
}
|
||
|
|
||
|
-template<typename DataMapper, typename Packet, typename Index, const Index accCols, int N, int StorageOrder>
|
||
|
-EIGEN_ALWAYS_INLINE void bload(PacketBlock<Packet,2>& acc, const DataMapper& res, Index row, Index col)
|
||
|
-{
|
||
|
- acc.packet[0] = res.template loadPacket<Packet>(row + N*accCols, col + 0);
|
||
|
- acc.packet[1] = res.template loadPacket<Packet>(row + (N+1)*accCols, col + 0);
|
||
|
-}
|
||
|
-
|
||
|
const static Packet4i mask41 = { -1, 0, 0, 0 };
|
||
|
const static Packet4i mask42 = { -1, -1, 0, 0 };
|
||
|
const static Packet4i mask43 = { -1, -1, -1, 0 };
|
||
|
@@ -1275,22 +1293,44 @@ EIGEN_ALWAYS_INLINE Packet2d bmask<Packet2d>(const int remaining_rows)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
-template<typename Packet>
|
||
|
-EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet,4>& acc, PacketBlock<Packet,4>& accZ, const Packet& pAlpha, const Packet& pMask)
|
||
|
+template<typename Packet, int N>
|
||
|
+EIGEN_ALWAYS_INLINE void bscale(PacketBlock<Packet,N>& acc, PacketBlock<Packet,N>& accZ, const Packet& pAlpha, const Packet& pMask)
|
||
|
{
|
||
|
- band<Packet>(accZ, pMask);
|
||
|
+ band<Packet, N>(accZ, pMask);
|
||
|
|
||
|
- bscale<Packet>(acc, accZ, pAlpha);
|
||
|
+ bscale<Packet, N>(acc, accZ, pAlpha);
|
||
|
}
|
||
|
|
||
|
-template<typename Packet>
|
||
|
-EIGEN_ALWAYS_INLINE void pbroadcast4_old(const __UNPACK_TYPE__(Packet)* a, Packet& a0, Packet& a1, Packet& a2, Packet& a3)
|
||
|
+template<typename Packet, int N> EIGEN_ALWAYS_INLINE void
|
||
|
+pbroadcastN_old(const __UNPACK_TYPE__(Packet) *a,
|
||
|
+ Packet& a0, Packet& a1, Packet& a2, Packet& a3)
|
||
|
+{
|
||
|
+ a0 = pset1<Packet>(a[0]);
|
||
|
+ if (N > 1) {
|
||
|
+ a1 = pset1<Packet>(a[1]);
|
||
|
+ } else {
|
||
|
+ EIGEN_UNUSED_VARIABLE(a1);
|
||
|
+ }
|
||
|
+ if (N > 2) {
|
||
|
+ a2 = pset1<Packet>(a[2]);
|
||
|
+ } else {
|
||
|
+ EIGEN_UNUSED_VARIABLE(a2);
|
||
|
+ }
|
||
|
+ if (N > 3) {
|
||
|
+ a3 = pset1<Packet>(a[3]);
|
||
|
+ } else {
|
||
|
+ EIGEN_UNUSED_VARIABLE(a3);
|
||
|
+ }
|
||
|
+}
|
||
|
+
|
||
|
+template<>
|
||
|
+EIGEN_ALWAYS_INLINE void pbroadcastN_old<Packet4f,4>(const float* a, Packet4f& a0, Packet4f& a1, Packet4f& a2, Packet4f& a3)
|
||
|
{
|
||
|
- pbroadcast4<Packet>(a, a0, a1, a2, a3);
|
||
|
+ pbroadcast4<Packet4f>(a, a0, a1, a2, a3);
|
||
|
}
|
||
|
|
||
|
template<>
|
||
|
-EIGEN_ALWAYS_INLINE void pbroadcast4_old<Packet2d>(const double* a, Packet2d& a0, Packet2d& a1, Packet2d& a2, Packet2d& a3)
|
||
|
+EIGEN_ALWAYS_INLINE void pbroadcastN_old<Packet2d,4>(const double* a, Packet2d& a0, Packet2d& a1, Packet2d& a2, Packet2d& a3)
|
||
|
{
|
||
|
a1 = pload<Packet2d>(a);
|
||
|
a3 = pload<Packet2d>(a + 2);
|
||
|
@@ -1300,89 +1340,96 @@ EIGEN_ALWAYS_INLINE void pbroadcast4_old<Packet2d>(const double* a, Packet2d& a0
|
||
|
a3 = vec_splat(a3, 1);
|
||
|
}
|
||
|
|
||
|
-// PEEL loop factor.
|
||
|
-#define PEEL 7
|
||
|
-
|
||
|
-template<typename Scalar, typename Packet, typename Index>
|
||
|
-EIGEN_ALWAYS_INLINE void MICRO_EXTRA_COL(
|
||
|
- const Scalar* &lhs_ptr,
|
||
|
- const Scalar* &rhs_ptr,
|
||
|
- PacketBlock<Packet,1> &accZero,
|
||
|
- Index remaining_rows,
|
||
|
- Index remaining_cols)
|
||
|
+template<typename Packet, int N> EIGEN_ALWAYS_INLINE void
|
||
|
+pbroadcastN(const __UNPACK_TYPE__(Packet) *a,
|
||
|
+ Packet& a0, Packet& a1, Packet& a2, Packet& a3)
|
||
|
{
|
||
|
- Packet rhsV[1];
|
||
|
- rhsV[0] = pset1<Packet>(rhs_ptr[0]);
|
||
|
- pger<1,Scalar, Packet, false>(&accZero, lhs_ptr, rhsV);
|
||
|
- lhs_ptr += remaining_rows;
|
||
|
- rhs_ptr += remaining_cols;
|
||
|
+ a0 = pset1<Packet>(a[0]);
|
||
|
+ if (N > 1) {
|
||
|
+ a1 = pset1<Packet>(a[1]);
|
||
|
+ } else {
|
||
|
+ EIGEN_UNUSED_VARIABLE(a1);
|
||
|
+ }
|
||
|
+ if (N > 2) {
|
||
|
+ a2 = pset1<Packet>(a[2]);
|
||
|
+ } else {
|
||
|
+ EIGEN_UNUSED_VARIABLE(a2);
|
||
|
+ }
|
||
|
+ if (N > 3) {
|
||
|
+ a3 = pset1<Packet>(a[3]);
|
||
|
+ } else {
|
||
|
+ EIGEN_UNUSED_VARIABLE(a3);
|
||
|
+ }
|
||
|
}
|
||
|
|
||
|
-template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows>
|
||
|
-EIGEN_STRONG_INLINE void gemm_extra_col(
|
||
|
- const DataMapper& res,
|
||
|
- const Scalar* lhs_base,
|
||
|
- const Scalar* rhs_base,
|
||
|
- Index depth,
|
||
|
- Index strideA,
|
||
|
- Index offsetA,
|
||
|
- Index row,
|
||
|
- Index col,
|
||
|
- Index remaining_rows,
|
||
|
- Index remaining_cols,
|
||
|
- const Packet& pAlpha)
|
||
|
+template<> EIGEN_ALWAYS_INLINE void
|
||
|
+pbroadcastN<Packet4f,4>(const float *a,
|
||
|
+ Packet4f& a0, Packet4f& a1, Packet4f& a2, Packet4f& a3)
|
||
|
{
|
||
|
- const Scalar* rhs_ptr = rhs_base;
|
||
|
- const Scalar* lhs_ptr = lhs_base + row*strideA + remaining_rows*offsetA;
|
||
|
- PacketBlock<Packet,1> accZero;
|
||
|
+ a3 = pload<Packet4f>(a);
|
||
|
+ a0 = vec_splat(a3, 0);
|
||
|
+ a1 = vec_splat(a3, 1);
|
||
|
+ a2 = vec_splat(a3, 2);
|
||
|
+ a3 = vec_splat(a3, 3);
|
||
|
+}
|
||
|
|
||
|
- bsetzero<Scalar, Packet>(accZero);
|
||
|
+// PEEL loop factor.
|
||
|
+#define PEEL 7
|
||
|
+#define PEEL_ROW 7
|
||
|
|
||
|
- Index remaining_depth = (depth & -accRows);
|
||
|
- Index k = 0;
|
||
|
- for(; k + PEEL <= remaining_depth; k+= PEEL)
|
||
|
- {
|
||
|
- EIGEN_POWER_PREFETCH(rhs_ptr);
|
||
|
- EIGEN_POWER_PREFETCH(lhs_ptr);
|
||
|
- for (int l = 0; l < PEEL; l++) {
|
||
|
- MICRO_EXTRA_COL<Scalar, Packet, Index>(lhs_ptr, rhs_ptr, accZero, remaining_rows, remaining_cols);
|
||
|
- }
|
||
|
- }
|
||
|
- for(; k < remaining_depth; k++)
|
||
|
- {
|
||
|
- MICRO_EXTRA_COL<Scalar, Packet, Index>(lhs_ptr, rhs_ptr, accZero, remaining_rows, remaining_cols);
|
||
|
+#define MICRO_UNROLL_PEEL(func) \
|
||
|
+ func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7)
|
||
|
+
|
||
|
+#define MICRO_ZERO_PEEL(peel) \
|
||
|
+ if ((PEEL_ROW > peel) && (peel != 0)) { \
|
||
|
+ bsetzero<Scalar, Packet, accRows>(accZero##peel); \
|
||
|
+ } else { \
|
||
|
+ EIGEN_UNUSED_VARIABLE(accZero##peel); \
|
||
|
}
|
||
|
- for(; k < depth; k++)
|
||
|
- {
|
||
|
- Packet rhsV[1];
|
||
|
- rhsV[0] = pset1<Packet>(rhs_ptr[0]);
|
||
|
- pger<1, Scalar, Packet, Index, false>(&accZero, lhs_ptr, rhsV, remaining_rows);
|
||
|
- lhs_ptr += remaining_rows;
|
||
|
- rhs_ptr += remaining_cols;
|
||
|
+
|
||
|
+#define MICRO_ZERO_PEEL_ROW \
|
||
|
+ MICRO_UNROLL_PEEL(MICRO_ZERO_PEEL);
|
||
|
+
|
||
|
+#define MICRO_WORK_PEEL(peel) \
|
||
|
+ if (PEEL_ROW > peel) { \
|
||
|
+ pbroadcastN<Packet,accRows>(rhs_ptr + (accRows * peel), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \
|
||
|
+ pger<accRows, Scalar, Packet, false>(&accZero##peel, lhs_ptr + (remaining_rows * peel), rhsV##peel); \
|
||
|
+ } else { \
|
||
|
+ EIGEN_UNUSED_VARIABLE(rhsV##peel); \
|
||
|
}
|
||
|
|
||
|
- accZero.packet[0] = vec_mul(pAlpha, accZero.packet[0]);
|
||
|
- for(Index i = 0; i < remaining_rows; i++) {
|
||
|
- res(row + i, col) += accZero.packet[0][i];
|
||
|
+#define MICRO_WORK_PEEL_ROW \
|
||
|
+ Packet rhsV0[4], rhsV1[4], rhsV2[4], rhsV3[4], rhsV4[4], rhsV5[4], rhsV6[4], rhsV7[4]; \
|
||
|
+ MICRO_UNROLL_PEEL(MICRO_WORK_PEEL); \
|
||
|
+ lhs_ptr += (remaining_rows * PEEL_ROW); \
|
||
|
+ rhs_ptr += (accRows * PEEL_ROW);
|
||
|
+
|
||
|
+#define MICRO_ADD_PEEL(peel, sum) \
|
||
|
+ if (PEEL_ROW > peel) { \
|
||
|
+ for (Index i = 0; i < accRows; i++) { \
|
||
|
+ accZero##sum.packet[i] += accZero##peel.packet[i]; \
|
||
|
+ } \
|
||
|
}
|
||
|
-}
|
||
|
|
||
|
-template<typename Scalar, typename Packet, typename Index, const Index accRows>
|
||
|
+#define MICRO_ADD_PEEL_ROW \
|
||
|
+ MICRO_ADD_PEEL(4, 0) MICRO_ADD_PEEL(5, 1) MICRO_ADD_PEEL(6, 2) MICRO_ADD_PEEL(7, 3) \
|
||
|
+ MICRO_ADD_PEEL(2, 0) MICRO_ADD_PEEL(3, 1) MICRO_ADD_PEEL(1, 0)
|
||
|
+
|
||
|
+template<typename Scalar, typename Packet, typename Index, const Index accRows, const Index remaining_rows>
|
||
|
EIGEN_ALWAYS_INLINE void MICRO_EXTRA_ROW(
|
||
|
const Scalar* &lhs_ptr,
|
||
|
const Scalar* &rhs_ptr,
|
||
|
- PacketBlock<Packet,4> &accZero,
|
||
|
- Index remaining_rows)
|
||
|
+ PacketBlock<Packet,accRows> &accZero)
|
||
|
{
|
||
|
Packet rhsV[4];
|
||
|
- pbroadcast4<Packet>(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
|
||
|
- pger<4, Scalar, Packet, false>(&accZero, lhs_ptr, rhsV);
|
||
|
+ pbroadcastN<Packet,accRows>(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
|
||
|
+ pger<accRows, Scalar, Packet, false>(&accZero, lhs_ptr, rhsV);
|
||
|
lhs_ptr += remaining_rows;
|
||
|
rhs_ptr += accRows;
|
||
|
}
|
||
|
|
||
|
-template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols>
|
||
|
-EIGEN_STRONG_INLINE void gemm_extra_row(
|
||
|
+template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols, const Index remaining_rows>
|
||
|
+EIGEN_ALWAYS_INLINE void gemm_unrolled_row_iteration(
|
||
|
const DataMapper& res,
|
||
|
const Scalar* lhs_base,
|
||
|
const Scalar* rhs_base,
|
||
|
@@ -1393,59 +1440,89 @@ EIGEN_STRONG_INLINE void gemm_extra_row(
|
||
|
Index col,
|
||
|
Index rows,
|
||
|
Index cols,
|
||
|
- Index remaining_rows,
|
||
|
const Packet& pAlpha,
|
||
|
const Packet& pMask)
|
||
|
{
|
||
|
const Scalar* rhs_ptr = rhs_base;
|
||
|
const Scalar* lhs_ptr = lhs_base + row*strideA + remaining_rows*offsetA;
|
||
|
- PacketBlock<Packet,4> accZero, acc;
|
||
|
+ PacketBlock<Packet,accRows> accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7, acc;
|
||
|
|
||
|
- bsetzero<Scalar, Packet>(accZero);
|
||
|
+ bsetzero<Scalar, Packet, accRows>(accZero0);
|
||
|
|
||
|
- Index remaining_depth = (col + accRows < cols) ? depth : (depth & -accRows);
|
||
|
+ Index remaining_depth = (col + quad_traits<Scalar>::rows < cols) ? depth : (depth & -quad_traits<Scalar>::rows);
|
||
|
Index k = 0;
|
||
|
- for(; k + PEEL <= remaining_depth; k+= PEEL)
|
||
|
- {
|
||
|
- EIGEN_POWER_PREFETCH(rhs_ptr);
|
||
|
- EIGEN_POWER_PREFETCH(lhs_ptr);
|
||
|
- for (int l = 0; l < PEEL; l++) {
|
||
|
- MICRO_EXTRA_ROW<Scalar, Packet, Index, accRows>(lhs_ptr, rhs_ptr, accZero, remaining_rows);
|
||
|
- }
|
||
|
+ if (remaining_depth >= PEEL_ROW) {
|
||
|
+ MICRO_ZERO_PEEL_ROW
|
||
|
+ do
|
||
|
+ {
|
||
|
+ EIGEN_POWER_PREFETCH(rhs_ptr);
|
||
|
+ EIGEN_POWER_PREFETCH(lhs_ptr);
|
||
|
+ MICRO_WORK_PEEL_ROW
|
||
|
+ } while ((k += PEEL_ROW) + PEEL_ROW <= remaining_depth);
|
||
|
+ MICRO_ADD_PEEL_ROW
|
||
|
}
|
||
|
for(; k < remaining_depth; k++)
|
||
|
{
|
||
|
- MICRO_EXTRA_ROW<Scalar, Packet, Index, accRows>(lhs_ptr, rhs_ptr, accZero, remaining_rows);
|
||
|
+ MICRO_EXTRA_ROW<Scalar, Packet, Index, accRows, remaining_rows>(lhs_ptr, rhs_ptr, accZero0);
|
||
|
}
|
||
|
|
||
|
if ((remaining_depth == depth) && (rows >= accCols))
|
||
|
{
|
||
|
- for(Index j = 0; j < 4; j++) {
|
||
|
- acc.packet[j] = res.template loadPacket<Packet>(row, col + j);
|
||
|
- }
|
||
|
- bscale<Packet>(acc, accZero, pAlpha, pMask);
|
||
|
- res.template storePacketBlock<Packet,4>(row, col, acc);
|
||
|
+ bload<DataMapper, Packet, Index, 0, ColMajor, false, accRows>(acc, res, row, 0);
|
||
|
+ bscale<Packet,accRows>(acc, accZero0, pAlpha, pMask);
|
||
|
+ res.template storePacketBlock<Packet,accRows>(row, 0, acc);
|
||
|
} else {
|
||
|
for(; k < depth; k++)
|
||
|
{
|
||
|
Packet rhsV[4];
|
||
|
- pbroadcast4<Packet>(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
|
||
|
- pger<4, Scalar, Packet, Index, false>(&accZero, lhs_ptr, rhsV, remaining_rows);
|
||
|
+ pbroadcastN<Packet,accRows>(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]);
|
||
|
+ pger<accRows, Scalar, Packet, Index, false, remaining_rows>(&accZero0, lhs_ptr, rhsV);
|
||
|
lhs_ptr += remaining_rows;
|
||
|
rhs_ptr += accRows;
|
||
|
}
|
||
|
|
||
|
- for(Index j = 0; j < 4; j++) {
|
||
|
- accZero.packet[j] = vec_mul(pAlpha, accZero.packet[j]);
|
||
|
- }
|
||
|
- for(Index j = 0; j < 4; j++) {
|
||
|
+ for(Index j = 0; j < accRows; j++) {
|
||
|
+ accZero0.packet[j] = vec_mul(pAlpha, accZero0.packet[j]);
|
||
|
for(Index i = 0; i < remaining_rows; i++) {
|
||
|
- res(row + i, col + j) += accZero.packet[j][i];
|
||
|
+ res(row + i, j) += accZero0.packet[j][i];
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
+template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols>
|
||
|
+EIGEN_ALWAYS_INLINE void gemm_extra_row(
|
||
|
+ const DataMapper& res,
|
||
|
+ const Scalar* lhs_base,
|
||
|
+ const Scalar* rhs_base,
|
||
|
+ Index depth,
|
||
|
+ Index strideA,
|
||
|
+ Index offsetA,
|
||
|
+ Index row,
|
||
|
+ Index col,
|
||
|
+ Index rows,
|
||
|
+ Index cols,
|
||
|
+ Index remaining_rows,
|
||
|
+ const Packet& pAlpha,
|
||
|
+ const Packet& pMask)
|
||
|
+{
|
||
|
+ switch(remaining_rows) {
|
||
|
+ case 1:
|
||
|
+ gemm_unrolled_row_iteration<Scalar, Packet, DataMapper, Index, accRows, accCols, 1>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, rows, cols, pAlpha, pMask);
|
||
|
+ break;
|
||
|
+ case 2:
|
||
|
+ if (sizeof(Scalar) == sizeof(float)) {
|
||
|
+ gemm_unrolled_row_iteration<Scalar, Packet, DataMapper, Index, accRows, accCols, 2>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, rows, cols, pAlpha, pMask);
|
||
|
+ }
|
||
|
+ break;
|
||
|
+ default:
|
||
|
+ if (sizeof(Scalar) == sizeof(float)) {
|
||
|
+ gemm_unrolled_row_iteration<Scalar, Packet, DataMapper, Index, accRows, accCols, 3>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, rows, cols, pAlpha, pMask);
|
||
|
+ }
|
||
|
+ break;
|
||
|
+ }
|
||
|
+}
|
||
|
+
|
||
|
#define MICRO_UNROLL(func) \
|
||
|
func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7)
|
||
|
|
||
|
@@ -1464,34 +1541,24 @@ EIGEN_STRONG_INLINE void gemm_extra_row(
|
||
|
|
||
|
#define MICRO_WORK_ONE(iter, peel) \
|
||
|
if (unroll_factor > iter) { \
|
||
|
- pger_common<Packet, false>(&accZero##iter, lhsV##iter, rhsV##peel); \
|
||
|
+ pger_common<Packet, false, accRows>(&accZero##iter, lhsV##iter, rhsV##peel); \
|
||
|
}
|
||
|
|
||
|
#define MICRO_TYPE_PEEL4(func, func2, peel) \
|
||
|
if (PEEL > peel) { \
|
||
|
Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \
|
||
|
- pbroadcast4<Packet>(rhs_ptr + (accRows * peel), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \
|
||
|
- MICRO_UNROLL_WORK(func, func2, peel) \
|
||
|
- } else { \
|
||
|
- EIGEN_UNUSED_VARIABLE(rhsV##peel); \
|
||
|
- }
|
||
|
-
|
||
|
-#define MICRO_TYPE_PEEL1(func, func2, peel) \
|
||
|
- if (PEEL > peel) { \
|
||
|
- Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \
|
||
|
- rhsV##peel[0] = pset1<Packet>(rhs_ptr[remaining_cols * peel]); \
|
||
|
+ pbroadcastN<Packet,accRows>(rhs_ptr + (accRows * peel), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \
|
||
|
MICRO_UNROLL_WORK(func, func2, peel) \
|
||
|
} else { \
|
||
|
EIGEN_UNUSED_VARIABLE(rhsV##peel); \
|
||
|
}
|
||
|
|
||
|
#define MICRO_UNROLL_TYPE_PEEL(M, func, func1, func2) \
|
||
|
- Packet rhsV0[M], rhsV1[M], rhsV2[M], rhsV3[M], rhsV4[M], rhsV5[M], rhsV6[M], rhsV7[M], rhsV8[M], rhsV9[M]; \
|
||
|
+ Packet rhsV0[M], rhsV1[M], rhsV2[M], rhsV3[M], rhsV4[M], rhsV5[M], rhsV6[M], rhsV7[M]; \
|
||
|
func(func1,func2,0); func(func1,func2,1); \
|
||
|
func(func1,func2,2); func(func1,func2,3); \
|
||
|
func(func1,func2,4); func(func1,func2,5); \
|
||
|
- func(func1,func2,6); func(func1,func2,7); \
|
||
|
- func(func1,func2,8); func(func1,func2,9);
|
||
|
+ func(func1,func2,6); func(func1,func2,7);
|
||
|
|
||
|
#define MICRO_UNROLL_TYPE_ONE(M, func, func1, func2) \
|
||
|
Packet rhsV0[M]; \
|
||
|
@@ -1505,17 +1572,9 @@ EIGEN_STRONG_INLINE void gemm_extra_row(
|
||
|
MICRO_UNROLL_TYPE_ONE(4, MICRO_TYPE_PEEL4, MICRO_WORK_ONE, MICRO_LOAD_ONE); \
|
||
|
rhs_ptr += accRows;
|
||
|
|
||
|
-#define MICRO_ONE_PEEL1 \
|
||
|
- MICRO_UNROLL_TYPE_PEEL(1, MICRO_TYPE_PEEL1, MICRO_WORK_ONE, MICRO_LOAD_ONE); \
|
||
|
- rhs_ptr += (remaining_cols * PEEL);
|
||
|
-
|
||
|
-#define MICRO_ONE1 \
|
||
|
- MICRO_UNROLL_TYPE_ONE(1, MICRO_TYPE_PEEL1, MICRO_WORK_ONE, MICRO_LOAD_ONE); \
|
||
|
- rhs_ptr += remaining_cols;
|
||
|
-
|
||
|
#define MICRO_DST_PTR_ONE(iter) \
|
||
|
if (unroll_factor > iter) { \
|
||
|
- bsetzero<Scalar, Packet>(accZero##iter); \
|
||
|
+ bsetzero<Scalar, Packet, accRows>(accZero##iter); \
|
||
|
} else { \
|
||
|
EIGEN_UNUSED_VARIABLE(accZero##iter); \
|
||
|
}
|
||
|
@@ -1524,7 +1583,7 @@ EIGEN_STRONG_INLINE void gemm_extra_row(
|
||
|
|
||
|
#define MICRO_SRC_PTR_ONE(iter) \
|
||
|
if (unroll_factor > iter) { \
|
||
|
- lhs_ptr##iter = lhs_base + ( (row/accCols) + iter )*strideA*accCols + accCols*offsetA; \
|
||
|
+ lhs_ptr##iter = lhs_base + ( (row/accCols) + iter )*strideA*accCols; \
|
||
|
} else { \
|
||
|
EIGEN_UNUSED_VARIABLE(lhs_ptr##iter); \
|
||
|
}
|
||
|
@@ -1540,25 +1599,13 @@ EIGEN_STRONG_INLINE void gemm_extra_row(
|
||
|
|
||
|
#define MICRO_STORE_ONE(iter) \
|
||
|
if (unroll_factor > iter) { \
|
||
|
- acc.packet[0] = res.template loadPacket<Packet>(row + iter*accCols, col + 0); \
|
||
|
- acc.packet[1] = res.template loadPacket<Packet>(row + iter*accCols, col + 1); \
|
||
|
- acc.packet[2] = res.template loadPacket<Packet>(row + iter*accCols, col + 2); \
|
||
|
- acc.packet[3] = res.template loadPacket<Packet>(row + iter*accCols, col + 3); \
|
||
|
- bscale<Packet>(acc, accZero##iter, pAlpha); \
|
||
|
- res.template storePacketBlock<Packet,4>(row + iter*accCols, col, acc); \
|
||
|
+ bload<DataMapper, Packet, Index, 0, ColMajor, false, accRows>(acc, res, row + iter*accCols, 0); \
|
||
|
+ bscale<Packet,accRows>(acc, accZero##iter, pAlpha); \
|
||
|
+ res.template storePacketBlock<Packet,accRows>(row + iter*accCols, 0, acc); \
|
||
|
}
|
||
|
|
||
|
#define MICRO_STORE MICRO_UNROLL(MICRO_STORE_ONE)
|
||
|
|
||
|
-#define MICRO_COL_STORE_ONE(iter) \
|
||
|
- if (unroll_factor > iter) { \
|
||
|
- acc.packet[0] = res.template loadPacket<Packet>(row + iter*accCols, col + 0); \
|
||
|
- bscale<Packet>(acc, accZero##iter, pAlpha); \
|
||
|
- res.template storePacketBlock<Packet,1>(row + iter*accCols, col, acc); \
|
||
|
- }
|
||
|
-
|
||
|
-#define MICRO_COL_STORE MICRO_UNROLL(MICRO_COL_STORE_ONE)
|
||
|
-
|
||
|
template<int unroll_factor, typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols>
|
||
|
EIGEN_STRONG_INLINE void gemm_unrolled_iteration(
|
||
|
const DataMapper& res,
|
||
|
@@ -1566,15 +1613,13 @@ EIGEN_STRONG_INLINE void gemm_unrolled_iteration(
|
||
|
const Scalar* rhs_base,
|
||
|
Index depth,
|
||
|
Index strideA,
|
||
|
- Index offsetA,
|
||
|
Index& row,
|
||
|
- Index col,
|
||
|
const Packet& pAlpha)
|
||
|
{
|
||
|
const Scalar* rhs_ptr = rhs_base;
|
||
|
const Scalar* lhs_ptr0 = NULL, * lhs_ptr1 = NULL, * lhs_ptr2 = NULL, * lhs_ptr3 = NULL, * lhs_ptr4 = NULL, * lhs_ptr5 = NULL, * lhs_ptr6 = NULL, * lhs_ptr7 = NULL;
|
||
|
- PacketBlock<Packet,4> accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7;
|
||
|
- PacketBlock<Packet,4> acc;
|
||
|
+ PacketBlock<Packet,accRows> accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7;
|
||
|
+ PacketBlock<Packet,accRows> acc;
|
||
|
|
||
|
MICRO_SRC_PTR
|
||
|
MICRO_DST_PTR
|
||
|
@@ -1595,101 +1640,100 @@ EIGEN_STRONG_INLINE void gemm_unrolled_iteration(
|
||
|
row += unroll_factor*accCols;
|
||
|
}
|
||
|
|
||
|
-template<int unroll_factor, typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accCols>
|
||
|
-EIGEN_STRONG_INLINE void gemm_unrolled_col_iteration(
|
||
|
+template<typename Scalar, typename Packet, typename DataMapper, typename Index, const Index accRows, const Index accCols>
|
||
|
+EIGEN_ALWAYS_INLINE void gemm_cols(
|
||
|
const DataMapper& res,
|
||
|
- const Scalar* lhs_base,
|
||
|
- const Scalar* rhs_base,
|
||
|
+ const Scalar* blockA,
|
||
|
+ const Scalar* blockB,
|
||
|
Index depth,
|
||
|
Index strideA,
|
||
|
Index offsetA,
|
||
|
- Index& row,
|
||
|
+ Index strideB,
|
||
|