From 9e3873b1dce3ba65980c7e7b979325dac2fb4bbd Mon Sep 17 00:00:00 2001 From: Chip-Kerchner Date: Wed, 20 Oct 2021 11:06:50 -0500 Subject: [PATCH 1/2] New branch for inverting rows and depth in non-vectorized portion of packing. --- Eigen/src/Core/arch/AltiVec/Complex.h | 10 +- Eigen/src/Core/arch/AltiVec/MatrixProduct.h | 1546 ++++++++--------- .../Core/arch/AltiVec/MatrixProductCommon.h | 206 +-- .../src/Core/arch/AltiVec/MatrixProductMMA.h | 335 ++-- 4 files changed, 927 insertions(+), 1170 deletions(-) diff --git a/Eigen/src/Core/arch/AltiVec/Complex.h b/Eigen/src/Core/arch/AltiVec/Complex.h index f730ce8d3..4fd923e84 100644 --- a/Eigen/src/Core/arch/AltiVec/Complex.h +++ b/Eigen/src/Core/arch/AltiVec/Complex.h @@ -129,20 +129,20 @@ template<> EIGEN_STRONG_INLINE Packet2cf ploaddup(const std::complex< template<> EIGEN_STRONG_INLINE void pstore >(std::complex * to, const Packet2cf& from) { pstore((float*)to, from.v); } template<> EIGEN_STRONG_INLINE void pstoreu >(std::complex * to, const Packet2cf& from) { pstoreu((float*)to, from.v); } -EIGEN_STRONG_INLINE Packet2cf pload2(const std::complex* from0, const std::complex* from1) +EIGEN_STRONG_INLINE Packet2cf pload2(const std::complex& from0, const std::complex& from1) { Packet4f res0, res1; #ifdef __VSX__ - __asm__ ("lxsdx %x0,%y1" : "=wa" (res0) : "Z" (*from0)); - __asm__ ("lxsdx %x0,%y1" : "=wa" (res1) : "Z" (*from1)); + __asm__ ("lxsdx %x0,%y1" : "=wa" (res0) : "Z" (from0)); + __asm__ ("lxsdx %x0,%y1" : "=wa" (res1) : "Z" (from1)); #ifdef _BIG_ENDIAN __asm__ ("xxpermdi %x0, %x1, %x2, 0" : "=wa" (res0) : "wa" (res0), "wa" (res1)); #else __asm__ ("xxpermdi %x0, %x2, %x1, 0" : "=wa" (res0) : "wa" (res0), "wa" (res1)); #endif #else - *reinterpret_cast *>(&res0) = *from0; - *reinterpret_cast *>(&res1) = *from1; + *reinterpret_cast *>(&res0) = from0; + *reinterpret_cast *>(&res1) = from1; res0 = vec_perm(res0, res1, p16uc_TRANSPOSE64_HI); #endif return Packet2cf(res0); diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h index 1d67d60d0..bd5da3623 100644 --- a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h +++ b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h @@ -166,24 +166,23 @@ EIGEN_STRONG_INLINE void symm_pack_complex_rhs_helper(std::complex* bloc rir += vectorDelta; } - if (j < cols) + + for(; j < cols; j++) { - rii = rir + ((cols - j) * rows); + rii = rir + rows; for(Index i = k2; i < depth; i++) { - Index k = j; - for(; k < cols; k++) - { - std::complex v = getAdjointVal(i, k, rhs); + std::complex v = getAdjointVal(i, j, rhs); - blockBf[rir] = v.real(); - blockBf[rii] = v.imag(); + blockBf[rir] = v.real(); + blockBf[rii] = v.imag(); - rir += 1; - rii += 1; - } + rir += 1; + rii += 1; } + + rir += rows; } } @@ -262,19 +261,15 @@ EIGEN_STRONG_INLINE void symm_pack_rhs_helper(Scalar* blockB, const Scalar* _rhs } } - if (j < cols) + for(; j < cols; j++) { for(Index i = k2; i < depth; i++) { - Index k = j; - for(; k < cols; k++) - { - if(k <= i) - blockB[ri] = rhs(i, k); - else - blockB[ri] = rhs(k, i); - ri += 1; - } + if(j <= i) + blockB[ri] = rhs(i, j); + else + blockB[ri] = rhs(j, i); + ri += 1; } } } @@ -408,22 +403,18 @@ struct symm_pack_lhs * and offset and behaves accordingly. 
**/ -template -EIGEN_ALWAYS_INLINE void storeBlock(Scalar* to, PacketBlock& block) -{ - const Index size = 16 / sizeof(Scalar); - pstore(to + (0 * size), block.packet[0]); - pstore(to + (1 * size), block.packet[1]); - pstore(to + (2 * size), block.packet[2]); - pstore(to + (3 * size), block.packet[3]); -} - -template -EIGEN_ALWAYS_INLINE void storeBlock(Scalar* to, PacketBlock& block) +template +EIGEN_ALWAYS_INLINE void storeBlock(Scalar* to, PacketBlock& block) { const Index size = 16 / sizeof(Scalar); pstore(to + (0 * size), block.packet[0]); pstore(to + (1 * size), block.packet[1]); + if (N > 2) { + pstore(to + (2 * size), block.packet[2]); + } + if (N > 3) { + pstore(to + (3 * size), block.packet[3]); + } } // General template for lhs & rhs complex packing. @@ -449,9 +440,9 @@ struct dhs_cpack { PacketBlock cblock; if (UseLhs) { - bload(cblock, lhs, j, i); + bload(cblock, lhs, j, i); } else { - bload(cblock, lhs, i, j); + bload(cblock, lhs, i, j); } blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[4].v, p16uc_GETREAL32); @@ -478,8 +469,8 @@ struct dhs_cpack { ptranspose(blocki); } - storeBlock(blockAt + rir, blockr); - storeBlock(blockAt + rii, blocki); + storeBlock(blockAt + rir, blockr); + storeBlock(blockAt + rii, blocki); rir += 4*vectorSize; rii += 4*vectorSize; @@ -499,21 +490,12 @@ struct dhs_cpack { cblock.packet[1] = lhs.template loadPacket(i, j + 2); } } else { - std::complex lhs0, lhs1; if (UseLhs) { - lhs0 = lhs(j + 0, i); - lhs1 = lhs(j + 1, i); - cblock.packet[0] = pload2(&lhs0, &lhs1); - lhs0 = lhs(j + 2, i); - lhs1 = lhs(j + 3, i); - cblock.packet[1] = pload2(&lhs0, &lhs1); + cblock.packet[0] = pload2(lhs(j + 0, i), lhs(j + 1, i)); + cblock.packet[1] = pload2(lhs(j + 2, i), lhs(j + 3, i)); } else { - lhs0 = lhs(i, j + 0); - lhs1 = lhs(i, j + 1); - cblock.packet[0] = pload2(&lhs0, &lhs1); - lhs0 = lhs(i, j + 2); - lhs1 = lhs(i, j + 3); - cblock.packet[1] = pload2(&lhs0, &lhs1); + cblock.packet[0] = pload2(lhs(i, j + 0), lhs(i, j + 1)); + cblock.packet[1] = pload2(lhs(i, j + 2), lhs(i, j + 3)); } } @@ -535,34 +517,50 @@ struct dhs_cpack { rir += ((PanelMode) ? (vectorSize*(2*stride - depth)) : vectorDelta); } - if (j < rows) + if (!UseLhs) { - if(PanelMode) rir += (offset*(rows - j - vectorSize)); - rii = rir + (((PanelMode) ? stride : depth) * (rows - j)); + if(PanelMode) rir -= (offset*(vectorSize - 1)); - for(Index i = 0; i < depth; i++) + for(; j < rows; j++) { - Index k = j; - for(; k < rows; k++) + rii = rir + ((PanelMode) ? stride : depth); + + for(Index i = 0; i < depth; i++) { - if (UseLhs) { + blockAt[rir] = lhs(i, j).real(); + + if(Conjugate) + blockAt[rii] = -lhs(i, j).imag(); + else + blockAt[rii] = lhs(i, j).imag(); + + rir += 1; + rii += 1; + } + + rir += ((PanelMode) ? (2*stride - depth) : depth); + } + } else { + if (j < rows) + { + if(PanelMode) rir += (offset*(rows - j - vectorSize)); + rii = rir + (((PanelMode) ? 
stride : depth) * (rows - j)); + + for(Index i = 0; i < depth; i++) + { + Index k = j; + for(; k < rows; k++) + { blockAt[rir] = lhs(k, i).real(); if(Conjugate) blockAt[rii] = -lhs(k, i).imag(); else blockAt[rii] = lhs(k, i).imag(); - } else { - blockAt[rir] = lhs(i, k).real(); - if(Conjugate) - blockAt[rii] = -lhs(i, k).imag(); - else - blockAt[rii] = lhs(i, k).imag(); + rir += 1; + rii += 1; } - - rir += 1; - rii += 1; } } } @@ -588,16 +586,16 @@ struct dhs_pack{ PacketBlock block; if (UseLhs) { - bload(block, lhs, j, i); + bload(block, lhs, j, i); } else { - bload(block, lhs, i, j); + bload(block, lhs, i, j); } if(((StorageOrder == RowMajor) && UseLhs) || ((StorageOrder == ColMajor) && !UseLhs)) { ptranspose(block); } - storeBlock(blockA + ri, block); + storeBlock(blockA + ri, block); ri += 4*vectorSize; } @@ -632,21 +630,33 @@ struct dhs_pack{ if(PanelMode) ri += vectorSize*(stride - offset - depth); } - if (j < rows) + if (!UseLhs) { - if(PanelMode) ri += offset*(rows - j); + if(PanelMode) ri += offset; - for(Index i = 0; i < depth; i++) + for(; j < rows; j++) { - Index k = j; - for(; k < rows; k++) + for(Index i = 0; i < depth; i++) { - if (UseLhs) { + blockA[ri] = lhs(i, j); + ri += 1; + } + + if(PanelMode) ri += stride - depth; + } + } else { + if (j < rows) + { + if(PanelMode) ri += offset*(rows - j); + + for(Index i = 0; i < depth; i++) + { + Index k = j; + for(; k < rows; k++) + { blockA[ri] = lhs(k, i); - } else { - blockA[ri] = lhs(i, k); + ri += 1; } - ri += 1; } } } @@ -682,7 +692,7 @@ struct dhs_pack(j, i + 1); } - storeBlock(blockA + ri, block); + storeBlock(blockA + ri, block); ri += 2*vectorSize; } @@ -759,7 +769,7 @@ struct dhs_pack(i + 1, j + 0); //[b1 b2] block.packet[3] = rhs.template loadPacket(i + 1, j + 2); //[b3 b4] - storeBlock(blockB + ri, block); + storeBlock(blockB + ri, block); } ri += 4*vectorSize; @@ -790,19 +800,17 @@ struct dhs_pack(blockAt + rir, blockr); - storeBlock(blockAt + rii, blocki); + storeBlock(blockAt + rir, blockr); + storeBlock(blockAt + rii, blocki); rir += 2*vectorSize; rii += 2*vectorSize; @@ -943,7 +951,7 @@ struct dhs_cpack cblock; PacketBlock blockr, blocki; - bload(cblock, rhs, i, j); + bload(cblock, rhs, i, j); blockr.packet[0] = vec_perm(cblock.packet[0].v, cblock.packet[1].v, p16uc_GETREAL64); blockr.packet[1] = vec_perm(cblock.packet[2].v, cblock.packet[3].v, p16uc_GETREAL64); @@ -957,8 +965,8 @@ struct dhs_cpack(blockBt + rir, blockr); - storeBlock(blockBt + rii, blocki); + storeBlock(blockBt + rir, blockr); + storeBlock(blockBt + rii, blocki); rir += 2*vectorSize; rii += 2*vectorSize; @@ -967,27 +975,26 @@ struct dhs_cpack -EIGEN_ALWAYS_INLINE void pger_common(PacketBlock* acc, const Packet& lhsV, const Packet* rhsV) -{ - if(NegativeAccumulate) - { - acc->packet[0] = vec_nmsub(lhsV, rhsV[0], acc->packet[0]); - acc->packet[1] = vec_nmsub(lhsV, rhsV[1], acc->packet[1]); - acc->packet[2] = vec_nmsub(lhsV, rhsV[2], acc->packet[2]); - acc->packet[3] = vec_nmsub(lhsV, rhsV[3], acc->packet[3]); - } else { - acc->packet[0] = vec_madd(lhsV, rhsV[0], acc->packet[0]); - acc->packet[1] = vec_madd(lhsV, rhsV[1], acc->packet[1]); - acc->packet[2] = vec_madd(lhsV, rhsV[2], acc->packet[2]); - acc->packet[3] = vec_madd(lhsV, rhsV[3], acc->packet[3]); - } -} - -template -EIGEN_ALWAYS_INLINE void pger_common(PacketBlock* acc, const Packet& lhsV, const Packet* rhsV) +template +EIGEN_ALWAYS_INLINE void pger_common(PacketBlock* acc, const Packet& lhsV, const Packet* rhsV) { if(NegativeAccumulate) { acc->packet[0] = vec_nmsub(lhsV, rhsV[0], 
acc->packet[0]); + if (N > 1) { + acc->packet[1] = vec_nmsub(lhsV, rhsV[1], acc->packet[1]); + } + if (N > 2) { + acc->packet[2] = vec_nmsub(lhsV, rhsV[2], acc->packet[2]); + } + if (N > 3) { + acc->packet[3] = vec_nmsub(lhsV, rhsV[3], acc->packet[3]); + } } else { acc->packet[0] = vec_madd(lhsV, rhsV[0], acc->packet[0]); + if (N > 1) { + acc->packet[1] = vec_madd(lhsV, rhsV[1], acc->packet[1]); + } + if (N > 2) { + acc->packet[2] = vec_madd(lhsV, rhsV[2], acc->packet[2]); + } + if (N > 3) { + acc->packet[3] = vec_madd(lhsV, rhsV[3], acc->packet[3]); + } } } @@ -1030,11 +1038,11 @@ EIGEN_ALWAYS_INLINE void pger(PacketBlock* acc, const Scalar* lhs, con { Packet lhsV = pload(lhs); - pger_common(acc, lhsV, rhsV); + pger_common(acc, lhsV, rhsV); } -template -EIGEN_ALWAYS_INLINE void loadPacketRemaining(const Scalar* lhs, Packet &lhsV, Index remaining_rows) +template +EIGEN_ALWAYS_INLINE void loadPacketRemaining(const Scalar* lhs, Packet &lhsV) { #ifdef _ARCH_PWR9 lhsV = vec_xl_len((Scalar *)lhs, remaining_rows * sizeof(Scalar)); @@ -1046,32 +1054,32 @@ EIGEN_ALWAYS_INLINE void loadPacketRemaining(const Scalar* lhs, Packet &lhsV, In #endif } -template -EIGEN_ALWAYS_INLINE void pger(PacketBlock* acc, const Scalar* lhs, const Packet* rhsV, Index remaining_rows) +template +EIGEN_ALWAYS_INLINE void pger(PacketBlock* acc, const Scalar* lhs, const Packet* rhsV) { Packet lhsV; - loadPacketRemaining(lhs, lhsV, remaining_rows); + loadPacketRemaining(lhs, lhsV); - pger_common(acc, lhsV, rhsV); + pger_common(acc, lhsV, rhsV); } // 512-bits rank1-update of complex acc. It takes decoupled accumulators as entries. It also takes cares of mixed types real * complex and complex * real. template EIGEN_ALWAYS_INLINE void pgerc_common(PacketBlock* accReal, PacketBlock* accImag, const Packet &lhsV, const Packet &lhsVi, const Packet* rhsV, const Packet* rhsVi) { - pger_common(accReal, lhsV, rhsV); + pger_common(accReal, lhsV, rhsV); if(LhsIsReal) { - pger_common(accImag, lhsV, rhsVi); + pger_common(accImag, lhsV, rhsVi); EIGEN_UNUSED_VARIABLE(lhsVi); } else { if (!RhsIsReal) { - pger_common(accReal, lhsVi, rhsVi); - pger_common(accImag, lhsV, rhsVi); + pger_common(accReal, lhsVi, rhsVi); + pger_common(accImag, lhsV, rhsVi); } else { EIGEN_UNUSED_VARIABLE(rhsVi); } - pger_common(accImag, lhsVi, rhsV); + pger_common(accImag, lhsVi, rhsV); } } @@ -1086,8 +1094,8 @@ EIGEN_ALWAYS_INLINE void pgerc(PacketBlock* accReal, PacketBlock(accReal, accImag, lhsV, lhsVi, rhsV, rhsVi); } -template -EIGEN_ALWAYS_INLINE void loadPacketRemaining(const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, Packet &lhsV, Packet &lhsVi, Index remaining_rows) +template +EIGEN_ALWAYS_INLINE void loadPacketRemaining(const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, Packet &lhsV, Packet &lhsVi) { #ifdef _ARCH_PWR9 lhsV = vec_xl_len((Scalar *)lhs_ptr, remaining_rows * sizeof(Scalar)); @@ -1103,11 +1111,11 @@ EIGEN_ALWAYS_INLINE void loadPacketRemaining(const Scalar* lhs_ptr, const Scalar #endif } -template -EIGEN_ALWAYS_INLINE void pgerc(PacketBlock* accReal, PacketBlock* accImag, const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, const Packet* rhsV, const Packet* rhsVi, Index remaining_rows) +template +EIGEN_ALWAYS_INLINE void pgerc(PacketBlock* accReal, PacketBlock* accImag, const Scalar* lhs_ptr, const Scalar* lhs_ptr_imag, const Packet* rhsV, const Packet* rhsVi) { Packet lhsV, lhsVi; - loadPacketRemaining(lhs_ptr, lhs_ptr_imag, lhsV, lhsVi, remaining_rows); + loadPacketRemaining(lhs_ptr, lhs_ptr_imag, lhsV, lhsVi); pgerc_common(accReal, accImag, 
lhsV, lhsVi, rhsV, rhsVi); } @@ -1119,132 +1127,142 @@ EIGEN_ALWAYS_INLINE Packet ploadLhs(const Scalar* lhs) } // Zero the accumulator on PacketBlock. -template -EIGEN_ALWAYS_INLINE void bsetzero(PacketBlock& acc) -{ - acc.packet[0] = pset1((Scalar)0); - acc.packet[1] = pset1((Scalar)0); - acc.packet[2] = pset1((Scalar)0); - acc.packet[3] = pset1((Scalar)0); -} - -template -EIGEN_ALWAYS_INLINE void bsetzero(PacketBlock& acc) +template +EIGEN_ALWAYS_INLINE void bsetzero(PacketBlock& acc) { acc.packet[0] = pset1((Scalar)0); + if (N > 1) { + acc.packet[1] = pset1((Scalar)0); + } + if (N > 2) { + acc.packet[2] = pset1((Scalar)0); + } + if (N > 3) { + acc.packet[3] = pset1((Scalar)0); + } } // Scale the PacketBlock vectors by alpha. -template -EIGEN_ALWAYS_INLINE void bscale(PacketBlock& acc, PacketBlock& accZ, const Packet& pAlpha) -{ - acc.packet[0] = pmadd(pAlpha, accZ.packet[0], acc.packet[0]); - acc.packet[1] = pmadd(pAlpha, accZ.packet[1], acc.packet[1]); - acc.packet[2] = pmadd(pAlpha, accZ.packet[2], acc.packet[2]); - acc.packet[3] = pmadd(pAlpha, accZ.packet[3], acc.packet[3]); -} - -template -EIGEN_ALWAYS_INLINE void bscale(PacketBlock& acc, PacketBlock& accZ, const Packet& pAlpha) +template +EIGEN_ALWAYS_INLINE void bscale(PacketBlock& acc, PacketBlock& accZ, const Packet& pAlpha) { acc.packet[0] = pmadd(pAlpha, accZ.packet[0], acc.packet[0]); + if (N > 1) { + acc.packet[1] = pmadd(pAlpha, accZ.packet[1], acc.packet[1]); + } + if (N > 2) { + acc.packet[2] = pmadd(pAlpha, accZ.packet[2], acc.packet[2]); + } + if (N > 3) { + acc.packet[3] = pmadd(pAlpha, accZ.packet[3], acc.packet[3]); + } } -template -EIGEN_ALWAYS_INLINE void bscalec_common(PacketBlock& acc, PacketBlock& accZ, const Packet& pAlpha) -{ - acc.packet[0] = pmul(accZ.packet[0], pAlpha); - acc.packet[1] = pmul(accZ.packet[1], pAlpha); - acc.packet[2] = pmul(accZ.packet[2], pAlpha); - acc.packet[3] = pmul(accZ.packet[3], pAlpha); -} - -template -EIGEN_ALWAYS_INLINE void bscalec_common(PacketBlock& acc, PacketBlock& accZ, const Packet& pAlpha) +template +EIGEN_ALWAYS_INLINE void bscalec_common(PacketBlock& acc, PacketBlock& accZ, const Packet& pAlpha) { acc.packet[0] = pmul(accZ.packet[0], pAlpha); + if (N > 1) { + acc.packet[1] = pmul(accZ.packet[1], pAlpha); + } + if (N > 2) { + acc.packet[2] = pmul(accZ.packet[2], pAlpha); + } + if (N > 3) { + acc.packet[3] = pmul(accZ.packet[3], pAlpha); + } } // Complex version of PacketBlock scaling. 
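The rewritten helpers above (bsetzero, bscale, bscalec_common) collapse the old PacketBlock-of-4 and PacketBlock-of-1 overloads into a single template parameterized on N, guarding each extra packet with an if (N > k) test on the compile-time constant. A minimal standalone sketch of that pattern, using hypothetical names (scale_block is illustrative, not part of the patch):

    // The N > k tests compare compile-time constants, so each instantiation
    // keeps only the operations it needs, matching the old hand-unrolled bodies.
    template<typename Packet, int N>
    inline void scale_block(Packet* acc, const Packet* accZ, const Packet& alpha)
    {
      acc[0] = accZ[0] * alpha;
      if (N > 1) acc[1] = accZ[1] * alpha;
      if (N > 2) acc[2] = accZ[2] * alpha;
      if (N > 3) acc[3] = accZ[3] * alpha;
    }

Instantiated with N = 4 this behaves like the former four-packet helper, while N = 1 covers the single-packet tail case without a separate overload. The complex scaling version that follows builds on these same helpers.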
template EIGEN_ALWAYS_INLINE void bscalec(PacketBlock& aReal, PacketBlock& aImag, const Packet& bReal, const Packet& bImag, PacketBlock& cReal, PacketBlock& cImag) { - bscalec_common(cReal, aReal, bReal); + bscalec_common(cReal, aReal, bReal); - bscalec_common(cImag, aImag, bReal); + bscalec_common(cImag, aImag, bReal); - pger_common(&cReal, bImag, aImag.packet); + pger_common(&cReal, bImag, aImag.packet); - pger_common(&cImag, bImag, aReal.packet); + pger_common(&cImag, bImag, aReal.packet); } -template -EIGEN_ALWAYS_INLINE void band(PacketBlock& acc, const Packet& pMask) +template +EIGEN_ALWAYS_INLINE void band(PacketBlock& acc, const Packet& pMask) { acc.packet[0] = pand(acc.packet[0], pMask); - acc.packet[1] = pand(acc.packet[1], pMask); - acc.packet[2] = pand(acc.packet[2], pMask); - acc.packet[3] = pand(acc.packet[3], pMask); + if (N > 1) { + acc.packet[1] = pand(acc.packet[1], pMask); + } + if (N > 2) { + acc.packet[2] = pand(acc.packet[2], pMask); + } + if (N > 3) { + acc.packet[3] = pand(acc.packet[3], pMask); + } } -template -EIGEN_ALWAYS_INLINE void bscalec(PacketBlock& aReal, PacketBlock& aImag, const Packet& bReal, const Packet& bImag, PacketBlock& cReal, PacketBlock& cImag, const Packet& pMask) +template +EIGEN_ALWAYS_INLINE void bscalec(PacketBlock& aReal, PacketBlock& aImag, const Packet& bReal, const Packet& bImag, PacketBlock& cReal, PacketBlock& cImag, const Packet& pMask) { - band(aReal, pMask); - band(aImag, pMask); + band(aReal, pMask); + band(aImag, pMask); - bscalec(aReal, aImag, bReal, bImag, cReal, cImag); + bscalec(aReal, aImag, bReal, bImag, cReal, cImag); } // Load a PacketBlock, the N parameters make tunning gemm easier so we can add more accumulators as needed. -template -EIGEN_ALWAYS_INLINE void bload(PacketBlock& acc, const DataMapper& res, Index row, Index col) -{ - if (StorageOrder == RowMajor) { - acc.packet[0] = res.template loadPacket(row + 0, col + N*accCols); - acc.packet[1] = res.template loadPacket(row + 1, col + N*accCols); - acc.packet[2] = res.template loadPacket(row + 2, col + N*accCols); - acc.packet[3] = res.template loadPacket(row + 3, col + N*accCols); - } else { - acc.packet[0] = res.template loadPacket(row + N*accCols, col + 0); - acc.packet[1] = res.template loadPacket(row + N*accCols, col + 1); - acc.packet[2] = res.template loadPacket(row + N*accCols, col + 2); - acc.packet[3] = res.template loadPacket(row + N*accCols, col + 3); - } -} - -// An overload of bload when you have a PacketBLock with 8 vectors. 
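The bscalec overloads above scale a packet-decoupled complex accumulator by alpha (the masked variant first applies pMask via band to zero the lanes beyond the remaining rows). A scalar model of the same arithmetic, for reference only (cscale is an illustrative name, not patch code):

    // c = a * b for complex numbers kept as separate real/imaginary parts:
    //   cReal = aReal*bReal - aImag*bImag   (bscalec_common, then pger_common with NegativeAccumulate)
    //   cImag = aImag*bReal + aReal*bImag   (bscalec_common, then pger_common without it)
    inline void cscale(float aReal, float aImag, float bReal, float bImag,
                       float& cReal, float& cImag)
    {
      cReal = aReal * bReal - aImag * bImag;
      cImag = aImag * bReal + aReal * bImag;
    }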
-template -EIGEN_ALWAYS_INLINE void bload(PacketBlock& acc, const DataMapper& res, Index row, Index col) +template +EIGEN_ALWAYS_INLINE void bload(PacketBlock& acc, const DataMapper& res, Index row, Index col) { if (StorageOrder == RowMajor) { - acc.packet[0] = res.template loadPacket(row + 0, col + N*accCols); - acc.packet[1] = res.template loadPacket(row + 1, col + N*accCols); - acc.packet[2] = res.template loadPacket(row + 2, col + N*accCols); - acc.packet[3] = res.template loadPacket(row + 3, col + N*accCols); - acc.packet[4] = res.template loadPacket(row + 0, col + (N+1)*accCols); - acc.packet[5] = res.template loadPacket(row + 1, col + (N+1)*accCols); - acc.packet[6] = res.template loadPacket(row + 2, col + (N+1)*accCols); - acc.packet[7] = res.template loadPacket(row + 3, col + (N+1)*accCols); + acc.packet[0] = res.template loadPacket(row + 0, col); + if (N > 1) { + acc.packet[1] = res.template loadPacket(row + 1, col); + } + if (N > 2) { + acc.packet[2] = res.template loadPacket(row + 2, col); + } + if (N > 3) { + acc.packet[3] = res.template loadPacket(row + 3, col); + } + if (Complex) { + acc.packet[0+N] = res.template loadPacket(row + 0, col + accCols); + if (N > 1) { + acc.packet[1+N] = res.template loadPacket(row + 1, col + accCols); + } + if (N > 2) { + acc.packet[2+N] = res.template loadPacket(row + 2, col + accCols); + } + if (N > 3) { + acc.packet[3+N] = res.template loadPacket(row + 3, col + accCols); + } + } } else { - acc.packet[0] = res.template loadPacket(row + N*accCols, col + 0); - acc.packet[1] = res.template loadPacket(row + N*accCols, col + 1); - acc.packet[2] = res.template loadPacket(row + N*accCols, col + 2); - acc.packet[3] = res.template loadPacket(row + N*accCols, col + 3); - acc.packet[4] = res.template loadPacket(row + (N+1)*accCols, col + 0); - acc.packet[5] = res.template loadPacket(row + (N+1)*accCols, col + 1); - acc.packet[6] = res.template loadPacket(row + (N+1)*accCols, col + 2); - acc.packet[7] = res.template loadPacket(row + (N+1)*accCols, col + 3); + acc.packet[0] = res.template loadPacket(row, col + 0); + if (N > 1) { + acc.packet[1] = res.template loadPacket(row, col + 1); + } + if (N > 2) { + acc.packet[2] = res.template loadPacket(row, col + 2); + } + if (N > 3) { + acc.packet[3] = res.template loadPacket(row, col + 3); + } + if (Complex) { + acc.packet[0+N] = res.template loadPacket(row + accCols, col + 0); + if (N > 1) { + acc.packet[1+N] = res.template loadPacket(row + accCols, col + 1); + } + if (N > 2) { + acc.packet[2+N] = res.template loadPacket(row + accCols, col + 2); + } + if (N > 3) { + acc.packet[3+N] = res.template loadPacket(row + accCols, col + 3); + } + } } } -template -EIGEN_ALWAYS_INLINE void bload(PacketBlock& acc, const DataMapper& res, Index row, Index col) -{ - acc.packet[0] = res.template loadPacket(row + N*accCols, col + 0); - acc.packet[1] = res.template loadPacket(row + (N+1)*accCols, col + 0); -} - const static Packet4i mask41 = { -1, 0, 0, 0 }; const static Packet4i mask42 = { -1, -1, 0, 0 }; const static Packet4i mask43 = { -1, -1, -1, 0 }; @@ -1275,22 +1293,44 @@ EIGEN_ALWAYS_INLINE Packet2d bmask(const int remaining_rows) } } -template -EIGEN_ALWAYS_INLINE void bscale(PacketBlock& acc, PacketBlock& accZ, const Packet& pAlpha, const Packet& pMask) +template +EIGEN_ALWAYS_INLINE void bscale(PacketBlock& acc, PacketBlock& accZ, const Packet& pAlpha, const Packet& pMask) { - band(accZ, pMask); + band(accZ, pMask); - bscale(acc, accZ, pAlpha); + bscale(acc, accZ, pAlpha); } -template -EIGEN_ALWAYS_INLINE void 
pbroadcast4_old(const __UNPACK_TYPE__(Packet)* a, Packet& a0, Packet& a1, Packet& a2, Packet& a3) +template EIGEN_ALWAYS_INLINE void +pbroadcastN_old(const __UNPACK_TYPE__(Packet) *a, + Packet& a0, Packet& a1, Packet& a2, Packet& a3) +{ + a0 = pset1(a[0]); + if (N > 1) { + a1 = pset1(a[1]); + } else { + EIGEN_UNUSED_VARIABLE(a1); + } + if (N > 2) { + a2 = pset1(a[2]); + } else { + EIGEN_UNUSED_VARIABLE(a2); + } + if (N > 3) { + a3 = pset1(a[3]); + } else { + EIGEN_UNUSED_VARIABLE(a3); + } +} + +template<> +EIGEN_ALWAYS_INLINE void pbroadcastN_old(const float* a, Packet4f& a0, Packet4f& a1, Packet4f& a2, Packet4f& a3) { - pbroadcast4(a, a0, a1, a2, a3); + pbroadcast4(a, a0, a1, a2, a3); } template<> -EIGEN_ALWAYS_INLINE void pbroadcast4_old(const double* a, Packet2d& a0, Packet2d& a1, Packet2d& a2, Packet2d& a3) +EIGEN_ALWAYS_INLINE void pbroadcastN_old(const double* a, Packet2d& a0, Packet2d& a1, Packet2d& a2, Packet2d& a3) { a1 = pload(a); a3 = pload(a + 2); @@ -1300,89 +1340,96 @@ EIGEN_ALWAYS_INLINE void pbroadcast4_old(const double* a, Packet2d& a0 a3 = vec_splat(a3, 1); } -// PEEL loop factor. -#define PEEL 7 - -template -EIGEN_ALWAYS_INLINE void MICRO_EXTRA_COL( - const Scalar* &lhs_ptr, - const Scalar* &rhs_ptr, - PacketBlock &accZero, - Index remaining_rows, - Index remaining_cols) +template EIGEN_ALWAYS_INLINE void +pbroadcastN(const __UNPACK_TYPE__(Packet) *a, + Packet& a0, Packet& a1, Packet& a2, Packet& a3) { - Packet rhsV[1]; - rhsV[0] = pset1(rhs_ptr[0]); - pger<1,Scalar, Packet, false>(&accZero, lhs_ptr, rhsV); - lhs_ptr += remaining_rows; - rhs_ptr += remaining_cols; + a0 = pset1(a[0]); + if (N > 1) { + a1 = pset1(a[1]); + } else { + EIGEN_UNUSED_VARIABLE(a1); + } + if (N > 2) { + a2 = pset1(a[2]); + } else { + EIGEN_UNUSED_VARIABLE(a2); + } + if (N > 3) { + a3 = pset1(a[3]); + } else { + EIGEN_UNUSED_VARIABLE(a3); + } } -template -EIGEN_STRONG_INLINE void gemm_extra_col( - const DataMapper& res, - const Scalar* lhs_base, - const Scalar* rhs_base, - Index depth, - Index strideA, - Index offsetA, - Index row, - Index col, - Index remaining_rows, - Index remaining_cols, - const Packet& pAlpha) +template<> EIGEN_ALWAYS_INLINE void +pbroadcastN(const float *a, + Packet4f& a0, Packet4f& a1, Packet4f& a2, Packet4f& a3) { - const Scalar* rhs_ptr = rhs_base; - const Scalar* lhs_ptr = lhs_base + row*strideA + remaining_rows*offsetA; - PacketBlock accZero; + a3 = pload(a); + a0 = vec_splat(a3, 0); + a1 = vec_splat(a3, 1); + a2 = vec_splat(a3, 2); + a3 = vec_splat(a3, 3); +} - bsetzero(accZero); +// PEEL loop factor. 
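The PEEL and PEEL_ROW factors defined next control how many depth iterations the remaining-rows path unrolls per pass: MICRO_WORK_PEEL_ROW keeps one accumulator per peeled iteration so the multiply-add chains stay independent, and MICRO_ADD_PEEL_ROW folds them back into accZero0 before the leftover depth loop runs. A scalar stand-in with the same structure (dot_peeled and kPeel are illustrative, not patch code):

    inline float dot_peeled(const float* lhs, const float* rhs, int depth)
    {
      const int kPeel = 7;                    // plays the role of PEEL_ROW
      float acc[kPeel] = {};                  // one accumulator per peeled step
      int k = 0;
      for (; k + kPeel <= depth; k += kPeel)
        for (int l = 0; l < kPeel; l++)
          acc[l] += lhs[k + l] * rhs[k + l];  // independent chains, as in MICRO_WORK_PEEL
      for (int l = 1; l < kPeel; l++)
        acc[0] += acc[l];                     // fold, as in MICRO_ADD_PEEL_ROW
      for (; k < depth; k++)
        acc[0] += lhs[k] * rhs[k];            // leftover depth iterations
      return acc[0];
    }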
+#define PEEL 7 +#define PEEL_ROW 7 - Index remaining_depth = (depth & -accRows); - Index k = 0; - for(; k + PEEL <= remaining_depth; k+= PEEL) - { - EIGEN_POWER_PREFETCH(rhs_ptr); - EIGEN_POWER_PREFETCH(lhs_ptr); - for (int l = 0; l < PEEL; l++) { - MICRO_EXTRA_COL(lhs_ptr, rhs_ptr, accZero, remaining_rows, remaining_cols); - } - } - for(; k < remaining_depth; k++) - { - MICRO_EXTRA_COL(lhs_ptr, rhs_ptr, accZero, remaining_rows, remaining_cols); +#define MICRO_UNROLL_PEEL(func) \ + func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7) + +#define MICRO_ZERO_PEEL(peel) \ + if ((PEEL_ROW > peel) && (peel != 0)) { \ + bsetzero(accZero##peel); \ + } else { \ + EIGEN_UNUSED_VARIABLE(accZero##peel); \ } - for(; k < depth; k++) - { - Packet rhsV[1]; - rhsV[0] = pset1(rhs_ptr[0]); - pger<1, Scalar, Packet, Index, false>(&accZero, lhs_ptr, rhsV, remaining_rows); - lhs_ptr += remaining_rows; - rhs_ptr += remaining_cols; + +#define MICRO_ZERO_PEEL_ROW \ + MICRO_UNROLL_PEEL(MICRO_ZERO_PEEL); + +#define MICRO_WORK_PEEL(peel) \ + if (PEEL_ROW > peel) { \ + pbroadcastN(rhs_ptr + (accRows * peel), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \ + pger(&accZero##peel, lhs_ptr + (remaining_rows * peel), rhsV##peel); \ + } else { \ + EIGEN_UNUSED_VARIABLE(rhsV##peel); \ } - accZero.packet[0] = vec_mul(pAlpha, accZero.packet[0]); - for(Index i = 0; i < remaining_rows; i++) { - res(row + i, col) += accZero.packet[0][i]; +#define MICRO_WORK_PEEL_ROW \ + Packet rhsV0[4], rhsV1[4], rhsV2[4], rhsV3[4], rhsV4[4], rhsV5[4], rhsV6[4], rhsV7[4]; \ + MICRO_UNROLL_PEEL(MICRO_WORK_PEEL); \ + lhs_ptr += (remaining_rows * PEEL_ROW); \ + rhs_ptr += (accRows * PEEL_ROW); + +#define MICRO_ADD_PEEL(peel, sum) \ + if (PEEL_ROW > peel) { \ + for (Index i = 0; i < accRows; i++) { \ + accZero##sum.packet[i] += accZero##peel.packet[i]; \ + } \ } -} -template +#define MICRO_ADD_PEEL_ROW \ + MICRO_ADD_PEEL(4, 0) MICRO_ADD_PEEL(5, 1) MICRO_ADD_PEEL(6, 2) MICRO_ADD_PEEL(7, 3) \ + MICRO_ADD_PEEL(2, 0) MICRO_ADD_PEEL(3, 1) MICRO_ADD_PEEL(1, 0) + +template EIGEN_ALWAYS_INLINE void MICRO_EXTRA_ROW( const Scalar* &lhs_ptr, const Scalar* &rhs_ptr, - PacketBlock &accZero, - Index remaining_rows) + PacketBlock &accZero) { Packet rhsV[4]; - pbroadcast4(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]); - pger<4, Scalar, Packet, false>(&accZero, lhs_ptr, rhsV); + pbroadcastN(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]); + pger(&accZero, lhs_ptr, rhsV); lhs_ptr += remaining_rows; rhs_ptr += accRows; } -template -EIGEN_STRONG_INLINE void gemm_extra_row( +template +EIGEN_ALWAYS_INLINE void gemm_unrolled_row_iteration( const DataMapper& res, const Scalar* lhs_base, const Scalar* rhs_base, @@ -1393,59 +1440,89 @@ EIGEN_STRONG_INLINE void gemm_extra_row( Index col, Index rows, Index cols, - Index remaining_rows, const Packet& pAlpha, const Packet& pMask) { const Scalar* rhs_ptr = rhs_base; const Scalar* lhs_ptr = lhs_base + row*strideA + remaining_rows*offsetA; - PacketBlock accZero, acc; + PacketBlock accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7, acc; - bsetzero(accZero); + bsetzero(accZero0); - Index remaining_depth = (col + accRows < cols) ? depth : (depth & -accRows); + Index remaining_depth = (col + quad_traits::rows < cols) ? 
depth : (depth & -quad_traits::rows); Index k = 0; - for(; k + PEEL <= remaining_depth; k+= PEEL) - { - EIGEN_POWER_PREFETCH(rhs_ptr); - EIGEN_POWER_PREFETCH(lhs_ptr); - for (int l = 0; l < PEEL; l++) { - MICRO_EXTRA_ROW(lhs_ptr, rhs_ptr, accZero, remaining_rows); - } + if (remaining_depth >= PEEL_ROW) { + MICRO_ZERO_PEEL_ROW + do + { + EIGEN_POWER_PREFETCH(rhs_ptr); + EIGEN_POWER_PREFETCH(lhs_ptr); + MICRO_WORK_PEEL_ROW + } while ((k += PEEL_ROW) + PEEL_ROW <= remaining_depth); + MICRO_ADD_PEEL_ROW } for(; k < remaining_depth; k++) { - MICRO_EXTRA_ROW(lhs_ptr, rhs_ptr, accZero, remaining_rows); + MICRO_EXTRA_ROW(lhs_ptr, rhs_ptr, accZero0); } if ((remaining_depth == depth) && (rows >= accCols)) { - for(Index j = 0; j < 4; j++) { - acc.packet[j] = res.template loadPacket(row, col + j); - } - bscale(acc, accZero, pAlpha, pMask); - res.template storePacketBlock(row, col, acc); + bload(acc, res, row, 0); + bscale(acc, accZero0, pAlpha, pMask); + res.template storePacketBlock(row, 0, acc); } else { for(; k < depth; k++) { Packet rhsV[4]; - pbroadcast4(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]); - pger<4, Scalar, Packet, Index, false>(&accZero, lhs_ptr, rhsV, remaining_rows); + pbroadcastN(rhs_ptr, rhsV[0], rhsV[1], rhsV[2], rhsV[3]); + pger(&accZero0, lhs_ptr, rhsV); lhs_ptr += remaining_rows; rhs_ptr += accRows; } - for(Index j = 0; j < 4; j++) { - accZero.packet[j] = vec_mul(pAlpha, accZero.packet[j]); - } - for(Index j = 0; j < 4; j++) { + for(Index j = 0; j < accRows; j++) { + accZero0.packet[j] = vec_mul(pAlpha, accZero0.packet[j]); for(Index i = 0; i < remaining_rows; i++) { - res(row + i, col + j) += accZero.packet[j][i]; + res(row + i, j) += accZero0.packet[j][i]; } } } } +template +EIGEN_ALWAYS_INLINE void gemm_extra_row( + const DataMapper& res, + const Scalar* lhs_base, + const Scalar* rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index row, + Index col, + Index rows, + Index cols, + Index remaining_rows, + const Packet& pAlpha, + const Packet& pMask) +{ + switch(remaining_rows) { + case 1: + gemm_unrolled_row_iteration(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, rows, cols, pAlpha, pMask); + break; + case 2: + if (sizeof(Scalar) == sizeof(float)) { + gemm_unrolled_row_iteration(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, rows, cols, pAlpha, pMask); + } + break; + default: + if (sizeof(Scalar) == sizeof(float)) { + gemm_unrolled_row_iteration(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, rows, cols, pAlpha, pMask); + } + break; + } +} + #define MICRO_UNROLL(func) \ func(0) func(1) func(2) func(3) func(4) func(5) func(6) func(7) @@ -1464,34 +1541,24 @@ EIGEN_STRONG_INLINE void gemm_extra_row( #define MICRO_WORK_ONE(iter, peel) \ if (unroll_factor > iter) { \ - pger_common(&accZero##iter, lhsV##iter, rhsV##peel); \ + pger_common(&accZero##iter, lhsV##iter, rhsV##peel); \ } #define MICRO_TYPE_PEEL4(func, func2, peel) \ if (PEEL > peel) { \ Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \ - pbroadcast4(rhs_ptr + (accRows * peel), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \ - MICRO_UNROLL_WORK(func, func2, peel) \ - } else { \ - EIGEN_UNUSED_VARIABLE(rhsV##peel); \ - } - -#define MICRO_TYPE_PEEL1(func, func2, peel) \ - if (PEEL > peel) { \ - Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4, lhsV5, lhsV6, lhsV7; \ - rhsV##peel[0] = pset1(rhs_ptr[remaining_cols * peel]); \ + pbroadcastN(rhs_ptr + (accRows * peel), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \ MICRO_UNROLL_WORK(func, 
func2, peel) \ } else { \ EIGEN_UNUSED_VARIABLE(rhsV##peel); \ } #define MICRO_UNROLL_TYPE_PEEL(M, func, func1, func2) \ - Packet rhsV0[M], rhsV1[M], rhsV2[M], rhsV3[M], rhsV4[M], rhsV5[M], rhsV6[M], rhsV7[M], rhsV8[M], rhsV9[M]; \ + Packet rhsV0[M], rhsV1[M], rhsV2[M], rhsV3[M], rhsV4[M], rhsV5[M], rhsV6[M], rhsV7[M]; \ func(func1,func2,0); func(func1,func2,1); \ func(func1,func2,2); func(func1,func2,3); \ func(func1,func2,4); func(func1,func2,5); \ - func(func1,func2,6); func(func1,func2,7); \ - func(func1,func2,8); func(func1,func2,9); + func(func1,func2,6); func(func1,func2,7); #define MICRO_UNROLL_TYPE_ONE(M, func, func1, func2) \ Packet rhsV0[M]; \ @@ -1505,17 +1572,9 @@ EIGEN_STRONG_INLINE void gemm_extra_row( MICRO_UNROLL_TYPE_ONE(4, MICRO_TYPE_PEEL4, MICRO_WORK_ONE, MICRO_LOAD_ONE); \ rhs_ptr += accRows; -#define MICRO_ONE_PEEL1 \ - MICRO_UNROLL_TYPE_PEEL(1, MICRO_TYPE_PEEL1, MICRO_WORK_ONE, MICRO_LOAD_ONE); \ - rhs_ptr += (remaining_cols * PEEL); - -#define MICRO_ONE1 \ - MICRO_UNROLL_TYPE_ONE(1, MICRO_TYPE_PEEL1, MICRO_WORK_ONE, MICRO_LOAD_ONE); \ - rhs_ptr += remaining_cols; - #define MICRO_DST_PTR_ONE(iter) \ if (unroll_factor > iter) { \ - bsetzero(accZero##iter); \ + bsetzero(accZero##iter); \ } else { \ EIGEN_UNUSED_VARIABLE(accZero##iter); \ } @@ -1524,7 +1583,7 @@ EIGEN_STRONG_INLINE void gemm_extra_row( #define MICRO_SRC_PTR_ONE(iter) \ if (unroll_factor > iter) { \ - lhs_ptr##iter = lhs_base + ( (row/accCols) + iter )*strideA*accCols + accCols*offsetA; \ + lhs_ptr##iter = lhs_base + ( (row/accCols) + iter )*strideA*accCols; \ } else { \ EIGEN_UNUSED_VARIABLE(lhs_ptr##iter); \ } @@ -1540,25 +1599,13 @@ EIGEN_STRONG_INLINE void gemm_extra_row( #define MICRO_STORE_ONE(iter) \ if (unroll_factor > iter) { \ - acc.packet[0] = res.template loadPacket(row + iter*accCols, col + 0); \ - acc.packet[1] = res.template loadPacket(row + iter*accCols, col + 1); \ - acc.packet[2] = res.template loadPacket(row + iter*accCols, col + 2); \ - acc.packet[3] = res.template loadPacket(row + iter*accCols, col + 3); \ - bscale(acc, accZero##iter, pAlpha); \ - res.template storePacketBlock(row + iter*accCols, col, acc); \ + bload(acc, res, row + iter*accCols, 0); \ + bscale(acc, accZero##iter, pAlpha); \ + res.template storePacketBlock(row + iter*accCols, 0, acc); \ } #define MICRO_STORE MICRO_UNROLL(MICRO_STORE_ONE) -#define MICRO_COL_STORE_ONE(iter) \ - if (unroll_factor > iter) { \ - acc.packet[0] = res.template loadPacket(row + iter*accCols, col + 0); \ - bscale(acc, accZero##iter, pAlpha); \ - res.template storePacketBlock(row + iter*accCols, col, acc); \ - } - -#define MICRO_COL_STORE MICRO_UNROLL(MICRO_COL_STORE_ONE) - template EIGEN_STRONG_INLINE void gemm_unrolled_iteration( const DataMapper& res, @@ -1566,15 +1613,13 @@ EIGEN_STRONG_INLINE void gemm_unrolled_iteration( const Scalar* rhs_base, Index depth, Index strideA, - Index offsetA, Index& row, - Index col, const Packet& pAlpha) { const Scalar* rhs_ptr = rhs_base; const Scalar* lhs_ptr0 = NULL, * lhs_ptr1 = NULL, * lhs_ptr2 = NULL, * lhs_ptr3 = NULL, * lhs_ptr4 = NULL, * lhs_ptr5 = NULL, * lhs_ptr6 = NULL, * lhs_ptr7 = NULL; - PacketBlock accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7; - PacketBlock acc; + PacketBlock accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7; + PacketBlock acc; MICRO_SRC_PTR MICRO_DST_PTR @@ -1595,101 +1640,100 @@ EIGEN_STRONG_INLINE void gemm_unrolled_iteration( row += unroll_factor*accCols; } -template -EIGEN_STRONG_INLINE void 
gemm_unrolled_col_iteration( +template +EIGEN_ALWAYS_INLINE void gemm_cols( const DataMapper& res, - const Scalar* lhs_base, - const Scalar* rhs_base, + const Scalar* blockA, + const Scalar* blockB, Index depth, Index strideA, Index offsetA, - Index& row, + Index strideB, + Index offsetB, Index col, - Index remaining_cols, - const Packet& pAlpha) + Index rows, + Index cols, + Index remaining_rows, + const Packet& pAlpha, + const Packet& pMask) { - const Scalar* rhs_ptr = rhs_base; - const Scalar* lhs_ptr0 = NULL, * lhs_ptr1 = NULL, * lhs_ptr2 = NULL, * lhs_ptr3 = NULL, * lhs_ptr4 = NULL, * lhs_ptr5 = NULL, * lhs_ptr6 = NULL, *lhs_ptr7 = NULL; - PacketBlock accZero0, accZero1, accZero2, accZero3, accZero4, accZero5, accZero6, accZero7; - PacketBlock acc; + const DataMapper res3 = res.getSubMapper(0, col); - MICRO_SRC_PTR - MICRO_DST_PTR + const Scalar* rhs_base = blockB + col*strideB + accRows*offsetB; + const Scalar* lhs_base = blockA + accCols*offsetA; + Index row = 0; - Index k = 0; - for(; k + PEEL <= depth; k+= PEEL) - { - EIGEN_POWER_PREFETCH(rhs_ptr); - MICRO_PREFETCH - MICRO_ONE_PEEL1 - } - for(; k < depth; k++) - { - MICRO_ONE1 - } - MICRO_COL_STORE - - row += unroll_factor*accCols; -} - -template -EIGEN_STRONG_INLINE void gemm_unrolled_col( - const DataMapper& res, - const Scalar* lhs_base, - const Scalar* rhs_base, - Index depth, - Index strideA, - Index offsetA, - Index& row, - Index rows, - Index col, - Index remaining_cols, - const Packet& pAlpha) -{ #define MAX_UNROLL 6 while(row + MAX_UNROLL*accCols <= rows) { - gemm_unrolled_col_iteration(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha); + gemm_unrolled_iteration(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha); } switch( (rows-row)/accCols ) { #if MAX_UNROLL > 7 case 7: - gemm_unrolled_col_iteration<7, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha); + gemm_unrolled_iteration<7, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha); break; #endif #if MAX_UNROLL > 6 case 6: - gemm_unrolled_col_iteration<6, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha); + gemm_unrolled_iteration<6, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha); break; #endif #if MAX_UNROLL > 5 - case 5: - gemm_unrolled_col_iteration<5, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha); + case 5: + gemm_unrolled_iteration<5, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha); break; #endif #if MAX_UNROLL > 4 - case 4: - gemm_unrolled_col_iteration<4, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha); + case 4: + gemm_unrolled_iteration<4, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha); break; #endif #if MAX_UNROLL > 3 - case 3: - gemm_unrolled_col_iteration<3, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha); - break; + case 3: + gemm_unrolled_iteration<3, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha); + break; #endif #if MAX_UNROLL > 2 - 
case 2: - gemm_unrolled_col_iteration<2, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha); - break; + case 2: + gemm_unrolled_iteration<2, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha); + break; #endif #if MAX_UNROLL > 1 - case 1: - gemm_unrolled_col_iteration<1, Scalar, Packet, DataMapper, Index, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_cols, pAlpha); - break; + case 1: + gemm_unrolled_iteration<1, Scalar, Packet, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha); + break; #endif - default: - break; + default: + break; } #undef MAX_UNROLL + + if(remaining_rows > 0) + { + gemm_extra_row(res3, blockA, rhs_base, depth, strideA, offsetA, row, col, rows, cols, remaining_rows, pAlpha, pMask); + } +} + +template +EIGEN_STRONG_INLINE void gemm_extra_cols( + const DataMapper& res, + const Scalar* blockA, + const Scalar* blockB, + Index depth, + Index strideA, + Index offsetA, + Index strideB, + Index offsetB, + Index col, + Index rows, + Index cols, + Index remaining_rows, + const Packet& pAlpha, + const Packet& pMask) +{ + for (; col < cols; col++) { + gemm_cols(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask); + } } /**************** @@ -1699,7 +1743,6 @@ template(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); - } - switch( (rows-row)/accCols ) { -#if MAX_UNROLL > 7 - case 7: - gemm_unrolled_iteration<7, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); - break; -#endif -#if MAX_UNROLL > 6 - case 6: - gemm_unrolled_iteration<6, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); - break; -#endif -#if MAX_UNROLL > 5 - case 5: - gemm_unrolled_iteration<5, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); - break; -#endif -#if MAX_UNROLL > 4 - case 4: - gemm_unrolled_iteration<4, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); - break; -#endif -#if MAX_UNROLL > 3 - case 3: - gemm_unrolled_iteration<3, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); - break; -#endif -#if MAX_UNROLL > 2 - case 2: - gemm_unrolled_iteration<2, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); - break; -#endif -#if MAX_UNROLL > 1 - case 1: - gemm_unrolled_iteration<1, Scalar, Packet, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); - break; -#endif - default: - break; - } -#undef MAX_UNROLL - - if(remaining_rows > 0) - { - gemm_extra_row(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, rows, cols, remaining_rows, pAlpha, pMask); - } - } - - if(remaining_cols > 0) - { - const Scalar* rhs_base = blockB + col*strideB + remaining_cols*offsetB; - const Scalar* lhs_base = blockA; - - for(; col < cols; col++) - { - Index row = 0; - - gemm_unrolled_col(res, lhs_base, rhs_base, depth, strideA, offsetA, row, rows, col, remaining_cols, pAlpha); - - if (remaining_rows > 0) - { - gemm_extra_col(res, lhs_base, rhs_base, depth, strideA, 
offsetA, row, col, remaining_rows, remaining_cols, pAlpha); - } - rhs_base++; + gemm_cols(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask); } - } + + gemm_extra_cols(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask); } #define accColsC (accCols / 2) @@ -1791,117 +1765,66 @@ EIGEN_STRONG_INLINE void gemm(const DataMapper& res, const Scalar* blockA, const // PEEL_COMPLEX loop factor. #define PEEL_COMPLEX 3 +#define PEEL_COMPLEX_ROW 3 -template -EIGEN_ALWAYS_INLINE void MICRO_COMPLEX_EXTRA_COL( - const Scalar* &lhs_ptr_real, const Scalar* &lhs_ptr_imag, - const Scalar* &rhs_ptr_real, const Scalar* &rhs_ptr_imag, - PacketBlock &accReal, PacketBlock &accImag, - Index remaining_rows, - Index remaining_cols) -{ - Packet rhsV[1], rhsVi[1]; - rhsV[0] = pset1(rhs_ptr_real[0]); - if(!RhsIsReal) rhsVi[0] = pset1(rhs_ptr_imag[0]); - pgerc<1, Scalar, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal, &accImag, lhs_ptr_real, lhs_ptr_imag, rhsV, rhsVi); - lhs_ptr_real += remaining_rows; - if(!LhsIsReal) lhs_ptr_imag += remaining_rows; - else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag); - rhs_ptr_real += remaining_cols; - if(!RhsIsReal) rhs_ptr_imag += remaining_cols; - else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag); -} - -template -EIGEN_STRONG_INLINE void gemm_complex_extra_col( - const DataMapper& res, - const Scalar* lhs_base, - const Scalar* rhs_base, - Index depth, - Index strideA, - Index offsetA, - Index strideB, - Index row, - Index col, - Index remaining_rows, - Index remaining_cols, - const Packet& pAlphaReal, - const Packet& pAlphaImag) -{ - const Scalar* rhs_ptr_real = rhs_base; - const Scalar* rhs_ptr_imag; - if(!RhsIsReal) rhs_ptr_imag = rhs_base + remaining_cols*strideB; - else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag); - const Scalar* lhs_ptr_real = lhs_base + advanceRows*row*strideA + remaining_rows*offsetA; - const Scalar* lhs_ptr_imag; - if(!LhsIsReal) lhs_ptr_imag = lhs_ptr_real + remaining_rows*strideA; - else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag); - PacketBlock accReal, accImag; - PacketBlock taccReal, taccImag; - PacketBlock acc0, acc1; - - bsetzero(accReal); - bsetzero(accImag); +#define MICRO_COMPLEX_UNROLL_PEEL(func) \ + func(0) func(1) func(2) func(3) - Index remaining_depth = (depth & -accRows); - Index k = 0; - for(; k + PEEL_COMPLEX <= remaining_depth; k+= PEEL_COMPLEX) - { - EIGEN_POWER_PREFETCH(rhs_ptr_real); - if(!RhsIsReal) { - EIGEN_POWER_PREFETCH(rhs_ptr_imag); - } - EIGEN_POWER_PREFETCH(lhs_ptr_real); - if(!LhsIsReal) { - EIGEN_POWER_PREFETCH(lhs_ptr_imag); - } - for (int l = 0; l < PEEL_COMPLEX; l++) { - MICRO_COMPLEX_EXTRA_COL(lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real, rhs_ptr_imag, accReal, accImag, remaining_rows, remaining_cols); - } - } - for(; k < remaining_depth; k++) - { - MICRO_COMPLEX_EXTRA_COL(lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real, rhs_ptr_imag, accReal, accImag, remaining_rows, remaining_cols); +#define MICRO_COMPLEX_ZERO_PEEL(peel) \ + if ((PEEL_COMPLEX_ROW > peel) && (peel != 0)) { \ + bsetzero(accReal##peel); \ + bsetzero(accImag##peel); \ + } else { \ + EIGEN_UNUSED_VARIABLE(accReal##peel); \ + EIGEN_UNUSED_VARIABLE(accImag##peel); \ } - for(; k < depth; k++) - { - Packet rhsV[1], rhsVi[1]; - rhsV[0] = pset1(rhs_ptr_real[0]); - if(!RhsIsReal) rhsVi[0] = pset1(rhs_ptr_imag[0]); - pgerc<1, Scalar, Packet, Index, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal, &accImag, lhs_ptr_real, lhs_ptr_imag, rhsV, rhsVi, 
remaining_rows); - lhs_ptr_real += remaining_rows; - if(!LhsIsReal) lhs_ptr_imag += remaining_rows; - rhs_ptr_real += remaining_cols; - if(!RhsIsReal) rhs_ptr_imag += remaining_cols; +#define MICRO_COMPLEX_ZERO_PEEL_ROW \ + MICRO_COMPLEX_UNROLL_PEEL(MICRO_COMPLEX_ZERO_PEEL); + +#define MICRO_COMPLEX_WORK_PEEL(peel) \ + if (PEEL_COMPLEX_ROW > peel) { \ + pbroadcastN_old(rhs_ptr_real + (accRows * peel), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \ + if(!RhsIsReal) pbroadcastN_old(rhs_ptr_imag + (accRows * peel), rhsVi##peel[0], rhsVi##peel[1], rhsVi##peel[2], rhsVi##peel[3]); \ + pgerc(&accReal##peel, &accImag##peel, lhs_ptr_real + (remaining_rows * peel), lhs_ptr_imag + (remaining_rows * peel), rhsV##peel, rhsVi##peel); \ + } else { \ + EIGEN_UNUSED_VARIABLE(rhsV##peel); \ + EIGEN_UNUSED_VARIABLE(rhsVi##peel); \ } - bscalec(accReal, accImag, pAlphaReal, pAlphaImag, taccReal, taccImag); - bcouple_common(taccReal, taccImag, acc0, acc1); +#define MICRO_COMPLEX_WORK_PEEL_ROW \ + Packet rhsV0[4], rhsV1[4], rhsV2[4], rhsV3[4]; \ + Packet rhsVi0[4], rhsVi1[4], rhsVi2[4], rhsVi3[4]; \ + MICRO_COMPLEX_UNROLL_PEEL(MICRO_COMPLEX_WORK_PEEL); \ + lhs_ptr_real += (remaining_rows * PEEL_COMPLEX_ROW); \ + if(!LhsIsReal) lhs_ptr_imag += (remaining_rows * PEEL_COMPLEX_ROW); \ + else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag); \ + rhs_ptr_real += (accRows * PEEL_COMPLEX_ROW); \ + if(!RhsIsReal) rhs_ptr_imag += (accRows * PEEL_COMPLEX_ROW); \ + else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag); - if ((sizeof(Scalar) == sizeof(float)) && (remaining_rows == 1)) - { - res(row + 0, col + 0) += pfirst(acc0.packet[0]); - } else { - acc0.packet[0] += res.template loadPacket(row + 0, col + 0); - res.template storePacketBlock(row + 0, col + 0, acc0); - if(remaining_rows > accColsC) { - res(row + accColsC, col + 0) += pfirst(acc1.packet[0]); - } +#define MICRO_COMPLEX_ADD_PEEL(peel, sum) \ + if (PEEL_COMPLEX_ROW > peel) { \ + for (Index i = 0; i < accRows; i++) { \ + accReal##sum.packet[i] += accReal##peel.packet[i]; \ + accImag##sum.packet[i] += accImag##peel.packet[i]; \ + } \ } -} -template +#define MICRO_COMPLEX_ADD_PEEL_ROW \ + MICRO_COMPLEX_ADD_PEEL(2, 0) MICRO_COMPLEX_ADD_PEEL(3, 1) \ + MICRO_COMPLEX_ADD_PEEL(1, 0) + +template EIGEN_ALWAYS_INLINE void MICRO_COMPLEX_EXTRA_ROW( const Scalar* &lhs_ptr_real, const Scalar* &lhs_ptr_imag, const Scalar* &rhs_ptr_real, const Scalar* &rhs_ptr_imag, - PacketBlock &accReal, PacketBlock &accImag, - Index remaining_rows) + PacketBlock &accReal, PacketBlock &accImag) { Packet rhsV[4], rhsVi[4]; - pbroadcast4_old(rhs_ptr_real, rhsV[0], rhsV[1], rhsV[2], rhsV[3]); - if(!RhsIsReal) pbroadcast4_old(rhs_ptr_imag, rhsVi[0], rhsVi[1], rhsVi[2], rhsVi[3]); - pgerc<4, Scalar, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal, &accImag, lhs_ptr_real, lhs_ptr_imag, rhsV, rhsVi); + pbroadcastN_old(rhs_ptr_real, rhsV[0], rhsV[1], rhsV[2], rhsV[3]); + if(!RhsIsReal) pbroadcastN_old(rhs_ptr_imag, rhsVi[0], rhsVi[1], rhsVi[2], rhsVi[3]); + pgerc(&accReal, &accImag, lhs_ptr_real, lhs_ptr_imag, rhsV, rhsVi); lhs_ptr_real += remaining_rows; if(!LhsIsReal) lhs_ptr_imag += remaining_rows; else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag); @@ -1910,8 +1833,8 @@ EIGEN_ALWAYS_INLINE void MICRO_COMPLEX_EXTRA_ROW( else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag); } -template -EIGEN_STRONG_INLINE void gemm_complex_extra_row( +template +EIGEN_ALWAYS_INLINE void gemm_unrolled_complex_row_iteration( const DataMapper& res, const Scalar* lhs_base, const Scalar* rhs_base, @@ -1923,7 +1846,6 @@ 
EIGEN_STRONG_INLINE void gemm_complex_extra_row( Index col, Index rows, Index cols, - Index remaining_rows, const Packet& pAlphaReal, const Packet& pAlphaImag, const Packet& pMask) @@ -1936,93 +1858,129 @@ EIGEN_STRONG_INLINE void gemm_complex_extra_row( const Scalar* lhs_ptr_imag; if(!LhsIsReal) lhs_ptr_imag = lhs_ptr_real + remaining_rows*strideA; else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag); - PacketBlock accReal, accImag; - PacketBlock taccReal, taccImag; - PacketBlock acc0, acc1; - PacketBlock tRes; + PacketBlock accReal0, accImag0, accReal1, accImag1, accReal2, accImag2, accReal3, accImag3; + PacketBlock taccReal, taccImag; + PacketBlock acc0, acc1; + PacketBlock tRes; - bsetzero(accReal); - bsetzero(accImag); + bsetzero(accReal0); + bsetzero(accImag0); - Index remaining_depth = (col + accRows < cols) ? depth : (depth & -accRows); + Index remaining_depth = (col + quad_traits::rows < cols) ? depth : (depth & -quad_traits::rows); Index k = 0; - for(; k + PEEL_COMPLEX <= remaining_depth; k+= PEEL_COMPLEX) - { - EIGEN_POWER_PREFETCH(rhs_ptr_real); - if(!RhsIsReal) { - EIGEN_POWER_PREFETCH(rhs_ptr_imag); - } - EIGEN_POWER_PREFETCH(lhs_ptr_real); - if(!LhsIsReal) { - EIGEN_POWER_PREFETCH(lhs_ptr_imag); - } - for (int l = 0; l < PEEL_COMPLEX; l++) { - MICRO_COMPLEX_EXTRA_ROW(lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real, rhs_ptr_imag, accReal, accImag, remaining_rows); - } + if (remaining_depth >= PEEL_COMPLEX_ROW) { + MICRO_COMPLEX_ZERO_PEEL_ROW + do + { + EIGEN_POWER_PREFETCH(rhs_ptr_real); + if(!RhsIsReal) { + EIGEN_POWER_PREFETCH(rhs_ptr_imag); + } + EIGEN_POWER_PREFETCH(lhs_ptr_real); + if(!LhsIsReal) { + EIGEN_POWER_PREFETCH(lhs_ptr_imag); + } + MICRO_COMPLEX_WORK_PEEL_ROW + } while ((k += PEEL_COMPLEX_ROW) + PEEL_COMPLEX_ROW <= remaining_depth); + MICRO_COMPLEX_ADD_PEEL_ROW } for(; k < remaining_depth; k++) { - MICRO_COMPLEX_EXTRA_ROW(lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real, rhs_ptr_imag, accReal, accImag, remaining_rows); + MICRO_COMPLEX_EXTRA_ROW(lhs_ptr_real, lhs_ptr_imag, rhs_ptr_real, rhs_ptr_imag, accReal0, accImag0); } if ((remaining_depth == depth) && (rows >= accCols)) { - bload(tRes, res, row, col); - bscalec(accReal, accImag, pAlphaReal, pAlphaImag, taccReal, taccImag, pMask); - bcouple(taccReal, taccImag, tRes, acc0, acc1); - res.template storePacketBlock(row + 0, col, acc0); - res.template storePacketBlock(row + accColsC, col, acc1); + bload(tRes, res, row, 0); + bscalec(accReal0, accImag0, pAlphaReal, pAlphaImag, taccReal, taccImag, pMask); + bcouple(taccReal, taccImag, tRes, acc0, acc1); + res.template storePacketBlock(row + 0, 0, acc0); + res.template storePacketBlock(row + accColsC, 0, acc1); } else { for(; k < depth; k++) { Packet rhsV[4], rhsVi[4]; - pbroadcast4_old(rhs_ptr_real, rhsV[0], rhsV[1], rhsV[2], rhsV[3]); - if(!RhsIsReal) pbroadcast4_old(rhs_ptr_imag, rhsVi[0], rhsVi[1], rhsVi[2], rhsVi[3]); - pgerc<4, Scalar, Packet, Index, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal, &accImag, lhs_ptr_real, lhs_ptr_imag, rhsV, rhsVi, remaining_rows); + pbroadcastN_old(rhs_ptr_real, rhsV[0], rhsV[1], rhsV[2], rhsV[3]); + if(!RhsIsReal) pbroadcastN_old(rhs_ptr_imag, rhsVi[0], rhsVi[1], rhsVi[2], rhsVi[3]); + pgerc(&accReal0, &accImag0, lhs_ptr_real, lhs_ptr_imag, rhsV, rhsVi); lhs_ptr_real += remaining_rows; if(!LhsIsReal) lhs_ptr_imag += remaining_rows; rhs_ptr_real += accRows; if(!RhsIsReal) rhs_ptr_imag += accRows; } - bscalec(accReal, accImag, pAlphaReal, pAlphaImag, taccReal, taccImag); - bcouple_common(taccReal, taccImag, acc0, acc1); + bscalec(accReal0, 
accImag0, pAlphaReal, pAlphaImag, taccReal, taccImag); + bcouple_common(taccReal, taccImag, acc0, acc1); if ((sizeof(Scalar) == sizeof(float)) && (remaining_rows == 1)) { - for(Index j = 0; j < 4; j++) { - res(row + 0, col + j) += pfirst(acc0.packet[j]); + for(Index j = 0; j < accRows; j++) { + res(row + 0, j) += pfirst(acc0.packet[j]); } } else { - for(Index j = 0; j < 4; j++) { + for(Index j = 0; j < accRows; j++) { PacketBlock acc2; - acc2.packet[0] = res.template loadPacket(row + 0, col + j) + acc0.packet[j]; - res.template storePacketBlock(row + 0, col + j, acc2); + acc2.packet[0] = res.template loadPacket(row + 0, j) + acc0.packet[j]; + res.template storePacketBlock(row + 0, j, acc2); if(remaining_rows > accColsC) { - res(row + accColsC, col + j) += pfirst(acc1.packet[j]); + res(row + accColsC, j) += pfirst(acc1.packet[j]); } } } } } +template +EIGEN_ALWAYS_INLINE void gemm_complex_extra_row( + const DataMapper& res, + const Scalar* lhs_base, + const Scalar* rhs_base, + Index depth, + Index strideA, + Index offsetA, + Index strideB, + Index row, + Index col, + Index rows, + Index cols, + Index remaining_rows, + const Packet& pAlphaReal, + const Packet& pAlphaImag, + const Packet& pMask) +{ + switch(remaining_rows) { + case 1: + gemm_unrolled_complex_row_iteration(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, pAlphaReal, pAlphaImag, pMask); + break; + case 2: + if (sizeof(Scalar) == sizeof(float)) { + gemm_unrolled_complex_row_iteration(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, pAlphaReal, pAlphaImag, pMask); + } + break; + default: + if (sizeof(Scalar) == sizeof(float)) { + gemm_unrolled_complex_row_iteration(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, pAlphaReal, pAlphaImag, pMask); + } + break; + } +} + #define MICRO_COMPLEX_UNROLL(func) \ - func(0) func(1) func(2) func(3) func(4) + func(0) func(1) func(2) func(3) #define MICRO_COMPLEX_UNROLL_WORK(func, func2, peel) \ MICRO_COMPLEX_UNROLL(func2); \ - func(0,peel) func(1,peel) func(2,peel) func(3,peel) func(4,peel) + func(0,peel) func(1,peel) func(2,peel) func(3,peel) #define MICRO_COMPLEX_LOAD_ONE(iter) \ if (unroll_factor > iter) { \ lhsV##iter = ploadLhs(lhs_ptr_real##iter); \ - lhs_ptr_real##iter += accCols; \ if(!LhsIsReal) { \ - lhsVi##iter = ploadLhs(lhs_ptr_imag##iter); \ - lhs_ptr_imag##iter += accCols; \ + lhsVi##iter = ploadLhs(lhs_ptr_real##iter + imag_delta); \ } else { \ EIGEN_UNUSED_VARIABLE(lhsVi##iter); \ } \ + lhs_ptr_real##iter += accCols; \ } else { \ EIGEN_UNUSED_VARIABLE(lhsV##iter); \ EIGEN_UNUSED_VARIABLE(lhsVi##iter); \ @@ -2030,37 +1988,16 @@ EIGEN_STRONG_INLINE void gemm_complex_extra_row( #define MICRO_COMPLEX_WORK_ONE4(iter, peel) \ if (unroll_factor > iter) { \ - pgerc_common<4, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV##peel, rhsVi##peel); \ - } - -#define MICRO_COMPLEX_WORK_ONE1(iter, peel) \ - if (unroll_factor > iter) { \ - pgerc_common<1, Packet, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(&accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV##peel, rhsVi##peel); \ + pgerc_common(&accReal##iter, &accImag##iter, lhsV##iter, lhsVi##iter, rhsV##peel, rhsVi##peel); \ } #define MICRO_COMPLEX_TYPE_PEEL4(func, func2, peel) \ if (PEEL_COMPLEX > peel) { \ - Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4; \ - Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3, lhsVi4; \ - pbroadcast4_old(rhs_ptr_real + (accRows * peel), 
rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \ + Packet lhsV0, lhsV1, lhsV2, lhsV3; \ + Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3; \ + pbroadcastN_old(rhs_ptr_real + (accRows * peel), rhsV##peel[0], rhsV##peel[1], rhsV##peel[2], rhsV##peel[3]); \ if(!RhsIsReal) { \ - pbroadcast4_old(rhs_ptr_imag + (accRows * peel), rhsVi##peel[0], rhsVi##peel[1], rhsVi##peel[2], rhsVi##peel[3]); \ - } else { \ - EIGEN_UNUSED_VARIABLE(rhsVi##peel); \ - } \ - MICRO_COMPLEX_UNROLL_WORK(func, func2, peel) \ - } else { \ - EIGEN_UNUSED_VARIABLE(rhsV##peel); \ - EIGEN_UNUSED_VARIABLE(rhsVi##peel); \ - } - -#define MICRO_COMPLEX_TYPE_PEEL1(func, func2, peel) \ - if (PEEL_COMPLEX > peel) { \ - Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4; \ - Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3, lhsVi4; \ - rhsV##peel[0] = pset1(rhs_ptr_real[remaining_cols * peel]); \ - if(!RhsIsReal) { \ - rhsVi##peel[0] = pset1(rhs_ptr_imag[remaining_cols * peel]); \ + pbroadcastN_old(rhs_ptr_imag + (accRows * peel), rhsVi##peel[0], rhsVi##peel[1], rhsVi##peel[2], rhsVi##peel[3]); \ } else { \ EIGEN_UNUSED_VARIABLE(rhsVi##peel); \ } \ @@ -2071,13 +2008,10 @@ EIGEN_STRONG_INLINE void gemm_complex_extra_row( } #define MICRO_COMPLEX_UNROLL_TYPE_PEEL(M, func, func1, func2) \ - Packet rhsV0[M], rhsV1[M], rhsV2[M], rhsV3[M], rhsV4[M], rhsV5[M], rhsV6[M], rhsV7[M], rhsV8[M], rhsV9[M]; \ - Packet rhsVi0[M], rhsVi1[M], rhsVi2[M], rhsVi3[M], rhsVi4[M], rhsVi5[M], rhsVi6[M], rhsVi7[M], rhsVi8[M], rhsVi9[M]; \ + Packet rhsV0[M], rhsV1[M], rhsV2[M], rhsV3[M]; \ + Packet rhsVi0[M], rhsVi1[M], rhsVi2[M], rhsVi3[M]; \ func(func1,func2,0); func(func1,func2,1); \ - func(func1,func2,2); func(func1,func2,3); \ - func(func1,func2,4); func(func1,func2,5); \ - func(func1,func2,6); func(func1,func2,7); \ - func(func1,func2,8); func(func1,func2,9); + func(func1,func2,2); func(func1,func2,3); #define MICRO_COMPLEX_UNROLL_TYPE_ONE(M, func, func1, func2) \ Packet rhsV0[M], rhsVi0[M];\ @@ -2093,20 +2027,10 @@ EIGEN_STRONG_INLINE void gemm_complex_extra_row( rhs_ptr_real += accRows; \ if(!RhsIsReal) rhs_ptr_imag += accRows; -#define MICRO_COMPLEX_ONE_PEEL1 \ - MICRO_COMPLEX_UNROLL_TYPE_PEEL(1, MICRO_COMPLEX_TYPE_PEEL1, MICRO_COMPLEX_WORK_ONE1, MICRO_COMPLEX_LOAD_ONE); \ - rhs_ptr_real += (remaining_cols * PEEL_COMPLEX); \ - if(!RhsIsReal) rhs_ptr_imag += (remaining_cols * PEEL_COMPLEX); - -#define MICRO_COMPLEX_ONE1 \ - MICRO_COMPLEX_UNROLL_TYPE_ONE(1, MICRO_COMPLEX_TYPE_PEEL1, MICRO_COMPLEX_WORK_ONE1, MICRO_COMPLEX_LOAD_ONE); \ - rhs_ptr_real += remaining_cols; \ - if(!RhsIsReal) rhs_ptr_imag += remaining_cols; - #define MICRO_COMPLEX_DST_PTR_ONE(iter) \ if (unroll_factor > iter) { \ - bsetzero(accReal##iter); \ - bsetzero(accImag##iter); \ + bsetzero(accReal##iter); \ + bsetzero(accImag##iter); \ } else { \ EIGEN_UNUSED_VARIABLE(accReal##iter); \ EIGEN_UNUSED_VARIABLE(accImag##iter); \ @@ -2116,15 +2040,9 @@ EIGEN_STRONG_INLINE void gemm_complex_extra_row( #define MICRO_COMPLEX_SRC_PTR_ONE(iter) \ if (unroll_factor > iter) { \ - lhs_ptr_real##iter = lhs_base + ( ((advanceRows*row)/accCols) + iter*advanceRows )*strideA*accCols + accCols*offsetA; \ - if(!LhsIsReal) { \ - lhs_ptr_imag##iter = lhs_ptr_real##iter + accCols*strideA; \ - } else { \ - EIGEN_UNUSED_VARIABLE(lhs_ptr_imag##iter); \ - } \ + lhs_ptr_real##iter = lhs_base + ( ((advanceRows*row)/accCols) + iter*advanceRows )*strideA*accCols; \ } else { \ EIGEN_UNUSED_VARIABLE(lhs_ptr_real##iter); \ - EIGEN_UNUSED_VARIABLE(lhs_ptr_imag##iter); \ } #define MICRO_COMPLEX_SRC_PTR 
MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_SRC_PTR_ONE) @@ -2132,35 +2050,21 @@ EIGEN_STRONG_INLINE void gemm_complex_extra_row( #define MICRO_COMPLEX_PREFETCH_ONE(iter) \ if (unroll_factor > iter) { \ EIGEN_POWER_PREFETCH(lhs_ptr_real##iter); \ - if(!LhsIsReal) { \ - EIGEN_POWER_PREFETCH(lhs_ptr_imag##iter); \ - } \ } #define MICRO_COMPLEX_PREFETCH MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_PREFETCH_ONE) #define MICRO_COMPLEX_STORE_ONE(iter) \ if (unroll_factor > iter) { \ - bload(tRes, res, row + iter*accCols, col); \ - bscalec(accReal##iter, accImag##iter, pAlphaReal, pAlphaImag, taccReal, taccImag); \ - bcouple(taccReal, taccImag, tRes, acc0, acc1); \ - res.template storePacketBlock(row + iter*accCols + 0, col, acc0); \ - res.template storePacketBlock(row + iter*accCols + accColsC, col, acc1); \ + bload(tRes, res, row + iter*accCols, 0); \ + bscalec(accReal##iter, accImag##iter, pAlphaReal, pAlphaImag, taccReal, taccImag); \ + bcouple(taccReal, taccImag, tRes, acc0, acc1); \ + res.template storePacketBlock(row + iter*accCols + 0, 0, acc0); \ + res.template storePacketBlock(row + iter*accCols + accColsC, 0, acc1); \ } #define MICRO_COMPLEX_STORE MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_STORE_ONE) -#define MICRO_COMPLEX_COL_STORE_ONE(iter) \ - if (unroll_factor > iter) { \ - bload(tRes, res, row + iter*accCols, col); \ - bscalec(accReal##iter, accImag##iter, pAlphaReal, pAlphaImag, taccReal, taccImag); \ - bcouple(taccReal, taccImag, tRes, acc0, acc1); \ - res.template storePacketBlock(row + iter*accCols + 0, col, acc0); \ - res.template storePacketBlock(row + iter*accCols + accColsC, col, acc1); \ - } - -#define MICRO_COMPLEX_COL_STORE MICRO_COMPLEX_UNROLL(MICRO_COMPLEX_COL_STORE_ONE) - template EIGEN_STRONG_INLINE void gemm_complex_unrolled_iteration( const DataMapper& res, @@ -2168,29 +2072,26 @@ EIGEN_STRONG_INLINE void gemm_complex_unrolled_iteration( const Scalar* rhs_base, Index depth, Index strideA, - Index offsetA, Index strideB, Index& row, - Index col, const Packet& pAlphaReal, const Packet& pAlphaImag) { const Scalar* rhs_ptr_real = rhs_base; const Scalar* rhs_ptr_imag; + const Index imag_delta = accCols*strideA; if(!RhsIsReal) { rhs_ptr_imag = rhs_base + accRows*strideB; } else { EIGEN_UNUSED_VARIABLE(rhs_ptr_imag); } - const Scalar* lhs_ptr_real0 = NULL, * lhs_ptr_imag0 = NULL, * lhs_ptr_real1 = NULL, * lhs_ptr_imag1 = NULL; - const Scalar* lhs_ptr_real2 = NULL, * lhs_ptr_imag2 = NULL, * lhs_ptr_real3 = NULL, * lhs_ptr_imag3 = NULL; - const Scalar* lhs_ptr_real4 = NULL, * lhs_ptr_imag4 = NULL; - PacketBlock accReal0, accImag0, accReal1, accImag1; - PacketBlock accReal2, accImag2, accReal3, accImag3; - PacketBlock accReal4, accImag4; - PacketBlock taccReal, taccImag; - PacketBlock acc0, acc1; - PacketBlock tRes; + const Scalar* lhs_ptr_real0 = NULL, * lhs_ptr_real1 = NULL; + const Scalar* lhs_ptr_real2 = NULL, * lhs_ptr_real3 = NULL; + PacketBlock accReal0, accImag0, accReal1, accImag1; + PacketBlock accReal2, accImag2, accReal3, accImag3; + PacketBlock taccReal, taccImag; + PacketBlock acc0, acc1; + PacketBlock tRes; MICRO_COMPLEX_SRC_PTR MICRO_COMPLEX_DST_PTR @@ -2214,112 +2115,93 @@ EIGEN_STRONG_INLINE void gemm_complex_unrolled_iteration( row += unroll_factor*accCols; } -template -EIGEN_STRONG_INLINE void gemm_complex_unrolled_col_iteration( +template +EIGEN_ALWAYS_INLINE void gemm_complex_cols( const DataMapper& res, - const Scalar* lhs_base, - const Scalar* rhs_base, + const Scalar* blockA, + const Scalar* blockB, Index depth, Index strideA, Index offsetA, Index strideB, - Index& row, + 
Index offsetB, Index col, - Index remaining_cols, + Index rows, + Index cols, + Index remaining_rows, const Packet& pAlphaReal, - const Packet& pAlphaImag) + const Packet& pAlphaImag, + const Packet& pMask) { - const Scalar* rhs_ptr_real = rhs_base; - const Scalar* rhs_ptr_imag; - if(!RhsIsReal) { - rhs_ptr_imag = rhs_base + remaining_cols*strideB; - } else { - EIGEN_UNUSED_VARIABLE(rhs_ptr_imag); - } - const Scalar* lhs_ptr_real0 = NULL, * lhs_ptr_imag0 = NULL, * lhs_ptr_real1 = NULL, * lhs_ptr_imag1 = NULL; - const Scalar* lhs_ptr_real2 = NULL, * lhs_ptr_imag2 = NULL, * lhs_ptr_real3 = NULL, * lhs_ptr_imag3 = NULL; - const Scalar* lhs_ptr_real4 = NULL, * lhs_ptr_imag4 = NULL; - PacketBlock accReal0, accImag0, accReal1, accImag1; - PacketBlock accReal2, accImag2, accReal3, accImag3; - PacketBlock accReal4, accImag4; - PacketBlock taccReal, taccImag; - PacketBlock acc0, acc1; - PacketBlock tRes; + const DataMapper res3 = res.getSubMapper(0, col); - MICRO_COMPLEX_SRC_PTR - MICRO_COMPLEX_DST_PTR + const Scalar* rhs_base = blockB + advanceCols*col*strideB + accRows*offsetB; + const Scalar* lhs_base = blockA + accCols*offsetA; + Index row = 0; - Index k = 0; - for(; k + PEEL_COMPLEX <= depth; k+= PEEL_COMPLEX) - { - EIGEN_POWER_PREFETCH(rhs_ptr_real); - if(!RhsIsReal) { - EIGEN_POWER_PREFETCH(rhs_ptr_imag); - } - MICRO_COMPLEX_PREFETCH - MICRO_COMPLEX_ONE_PEEL1 +#define MAX_COMPLEX_UNROLL 3 + while(row + MAX_COMPLEX_UNROLL*accCols <= rows) { + gemm_complex_unrolled_iteration(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag); } - for(; k < depth; k++) - { - MICRO_COMPLEX_ONE1 + switch( (rows-row)/accCols ) { +#if MAX_COMPLEX_UNROLL > 4 + case 4: + gemm_complex_unrolled_iteration<4, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag); + break; +#endif +#if MAX_COMPLEX_UNROLL > 3 + case 3: + gemm_complex_unrolled_iteration<3, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag); + break; +#endif +#if MAX_COMPLEX_UNROLL > 2 + case 2: + gemm_complex_unrolled_iteration<2, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag); + break; +#endif +#if MAX_COMPLEX_UNROLL > 1 + case 1: + gemm_complex_unrolled_iteration<1, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag); + break; +#endif + default: + break; } - MICRO_COMPLEX_COL_STORE +#undef MAX_COMPLEX_UNROLL - row += unroll_factor*accCols; + if(remaining_rows > 0) + { + gemm_complex_extra_row(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask); + } } template -EIGEN_STRONG_INLINE void gemm_complex_unrolled_col( +EIGEN_STRONG_INLINE void gemm_complex_extra_cols( const DataMapper& res, - const Scalar* lhs_base, - const Scalar* rhs_base, + const Scalar* blockA, + const Scalar* blockB, Index depth, Index strideA, Index offsetA, Index strideB, - Index& row, - Index rows, + Index offsetB, Index col, - Index remaining_cols, + Index rows, + Index cols, + Index remaining_rows, const Packet& pAlphaReal, - 
const Packet& pAlphaImag) + const Packet& pAlphaImag, + const Packet& pMask) { -#define MAX_COMPLEX_UNROLL 3 - while(row + MAX_COMPLEX_UNROLL*accCols <= rows) { - gemm_complex_unrolled_col_iteration(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_cols, pAlphaReal, pAlphaImag); + for (; col < cols; col++) { + gemm_complex_cols(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask); } - switch( (rows-row)/accCols ) { -#if MAX_COMPLEX_UNROLL > 4 - case 4: - gemm_complex_unrolled_col_iteration<4, Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_cols, pAlphaReal, pAlphaImag); - break; -#endif -#if MAX_COMPLEX_UNROLL > 3 - case 3: - gemm_complex_unrolled_col_iteration<3, Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_cols, pAlphaReal, pAlphaImag); - break; -#endif -#if MAX_COMPLEX_UNROLL > 2 - case 2: - gemm_complex_unrolled_col_iteration<2, Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_cols, pAlphaReal, pAlphaImag); - break; -#endif -#if MAX_COMPLEX_UNROLL > 1 - case 1: - gemm_complex_unrolled_col_iteration<1, Scalar, Packet, Packetc, DataMapper, Index, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_cols, pAlphaReal, pAlphaImag); - break; -#endif - default: - break; - } -#undef MAX_COMPLEX_UNROLL } template EIGEN_STRONG_INLINE void gemm_complex(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) { const Index remaining_rows = rows % accCols; - const Index remaining_cols = cols % accRows; if( strideA == -1 ) strideA = depth; if( strideB == -1 ) strideB = depth; @@ -2334,64 +2216,10 @@ EIGEN_STRONG_INLINE void gemm_complex(const DataMapper& res, const LhsScalar* bl Index col = 0; for(; col + accRows <= cols; col += accRows) { - const Scalar* rhs_base = blockB + advanceCols*col*strideB + accRows*offsetB; - const Scalar* lhs_base = blockA; - Index row = 0; - -#define MAX_COMPLEX_UNROLL 3 - while(row + MAX_COMPLEX_UNROLL*accCols <= rows) { - gemm_complex_unrolled_iteration(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag); - } - switch( (rows-row)/accCols ) { -#if MAX_COMPLEX_UNROLL > 4 - case 4: - gemm_complex_unrolled_iteration<4, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag); - break; -#endif -#if MAX_COMPLEX_UNROLL > 3 - case 3: - gemm_complex_unrolled_iteration<3, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag); - break; -#endif -#if MAX_COMPLEX_UNROLL > 2 - case 2: - gemm_complex_unrolled_iteration<2, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, 
lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag); - break; -#endif -#if MAX_COMPLEX_UNROLL > 1 - case 1: - gemm_complex_unrolled_iteration<1, Scalar, Packet, Packetc, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag); - break; -#endif - default: - break; - } -#undef MAX_COMPLEX_UNROLL - - if(remaining_rows > 0) - { - gemm_complex_extra_row(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask); - } + gemm_complex_cols(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask); } - if(remaining_cols > 0) - { - const Scalar* rhs_base = blockB + advanceCols*col*strideB + remaining_cols*offsetB; - const Scalar* lhs_base = blockA; - - for(; col < cols; col++) - { - Index row = 0; - - gemm_complex_unrolled_col(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, rows, col, remaining_cols, pAlphaReal, pAlphaImag); - - if (remaining_rows > 0) - { - gemm_complex_extra_col(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_rows, remaining_cols, pAlphaReal, pAlphaImag); - } - rhs_base++; - } - } + gemm_complex_extra_cols(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask); } #undef accColsC diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h b/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h index d4287cc6f..768d9c7c4 100644 --- a/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h +++ b/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h @@ -11,22 +11,8 @@ namespace Eigen { namespace internal { -template -EIGEN_STRONG_INLINE void gemm_extra_col( - const DataMapper& res, - const Scalar* lhs_base, - const Scalar* rhs_base, - Index depth, - Index strideA, - Index offsetA, - Index row, - Index col, - Index remaining_rows, - Index remaining_cols, - const Packet& pAlpha); - template -EIGEN_STRONG_INLINE void gemm_extra_row( +EIGEN_ALWAYS_INLINE void gemm_extra_row( const DataMapper& res, const Scalar* lhs_base, const Scalar* rhs_base, @@ -41,41 +27,28 @@ EIGEN_STRONG_INLINE void gemm_extra_row( const Packet& pAlpha, const Packet& pMask); -template -EIGEN_STRONG_INLINE void gemm_unrolled_col( +template +EIGEN_STRONG_INLINE void gemm_extra_cols( const DataMapper& res, - const Scalar* lhs_base, - const Scalar* rhs_base, + const Scalar* blockA, + const Scalar* blockB, Index depth, Index strideA, Index offsetA, - Index& row, - Index rows, + Index strideB, + Index offsetB, Index col, - Index remaining_cols, - const Packet& pAlpha); + Index rows, + Index cols, + Index remaining_rows, + const Packet& pAlpha, + const Packet& pMask); template EIGEN_ALWAYS_INLINE Packet bmask(const int remaining_rows); template -EIGEN_STRONG_INLINE void gemm_complex_extra_col( - const DataMapper& res, - const Scalar* lhs_base, - const Scalar* rhs_base, - Index depth, - Index strideA, - Index offsetA, - Index strideB, - Index row, - Index col, - Index remaining_rows, - Index remaining_cols, - const Packet& pAlphaReal, - const Packet& pAlphaImag); - -template -EIGEN_STRONG_INLINE void gemm_complex_extra_row( +EIGEN_ALWAYS_INLINE void gemm_complex_extra_row( const DataMapper& res, const Scalar* lhs_base, const Scalar* rhs_base, @@ -93,123 +66,88 @@ EIGEN_STRONG_INLINE void gemm_complex_extra_row( const Packet& 
pMask); template -EIGEN_STRONG_INLINE void gemm_complex_unrolled_col( +EIGEN_STRONG_INLINE void gemm_complex_extra_cols( const DataMapper& res, - const Scalar* lhs_base, - const Scalar* rhs_base, + const Scalar* blockA, + const Scalar* blockB, Index depth, Index strideA, Index offsetA, Index strideB, - Index& row, - Index rows, + Index offsetB, Index col, - Index remaining_cols, + Index rows, + Index cols, + Index remaining_rows, const Packet& pAlphaReal, - const Packet& pAlphaImag); + const Packet& pAlphaImag, + const Packet& pMask); template EIGEN_ALWAYS_INLINE Packet ploadLhs(const Scalar* lhs); -template -EIGEN_ALWAYS_INLINE void bload(PacketBlock& acc, const DataMapper& res, Index row, Index col); +template +EIGEN_ALWAYS_INLINE void bload(PacketBlock& acc, const DataMapper& res, Index row, Index col); -template -EIGEN_ALWAYS_INLINE void bload(PacketBlock& acc, const DataMapper& res, Index row, Index col); - -template -EIGEN_ALWAYS_INLINE void bscale(PacketBlock& acc, PacketBlock& accZ, const Packet& pAlpha); +template +EIGEN_ALWAYS_INLINE void bscale(PacketBlock& acc, PacketBlock& accZ, const Packet& pAlpha); template EIGEN_ALWAYS_INLINE void bscalec(PacketBlock& aReal, PacketBlock& aImag, const Packet& bReal, const Packet& bImag, PacketBlock& cReal, PacketBlock& cImag); -const static Packet16uc p16uc_SETCOMPLEX32_FIRST = { 0, 1, 2, 3, - 16, 17, 18, 19, - 4, 5, 6, 7, - 20, 21, 22, 23}; - -const static Packet16uc p16uc_SETCOMPLEX32_SECOND = { 8, 9, 10, 11, - 24, 25, 26, 27, - 12, 13, 14, 15, - 28, 29, 30, 31}; -//[a,b],[ai,bi] = [a,ai] - This is equivalent to p16uc_GETREAL64 -const static Packet16uc p16uc_SETCOMPLEX64_FIRST = { 0, 1, 2, 3, 4, 5, 6, 7, - 16, 17, 18, 19, 20, 21, 22, 23}; - -//[a,b],[ai,bi] = [b,bi] - This is equivalent to p16uc_GETIMAG64 -const static Packet16uc p16uc_SETCOMPLEX64_SECOND = { 8, 9, 10, 11, 12, 13, 14, 15, - 24, 25, 26, 27, 28, 29, 30, 31}; - - // Grab two decouples real/imaginary PacketBlocks and return two coupled (real/imaginary pairs) PacketBlocks. 
-template -EIGEN_ALWAYS_INLINE void bcouple_common(PacketBlock& taccReal, PacketBlock& taccImag, PacketBlock& acc1, PacketBlock& acc2) -{ - acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX32_FIRST); - acc1.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX32_FIRST); - acc1.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX32_FIRST); - acc1.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX32_FIRST); - - acc2.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX32_SECOND); - acc2.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX32_SECOND); - acc2.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX32_SECOND); - acc2.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX32_SECOND); -} - -template -EIGEN_ALWAYS_INLINE void bcouple(PacketBlock& taccReal, PacketBlock& taccImag, PacketBlock& tRes, PacketBlock& acc1, PacketBlock& acc2) -{ - bcouple_common(taccReal, taccImag, acc1, acc2); - - acc1.packet[0] = padd(tRes.packet[0], acc1.packet[0]); - acc1.packet[1] = padd(tRes.packet[1], acc1.packet[1]); - acc1.packet[2] = padd(tRes.packet[2], acc1.packet[2]); - acc1.packet[3] = padd(tRes.packet[3], acc1.packet[3]); - - acc2.packet[0] = padd(tRes.packet[4], acc2.packet[0]); - acc2.packet[1] = padd(tRes.packet[5], acc2.packet[1]); - acc2.packet[2] = padd(tRes.packet[6], acc2.packet[2]); - acc2.packet[3] = padd(tRes.packet[7], acc2.packet[3]); -} - -template -EIGEN_ALWAYS_INLINE void bcouple_common(PacketBlock& taccReal, PacketBlock& taccImag, PacketBlock& acc1, PacketBlock& acc2) +template +EIGEN_ALWAYS_INLINE void bcouple_common(PacketBlock& taccReal, PacketBlock& taccImag, PacketBlock& acc1, PacketBlock& acc2) { - acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX32_FIRST); - - acc2.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX32_SECOND); + acc1.packet[0].v = vec_mergeh(taccReal.packet[0], taccImag.packet[0]); + if (N > 1) { + acc1.packet[1].v = vec_mergeh(taccReal.packet[1], taccImag.packet[1]); + } + if (N > 2) { + acc1.packet[2].v = vec_mergeh(taccReal.packet[2], taccImag.packet[2]); + } + if (N > 3) { + acc1.packet[3].v = vec_mergeh(taccReal.packet[3], taccImag.packet[3]); + } + + acc2.packet[0].v = vec_mergel(taccReal.packet[0], taccImag.packet[0]); + if (N > 1) { + acc2.packet[1].v = vec_mergel(taccReal.packet[1], taccImag.packet[1]); + } + if (N > 2) { + acc2.packet[2].v = vec_mergel(taccReal.packet[2], taccImag.packet[2]); + } + if (N > 3) { + acc2.packet[3].v = vec_mergel(taccReal.packet[3], taccImag.packet[3]); + } } -template -EIGEN_ALWAYS_INLINE void bcouple(PacketBlock& taccReal, PacketBlock& taccImag, PacketBlock& tRes, PacketBlock& acc1, PacketBlock& acc2) +template +EIGEN_ALWAYS_INLINE void bcouple(PacketBlock& taccReal, PacketBlock& taccImag, PacketBlock& tRes, PacketBlock& acc1, PacketBlock& acc2) { - bcouple_common(taccReal, taccImag, acc1, acc2); + bcouple_common(taccReal, taccImag, acc1, acc2); acc1.packet[0] = padd(tRes.packet[0], acc1.packet[0]); - - acc2.packet[0] = padd(tRes.packet[1], acc2.packet[0]); -} - -template<> -EIGEN_ALWAYS_INLINE void bcouple_common(PacketBlock& taccReal, PacketBlock& taccImag, PacketBlock& acc1, PacketBlock& acc2) -{ - acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX64_FIRST); - acc1.packet[1].v = 
vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX64_FIRST); - acc1.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX64_FIRST); - acc1.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX64_FIRST); - - acc2.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX64_SECOND); - acc2.packet[1].v = vec_perm(taccReal.packet[1], taccImag.packet[1], p16uc_SETCOMPLEX64_SECOND); - acc2.packet[2].v = vec_perm(taccReal.packet[2], taccImag.packet[2], p16uc_SETCOMPLEX64_SECOND); - acc2.packet[3].v = vec_perm(taccReal.packet[3], taccImag.packet[3], p16uc_SETCOMPLEX64_SECOND); -} - -template<> -EIGEN_ALWAYS_INLINE void bcouple_common(PacketBlock& taccReal, PacketBlock& taccImag, PacketBlock& acc1, PacketBlock& acc2) -{ - acc1.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX64_FIRST); - - acc2.packet[0].v = vec_perm(taccReal.packet[0], taccImag.packet[0], p16uc_SETCOMPLEX64_SECOND); + if (N > 1) { + acc1.packet[1] = padd(tRes.packet[1], acc1.packet[1]); + } + if (N > 2) { + acc1.packet[2] = padd(tRes.packet[2], acc1.packet[2]); + } + if (N > 3) { + acc1.packet[3] = padd(tRes.packet[3], acc1.packet[3]); + } + + acc2.packet[0] = padd(tRes.packet[0+N], acc2.packet[0]); + if (N > 1) { + acc2.packet[1] = padd(tRes.packet[1+N], acc2.packet[1]); + } + if (N > 2) { + acc2.packet[2] = padd(tRes.packet[2+N], acc2.packet[2]); + } + if (N > 3) { + acc2.packet[3] = padd(tRes.packet[3+N], acc2.packet[3]); + } } // This is necessary because ploadRhs for double returns a pair of vectors when MMA is enabled. diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h b/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h index f1f8352c9..e18b7f267 100644 --- a/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h +++ b/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h @@ -11,7 +11,7 @@ #ifndef EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H #define EIGEN_MATRIX_PRODUCT_MMA_ALTIVEC_H -#pragma GCC target("cpu=power10") +#pragma GCC target("cpu=power10,htm") #ifdef __has_builtin #if !__has_builtin(__builtin_vsx_assemble_pair) @@ -32,37 +32,37 @@ EIGEN_ALWAYS_INLINE void bsetzeroMMA(__vector_quad* acc) } template -EIGEN_ALWAYS_INLINE void storeAccumulator(Index i, Index j, const DataMapper& data, const Packet& alpha, __vector_quad* acc) +EIGEN_ALWAYS_INLINE void storeAccumulator(Index i, const DataMapper& data, const Packet& alpha, __vector_quad* acc) { PacketBlock result; __builtin_mma_disassemble_acc(&result.packet, acc); PacketBlock tRes; - bload(tRes, data, i, j); + bload(tRes, data, i, 0); - bscale(tRes, result, alpha); + bscale(tRes, result, alpha); - data.template storePacketBlock(i, j, tRes); + data.template storePacketBlock(i, 0, tRes); } -template -EIGEN_ALWAYS_INLINE void storeComplexAccumulator(Index i, Index j, const DataMapper& data, const Packet& alphaReal, const Packet& alphaImag, __vector_quad* accReal, __vector_quad* accImag) +template +EIGEN_ALWAYS_INLINE void storeComplexAccumulator(Index i, const DataMapper& data, const Packet& alphaReal, const Packet& alphaImag, __vector_quad* accReal, __vector_quad* accImag) { PacketBlock resultReal, resultImag; __builtin_mma_disassemble_acc(&resultReal.packet, accReal); __builtin_mma_disassemble_acc(&resultImag.packet, accImag); PacketBlock tRes; - bload(tRes, data, i, j); + bload(tRes, data, i, 0); PacketBlock taccReal, taccImag; bscalec(resultReal, resultImag, alphaReal, alphaImag, taccReal, taccImag); PacketBlock acc1, acc2; - bcouple(taccReal, taccImag, tRes, acc1, 
acc2); + bcouple(taccReal, taccImag, tRes, acc1, acc2); - data.template storePacketBlock(i + N*accColsC, j, acc1); - data.template storePacketBlock(i + (N+1)*accColsC, j, acc2); + data.template storePacketBlock(i, 0, acc1); + data.template storePacketBlock(i + accColsC, 0, acc2); } // Defaults to float32, since Eigen still supports C++03 we can't use default template arguments @@ -127,7 +127,7 @@ EIGEN_ALWAYS_INLINE void pgercMMA(__vector_quad* accReal, __vector_quad* accImag template EIGEN_ALWAYS_INLINE void ploadRhsMMA(const Scalar* rhs, Packet& rhsV) { - rhsV = ploadRhs((const Scalar*)(rhs)); + rhsV = ploadRhs(rhs); } template<> @@ -186,12 +186,11 @@ EIGEN_ALWAYS_INLINE void ploadRhsMMA(const float*, __vector_pair&) } #define MICRO_MMA_UNROLL_TYPE_PEEL(func, func2, type) \ - type rhsV0, rhsV1, rhsV2, rhsV3, rhsV4, rhsV5, rhsV6, rhsV7, rhsV8, rhsV9; \ + type rhsV0, rhsV1, rhsV2, rhsV3, rhsV4, rhsV5, rhsV6, rhsV7; \ MICRO_MMA_TYPE_PEEL(func,func2,type,0); MICRO_MMA_TYPE_PEEL(func,func2,type,1); \ MICRO_MMA_TYPE_PEEL(func,func2,type,2); MICRO_MMA_TYPE_PEEL(func,func2,type,3); \ MICRO_MMA_TYPE_PEEL(func,func2,type,4); MICRO_MMA_TYPE_PEEL(func,func2,type,5); \ - MICRO_MMA_TYPE_PEEL(func,func2,type,6); MICRO_MMA_TYPE_PEEL(func,func2,type,7); \ - MICRO_MMA_TYPE_PEEL(func,func2,type,8); MICRO_MMA_TYPE_PEEL(func,func2,type,9); + MICRO_MMA_TYPE_PEEL(func,func2,type,6); MICRO_MMA_TYPE_PEEL(func,func2,type,7); #define MICRO_MMA_UNROLL_TYPE_ONE(func, func2, type) \ type rhsV0; \ @@ -224,7 +223,7 @@ EIGEN_ALWAYS_INLINE void ploadRhsMMA(const float*, __vector_pair&) #define MICRO_MMA_SRC_PTR_ONE(iter) \ if (unroll_factor > iter) { \ - lhs_ptr##iter = lhs_base + ( (row/accCols) + iter )*strideA*accCols + accCols*offsetA; \ + lhs_ptr##iter = lhs_base + ( (row/accCols) + iter )*strideA*accCols; \ } else { \ EIGEN_UNUSED_VARIABLE(lhs_ptr##iter); \ } @@ -240,21 +239,19 @@ EIGEN_ALWAYS_INLINE void ploadRhsMMA(const float*, __vector_pair&) #define MICRO_MMA_STORE_ONE(iter) \ if (unroll_factor > iter) { \ - storeAccumulator(row + iter*accCols, col, res, pAlpha, &accZero##iter); \ + storeAccumulator(row + iter*accCols, res, pAlpha, &accZero##iter); \ } #define MICRO_MMA_STORE MICRO_MMA_UNROLL(MICRO_MMA_STORE_ONE) template -EIGEN_STRONG_INLINE void gemm_unrolled_MMA_iteration( +EIGEN_ALWAYS_INLINE void gemm_unrolled_MMA_iteration( const DataMapper& res, const Scalar* lhs_base, const Scalar* rhs_base, Index depth, Index strideA, - Index offsetA, Index& row, - Index col, const Packet& pAlpha) { const Scalar* rhs_ptr = rhs_base; @@ -280,94 +277,98 @@ EIGEN_STRONG_INLINE void gemm_unrolled_MMA_iteration( row += unroll_factor*accCols; } -template -void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) +template +EIGEN_ALWAYS_INLINE void gemmMMA_cols( + const DataMapper& res, + const Scalar* blockA, + const Scalar* blockB, + Index depth, + Index strideA, + Index offsetA, + Index strideB, + Index offsetB, + Index col, + Index rows, + Index cols, + Index remaining_rows, + const Packet& pAlpha, + const Packet& pMask) { - const Index remaining_rows = rows % accCols; - const Index remaining_cols = cols % accRows; - - if( strideA == -1 ) strideA = depth; - if( strideB == -1 ) strideB = depth; - - const Packet pAlpha = pset1(alpha); - const Packet pMask = bmask((const int)(remaining_rows)); + const DataMapper res3 = res.getSubMapper(0, col); - Index col = 0; - for(; col + accRows <= cols; 
col += accRows) - { - const Scalar* rhs_base = blockB + col*strideB + accRows*offsetB; - const Scalar* lhs_base = blockA; + const Scalar* rhs_base = blockB + col*strideB + accRows*offsetB; + const Scalar* lhs_base = blockA + accCols*offsetA; + Index row = 0; - Index row = 0; #define MAX_MMA_UNROLL 7 - while(row + MAX_MMA_UNROLL*accCols <= rows) { - gemm_unrolled_MMA_iteration(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); - } - switch( (rows-row)/accCols ) { + while(row + MAX_MMA_UNROLL*accCols <= rows) { + gemm_unrolled_MMA_iteration(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha); + } + switch( (rows-row)/accCols ) { #if MAX_MMA_UNROLL > 7 - case 7: - gemm_unrolled_MMA_iteration<7, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); - break; + case 7: + gemm_unrolled_MMA_iteration<7, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha); + break; #endif #if MAX_MMA_UNROLL > 6 - case 6: - gemm_unrolled_MMA_iteration<6, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); - break; + case 6: + gemm_unrolled_MMA_iteration<6, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha); + break; #endif #if MAX_MMA_UNROLL > 5 - case 5: - gemm_unrolled_MMA_iteration<5, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); - break; + case 5: + gemm_unrolled_MMA_iteration<5, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha); + break; #endif #if MAX_MMA_UNROLL > 4 - case 4: - gemm_unrolled_MMA_iteration<4, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); - break; + case 4: + gemm_unrolled_MMA_iteration<4, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha); + break; #endif #if MAX_MMA_UNROLL > 3 - case 3: - gemm_unrolled_MMA_iteration<3, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); - break; + case 3: + gemm_unrolled_MMA_iteration<3, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha); + break; #endif #if MAX_MMA_UNROLL > 2 - case 2: - gemm_unrolled_MMA_iteration<2, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); - break; + case 2: + gemm_unrolled_MMA_iteration<2, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha); + break; #endif #if MAX_MMA_UNROLL > 1 - case 1: - gemm_unrolled_MMA_iteration<1, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, pAlpha); - break; + case 1: + gemm_unrolled_MMA_iteration<1, Scalar, Packet, RhsPacket, DataMapper, Index, accRows, accCols>(res3, lhs_base, rhs_base, depth, strideA, row, pAlpha); + break; #endif - default: - break; - } + default: + break; + } #undef MAX_MMA_UNROLL - if(remaining_rows > 0) - { - gemm_extra_row(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, rows, cols, 
remaining_rows, pAlpha, pMask); - } - } + if(remaining_rows > 0) + { + gemm_extra_row(res3, blockA, rhs_base, depth, strideA, offsetA, row, col, rows, cols, remaining_rows, pAlpha, pMask); + } +} - if(remaining_cols > 0) - { - const Scalar* rhs_base = blockB + col*strideB + remaining_cols*offsetB; - const Scalar* lhs_base = blockA; +template +void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, Index rows, Index depth, Index cols, Scalar alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) +{ + const Index remaining_rows = rows % accCols; - for(; col < cols; col++) - { - Index row = 0; + if( strideA == -1 ) strideA = depth; + if( strideB == -1 ) strideB = depth; - gemm_unrolled_col(res, lhs_base, rhs_base, depth, strideA, offsetA, row, rows, col, remaining_cols, pAlpha); + const Packet pAlpha = pset1(alpha); + const Packet pMask = bmask((const int)(remaining_rows)); - if (remaining_rows > 0) - { - gemm_extra_col(res, lhs_base, rhs_base, depth, strideA, offsetA, row, col, remaining_rows, remaining_cols, pAlpha); - } - rhs_base++; - } + Index col = 0; + for(; col + accRows <= cols; col += accRows) + { + gemmMMA_cols(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask); } + + gemm_extra_cols(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlpha, pMask); } #define accColsC (accCols / 2) @@ -375,21 +376,20 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, #define advanceCols ((RhsIsReal) ? 1 : 2) // PEEL_COMPLEX_MMA loop factor. -#define PEEL_COMPLEX_MMA 7 +#define PEEL_COMPLEX_MMA 3 #define MICRO_COMPLEX_MMA_UNROLL(func) \ - func(0) func(1) func(2) func(3) func(4) + func(0) func(1) func(2) func(3) #define MICRO_COMPLEX_MMA_LOAD_ONE(iter) \ if (unroll_factor > iter) { \ lhsV##iter = ploadLhs(lhs_ptr_real##iter); \ - lhs_ptr_real##iter += accCols; \ if(!LhsIsReal) { \ - lhsVi##iter = ploadLhs(lhs_ptr_imag##iter); \ - lhs_ptr_imag##iter += accCols; \ + lhsVi##iter = ploadLhs(lhs_ptr_real##iter + imag_delta); \ } else { \ EIGEN_UNUSED_VARIABLE(lhsVi##iter); \ } \ + lhs_ptr_real##iter += accCols; \ } else { \ EIGEN_UNUSED_VARIABLE(lhsV##iter); \ EIGEN_UNUSED_VARIABLE(lhsVi##iter); \ @@ -402,8 +402,8 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, #define MICRO_COMPLEX_MMA_TYPE_PEEL(func, func2, type, peel) \ if (PEEL_COMPLEX_MMA > peel) { \ - Packet lhsV0, lhsV1, lhsV2, lhsV3, lhsV4; \ - Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3, lhsVi4; \ + Packet lhsV0, lhsV1, lhsV2, lhsV3; \ + Packet lhsVi0, lhsVi1, lhsVi2, lhsVi3; \ ploadRhsMMA(rhs_ptr_real + (accRows * peel), rhsV##peel); \ if(!RhsIsReal) { \ ploadRhsMMA(rhs_ptr_imag + (accRows * peel), rhsVi##peel); \ @@ -411,20 +411,17 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, EIGEN_UNUSED_VARIABLE(rhsVi##peel); \ } \ MICRO_COMPLEX_MMA_UNROLL(func2); \ - func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel) func(4,type,peel) \ + func(0,type,peel) func(1,type,peel) func(2,type,peel) func(3,type,peel) \ } else { \ EIGEN_UNUSED_VARIABLE(rhsV##peel); \ EIGEN_UNUSED_VARIABLE(rhsVi##peel); \ } #define MICRO_COMPLEX_MMA_UNROLL_TYPE_PEEL(func, func2, type) \ - type rhsV0, rhsV1, rhsV2, rhsV3, rhsV4, rhsV5, rhsV6, rhsV7, rhsV8, rhsV9; \ - type rhsVi0, rhsVi1, rhsVi2, rhsVi3, rhsVi4, rhsVi5, rhsVi6, rhsVi7, rhsVi8, rhsVi9; \ + type rhsV0, rhsV1, rhsV2, rhsV3; \ + type rhsVi0, rhsVi1, rhsVi2, rhsVi3; 
\ MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,0); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,1); \ - MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,2); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,3); \ - MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,4); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,5); \ - MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,6); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,7); \ - MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,8); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,9); + MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,2); MICRO_COMPLEX_MMA_TYPE_PEEL(func,func2,type,3); #define MICRO_COMPLEX_MMA_UNROLL_TYPE_ONE(func, func2, type) \ type rhsV0, rhsVi0; \ @@ -461,15 +458,9 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, #define MICRO_COMPLEX_MMA_SRC_PTR_ONE(iter) \ if (unroll_factor > iter) { \ - lhs_ptr_real##iter = lhs_base + ( ((advanceRows*row)/accCols) + iter*advanceRows )*strideA*accCols + accCols*offsetA; \ - if(!LhsIsReal) { \ - lhs_ptr_imag##iter = lhs_ptr_real##iter + accCols*strideA; \ - } else { \ - EIGEN_UNUSED_VARIABLE(lhs_ptr_imag##iter); \ - } \ + lhs_ptr_real##iter = lhs_base + ( ((advanceRows*row)/accCols) + iter*advanceRows )*strideA*accCols; \ } else { \ EIGEN_UNUSED_VARIABLE(lhs_ptr_real##iter); \ - EIGEN_UNUSED_VARIABLE(lhs_ptr_imag##iter); \ } #define MICRO_COMPLEX_MMA_SRC_PTR MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_SRC_PTR_ONE) @@ -477,45 +468,40 @@ void gemmMMA(const DataMapper& res, const Scalar* blockA, const Scalar* blockB, #define MICRO_COMPLEX_MMA_PREFETCH_ONE(iter) \ if (unroll_factor > iter) { \ EIGEN_POWER_PREFETCH(lhs_ptr_real##iter); \ - if(!LhsIsReal) { \ - EIGEN_POWER_PREFETCH(lhs_ptr_imag##iter); \ - } \ } #define MICRO_COMPLEX_MMA_PREFETCH MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_PREFETCH_ONE) #define MICRO_COMPLEX_MMA_STORE_ONE(iter) \ if (unroll_factor > iter) { \ - storeComplexAccumulator(row + iter*accCols, col, res, pAlphaReal, pAlphaImag, &accReal##iter, &accImag##iter); \ + storeComplexAccumulator(row + iter*accCols, res, pAlphaReal, pAlphaImag, &accReal##iter, &accImag##iter); \ } #define MICRO_COMPLEX_MMA_STORE MICRO_COMPLEX_MMA_UNROLL(MICRO_COMPLEX_MMA_STORE_ONE) template -EIGEN_STRONG_INLINE void gemm_complex_unrolled_MMA_iteration( +EIGEN_ALWAYS_INLINE void gemm_complex_unrolled_MMA_iteration( const DataMapper& res, const Scalar* lhs_base, const Scalar* rhs_base, Index depth, Index strideA, - Index offsetA, Index strideB, Index& row, - Index col, const Packet& pAlphaReal, const Packet& pAlphaImag) { const Scalar* rhs_ptr_real = rhs_base; const Scalar* rhs_ptr_imag; + const Index imag_delta = accCols*strideA; if(!RhsIsReal) { rhs_ptr_imag = rhs_base + accRows*strideB; } else { EIGEN_UNUSED_VARIABLE(rhs_ptr_imag); } - const Scalar* lhs_ptr_real0 = NULL, * lhs_ptr_imag0 = NULL, * lhs_ptr_real1 = NULL, * lhs_ptr_imag1 = NULL; - const Scalar* lhs_ptr_real2 = NULL, * lhs_ptr_imag2 = NULL, * lhs_ptr_real3 = NULL, * lhs_ptr_imag3 = NULL; - const Scalar* lhs_ptr_real4 = NULL, * lhs_ptr_imag4 = NULL; - __vector_quad accReal0, accImag0, accReal1, accImag1, accReal2, accImag2, accReal3, accImag3, accReal4, accImag4; + const Scalar* lhs_ptr_real0 = NULL, * lhs_ptr_real1 = NULL; + const Scalar* lhs_ptr_real2 = NULL, * lhs_ptr_real3 = NULL; + __vector_quad accReal0, accImag0, accReal1, accImag1, accReal2, accImag2, accReal3, accImag3; MICRO_COMPLEX_MMA_SRC_PTR MICRO_COMPLEX_MMA_DST_PTR @@ -539,11 +525,70 @@ EIGEN_STRONG_INLINE void gemm_complex_unrolled_MMA_iteration( row += unroll_factor*accCols; 
} +template +EIGEN_ALWAYS_INLINE void gemmMMA_complex_cols( + const DataMapper& res, + const Scalar* blockA, + const Scalar* blockB, + Index depth, + Index strideA, + Index offsetA, + Index strideB, + Index offsetB, + Index col, + Index rows, + Index cols, + Index remaining_rows, + const Packet& pAlphaReal, + const Packet& pAlphaImag, + const Packet& pMask) +{ + const DataMapper res3 = res.getSubMapper(0, col); + + const Scalar* rhs_base = blockB + advanceCols*col*strideB + accRows*offsetB; + const Scalar* lhs_base = blockA + accCols*offsetA; + Index row = 0; + +#define MAX_COMPLEX_MMA_UNROLL 4 + while(row + MAX_COMPLEX_MMA_UNROLL*accCols <= rows) { + gemm_complex_unrolled_MMA_iteration(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag); + } + switch( (rows-row)/accCols ) { +#if MAX_COMPLEX_MMA_UNROLL > 4 + case 4: + gemm_complex_unrolled_MMA_iteration<4, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag); + break; +#endif +#if MAX_COMPLEX_MMA_UNROLL > 3 + case 3: + gemm_complex_unrolled_MMA_iteration<3, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag); + break; +#endif +#if MAX_COMPLEX_MMA_UNROLL > 2 + case 2: + gemm_complex_unrolled_MMA_iteration<2, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag); + break; +#endif +#if MAX_COMPLEX_MMA_UNROLL > 1 + case 1: + gemm_complex_unrolled_MMA_iteration<1, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res3, lhs_base, rhs_base, depth, strideA, strideB, row, pAlphaReal, pAlphaImag); + break; +#endif + default: + break; + } +#undef MAX_COMPLEX_MMA_UNROLL + + if(remaining_rows > 0) + { + gemm_complex_extra_row(res3, blockA, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask); + } +} + template void gemm_complexMMA(const DataMapper& res, const LhsScalar* blockAc, const RhsScalar* blockBc, Index rows, Index depth, Index cols, Scalarc alpha, Index strideA, Index strideB, Index offsetA, Index offsetB) { const Index remaining_rows = rows % accCols; - const Index remaining_cols = cols % accRows; if( strideA == -1 ) strideA = depth; if( strideB == -1 ) strideB = depth; @@ -558,64 +603,10 @@ void gemm_complexMMA(const DataMapper& res, const LhsScalar* blockAc, const RhsS Index col = 0; for(; col + accRows <= cols; col += accRows) { - const Scalar* rhs_base = blockB + advanceCols*col*strideB + accRows*offsetB; - const Scalar* lhs_base = blockA; - Index row = 0; - -#define MAX_COMPLEX_MMA_UNROLL 4 - while(row + MAX_COMPLEX_MMA_UNROLL*accCols <= rows) { - gemm_complex_unrolled_MMA_iteration(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag); - } - switch( (rows-row)/accCols ) { -#if MAX_COMPLEX_MMA_UNROLL > 4 - case 4: - gemm_complex_unrolled_MMA_iteration<4, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag); - break; -#endif -#if 
MAX_COMPLEX_MMA_UNROLL > 3 - case 3: - gemm_complex_unrolled_MMA_iteration<3, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag); - break; -#endif -#if MAX_COMPLEX_MMA_UNROLL > 2 - case 2: - gemm_complex_unrolled_MMA_iteration<2, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag); - break; -#endif -#if MAX_COMPLEX_MMA_UNROLL > 1 - case 1: - gemm_complex_unrolled_MMA_iteration<1, Scalar, Packet, Packetc, RhsPacket, DataMapper, Index, accRows, accCols, ConjugateLhs, ConjugateRhs, LhsIsReal, RhsIsReal>(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, pAlphaReal, pAlphaImag); - break; -#endif - default: - break; - } -#undef MAX_COMPLEX_MMA_UNROLL - - if(remaining_rows > 0) - { - gemm_complex_extra_row(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask); - } + gemmMMA_complex_cols(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask); } - if(remaining_cols > 0) - { - const Scalar* rhs_base = blockB + advanceCols*col*strideB + remaining_cols*offsetB; - const Scalar* lhs_base = blockA; - - for(; col < cols; col++) - { - Index row = 0; - - gemm_complex_unrolled_col(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, rows, col, remaining_cols, pAlphaReal, pAlphaImag); - - if (remaining_rows > 0) - { - gemm_complex_extra_col(res, lhs_base, rhs_base, depth, strideA, offsetA, strideB, row, col, remaining_rows, remaining_cols, pAlphaReal, pAlphaImag); - } - rhs_base++; - } - } + gemm_complex_extra_cols(res, blockA, blockB, depth, strideA, offsetA, strideB, offsetB, col, rows, cols, remaining_rows, pAlphaReal, pAlphaImag, pMask); } #undef accColsC -- GitLab From efdb8ac4662beb0c171adec9d36bbb8a6269488b Mon Sep 17 00:00:00 2001 From: Chip-Kerchner Date: Tue, 26 Oct 2021 16:42:23 -0500 Subject: [PATCH 2/2] Fix used uninitialized warnings. 
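GCC's -Wmaybe-uninitialized fires on rhs_ptr_imag and lhs_ptr_imag because they are only assigned when the corresponding operand is complex; the branches that skip the assignment never read the pointers, but the compiler cannot prove that. Initializing them to NULL silences the warning without changing behaviour. A minimal standalone sketch of the pattern (illustrative names, not the Eigen code itself):

    #include <cstddef>

    // When RhsIsReal there is no imaginary plane, so the pointer is never
    // read; initializing it to NULL keeps -Wmaybe-uninitialized quiet on the
    // branches the compiler cannot prove are dead.
    template <bool RhsIsReal>
    float sum_planes(const float* rhs_base, std::size_t stride, std::size_t n)
    {
      const float* rhs_ptr_real = rhs_base;
      const float* rhs_ptr_imag = NULL;    // previously left uninitialized
      if (!RhsIsReal) {
        rhs_ptr_imag = rhs_base + stride;  // imaginary plane follows the real one
      }

      float acc = 0.0f;
      for (std::size_t i = 0; i < n; ++i) {
        acc += rhs_ptr_real[i];
        if (!RhsIsReal) {
          acc += rhs_ptr_imag[i];          // only dereferenced when it was set
        }
      }
      return acc;
    }
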
--- Eigen/src/Core/arch/AltiVec/MatrixProduct.h | 6 +++--- Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h index bd5da3623..3745a87cb 100644 --- a/Eigen/src/Core/arch/AltiVec/MatrixProduct.h +++ b/Eigen/src/Core/arch/AltiVec/MatrixProduct.h @@ -1851,11 +1851,11 @@ EIGEN_ALWAYS_INLINE void gemm_unrolled_complex_row_iteration( const Packet& pMask) { const Scalar* rhs_ptr_real = rhs_base; - const Scalar* rhs_ptr_imag; + const Scalar* rhs_ptr_imag = NULL; if(!RhsIsReal) rhs_ptr_imag = rhs_base + accRows*strideB; else EIGEN_UNUSED_VARIABLE(rhs_ptr_imag); const Scalar* lhs_ptr_real = lhs_base + advanceRows*row*strideA + remaining_rows*offsetA; - const Scalar* lhs_ptr_imag; + const Scalar* lhs_ptr_imag = NULL; if(!LhsIsReal) lhs_ptr_imag = lhs_ptr_real + remaining_rows*strideA; else EIGEN_UNUSED_VARIABLE(lhs_ptr_imag); PacketBlock accReal0, accImag0, accReal1, accImag1, accReal2, accImag2, accReal3, accImag3; @@ -2078,7 +2078,7 @@ EIGEN_STRONG_INLINE void gemm_complex_unrolled_iteration( const Packet& pAlphaImag) { const Scalar* rhs_ptr_real = rhs_base; - const Scalar* rhs_ptr_imag; + const Scalar* rhs_ptr_imag = NULL; const Index imag_delta = accCols*strideA; if(!RhsIsReal) { rhs_ptr_imag = rhs_base + accRows*strideB; diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h b/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h index e18b7f267..9a3132276 100644 --- a/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h +++ b/Eigen/src/Core/arch/AltiVec/MatrixProductMMA.h @@ -492,7 +492,7 @@ EIGEN_ALWAYS_INLINE void gemm_complex_unrolled_MMA_iteration( const Packet& pAlphaImag) { const Scalar* rhs_ptr_real = rhs_base; - const Scalar* rhs_ptr_imag; + const Scalar* rhs_ptr_imag = NULL; const Index imag_delta = accCols*strideA; if(!RhsIsReal) { rhs_ptr_imag = rhs_base + accRows*strideB; -- GitLab
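
Throughout the first patch, the fixed 4-packet and 1-packet helpers (storeBlock, bcouple_common, bcouple, bload, bscale) are collapsed into single N-templated versions whose "if (N > k)" guards fold away at compile time, and the p16uc_SETCOMPLEX32/64 permute masks are replaced by vec_mergeh/vec_mergel. A scalar model of that merge step, with illustrative names (the real code operates on AltiVec vectors, not arrays):

    #include <cstdio>

    // merge_high/merge_low mimic vec_mergeh/vec_mergel on 4-float vectors:
    // interleaving a vector of real parts with a vector of imaginary parts
    // reproduces the (real, imag) complex layout that the removed
    // SETCOMPLEX32 permute masks used to build.
    static void merge_high(const float re[4], const float im[4], float out[4])
    {
      out[0] = re[0]; out[1] = im[0]; out[2] = re[1]; out[3] = im[1];
    }

    static void merge_low(const float re[4], const float im[4], float out[4])
    {
      out[0] = re[2]; out[1] = im[2]; out[2] = re[3]; out[3] = im[3];
    }

    int main()
    {
      const float re[4] = {1.f, 2.f, 3.f, 4.f};
      const float im[4] = {10.f, 20.f, 30.f, 40.f};
      float hi[4], lo[4];
      merge_high(re, im, hi); // {1, 10, 2, 20} -> first two complex values
      merge_low(re, im, lo);  // {3, 30, 4, 40} -> last two complex values
      std::printf("%g %g %g %g | %g %g %g %g\n",
                  hi[0], hi[1], hi[2], hi[3], lo[0], lo[1], lo[2], lo[3]);
      return 0;
    }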