author    Michalis Spyrou <michalis.spyrou@arm.com>    2018-04-13 13:44:10 +0100
committer Anthony Barbier <anthony.barbier@arm.com>    2018-11-02 16:49:37 +0000
commit    e7e96e09ff0d3e47797adf197aff2bc39671788c (patch)
tree      b52ecdd7627bdf51b8b8da9b9553cb900460222f /src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
parent    1ed1fc6d3b7d8494ce3bbc5f8b46bfde6fc586f9 (diff)
COMPMID-1054 Update RSH's GEMM to add batch+multi support
Change-Id: Ib9d91b77f1d51976da4449fa1e6eeeffae307353
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/127876
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Pablo Tello <pablo.tello@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp')
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp | 91
1 file changed, 61 insertions(+), 30 deletions(-)
diff --git a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
index 0df331acb4..770ee033c8 100644
--- a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
@@ -28,17 +28,18 @@
#include "arm_gemm.hpp"
#include "mergeresults.hpp"
-#include "profiler.hpp"
#include "transform.hpp"
+#ifdef CYCLE_PROFILING
+#include "profiler.hpp"
+#endif
+
namespace arm_gemm
{
// Implementation of the GemmCommon abstract class.
//
-// This is implementation is for GEMV with a transposed matrix.
-//
-// By default the source data is used in-place, but if type conversion is
-// needed we need to allocate working space (CURRENTLY NOT IMPLEMENTED).
+// This implementation is for GEMV with pretransposition.
+// Batches are not supported, as a batched GEMV makes no sense (it can be converted to a GEMM).
template <typename strategy, typename To, typename Tr>
class GemvPretransposed : public GemmCommon<To, Tr>
@@ -48,12 +49,14 @@ class GemvPretransposed : public GemmCommon<To, Tr>
const unsigned int _Nsize;
const unsigned int _Ksize;
+ const unsigned int _nmultis;
const bool _trB;
const Tr _beta;
const CPUInfo *const _ci;
+ const unsigned int _buffer_per_multi;
unsigned int m_block = 0;
unsigned int n_block = 0;
@@ -64,44 +67,64 @@ public:
GemvPretransposed(GemvPretransposed &) = delete;
GemvPretransposed &operator=(GemvPretransposed &) = delete;
- GemvPretransposed(const CPUInfo *ci, const unsigned int N, const unsigned int K, const bool trB, const Tr beta)
- : _Nsize(N), _Ksize(K), _trB(trB), _beta(beta), _ci(ci)
+ GemvPretransposed(const CPUInfo *ci, const unsigned int N, const unsigned int K, const unsigned int nmultis, const bool trB, const Tr beta)
+ : _Nsize(N), _Ksize(K), _nmultis(nmultis), _trB(trB), _beta(beta), _ci(ci), _buffer_per_multi(_Ksize * iceildiv(_Nsize, strategy::A_interleave) * strategy::A_interleave)
{
/* For now don't do any blocking. TODO: figure out if we should. */
m_block = K;
n_block = N;
}
- // Window is number of out_width blocks.
+ // Window is number of out_width blocks, times number of multis.
unsigned int get_window_size() const override
{
- return iceildiv(_Nsize, strategy::out_width);
+ return iceildiv(_Nsize, strategy::out_width) * _nmultis;
}
// Actually execute the GEMV.
void execute(unsigned int start, unsigned int end, int) override
{
+#ifdef CYCLE_PROFILING
profiler prof;
+#endif
+
strategy strat(_ci);
- unsigned int N_start = start * strategy::out_width;
- unsigned int N_end = std::min(end * strategy::out_width, _Nsize);
+ /* Break the window values down into multis of interest... */
+ const unsigned int window_per_multi = iceildiv(_Nsize, strategy::out_width);
+ const unsigned int multi_0 = start / window_per_multi;
+ const unsigned int multi_end = end / window_per_multi;
+
+ /* ... and figure out where we start and end in the first and last multi. */
+ const unsigned int n_0 = (start - (multi_0 * window_per_multi)) * strategy::out_width;
+ const unsigned int n_max = (end - (multi_end * window_per_multi)) * strategy::out_width;
static_assert(std::is_same<Tr, Tri>::value, "GemvPretransposed: Result types must be the same.");
- for(unsigned int m0 = 0; m0 < _Ksize; m0 += m_block)
+ for(unsigned int multi = multi_0; multi <= multi_end; multi++)
{
- unsigned int mmax = std::min(m0 + m_block, _Ksize);
+ const unsigned int n_start = (multi == multi_0) ? n_0 : 0;
+ const unsigned int n_end = (multi == multi_end) ? n_max : _Nsize;
- for(unsigned int n0 = N_start; n0 < N_end; n0 += n_block)
- {
- unsigned int nmax = std::min(n0 + n_block, N_end);
+ if(n_end <= n_start)
+ continue;
- prof(PROFILE_KERNEL, ((mmax - m0) * (nmax - n0)), [&](void)
+ for(unsigned int m0 = 0; m0 < _Ksize; m0 += m_block)
+ {
+ unsigned int mmax = std::min(m0 + m_block, _Ksize);
+ for(unsigned int n = n_start; n < n_end; n += n_block)
{
+ unsigned int nmax = std::min(n + n_block, n_end);
+#ifdef CYCLE_PROFILING
+ auto p = prof.ScopedProfiler(PROFILE_KERNEL, (mmax - m0) * (nmax - n));
+#endif
/* This assumes that the underlying call was a GEMM with M=1; for the N=1 case we would have to pick up this->_Bptr below instead */
- strat.kernel(_A_pretransposed + (n0 * _Ksize) + (m0 * strategy::A_interleave), (_Ksize * strategy::A_interleave), this->_Aptr + m0, this->_Cptr + n0, _beta, (mmax - m0), (nmax - n0));
- });
+ strat.kernel(_A_pretransposed + (multi * _buffer_per_multi) + (n * _Ksize) + (m0 * strategy::A_interleave),
+ (_Ksize * strategy::A_interleave),
+ this->_Aptr + (multi * this->_A_multi_stride) + m0,
+ this->_Cptr + (multi * this->_C_multi_stride) + n,
+ _beta, (mmax - m0), (nmax - n));
+ }
}
}
}
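
To make the window decomposition in the execute() hunk above concrete, here is a minimal standalone sketch. The sizes (Nsize, out_width, nmultis) and the start/end slice are illustrative assumptions, not values from the library; iceildiv simply mirrors the arm_gemm helper of the same name.

    // Sketch of the window -> (multi, column-range) decomposition used in execute().
    #include <cstdio>

    static unsigned int iceildiv(unsigned int a, unsigned int b)
    {
        return (a + b - 1) / b;
    }

    int main()
    {
        const unsigned int Nsize = 100, out_width = 12, nmultis = 3; // assumed problem shape
        const unsigned int start = 7, end = 20;                      // one worker's slice of the window

        // Each multi contributes iceildiv(Nsize, out_width) window units (9 here).
        const unsigned int window_per_multi = iceildiv(Nsize, out_width);

        const unsigned int multi_0   = start / window_per_multi;                       // first multi touched
        const unsigned int multi_end = end / window_per_multi;                         // last multi touched
        const unsigned int n_0   = (start - (multi_0 * window_per_multi)) * out_width; // start column in first multi
        const unsigned int n_max = (end - (multi_end * window_per_multi)) * out_width; // end column in last multi

        for(unsigned int multi = multi_0; multi <= multi_end; multi++)
        {
            const unsigned int n_start = (multi == multi_0) ? n_0 : 0;
            const unsigned int n_end   = (multi == multi_end) ? n_max : Nsize;

            if(n_end <= n_start)
            {
                continue; // the slice does not actually reach into this multi
            }

            // Prints: multi 0 -> [84, 100), multi 1 -> [0, 100), multi 2 -> [0, 24)
            printf("multi %u -> columns [%u, %u)\n", multi, n_start, n_end);
        }
        return 0;
    }
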
@@ -120,27 +143,35 @@ public:
size_t get_B_pretransposed_array_size() const override
{
- return _Ksize * iceildiv(_Nsize, strategy::A_interleave) * strategy::A_interleave * sizeof(float);
+ return _buffer_per_multi * _nmultis * sizeof(To);
}
- void pretranspose_B_array(void *buffer, const To *B, const int ldb) override
+ void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) override
{
Toi *A_buffer = reinterpret_cast<Toi *>(buffer);
- /* Reverse sense here as we are dealing with B rather than A. So if
- * strategy::A_transpose is false and _trB is false, we still
- * transpose. */
- if(_trB ^ strategy::A_transpose)
+ for(unsigned int multi = 0; multi < _nmultis; multi++)
{
- Transform<strategy::A_interleave, strategy::A_block, false>(A_buffer, B, ldb, 0, _Nsize, 0, _Ksize);
- }
- else
- {
- Transform<strategy::A_interleave, strategy::A_block, true>(A_buffer, B, ldb, 0, _Nsize, 0, _Ksize);
+ /* Reverse sense here as we are dealing with B rather than A. So if
+ * strategy::A_transpose is false and _trB is false, we still
+ * transpose. */
+ if(_trB ^ strategy::A_transpose)
+ {
+ Transform<strategy::A_interleave, strategy::A_block, false>(A_buffer + (multi * _buffer_per_multi), B + (multi * B_multi_stride), ldb, 0, _Nsize, 0, _Ksize);
+ }
+ else
+ {
+ Transform<strategy::A_interleave, strategy::A_block, true>(A_buffer + (multi * _buffer_per_multi), B + (multi * B_multi_stride), ldb, 0, _Nsize, 0, _Ksize);
+ }
}
_A_pretransposed = A_buffer;
}
+
+ void set_pretransposed_B_data(void *buffer) override
+ {
+ _A_pretransposed = reinterpret_cast<Toi *>(buffer);
+ }
};
} // namespace arm_gemm
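
For reference, a hedged sketch of the per-multi pretransposed buffer layout implied by the patch: each multi owns a contiguous block of _buffer_per_multi elements, and the panel read by a kernel call for a given (multi, n, m0) starts at multi * _buffer_per_multi + n * Ksize + m0 * A_interleave. All concrete numbers here (Nsize, Ksize, nmultis, A_interleave, the element type) are assumptions for illustration only.

    // Sketch of the buffer sizing in the constructor and get_B_pretransposed_array_size().
    #include <cstddef>
    #include <cstdio>

    static unsigned int iceildiv(unsigned int a, unsigned int b)
    {
        return (a + b - 1) / b;
    }

    int main()
    {
        const unsigned int Nsize = 100, Ksize = 64, nmultis = 3; // assumed problem shape
        const unsigned int A_interleave = 32;                    // stands in for strategy::A_interleave
        const size_t       sizeof_To    = sizeof(float);         // stands in for sizeof(To)

        // Mirrors _buffer_per_multi: Nsize rounded up to a multiple of the
        // interleave width, times the full K depth.
        const size_t buffer_per_multi = size_t(Ksize) * iceildiv(Nsize, A_interleave) * A_interleave;

        // Mirrors get_B_pretransposed_array_size(): one block per multi, in bytes.
        const size_t total_bytes = buffer_per_multi * nmultis * sizeof_To;

        // Element offset the kernel call in execute() would start from for
        // multi = 1, n = 32, m0 = 0.
        const size_t offset = 1 * buffer_per_multi + 32 * size_t(Ksize) + 0 * A_interleave;

        printf("buffer_per_multi = %zu elements, total = %zu bytes, offset(1, 32, 0) = %zu elements\n",
               buffer_per_multi, total_bytes, offset);
        return 0;
    }
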