author     Michalis Spyrou <michalis.spyrou@arm.com>   2018-04-13 13:44:10 +0100
committer  Anthony Barbier <anthony.barbier@arm.com>   2018-11-02 16:49:37 +0000
commit     e7e96e09ff0d3e47797adf197aff2bc39671788c (patch)
tree       b52ecdd7627bdf51b8b8da9b9553cb900460222f
parent     1ed1fc6d3b7d8494ce3bbc5f8b46bfde6fc586f9 (diff)
download   ComputeLibrary-e7e96e09ff0d3e47797adf197aff2bc39671788c.tar.gz
COMPMID-1054 Update RSH's GEMM to add batch+multi support
Change-Id: Ib9d91b77f1d51976da4449fa1e6eeeffae307353
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/127876
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Pablo Tello <pablo.tello@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
-rw-r--r--  arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp        |   7
-rw-r--r--  arm_compute/core/NEON/kernels/assembly/gemm_common.hpp     |  18
-rw-r--r--  arm_compute/runtime/NEON/AssemblyHelper.h                  |  32
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemm_batched.hpp            | 106
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp               |   7
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp               |  19
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemm_int16.cpp              |   3
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemm_int8.cpp               |   5
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp        | 216
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemm_native.hpp             |  54
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp             |   3
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp              |   5
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemv_native_transposed.hpp  |  50
-rw-r--r--  src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp      |  91
-rw-r--r--  src/core/NEON/kernels/arm_gemm/profiler.hpp                |  43
15 files changed, 489 insertions, 170 deletions
diff --git a/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp b/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp
index d6c9931a21..0a541c6db9 100644
--- a/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp
+++ b/arm_compute/core/NEON/kernels/assembly/arm_gemm.hpp
@@ -34,6 +34,9 @@ template<typename Top, typename Tret>
using UniqueGemmCommon = std::unique_ptr<GemmCommon<Top, Tret> >;
template<typename Top, typename Tret>
-UniqueGemmCommon<Top, Tret> gemm(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K, const bool trA, const bool trB, const Tret alpha, const Tret beta, const int maxthreads, const bool pretransposed_hint);
-
+UniqueGemmCommon<Top, Tret> gemm(const CPUInfo &ci,
+ const unsigned int M, const unsigned int N, const unsigned int K,
+ const unsigned int nbatches, const unsigned int nmulti,
+ const bool trA, const bool trB, const Tret alpha, const Tret beta,
+ const int maxthreads, const bool pretransposed_hint);
} // namespace arm_gemm
diff --git a/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp b/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp
index 7f47abcbb9..3919c339bf 100644
--- a/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp
+++ b/arm_compute/core/NEON/kernels/assembly/gemm_common.hpp
@@ -39,23 +39,35 @@ class GemmCommon {
protected:
const To *_Aptr=nullptr;
int _lda=0;
+ int _A_batch_stride=0;
+ int _A_multi_stride=0;
const To *_Bptr=nullptr;
int _ldb=0;
+ int _B_multi_stride=0;
Tr *_Cptr=nullptr;
int _ldc=0;
+ int _C_batch_stride=0;
+ int _C_multi_stride=0;
public:
/* Pass in the pointers to the arrays to be operated on and their
* strides. This has a default implementation that just captures them
* all in protected members. If B is pretransposed (see below) then the
* settings for B here are ignored. */
- virtual void set_arrays(const To *A, const int lda, const To *B, const int ldb, Tr *C, const int ldc) {
+ virtual void set_arrays(const To *A, const int lda, const int A_batch_stride, const int A_multi_stride,
+ const To *B, const int ldb, /* batches share B */ const int B_multi_stride,
+ Tr *C, const int ldc, const int C_batch_stride, const int C_multi_stride) {
_Aptr = A;
_lda = lda;
+ _A_batch_stride = A_batch_stride;
+ _A_multi_stride = A_multi_stride;
_Bptr = B;
_ldb = ldb;
+ _B_multi_stride = B_multi_stride;
_Cptr = C;
_ldc = ldc;
+ _C_batch_stride = C_batch_stride;
+ _C_multi_stride = C_multi_stride;
}
/* For threading, we divide the work into some number of units and work
@@ -95,7 +107,9 @@ public:
/* Total number of bytes of space needed for pretransposed arrays. */
virtual size_t get_B_pretransposed_array_size() const { return 0; }
/* Perform pretranspose - the void * passed in must remain allocated for the duration of any execute calls. */
- virtual void pretranspose_B_array(void *buffer, const To *, const int) { };
+ virtual void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) { };
+ /* Set pretransposed data - the void * passed in must previously have been passed to pretranspose_B_array() for the same or a similar GEMM. */
+ virtual void set_pretransposed_B_data(void *buffer) { }
// Destructor
virtual ~GemmCommon() { }
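With the two interface changes above (the extra nbatches/nmulti parameters on gemm() and the extra stride arguments on set_arrays()), a caller drives the library roughly as in the sketch below. This is an illustration only, not part of the patch: the packed row-major layout, the CPUInfo argument and single-threaded execution are assumptions.

#include <cstdint>
#include <vector>

#include "arm_gemm.hpp"

// Hypothetical caller: A is [nmulti][nbatches][M][K], B is [nmulti][K][N]
// (batches share B), C is [nmulti][nbatches][M][N], all tightly packed and
// row-major. Strides are given in elements, as set_arrays() expects.
void run_gemm(const CPUInfo &ci,
              const float *A, const float *B, float *C,
              const unsigned int M, const unsigned int N, const unsigned int K,
              const unsigned int nbatches, const unsigned int nmulti)
{
    auto asm_gemm = arm_gemm::gemm<float, float>(ci, M, N, K, nbatches, nmulti,
                                                 /* trA */ false, /* trB */ false,
                                                 /* alpha */ 1.0f, /* beta */ 0.0f,
                                                 /* maxthreads */ 1, /* pretransposed_hint */ false);

    asm_gemm->set_arrays(A, /* lda */ K, /* A_batch_stride */ M * K, /* A_multi_stride */ nbatches * M * K,
                         B, /* ldb */ N, /* B_multi_stride */ K * N,
                         C, /* ldc */ N, /* C_batch_stride */ M * N, /* C_multi_stride */ nbatches * M * N);

    // Some implementations need scratch space; allocate it if asked.
    std::vector<int8_t> working_space(asm_gemm->get_working_size());
    if(!working_space.empty())
    {
        asm_gemm->set_working_space(working_space.data());
    }

    // Run the whole window on a single thread (thread id 0).
    asm_gemm->execute(0, asm_gemm->get_window_size(), 0);
}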
diff --git a/arm_compute/runtime/NEON/AssemblyHelper.h b/arm_compute/runtime/NEON/AssemblyHelper.h
index 40f28587c2..2b4f35f2e1 100644
--- a/arm_compute/runtime/NEON/AssemblyHelper.h
+++ b/arm_compute/runtime/NEON/AssemblyHelper.h
@@ -82,24 +82,19 @@ public:
const int ldb = _b->info()->strides_in_bytes().y() / sizeof(TypeInput);
const int ldd = _d->info()->strides_in_bytes().y() / sizeof(TypeOutput);
- // Configure kernel window
- Window window = calculate_max_window(*_d->info());
+ const int batch_stride_a = _a->info()->strides_in_bytes().z() / sizeof(TypeInput);
+ const int batch_stride_d = _d->info()->strides_in_bytes().z() / sizeof(TypeOutput);
+
+ const int multi_stride_a = _a->info()->strides_in_bytes()[3] / sizeof(TypeInput);
+ const int multi_stride_b = _b->info()->strides_in_bytes().z() / sizeof(TypeInput);
+ const int multi_stride_d = _d->info()->strides_in_bytes()[3] / sizeof(TypeOutput);
+
+ const auto in0_ptr = reinterpret_cast<const TypeInput *>(_a->buffer());
const auto in1_ptr = reinterpret_cast<const TypeInput *>(_b->buffer());
+ auto out_ptr = reinterpret_cast<TypeOutput *>(_d->buffer());
- // Only iterate over batches
- Window win(window);
- win.set(0, Window::Dimension(0, 1, 1));
- win.set(1, Window::Dimension(0, 1, 1));
- Iterator in0(_a, window);
- Iterator out(_d, window);
- execute_window_loop(win, [&](const Coordinates &)
- {
- const auto in0_ptr = reinterpret_cast<const TypeInput *>(in0.ptr());
- auto out_ptr = reinterpret_cast<TypeOutput *>(out.ptr());
- _gemm_kernel_asm->set_arrays(in0_ptr, lda, in1_ptr, ldb, out_ptr, ldd);
- NEScheduler::get().schedule(_optimised_kernel.get(), Window::DimX);
- },
- in0, out);
+ _gemm_kernel_asm->set_arrays(in0_ptr, lda, batch_stride_a, multi_stride_a, in1_ptr, ldb, multi_stride_b, out_ptr, ldd, batch_stride_d, multi_stride_d);
+ NEScheduler::get().schedule(_optimised_kernel.get(), Window::DimX);
}
};
@@ -146,10 +141,13 @@ inline bool setup_assembly_kernel(const ITensor *a, const ITensor *b, ITensor *d
const int M = d->info()->tensor_shape().y();
const int N = d->info()->tensor_shape().x();
const int K = a->info()->tensor_shape().x();
+ const int batches = a->info()->tensor_shape().total_size_upper(2);
+ const int multis = b->info()->tensor_shape().z();
unsigned int num_threads = NEScheduler::get().num_threads();
+
// unique_ptr to a Gemm object
std::unique_ptr<typename T::AssemblyGemm>
- asm_gemm(arm_gemm::gemm<typename T::TypeOperator, typename T::TypeResult>(ci, M, N, K, false, false, alpha, beta, num_threads, false));
+ asm_gemm(arm_gemm::gemm<typename T::TypeOperator, typename T::TypeResult>(ci, M, N, K, batches, multis, false, false, alpha, beta, num_threads, false));
// arm_compute wrapper for the Gemm object (see above)
std::unique_ptr<NEGEMMAssemblyWrapper<typename T::AssemblyGemm>>
acl_gemm_wrapper = support::cpp14::make_unique<NEGEMMAssemblyWrapper<typename T::AssemblyGemm>>();
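The element strides read from strides_in_bytes() above reduce to simple products for a packed tensor; a small self-contained check of that arithmetic (shape and float element type assumed for illustration):

#include <array>
#include <cassert>
#include <cstddef>

// Assumed packed float tensor A of shape [x=K, y=M, z=batches, w=multis].
// Dividing the y/z/w byte strides by the element size yields lda, the batch
// stride and the multi stride passed to set_arrays(). B is only 3D because
// all batches share the same B, so its z stride is a multi stride.
int main()
{
    const unsigned int K = 8, M = 4, batches = 2;
    const std::array<size_t, 4> a_strides_in_bytes = {
        sizeof(float),
        sizeof(float) * K,
        sizeof(float) * K * M,
        sizeof(float) * K * M * batches
    };

    const int lda            = a_strides_in_bytes[1] / sizeof(float); // one row
    const int batch_stride_a = a_strides_in_bytes[2] / sizeof(float); // one batch
    const int multi_stride_a = a_strides_in_bytes[3] / sizeof(float); // one multi

    assert(lda == 8 && batch_stride_a == 32 && multi_stride_a == 64);
    return 0;
}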
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_batched.hpp b/src/core/NEON/kernels/arm_gemm/gemm_batched.hpp
new file mode 100644
index 0000000000..385358f615
--- /dev/null
+++ b/src/core/NEON/kernels/arm_gemm/gemm_batched.hpp
@@ -0,0 +1,106 @@
+/*
+ * Copyright (c) 2017-2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#pragma once
+
+#include "arm_gemm.hpp"
+
+namespace arm_gemm
+{
+template <typename To, typename Tr>
+class GemmBatched : public GemmCommon<To, Tr>
+{
+private:
+ UniqueGemmCommon<To, Tr> _subgemm = nullptr;
+
+public:
+ GemmBatched(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K,
+ const unsigned int nbatches, const unsigned int nmulti, const bool trA, const bool trB,
+ const To alpha, const To beta, const int maxthreads, const bool pretransposed_hint)
+ {
+ /* Just create a subgemm with batches->M */
+ _subgemm = gemm<To, Tr>(ci, nbatches, N, K, 1, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint);
+ }
+
+ void set_arrays(const To *A, const int lda, const int A_batch_stride, const int A_multi_stride,
+ const To *B, const int ldb, const int B_multi_stride,
+ Tr *C, const int ldc, const int C_batch_stride, const int C_multi_stride) override
+ {
+ /* A and C's batch stride becomes their new row stride. New batch stride is 0 as nbatches for subgemm is always 1. */
+ _subgemm->set_arrays(A, A_batch_stride, 0, A_multi_stride,
+ B, ldb, B_multi_stride,
+ C, C_batch_stride, 0, C_multi_stride);
+ }
+
+ unsigned int get_window_size() const override
+ {
+ return _subgemm->get_window_size();
+ }
+
+ void set_nthreads(int nthreads) override
+ {
+ _subgemm->set_nthreads(nthreads);
+ }
+
+ void execute(unsigned int start, unsigned int end, int threadid) override
+ {
+ _subgemm->execute(start, end, threadid);
+ }
+
+ size_t get_working_size() const override
+ {
+ return _subgemm->get_working_size();
+ }
+
+ void set_working_space(void *space) override
+ {
+ _subgemm->set_working_space(space);
+ }
+
+ bool B_is_pretransposed() const override
+ {
+ return _subgemm->B_is_pretransposed();
+ }
+
+ bool B_pretranspose_required() const override
+ {
+ return _subgemm->B_pretranspose_required();
+ }
+
+ size_t get_B_pretransposed_array_size() const override
+ {
+ return _subgemm->get_B_pretransposed_array_size();
+ }
+
+ void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) override
+ {
+ _subgemm->pretranspose_B_array(buffer, B, ldb, B_multi_stride);
+ }
+
+ void set_pretransposed_B_data(void *buffer) override
+ {
+ _subgemm->set_pretransposed_B_data(buffer);
+ }
+};
+
+} // namespace arm_gemm
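The remapping in GemmBatched::set_arrays() above is easiest to see with concrete numbers; a small self-contained check (M=1, K=4, nbatches=3 and packed storage are assumptions for illustration):

#include <cassert>

// With M == 1, element (batch b, column k) of A sits at A[b * A_batch_stride + k],
// which is exactly row b of an nbatches x K matrix whose leading dimension is
// A_batch_stride -- so the subgemm can treat batches as rows.
int main()
{
    const int K = 4, nbatches = 3;
    const int lda = K, A_batch_stride = K; // outer view: one row per batch
    const int sub_lda = A_batch_stride;    // subgemm view: batches become rows

    for(int batch = 0; batch < nbatches; batch++)
    {
        for(int k = 0; k < K; k++)
        {
            const int outer_index = batch * A_batch_stride + 0 * lda + k; // outer GEMM addressing
            const int sub_index   = batch * sub_lda + k;                  // subgemm addressing
            assert(outer_index == sub_index);
        }
    }
    return 0;
}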
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
index 484892dc81..d1180b13cb 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
@@ -38,6 +38,7 @@ namespace arm_gemm
{
template <>
UniqueGemmCommon<__fp16, __fp16> gemm(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K,
+ const unsigned int nbatches, const unsigned int nmulti,
const bool trA, const bool trB, const __fp16 alpha, const __fp16 beta,
const int maxthreads, const bool pretransposed_hint)
{
@@ -56,15 +57,15 @@ UniqueGemmCommon<__fp16, __fp16> gemm(const CPUInfo &ci, const unsigned int M, c
// If FP16 is supported, use it.
if(use_fp16)
{
- return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved<hgemm_24x8, __fp16, __fp16>(&ci, M, N, K, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+ return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved<hgemm_24x8, __fp16, __fp16>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
}
#endif
// Fallback to using the blocked SGEMM kernel.
- return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved<sgemm_12x8, __fp16, __fp16>(&ci, M, N, K, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+ return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved<sgemm_12x8, __fp16, __fp16>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
#else
// For AArch32, only support the SGEMM route for now.
- return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved<sgemm_8x6, __fp16, __fp16>(&ci, M, N, K, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+ return UniqueGemmCommon<__fp16, __fp16>(new GemmInterleaved<sgemm_8x6, __fp16, __fp16>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
#endif
}
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
index a5b41cac2f..43df1aa779 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
@@ -22,6 +22,7 @@
* SOFTWARE.
*/
#include "arm_gemm.hpp"
+#include "gemm_batched.hpp"
#include "gemm_common.hpp"
#include "gemm_interleaved.hpp"
#include "gemm_native.hpp"
@@ -38,21 +39,27 @@ namespace arm_gemm
{
template <>
UniqueGemmCommon<float, float> gemm<float, float>(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K,
+ const unsigned int nbatches, const unsigned int nmulti,
const bool trA, const bool trB, const float alpha, const float beta,
const int maxthreads, const bool pretransposed_hint)
{
+ /* Handle "batched GEMM" */
+ if(M == 1 && nbatches > 1)
+ {
+ return UniqueGemmCommon<float, float>(new GemmBatched<float, float>(ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+ }
#ifdef __aarch64__
/* Cases in priority order */
- /* GemvPretransposed: requires M=1, alpha=1, and transposed hint set */
+ /* GemvPretransposed: requires M=1, alpha=1, and transposed hint set. nbatches must be 1 or we would have returned above so don't test. */
if(M == 1 && alpha == 1.0f && pretransposed_hint)
{
- return UniqueGemmCommon<float, float>(new GemvPretransposed<sgemv_pretransposed, float, float>(&ci, N, K, trB, beta));
+ return UniqueGemmCommon<float, float>(new GemvPretransposed<sgemv_pretransposed, float, float>(&ci, N, K, nmulti, trB, beta));
}
/* GemvNativeTransposed: requires M=1, no trA or trB, doesn't handle alpha */
if(M == 1 && alpha == 1.0f && !trA && !trB)
{
- return UniqueGemmCommon<float, float>(new GemvNativeTransposed<sgemv_trans, float, float>(&ci, N, K, beta));
+ return UniqueGemmCommon<float, float>(new GemvNativeTransposed<sgemv_trans, float, float>(&ci, N, K, nmulti, beta));
}
/* Native GEMM: requires M to be a multiple of 4, K at least 4, N a
@@ -60,13 +67,13 @@ UniqueGemmCommon<float, float> gemm<float, float>(const CPUInfo &ci, const unsig
* sizes. */
if(N <= 128 && K <= 128 && ((M % 4) == 0) && (K >= 4) && ((N % 16) == 0) && alpha == 1.0f)
{
- return UniqueGemmCommon<float, float>(new GemmNative<sgemm_native_16x4, float, float>(&ci, M, N, K, beta));
+ return UniqueGemmCommon<float, float>(new GemmNative<sgemm_native_16x4, float, float>(&ci, M, N, K, nbatches, nmulti, beta));
}
/* Blocked GEMM, handles all cases. */
- return UniqueGemmCommon<float, float>(new GemmInterleaved<sgemm_12x8, float, float>(&ci, M, N, K, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+ return UniqueGemmCommon<float, float>(new GemmInterleaved<sgemm_12x8, float, float>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
#else
- return UniqueGemmCommon<float, float>(new GemmInterleaved<sgemm_8x6, float, float>(&ci, M, N, K, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+ return UniqueGemmCommon<float, float>(new GemmInterleaved<sgemm_8x6, float, float>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
#endif
}
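For reference, the selection order implemented above (AArch64 path) condensed into a single predicate; this is illustration only, and the returned names are just labels for the classes chosen:

// Mirrors the priority order in gemm<float, float>() above.
const char *pick_sgemm_impl(unsigned int M, unsigned int N, unsigned int K,
                            unsigned int nbatches, bool trA, bool trB,
                            float alpha, bool pretransposed_hint)
{
    if(M == 1 && nbatches > 1)
        return "GemmBatched";            // fold batches into M first
    if(M == 1 && alpha == 1.0f && pretransposed_hint)
        return "GemvPretransposed";
    if(M == 1 && alpha == 1.0f && !trA && !trB)
        return "GemvNativeTransposed";
    if(N <= 128 && K <= 128 && ((M % 4) == 0) && (K >= 4) && ((N % 16) == 0) && alpha == 1.0f)
        return "GemmNative";
    return "GemmInterleaved";            // blocked GEMM handles all remaining cases
}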
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp
index 344bfed12b..7669fe0ff1 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp
@@ -33,10 +33,11 @@ namespace arm_gemm
{
template <>
UniqueGemmCommon<int16_t, int32_t> gemm<int16_t, int32_t>(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K,
+ const unsigned int nbatches, const unsigned int nmulti,
const bool trA, const bool trB, const int32_t alpha, const int32_t beta,
const int maxthreads, const bool pretransposed_hint)
{
- return UniqueGemmCommon<int16_t, int32_t>(new GemmInterleaved<gemm_s16_12x8, int16_t, int32_t>(&ci, M, N, K, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+ return UniqueGemmCommon<int16_t, int32_t>(new GemmInterleaved<gemm_s16_12x8, int16_t, int32_t>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
}
// Instantiate static class members
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
index 856d407cfa..f13406284c 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
@@ -35,16 +35,17 @@ namespace arm_gemm
{
template <>
UniqueGemmCommon<int8_t, int32_t> gemm<int8_t, int32_t>(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K,
+ const unsigned int nbatches, const unsigned int nmulti,
const bool trA, const bool trB, const int32_t alpha, const int32_t beta,
const int maxthreads, const bool pretransposed_hint)
{
if(ci.has_dotprod())
{
// Dot product supporting CPUs. This family has a special version for A55r1.
- return UniqueGemmCommon<int8_t, int32_t>(new GemmInterleaved<gemm_s8_12x8, int8_t, int32_t>(&ci, M, N, K, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+ return UniqueGemmCommon<int8_t, int32_t>(new GemmInterleaved<gemm_s8_12x8, int8_t, int32_t>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
}
- return UniqueGemmCommon<int8_t, int32_t>(new GemmInterleaved<gemm_s8_4x4, int8_t, int32_t>(&ci, M, N, K, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+ return UniqueGemmCommon<int8_t, int32_t>(new GemmInterleaved<gemm_s8_4x4, int8_t, int32_t>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
// TODO: There's a better approach for A53, but it doesn't work
// well on heterogeneous systems as the required data formats
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
index 27e4e8d411..32c65cd3fb 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
@@ -33,9 +33,12 @@
#include "buffer_manager.hpp"
#include "mergeresults.hpp"
-#include "profiler.hpp"
#include "transform.hpp"
+#ifdef CYCLE_PROFILING
+#include "profiler.hpp"
+#endif
+
// Some macros used to decide how much working space to allocate.
// Round allocations up to the next cache line.
#define ALLOC_ROUND 64
@@ -60,6 +63,9 @@ class GemmInterleaved : public GemmCommon<To, Tr>
const unsigned int _Nsize;
const unsigned int _Ksize;
+ const unsigned int _nbatches;
+ const unsigned int _nmulti;
+
const bool _trA;
const bool _trB;
@@ -84,30 +90,31 @@ class GemmInterleaved : public GemmCommon<To, Tr>
class blockwalker
{
private:
- /* Loop parameters, we only block up N and K so don't worry about M. */
- const unsigned int _Nsize, _Ksize, _x_block, _k_block;
+ /* Size loops, etc. based on our parent's configuration */
+ const GemmInterleaved<strategy, To, Tr> &_parent;
- /* K and X parameters for current iteration. */
- unsigned int _k0 = 0, _x0 = 0;
+ /* K and X and multi parameters for current iteration. */
+ unsigned int _k0 = 0, _x0 = 0, _multi = 0;
unsigned int _index = 0;
bool _done = false;
bool _newkblock = true;
+ bool _newmulti = true;
public:
- blockwalker(const unsigned int K, const unsigned int k_block, const unsigned int N, const unsigned int x_block)
- : _Nsize(N), _Ksize(K), _x_block(x_block), _k_block(k_block)
+ blockwalker(const GemmInterleaved<strategy, To, Tr> &parent)
+ : _parent(parent)
{
}
unsigned int xmax()
{
- return std::min(_x0 + _x_block, _Nsize);
+ return std::min(_x0 + _parent._x_block, _parent._Nsize);
}
unsigned int kmax()
{
- return std::min(_k0 + _k_block, _Ksize);
+ return std::min(_k0 + _parent._k_block, _parent._Ksize);
}
/* Advance to the next block, return false at the end. */
@@ -119,15 +126,21 @@ class GemmInterleaved : public GemmCommon<To, Tr>
}
_newkblock = false;
- _x0 += _x_block;
- if(_x0 >= _Nsize)
+ _x0 += _parent._x_block;
+ if(_x0 >= _parent._Nsize)
{
_x0 = 0;
- _k0 += _k_block;
- if(_k0 >= _Ksize)
+ _k0 += _parent._k_block;
+ if(_k0 >= _parent._Ksize)
{
- _done = true;
- return false;
+ _k0 = 0;
+ _multi++;
+ if(_multi >= _parent._nmulti)
+ {
+ _done = true;
+ return false;
+ }
+ _newmulti = true;
}
_newkblock = true;
}
@@ -144,6 +157,10 @@ class GemmInterleaved : public GemmCommon<To, Tr>
{
return _x0;
}
+ unsigned int multi(void)
+ {
+ return _multi;
+ }
unsigned int index(void)
{
return _index;
@@ -161,7 +178,7 @@ class GemmInterleaved : public GemmCommon<To, Tr>
// A working size: One of these needed, regardless of thread count. Divided according to window.
size_t get_a_working_size() const
{
- return ROUND_UP(sizeof(Toi) * _k_block * _Mround);
+ return ROUND_UP(sizeof(Toi) * _k_block * _Mround * _nbatches);
}
// B working size: 0, 1 or 3 of these needed depending on pretransposed and threading settings.
@@ -181,15 +198,23 @@ class GemmInterleaved : public GemmCommon<To, Tr>
template <bool pretransposed>
void execute_internal(unsigned int start, unsigned int end, int threadid)
{
+#ifdef CYCLE_PROFILING
profiler prof;
+#endif
+
strategy strat(_ci);
- blockwalker current(_Ksize, _k_block, _Nsize, _x_block);
+ blockwalker current(*this);
blockwalker next = current;
+ /* Translate 'start' and 'end' into a position within the batches and rows. */
+ const unsigned int window_per_batch = _Mround / strategy::out_height;
+ unsigned int batch_0 = start / window_per_batch;
+ unsigned int batch_end = end / window_per_batch;
+
/* Compute the M values to operate on */
- unsigned int m_0 = start * strategy::out_height;
- unsigned int m_max = std::min(end * strategy::out_height, _Msize);
+ unsigned int m_0 = (start - (batch_0 * window_per_batch)) * strategy::out_height;
+ unsigned int m_max = (end - (batch_end * window_per_batch)) * strategy::out_height;
/* Make sure we've been set up correctly. */
if(pretransposed)
@@ -205,7 +230,8 @@ class GemmInterleaved : public GemmCommon<To, Tr>
int8_t *working_space_bytes = reinterpret_cast<int8_t *>(_working_space);
// Private buffers. Treat working_space as an array of C buffers (one per thread) first, followed by the (window-divided) A buffer.
- Toi *const a_panel = reinterpret_cast<Toi *>(working_space_bytes + (_maxthreads * get_c_working_size()) + (m_0 * _k_block * sizeof(Toi)));
+ // Set a_panel to the base of the A buffers - compute offsets into it based on M/batches later.
+ Toi *const a_panel = reinterpret_cast<Toi *>(working_space_bytes + (_maxthreads * get_c_working_size()));
Tri *const c_panel = reinterpret_cast<Tri *>(working_space_bytes + (threadid * get_c_working_size()));
// Shared buffers - these come either from BufferManager or _B_transposed.
@@ -225,17 +251,31 @@ class GemmInterleaved : public GemmCommon<To, Tr>
{
if(current.newkblock())
{
- prof(PROFILE_PREPA, ((m_max - m_0) * (current.kmax() - current.k0()) * sizeof(Toi)), [&](void)
+#ifdef CYCLE_PROFILING
+ auto p = prof.ScopedProfiler(PROFILE_PREPA, (end - start) * strategy::out_height * (current.kmax() - current.k0()) * sizeof(Toi));
+#endif
+ for(unsigned int batch = batch_0; batch <= batch_end; batch++)
{
+ unsigned int first_m = (batch == batch_0) ? m_0 : 0;
+ unsigned int last_m = (batch == batch_end) ? m_max : _Msize;
+
+ if(first_m >= last_m)
+ continue;
if(_trA ^ strategy::A_transpose)
{
- Transform<strategy::A_interleave, strategy::A_block, true>(a_panel, this->_Aptr, this->_lda, m_0, m_max, current.k0(), current.kmax());
+ Transform<strategy::A_interleave, strategy::A_block, true>(
+ a_panel + ((batch * _Mround + first_m) * _k_block),
+ this->_Aptr + (batch * this->_A_batch_stride) + (current.multi() * this->_A_multi_stride),
+ this->_lda, first_m, last_m, current.k0(), current.kmax());
}
else
{
- Transform<strategy::A_interleave, strategy::A_block, false>(a_panel, this->_Aptr, this->_lda, m_0, m_max, current.k0(), current.kmax());
+ Transform<strategy::A_interleave, strategy::A_block, false>(
+ a_panel + ((batch * _Mround + first_m) * _k_block),
+ this->_Aptr + (batch * this->_A_batch_stride) + (current.multi() * this->_A_multi_stride),
+ this->_lda, first_m, last_m, current.k0(), current.kmax());
}
- });
+ }
// Figure out how many "K" the kernel will actually process.
kern_k = iceildiv(current.kmax() - current.k0(), strategy::k_unroll);
@@ -258,53 +298,84 @@ class GemmInterleaved : public GemmCommon<To, Tr>
{
_bm->try_populate(next.index(), [&](void *buffer)
{
- prof(PROFILE_PREPB, (next.xmax() - next.x0()) * (next.kmax() - next.k0()) * sizeof(Toi), [&](void)
- {
- Toi *b_panel = reinterpret_cast<Toi *>(buffer);
- if(_trB ^ strategy::B_transpose)
- {
- Transform<strategy::B_interleave, strategy::B_block, true>(b_panel, this->_Bptr, this->_ldb, next.x0(), next.xmax(), next.k0(), next.kmax());
- }
- else
- {
- Transform<strategy::B_interleave, strategy::B_block, false>(b_panel, this->_Bptr, this->_ldb, next.x0(), next.xmax(), next.k0(), next.kmax());
- }
- });
- });
- }
+#ifdef CYCLE_PROFILING
+ auto p = prof.ScopedProfiler(PROFILE_PREPB, (next.xmax() - next.x0()) * (next.kmax() - next.k0()) * sizeof(Toi));
+#endif
- /* Get the buffer for this iteration from the BufferManager. */
- b_panel = reinterpret_cast<Toi *>(_bm->get(current.index(), [&](void *bpv)
- {
- prof(PROFILE_PREPB, (current.xmax() - current.x0()) * (current.kmax() - current.k0()) * sizeof(Toi), [&](void)
- {
- Toi *b_panel = reinterpret_cast<Toi *>(bpv);
+ Toi *b_panel = reinterpret_cast<Toi *>(buffer);
if(_trB ^ strategy::B_transpose)
{
- Transform<strategy::B_interleave, strategy::B_block, true>(b_panel, this->_Bptr, this->_ldb, current.x0(), current.xmax(), current.k0(), current.kmax());
+ Transform<strategy::B_interleave, strategy::B_block, true>(
+ b_panel, this->_Bptr + (next.multi() * this->_B_multi_stride), this->_ldb,
+ next.x0(), next.xmax(), next.k0(), next.kmax());
}
else
{
- Transform<strategy::B_interleave, strategy::B_block, false>(b_panel, this->_Bptr, this->_ldb, current.x0(), current.xmax(), current.k0(), current.kmax());
+ Transform<strategy::B_interleave, strategy::B_block, false>(
+ b_panel, this->_Bptr + (next.multi() * this->_B_multi_stride), this->_ldb,
+ next.x0(), next.xmax(), next.k0(), next.kmax());
}
});
+ }
+ /* Get the buffer for this iteration from the BufferManager. */
+ b_panel = reinterpret_cast<Toi *>(_bm->get(current.index(), [&](void *bpv)
+ {
+#ifdef CYCLE_PROFILING
+ auto p = prof.ScopedProfiler(PROFILE_PREPB, (current.xmax() - current.x0()) * (current.kmax() - current.k0()) * sizeof(Toi));
+#endif
+
+ Toi *b_panel = reinterpret_cast<Toi *>(bpv);
+ if(_trB ^ strategy::B_transpose)
+ {
+ Transform<strategy::B_interleave, strategy::B_block, true>(
+ b_panel, this->_Bptr + (current.multi() * this->_B_multi_stride), this->_ldb,
+ current.x0(), current.xmax(), current.k0(), current.kmax());
+ }
+ else
+ {
+ Transform<strategy::B_interleave, strategy::B_block, false>(
+ b_panel, this->_Bptr + (current.multi() * this->_B_multi_stride), this->_ldb,
+ current.x0(), current.xmax(), current.k0(), current.kmax());
+ }
+
}));
}
/* Do the actual work. */
- for(unsigned int y = m_0; y < m_max; y += strategy::out_height)
+ for(unsigned int batch = batch_0; batch <= batch_end; batch++)
{
- unsigned int ymax = std::min(_Msize, y + strategy::out_height);
+ unsigned int first_m = (batch == batch_0) ? m_0 : 0;
+ unsigned int last_m = (batch == batch_end) ? m_max : _Msize;
- prof(PROFILE_KERNEL, (strategy::out_height * bblocks * strategy::out_width * kern_k), [&](void)
- {
- strat.kernel(a_panel + ((y - m_0) * kern_k), b_panel, c_panel, 1, bblocks, kern_k);
- });
- prof(PROFILE_MERGE, (strategy::out_height * bblocks * strategy::out_width * sizeof(Tr)), [&](void)
+ const Toi *a_ptr = a_panel + (batch * _Mround + first_m) * _k_block;
+
+ if(first_m >= last_m)
+ continue;
+
+ for(unsigned int y = first_m; y < last_m; y += strategy::out_height)
{
- MergeResults<strategy::out_width, strategy::out_height>(this->_Cptr, c_panel, this->_ldc, y, ymax,
- current.x0(), current.xmax(), _alpha, (current.k0() == 0 ? _beta : static_cast<Tr>(1)));
- });
+ unsigned int ymax = std::min(_Msize, y + strategy::out_height);
+
+ {
+#ifdef CYCLE_PROFILING
+ auto p = prof.ScopedProfiler(PROFILE_KERNEL, (strategy::out_height * bblocks * strategy::out_width * kern_k));
+#endif
+
+ strat.kernel(a_ptr, b_panel, c_panel, 1, bblocks, kern_k);
+
+ a_ptr += (strategy::out_height * kern_k);
+ }
+
+ {
+#ifdef CYCLE_PROFILING
+ auto p = prof.ScopedProfiler(PROFILE_MERGE, (strategy::out_height * bblocks * strategy::out_width * sizeof(Tr)));
+#endif
+ MergeResults<strategy::out_width, strategy::out_height>(
+ this->_Cptr + (batch * this->_C_batch_stride) + (current.multi() * this->_C_multi_stride),
+ c_panel, this->_ldc, y, ymax, current.x0(), current.xmax(),
+ _alpha, (current.k0() == 0 ? _beta : static_cast<Tr>(1)));
+ }
+ }
}
if(pretransposed)
@@ -324,9 +395,9 @@ public:
/* Constructor */
GemmInterleaved(const CPUInfo *ci, const unsigned int M, const unsigned int N, const unsigned int K,
- const bool trA, const bool trB, const Tr alpha, const Tr beta, const int maxthreads,
- const bool pretransposed)
- : _ci(ci), _Msize(M), _Nsize(N), _Ksize(K), _trA(trA), _trB(trB), _alpha(alpha), _beta(beta), _maxthreads(maxthreads), _pretransposed(pretransposed)
+ const unsigned int nbatches, const unsigned int nmulti, const bool trA, const bool trB,
+ const Tr alpha, const Tr beta, const int maxthreads, const bool pretransposed)
+ : _ci(ci), _Msize(M), _Nsize(N), _Ksize(K), _nbatches(nbatches), _nmulti(nmulti), _trA(trA), _trB(trB), _alpha(alpha), _beta(beta), _maxthreads(maxthreads), _pretransposed(pretransposed)
{
const unsigned int L1_size = ci->get_L1_cache_size();
const unsigned int L2_size = ci->get_L2_cache_size();
@@ -375,11 +446,15 @@ public:
// Interface implementation - Compulsory functions
- // Window size: Only the last thread should do a ragged block, so dole out work in units of out_height */
+ // Window size: Only the last thread should do a ragged block, so dole
+ // out work in units of out_height. Factor batches into the window, but
+ // not multi for now (as this would cause problems with the buffer
+ // manager).
+
unsigned int get_window_size() const override
{
// _Mround is a multiple of out_height by definition.
- return _Mround / strategy::out_height;
+ return (_Mround / strategy::out_height) * _nbatches;
}
// set_nthreads: pass on to buffer manager to avoid it waiting for non-existent threads.
@@ -471,7 +546,7 @@ public:
size_t get_B_pretransposed_array_size() const override
{
size_t total = 0;
- blockwalker current(_Ksize, _k_block, _Nsize, _x_block);
+ blockwalker current(*this);
do
{
@@ -493,9 +568,9 @@ public:
return total;
}
- void pretranspose_B_array(void *in_buffer, const To *B, const int ldb) override
+ void pretranspose_B_array(void *in_buffer, const To *B, const int ldb, const int B_multi_stride) override
{
- blockwalker current(_Ksize, _k_block, _Nsize, _x_block);
+ blockwalker current(*this);
Toi *buffer = reinterpret_cast<Toi *>(in_buffer);
_B_transposed = buffer;
@@ -514,11 +589,15 @@ public:
if(_trB ^ strategy::B_transpose)
{
- Transform<strategy::B_interleave, strategy::B_block, true>(buffer, B, ldb, current.x0(), current.xmax(), current.k0(), current.kmax());
+ Transform<strategy::B_interleave, strategy::B_block, true>(
+ buffer, B + (current.multi() * B_multi_stride), ldb,
+ current.x0(), current.xmax(), current.k0(), current.kmax());
}
else
{
- Transform<strategy::B_interleave, strategy::B_block, false>(buffer, B, ldb, current.x0(), current.xmax(), current.k0(), current.kmax());
+ Transform<strategy::B_interleave, strategy::B_block, false>(
+ buffer, B + (current.multi() * B_multi_stride), ldb,
+ current.x0(), current.xmax(), current.k0(), current.kmax());
}
buffer += (x_size * k_size);
@@ -526,6 +605,11 @@ public:
while(current.advance());
}
+ void set_pretransposed_B_data(void *in_buffer) override
+ {
+ _B_transposed = reinterpret_cast<Toi *>(in_buffer);
+ }
+
~GemmInterleaved() override
{
delete _bm;
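The start/end bookkeeping in execute_internal() above packs batches into the window; a worked, self-contained example with assumed sizes (Msize=20, out_height=8, so Mround=24 and window_per_batch=3, nbatches=2):

#include <cassert>

// A thread asked to run window units [start=2, end=5) out of a total of 6
// (3 units per batch x 2 batches) must finish the tail of batch 0 and then
// the head of batch 1 -- the same split execute_internal() computes above.
int main()
{
    const unsigned int Msize = 20, out_height = 8, Mround = 24, nbatches = 2;
    const unsigned int window_per_batch = Mround / out_height; // 3
    assert(window_per_batch * nbatches == 6);                  // get_window_size()

    const unsigned int start = 2, end = 5;
    const unsigned int batch_0   = start / window_per_batch;                        // 0
    const unsigned int batch_end = end / window_per_batch;                          // 1
    const unsigned int m_0   = (start - (batch_0 * window_per_batch)) * out_height; // 16
    const unsigned int m_max = (end - (batch_end * window_per_batch)) * out_height; // 16

    for(unsigned int batch = batch_0; batch <= batch_end; batch++)
    {
        const unsigned int first_m = (batch == batch_0) ? m_0 : 0;
        const unsigned int last_m  = (batch == batch_end) ? m_max : Msize;

        if(batch == 0) { assert(first_m == 16 && last_m == 20); } // rows 16..20 of batch 0
        if(batch == 1) { assert(first_m == 0  && last_m == 16); } // rows  0..16 of batch 1
    }
    return 0;
}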
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_native.hpp b/src/core/NEON/kernels/arm_gemm/gemm_native.hpp
index b0192793b9..beecb76f20 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_native.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_native.hpp
@@ -28,9 +28,12 @@
#include "arm_gemm.hpp"
#include "mergeresults.hpp"
-#include "profiler.hpp"
#include "transform.hpp"
+#ifdef CYCLE_PROFILING
+#include "profiler.hpp"
+#endif
+
namespace arm_gemm
{
// Implementation of the GemmCommon abstract class.
@@ -50,6 +53,9 @@ class GemmNative : public GemmCommon<To, Tr>
const unsigned int _Nsize;
const unsigned int _Ksize;
+ const unsigned int _nbatches;
+ const unsigned int _nmultis;
+
Tr _beta;
const CPUInfo *const _ci;
@@ -61,8 +67,8 @@ public:
GemmNative(GemmNative &) = delete;
GemmNative &operator=(GemmNative &) = delete;
- GemmNative(const CPUInfo *ci, const unsigned int M, const unsigned int N, const unsigned int K, const Tr beta)
- : _Msize(M), _Nsize(N), _Ksize(K), _beta(beta), _ci(ci)
+ GemmNative(const CPUInfo *ci, const unsigned int M, const unsigned int N, const unsigned int K, const unsigned int nbatches, const unsigned int nmultis, const Tr beta)
+ : _Msize(M), _Nsize(N), _Ksize(K), _nbatches(nbatches), _nmultis(nmultis), _beta(beta), _ci(ci)
{
/* For now don't do any blocking. TODO: figure out if we should. */
k_block = K;
@@ -72,29 +78,55 @@ public:
// Window is number of out_height blocks
unsigned int get_window_size() const override
{
- return iceildiv(_Msize, strategy::out_height);
+ return iceildiv(_Msize, strategy::out_height) * _nbatches * _nmultis;
}
// Actually execute the GEMM.
void execute(unsigned int start, unsigned int end, int) override
{
+#ifdef CYCLE_PROFILING
profiler prof;
+#endif
strategy strat(_ci);
- unsigned int M_start = start * strategy::out_height;
- unsigned int M_end = std::min(end * strategy::out_height, _Msize);
+ const unsigned int window_per_batch = iceildiv(_Msize, strategy::out_height);
+ const unsigned int window_per_multi = window_per_batch * _nbatches;
+
+ const unsigned int first_multi = start / window_per_multi;
+ const unsigned int last_multi = end / window_per_multi;
+
+ const unsigned int first_batch = (start - (first_multi * window_per_multi)) / window_per_batch;
+ const unsigned int last_batch = (end - (last_multi * window_per_multi)) / window_per_batch;
+
+ const unsigned int first_row = ((start - (first_multi * window_per_multi)) % window_per_batch) * strategy::out_height;
+ const unsigned int last_row = ((end - (last_multi * window_per_multi)) % window_per_batch) * strategy::out_height;
static_assert(std::is_same<To, Toi>::value, "gemm_native: Operand types must be the same.");
static_assert(std::is_same<Tr, Tri>::value, "gemm_native: Result types must be the same.");
- for(unsigned int y0 = M_start; y0 < M_end; y0 += strategy::out_height)
+ for(unsigned int multi = first_multi; multi <= last_multi; multi++)
{
- unsigned int ymax = std::min(y0 + strategy::out_height, M_end);
+ const unsigned int batch_0 = (multi == first_multi) ? first_batch : 0;
+ const unsigned int batch_max = (multi == last_multi) ? last_batch : _nbatches;
- prof(PROFILE_KERNEL, (ymax - y0) * _Nsize * _Ksize, [&](void)
+ for(unsigned int batch = batch_0; batch < batch_max; batch++)
{
- strat.kernel(this->_Aptr + (y0 * this->_lda), this->_lda, this->_Bptr, this->_ldb, this->_Cptr + (y0 * this->_ldc), this->_ldc, _beta, (ymax - y0), _Nsize, _Ksize);
- });
+ const unsigned int m_start = ((multi == first_multi) && (batch == first_batch)) ? first_row : 0;
+ const unsigned int m_end = ((multi == last_multi) && (batch == last_batch)) ? last_row : _Msize;
+
+ for(unsigned int y0 = m_start; y0 < m_end; y0 += strategy::out_height)
+ {
+ const unsigned int ymax = std::min(y0 + strategy::out_height, m_end);
+#ifdef CYCLE_PROFILING
+ auto p = prof.ScopedProfiler(PROFILE_KERNEL, (ymax - y0) * _Nsize * _Ksize);
+#endif
+
+ strat.kernel(this->_Aptr + (multi * this->_A_multi_stride) + (batch * this->_A_batch_stride) + (y0 * this->_lda), this->_lda,
+ this->_Bptr + (multi * this->_B_multi_stride), this->_ldb,
+ this->_Cptr + (multi * this->_C_multi_stride) + (batch * this->_C_batch_stride) + (y0 * this->_ldc), this->_ldc,
+ _beta, (ymax - y0), _Nsize, _Ksize);
+ }
+ }
}
}
};
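GemmNative decomposes its window one level further (multis, then batches, then rows). With the assumed sizes below, window unit 4 lands in multi 0, batch 1, starting at row 4, matching the arithmetic in execute() above:

#include <cassert>

// Assumed sizes: M=10, out_height=4 => window_per_batch = ceil(10/4) = 3,
// nbatches=2 => window_per_multi = 6, nmultis=2 => total window = 12.
int main()
{
    const unsigned int out_height = 4, window_per_batch = 3, window_per_multi = 6;

    const unsigned int start = 4;
    const unsigned int first_multi = start / window_per_multi;                                    // 0
    const unsigned int first_batch = (start - (first_multi * window_per_multi)) / window_per_batch; // 1
    const unsigned int first_row   = ((start - (first_multi * window_per_multi)) % window_per_batch)
                                     * out_height;                                                // 4

    assert(first_multi == 0 && first_batch == 1 && first_row == 4);
    return 0;
}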
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp
index 3e790e1b2a..8f1f377aaf 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp
@@ -33,10 +33,11 @@ namespace arm_gemm
{
template <>
UniqueGemmCommon<uint16_t, uint32_t> gemm<uint16_t, uint32_t>(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K,
+ const unsigned int nbatches, const unsigned int nmulti,
const bool trA, const bool trB, uint32_t alpha, uint32_t beta,
const int maxthreads, const bool pretransposed_hint)
{
- return UniqueGemmCommon<uint16_t, uint32_t>(new GemmInterleaved<gemm_u16_12x8, uint16_t, uint32_t>(&ci, M, N, K, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+ return UniqueGemmCommon<uint16_t, uint32_t>(new GemmInterleaved<gemm_u16_12x8, uint16_t, uint32_t>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
}
// Instantiate static class members
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
index 9ec479ca7c..0e9f3f255c 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
@@ -34,17 +34,18 @@ namespace arm_gemm
{
template <>
UniqueGemmCommon<uint8_t, uint32_t> gemm<uint8_t, uint32_t>(const CPUInfo &ci, const unsigned int M, const unsigned int N, const unsigned int K,
+ const unsigned int nbatches, const unsigned int nmulti,
const bool trA, const bool trB, const uint32_t alpha, const uint32_t beta,
const int maxthreads, const bool pretransposed_hint)
{
if(ci.has_dotprod())
{
// Dot product supporting CPUs. This family has a special version for A55r1.
- return UniqueGemmCommon<uint8_t, uint32_t>(new GemmInterleaved<gemm_u8_12x8, uint8_t, uint32_t>(&ci, M, N, K, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+ return UniqueGemmCommon<uint8_t, uint32_t>(new GemmInterleaved<gemm_u8_12x8, uint8_t, uint32_t>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
}
// Non dot-product code.
- return UniqueGemmCommon<uint8_t, uint32_t>(new GemmInterleaved<gemm_u8_4x4, uint8_t, uint32_t>(&ci, M, N, K, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
+ return UniqueGemmCommon<uint8_t, uint32_t>(new GemmInterleaved<gemm_u8_4x4, uint8_t, uint32_t>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
// TODO: There's a better approach for A53, but it doesn't work
// well on heterogeneous systems as the required data formats
diff --git a/src/core/NEON/kernels/arm_gemm/gemv_native_transposed.hpp b/src/core/NEON/kernels/arm_gemm/gemv_native_transposed.hpp
index 29c71f2511..e5cc79eaed 100644
--- a/src/core/NEON/kernels/arm_gemm/gemv_native_transposed.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemv_native_transposed.hpp
@@ -28,9 +28,12 @@
#include "arm_gemm.hpp"
#include "mergeresults.hpp"
-#include "profiler.hpp"
#include "transform.hpp"
+#ifdef CYCLE_PROFILING
+#include "profiler.hpp"
+#endif
+
namespace arm_gemm
{
// Implementation of the GemmCommon abstract class.
@@ -48,6 +51,7 @@ class GemvNativeTransposed : public GemmCommon<To, Tr>
const unsigned int _Nsize;
const unsigned int _Ksize;
+ const unsigned int _nmultis;
const Tr _beta;
@@ -60,45 +64,61 @@ public:
GemvNativeTransposed(GemvNativeTransposed &) = delete;
GemvNativeTransposed &operator=(GemvNativeTransposed &) = delete;
- GemvNativeTransposed(const CPUInfo *ci, const unsigned int N, const unsigned int K, const Tr beta)
- : _Nsize(N), _Ksize(K), _beta(beta), _ci(ci)
+ GemvNativeTransposed(const CPUInfo *ci, const unsigned int N, const unsigned int K, const unsigned int nmultis, const Tr beta)
+ : _Nsize(N), _Ksize(K), _nmultis(nmultis), _beta(beta), _ci(ci)
{
/* For now don't do any blocking. TODO: figure out if we should. */
m_block = K;
n_block = N;
}
- // Window is number of out_width blocks.
+ // Window is number of out_width blocks times number of multis.
unsigned int get_window_size() const override
{
- return iceildiv(_Nsize, strategy::out_width);
+ return iceildiv(_Nsize, strategy::out_width) * _nmultis;
}
// Actually execute the GEMV.
void execute(unsigned int start, unsigned int end, int) override
{
+#ifdef CYCLE_PROFILING
profiler prof;
+#endif
+
strategy strat(_ci);
- unsigned int N_start = start * strategy::out_width;
- unsigned int N_end = std::min(end * strategy::out_width, _Nsize);
+ const unsigned int window_per_multi = iceildiv(_Nsize, strategy::out_width);
+ const unsigned int multi_0 = start / window_per_multi;
+ const unsigned int multi_end = end / window_per_multi;
+
+ const unsigned int n_0 = (start - (multi_0 * window_per_multi)) * strategy::out_width;
+ const unsigned int n_max = (end - (multi_end * window_per_multi)) * strategy::out_width;
static_assert(std::is_same<To, Toi>::value, "gemv_transposed: Operand types must be the same.");
static_assert(std::is_same<Tr, Tri>::value, "gemv_transposed: Result types must be the same.");
- for(unsigned int m0 = 0; m0 < _Ksize; m0 += m_block)
+ for(unsigned int multi = multi_0; multi <= multi_end; multi++)
{
- unsigned int mmax = std::min(m0 + m_block, _Ksize);
+ const unsigned int n_start = (multi == multi_0) ? n_0 : 0;
+ const unsigned int n_end = (multi == multi_end) ? n_max : _Nsize;
- for(unsigned int n0 = N_start; n0 < N_end; n0 += n_block)
- {
- unsigned int nmax = std::min(n0 + n_block, N_end);
+ if(n_end <= n_start)
+ continue;
- prof(PROFILE_KERNEL, ((mmax - m0) * (nmax - n0)), [&](void)
+ for(unsigned int m0 = 0; m0 < _Ksize; m0 += m_block)
+ {
+ unsigned int mmax = std::min(m0 + m_block, _Ksize);
+ for(unsigned int n0 = n_start; n0 < n_end; n0 += n_block)
{
- strat.kernel(this->_Bptr + (m0 * this->_ldb) + n0, this->_Aptr + m0, this->_Cptr + n0,
+ unsigned int nmax = std::min(n0 + n_block, n_end);
+#ifdef CYCLE_PROFILING
+ auto p = prof.ScopedProfiler(PROFILE_KERNEL, (mmax - m0) * (nmax - n0));
+#endif
+ strat.kernel(this->_Bptr + (multi * this->_B_multi_stride) + (m0 * this->_ldb) + n0,
+ this->_Aptr + (multi * this->_A_multi_stride) + m0,
+ this->_Cptr + (multi * this->_C_multi_stride) + n0,
_beta, this->_ldb, (mmax - m0), (nmax - n0));
- });
+ }
}
}
}
diff --git a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
index 0df331acb4..770ee033c8 100644
--- a/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemv_pretransposed.hpp
@@ -28,17 +28,18 @@
#include "arm_gemm.hpp"
#include "mergeresults.hpp"
-#include "profiler.hpp"
#include "transform.hpp"
+#ifdef CYCLE_PROFILING
+#include "profiler.hpp"
+#endif
+
namespace arm_gemm
{
// Implementation of the GemmCommon abstract class.
//
-// This is implementation is for GEMV with a transposed matrix.
-//
-// By default the source data is used in-place, but if type conversion is
-// needed we need to allocate working space (CURRENTLY NOT IMPLEMENTED).
+// This implementation is for GEMV with pretransposition.
+// Batches are not supported, as a batched GEMV makes no sense (it can be converted to a GEMM).
template <typename strategy, typename To, typename Tr>
class GemvPretransposed : public GemmCommon<To, Tr>
@@ -48,12 +49,14 @@ class GemvPretransposed : public GemmCommon<To, Tr>
const unsigned int _Nsize;
const unsigned int _Ksize;
+ const unsigned int _nmultis;
const bool _trB;
const Tr _beta;
const CPUInfo *const _ci;
+ const unsigned int _buffer_per_multi;
unsigned int m_block = 0;
unsigned int n_block = 0;
@@ -64,44 +67,64 @@ public:
GemvPretransposed(GemvPretransposed &) = delete;
GemvPretransposed &operator=(GemvPretransposed &) = delete;
- GemvPretransposed(const CPUInfo *ci, const unsigned int N, const unsigned int K, const bool trB, const Tr beta)
- : _Nsize(N), _Ksize(K), _trB(trB), _beta(beta), _ci(ci)
+ GemvPretransposed(const CPUInfo *ci, const unsigned int N, const unsigned int K, const unsigned int nmultis, const bool trB, const Tr beta)
+ : _Nsize(N), _Ksize(K), _nmultis(nmultis), _trB(trB), _beta(beta), _ci(ci), _buffer_per_multi(_Ksize * iceildiv(_Nsize, strategy::A_interleave) * strategy::A_interleave)
{
/* For now don't do any blocking. TODO: figure out if we should. */
m_block = K;
n_block = N;
}
- // Window is number of out_width blocks.
+ // Window is number of out_width blocks, times number of multis.
unsigned int get_window_size() const override
{
- return iceildiv(_Nsize, strategy::out_width);
+ return iceildiv(_Nsize, strategy::out_width) * _nmultis;
}
// Actually execute the GEMV.
void execute(unsigned int start, unsigned int end, int) override
{
+#ifdef CYCLE_PROFILING
profiler prof;
+#endif
+
strategy strat(_ci);
- unsigned int N_start = start * strategy::out_width;
- unsigned int N_end = std::min(end * strategy::out_width, _Nsize);
+ /* Break the window values down into multis of interest... */
+ const unsigned int window_per_multi = iceildiv(_Nsize, strategy::out_width);
+ const unsigned int multi_0 = start / window_per_multi;
+ const unsigned int multi_end = end / window_per_multi;
+
+ /* ... and figure out where we start and end in the first and last multi. */
+ const unsigned int n_0 = (start - (multi_0 * window_per_multi)) * strategy::out_width;
+ const unsigned int n_max = (end - (multi_end * window_per_multi)) * strategy::out_width;
static_assert(std::is_same<Tr, Tri>::value, "GemvPretransposed: Result types must be the same.");
- for(unsigned int m0 = 0; m0 < _Ksize; m0 += m_block)
+ for(unsigned int multi = multi_0; multi <= multi_end; multi++)
{
- unsigned int mmax = std::min(m0 + m_block, _Ksize);
+ const unsigned int n_start = (multi == multi_0) ? n_0 : 0;
+ const unsigned int n_end = (multi == multi_end) ? n_max : _Nsize;
- for(unsigned int n0 = N_start; n0 < N_end; n0 += n_block)
- {
- unsigned int nmax = std::min(n0 + n_block, N_end);
+ if(n_end <= n_start)
+ continue;
- prof(PROFILE_KERNEL, ((mmax - m0) * (nmax - n0)), [&](void)
+ for(unsigned int m0 = 0; m0 < _Ksize; m0 += m_block)
+ {
+ unsigned int mmax = std::min(m0 + m_block, _Ksize);
+ for(unsigned int n = n_start; n < n_end; n += n_block)
{
+ unsigned int nmax = std::min(n + n_block, n_end);
+#ifdef CYCLE_PROFILING
+ auto p = prof.ScopedProfiler(PROFILE_KERNEL, (mmax - m0) * (nmax - n));
+#endif
/* This assumes that the underlying call was a GEMM with M=1; for the N=1 case we would have to pick up this->_Bptr below instead */
- strat.kernel(_A_pretransposed + (n0 * _Ksize) + (m0 * strategy::A_interleave), (_Ksize * strategy::A_interleave), this->_Aptr + m0, this->_Cptr + n0, _beta, (mmax - m0), (nmax - n0));
- });
+ strat.kernel(_A_pretransposed + (multi * _buffer_per_multi) + (n * _Ksize) + (m0 * strategy::A_interleave),
+ (_Ksize * strategy::A_interleave),
+ this->_Aptr + (multi * this->_A_multi_stride) + m0,
+ this->_Cptr + (multi * this->_C_multi_stride) + n,
+ _beta, (mmax - m0), (nmax - n));
+ }
}
}
}
@@ -120,27 +143,35 @@ public:
size_t get_B_pretransposed_array_size() const override
{
- return _Ksize * iceildiv(_Nsize, strategy::A_interleave) * strategy::A_interleave * sizeof(float);
+ return _buffer_per_multi * _nmultis * sizeof(To);
}
- void pretranspose_B_array(void *buffer, const To *B, const int ldb) override
+ void pretranspose_B_array(void *buffer, const To *B, const int ldb, const int B_multi_stride) override
{
Toi *A_buffer = reinterpret_cast<Toi *>(buffer);
- /* Reverse sense here as we are dealing with B rather than A. So if
- * strategy::A_transpose is false and _trB is false, we still
- * transpose. */
- if(_trB ^ strategy::A_transpose)
+ for(unsigned int multi = 0; multi < _nmultis; multi++)
{
- Transform<strategy::A_interleave, strategy::A_block, false>(A_buffer, B, ldb, 0, _Nsize, 0, _Ksize);
- }
- else
- {
- Transform<strategy::A_interleave, strategy::A_block, true>(A_buffer, B, ldb, 0, _Nsize, 0, _Ksize);
+ /* Reverse sense here as we are dealing with B rather than A. So if
+ * strategy::A_transpose is false and _trB is false, we still
+ * transpose. */
+ if(_trB ^ strategy::A_transpose)
+ {
+ Transform<strategy::A_interleave, strategy::A_block, false>(A_buffer + (multi * _buffer_per_multi), B + (multi * B_multi_stride), ldb, 0, _Nsize, 0, _Ksize);
+ }
+ else
+ {
+ Transform<strategy::A_interleave, strategy::A_block, true>(A_buffer + (multi * _buffer_per_multi), B + (multi * B_multi_stride), ldb, 0, _Nsize, 0, _Ksize);
+ }
}
_A_pretransposed = A_buffer;
}
+
+ void set_pretransposed_B_data(void *buffer) override
+ {
+ _A_pretransposed = reinterpret_cast<Toi *>(buffer);
+ }
};
} // namespace arm_gemm
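The per-multi buffer bookkeeping in GemvPretransposed above (one interleaved copy of B per multi, sized in elements of To rather than hard-coded float) can be sanity-checked in isolation; the sizes below are assumptions for illustration only:

#include <cassert>
#include <cstddef>

static unsigned int iceildiv(unsigned int a, unsigned int b) { return (a + b - 1) / b; }

// Assumed: K=6, N=10, A_interleave=8, nmultis=3, To = float. Mirrors
// get_B_pretransposed_array_size() and the (multi * _buffer_per_multi)
// offsets used in pretranspose_B_array() and execute() above.
int main()
{
    const unsigned int Ksize = 6, Nsize = 10, A_interleave = 8, nmultis = 3;

    // Each multi gets a full interleaved copy of its own B panel...
    const unsigned int buffer_per_multi = Ksize * iceildiv(Nsize, A_interleave) * A_interleave; // 96 elements
    // ...and the total allocation covers all multis.
    const size_t total_bytes = static_cast<size_t>(buffer_per_multi) * nmultis * sizeof(float); // 1152 bytes

    assert(buffer_per_multi == 96 && total_bytes == 1152);
    return 0;
}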
diff --git a/src/core/NEON/kernels/arm_gemm/profiler.hpp b/src/core/NEON/kernels/arm_gemm/profiler.hpp
index c38b0a443c..ada0c95e26 100644
--- a/src/core/NEON/kernels/arm_gemm/profiler.hpp
+++ b/src/core/NEON/kernels/arm_gemm/profiler.hpp
@@ -47,6 +47,35 @@ private:
int currentevent = 0;
int countfd = 0;
+ class ScopedProfilerClass
+ {
+ private:
+ profiler &_parent;
+ bool legal = false;
+
+ public:
+ ScopedProfilerClass(profiler &prof, int i, unsigned long u)
+ : _parent(prof)
+ {
+ if(prof.currentevent == maxevents)
+ return;
+
+ prof.events[prof.currentevent] = i;
+ prof.units[prof.currentevent] = u;
+ legal = true;
+ start_counter(prof.countfd);
+ }
+
+ ~ScopedProfilerClass()
+ {
+ if(!legal)
+ return;
+
+ long long cycs = stop_counter(_parent.countfd);
+ _parent.times[_parent.currentevent++] = cycs;
+ }
+ };
+
public:
profiler()
{
@@ -107,19 +136,9 @@ public:
times[currentevent++] = cycs;
}
}
-};
-
-#else
-
-namespace arm_gemm
-{
-class profiler
-{
-public:
- template <typename T>
- void operator()(int i, unsigned long u, T func)
+ ScopedProfilerClass ScopedProfiler(int i, unsigned long u)
{
- func();
+ return ScopedProfilerClass(*this, i, u);
}
};
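The ScopedProfiler accessor added above turns each profiled region into an RAII scope: construction starts the cycle counter and destruction records the event. A standalone analogue of the same pattern using std::chrono (illustration only, not the library's perf-counter backend):

#include <chrono>
#include <cstdio>

// RAII timer: starts timing on construction, prints the elapsed time on
// destruction, so "ScopedTimer t(...);" brackets exactly the enclosing scope.
class ScopedTimer
{
public:
    explicit ScopedTimer(const char *name)
        : _name(name), _start(std::chrono::steady_clock::now())
    {
    }

    ~ScopedTimer()
    {
        const auto ns = std::chrono::duration_cast<std::chrono::nanoseconds>(
                            std::chrono::steady_clock::now() - _start).count();
        std::printf("%s: %lld ns\n", _name, static_cast<long long>(ns));
    }

private:
    const char                           *_name;
    std::chrono::steady_clock::time_point _start;
};

int main()
{
    {
        ScopedTimer t("PROFILE_KERNEL"); // timed scope, mirroring ScopedProfiler usage
        volatile int sink = 0;
        for(int i = 0; i < 1000000; i++) { sink = sink + i; }
    } // timer stops and prints here
    return 0;
}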