author    David Mansell <David.Mansell@arm.com>      2018-07-06 14:52:52 +0100
committer Anthony Barbier <anthony.barbier@arm.com>  2018-11-02 16:54:10 +0000
commit    d93991e290618a685b67506c78090350e6aee43f (patch)
tree      1d5c3b3017cfccd3f0ec3f24e8e11334cf977ce3
parent    dec32a9edd4b3c6dc55c60d7436e79af6be58c3d (diff)
COMPMID-1380: Pre-work for SVE support.
This patch makes the infrastructure changes needed to allow SVE kernels
to be added later on.

Change-Id: Ide5bccac2f47278e93fff3d648231aee2d5f8c2e
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/139070
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Tested-by: Jenkins <bsgcomp@arm.com>
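For context, the bulk of the patch replaces compile-time constant blocking parameters with static member functions, so that a future SVE strategy can report sizes derived from the runtime vector length rather than fixed constants. A minimal sketch of the pattern (example_strategy is an illustrative name, not a class from this patch):

// Before: blocking parameters were static const data members, which
// required out-of-class definitions ("const int sgemm_12x8::out_width;")
// and could never depend on runtime state.
//
// After: the same values are exposed as static functions. A fixed-size
// NEON strategy simply returns a constant; an SVE strategy can later
// derive the value from the hardware vector length instead.
struct example_strategy {
    static int out_width()  { return 12; }
    static int out_height() { return 8; }
    static int k_unroll()   { return 1; }
};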
 src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp | 6
 src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp | 6
 src/core/NEON/kernels/arm_gemm/gemm_int16.cpp | 4
 src/core/NEON/kernels/arm_gemm/gemm_int8.cpp | 6
 src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp | 127
 src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp | 4
 src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp | 7
 src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6.hpp | 27
 src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s16_12x8.hpp | 27
 src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_12x8.hpp | 27
 src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_4x4.hpp | 27
 src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u16_12x8.hpp | 27
 src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_12x8.hpp | 19
 src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_4x4.hpp | 22
 src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_24x8.hpp | 24
 src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8.hpp | 27
 src/core/NEON/kernels/arm_gemm/mergeresults.hpp | 18
 src/core/NEON/kernels/arm_gemm/merges/a32_merge_float_8x6.hpp | 2
 src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_12x8.hpp | 2
 src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_to_half_12x8.hpp | 2
 src/core/NEON/kernels/arm_gemm/merges/a64_merge_int32_12x8.hpp | 2
 src/core/NEON/kernels/arm_gemm/std_transforms_fixed.hpp | 69
 src/core/NEON/kernels/arm_gemm/transform.hpp | 10
 src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp | 2
 src/core/NEON/kernels/arm_gemm/transforms/a32_transpose_interleave_8way_32bit.hpp | 8
 src/core/NEON/kernels/arm_gemm/transforms/a64_block16_interleave4_8bit.hpp | 2
 src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_16bit.hpp | 2
 src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_32bit.hpp | 2
 src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp | 2
 src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12way_16bit.hpp | 8
 src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12way_half_to_float.hpp | 2
 src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_24way_16bit.hpp | 8
 src/core/NEON/kernels/arm_gemm/utils.hpp | 7
 33 files changed, 303 insertions(+), 232 deletions(-)
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
index fa12942829..65f43f302b 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp16.cpp
@@ -67,12 +67,6 @@ UniqueGemmCommon<__fp16, __fp16> gemm(const CPUInfo &ci, const unsigned int M, c
#endif
}
-// Instantiate static class members if necessary.
-#if defined(__aarch64__) && (defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) || defined(FP16_KERNELS))
-const int hgemm_24x8::out_width;
-const int hgemm_24x8::out_height;
-#endif
-
} // namespace arm_gemm
#endif // __ARM_FP16_ARGS
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
index 99f061bde8..2fd040efbe 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_fp32.cpp
@@ -76,14 +76,8 @@ UniqueGemmCommon<float, float> gemm<float, float>(const CPUInfo &ci, const unsig
// Instantiate static class variables.
#ifdef __aarch64__
-const int sgemm_12x8::out_width;
-const int sgemm_12x8::out_height;
-
const int sgemm_native_16x4::out_width;
const int sgemm_native_16x4::out_height;
-#else
-const int sgemm_8x6::out_width;
-const int sgemm_8x6::out_height;
#endif
} // namespace arm_gemm
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp
index 317541919b..57cd15f698 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_int16.cpp
@@ -39,10 +39,6 @@ UniqueGemmCommon<int16_t, int32_t> gemm<int16_t, int32_t>(const CPUInfo &ci, con
return UniqueGemmCommon<int16_t, int32_t>(new GemmInterleaved<gemm_s16_12x8, int16_t, int32_t>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
}
-// Instantiate static class members
-const int gemm_s16_12x8::out_width;
-const int gemm_s16_12x8::out_height;
-
} // namespace arm_gemm
#endif // __aarch64__
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
index 7eff47de68..04803eb81a 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_int8.cpp
@@ -51,12 +51,6 @@ UniqueGemmCommon<int8_t, int32_t> gemm<int8_t, int32_t>(const CPUInfo &ci, const
// gemm = new GemmInterleaved<gemm_s16_12x8, int8_t, int32_t>(ci, M, N, K, trA, trB);
}
-// Instantiate static class members
-const int gemm_s8_12x8::out_width;
-const int gemm_s8_12x8::out_height;
-const int gemm_s8_4x4::out_width;
-const int gemm_s8_4x4::out_height;
-
} // namespace arm_gemm
#endif // aarch64
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
index c304edd1f9..c5a43e6519 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_interleaved.hpp
@@ -158,7 +158,7 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
// C working size: One needed per thread.
size_t get_c_working_size() const {
- return ROUND_UP(sizeof(Tri) * _x_block * strategy::out_height);
+ return ROUND_UP(sizeof(Tri) * _x_block * strategy::out_height());
}
// Internal execute function.
@@ -174,13 +174,13 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
blockwalker next=current;
/* Translate 'start' and 'end' into a position within the batches and rows. */
- const unsigned int window_per_batch = _Mround / strategy::out_height;
+ const unsigned int window_per_batch = _Mround / strategy::out_height();
unsigned int batch_0 = start / window_per_batch;
unsigned int batch_end = end / window_per_batch;
/* Compute the M values to operate on */
- unsigned int m_0 = (start - (batch_0 * window_per_batch)) * strategy::out_height;
- unsigned int m_max = (end - (batch_end * window_per_batch)) * strategy::out_height;
+ unsigned int m_0 = (start - (batch_0 * window_per_batch)) * strategy::out_height();
+ unsigned int m_max = (end - (batch_end * window_per_batch)) * strategy::out_height();
/* Make sure we've been set up correctly. */
if (pretransposed) {
@@ -214,7 +214,7 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
for (;!current.done();current.advance()) {
if (current.newkblock()) {
#ifdef CYCLE_PROFILING
- auto p=prof.ScopedProfiler(PROFILE_PREPA, (end - start) * strategy::out_height * (current.kmax()-current.k0()) * sizeof(Toi));
+ auto p=prof.ScopedProfiler(PROFILE_PREPA, (end - start) * strategy::out_height() * (current.kmax()-current.k0()) * sizeof(Toi));
#endif
for (unsigned int batch = batch_0; batch <= batch_end; batch++) {
unsigned int first_m = (batch == batch_0) ? m_0 : 0;
@@ -223,25 +223,17 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
if (first_m >= last_m)
continue;
- if (_trA ^ strategy::A_transpose) {
- Transform<strategy::A_interleave, strategy::A_block, true>(
- a_panel + ((batch * _Mround + first_m) * _k_block),
- this->_Aptr + (batch * this->_A_batch_stride) + (current.multi() * this->_A_multi_stride),
- this->_lda, first_m, last_m, current.k0(), current.kmax());
- } else {
- Transform<strategy::A_interleave, strategy::A_block, false>(
- a_panel + ((batch * _Mround + first_m) * _k_block),
- this->_Aptr + (batch * this->_A_batch_stride) + (current.multi() * this->_A_multi_stride),
- this->_lda, first_m, last_m, current.k0(), current.kmax());
- }
+ strat.transforms.PrepareA(a_panel + ((batch * _Mround + first_m) * _k_block),
+ this->_Aptr + (batch * this->_A_batch_stride) + (current.multi() * this->_A_multi_stride),
+ this->_lda, first_m, last_m, current.k0(), current.kmax(), _trA);
}
// Figure out how many "K" the kernel will actually process.
- kern_k = iceildiv(current.kmax() - current.k0(), strategy::k_unroll);
- kern_k *= strat.k_unroll;
+ kern_k = iceildiv(current.kmax() - current.k0(), strategy::k_unroll());
+ kern_k *= strat.k_unroll();
}
- int bblocks = iceildiv(current.xmax() - current.x0(), strategy::out_width);
+ int bblocks = iceildiv(current.xmax() - current.x0(), strategy::out_width());
if (!pretransposed) {
/* Look ahead to the next block and populate it if necessary.
@@ -259,15 +251,9 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
#endif
Toi *b_panel = reinterpret_cast<Toi *>(buffer);
- if (_trB ^ strategy::B_transpose) {
- Transform<strategy::B_interleave, strategy::B_block, true>(
- b_panel, this->_Bptr + (next.multi() * this->_B_multi_stride), this->_ldb,
- next.x0(), next.xmax(), next.k0(), next.kmax());
- } else {
- Transform<strategy::B_interleave, strategy::B_block, false>(
- b_panel, this->_Bptr + (next.multi() * this->_B_multi_stride), this->_ldb,
- next.x0(), next.xmax(), next.k0(), next.kmax());
- }
+
+ strat.transforms.PrepareB(b_panel, this->_Bptr + (next.multi() * this->_B_multi_stride), this->_ldb,
+ next.x0(), next.xmax(), next.k0(), next.kmax(), _trB);
});
}
@@ -278,15 +264,9 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
#endif
Toi *b_panel = reinterpret_cast<Toi *>(bpv);
- if (_trB ^ strategy::B_transpose) {
- Transform<strategy::B_interleave, strategy::B_block, true>(
- b_panel, this->_Bptr + (current.multi() * this->_B_multi_stride), this->_ldb,
- current.x0(), current.xmax(), current.k0(), current.kmax());
- } else {
- Transform<strategy::B_interleave, strategy::B_block, false>(
- b_panel, this->_Bptr + (current.multi() * this->_B_multi_stride), this->_ldb,
- current.x0(), current.xmax(), current.k0(), current.kmax());
- }
+
+ strat.transforms.PrepareB(b_panel, this->_Bptr + (current.multi() * this->_B_multi_stride), this->_ldb,
+ current.x0(), current.xmax(), current.k0(), current.kmax(), _trB);
}));
}
@@ -300,33 +280,32 @@ class GemmInterleaved : public GemmCommon<To, Tr> {
if (first_m >= last_m)
continue;
- for (unsigned int y=first_m; y<last_m; y+=strategy::out_height) {
- unsigned int ymax = std::min(_Msize, y + strategy::out_height);
+ for (unsigned int y=first_m; y<last_m; y+=strategy::out_height()) {
+ unsigned int ymax = std::min(_Msize, y + strategy::out_height());
{
#ifdef CYCLE_PROFILING
- auto p=prof.ScopedProfiler(PROFILE_KERNEL, (strategy::out_height * bblocks * strategy::out_width * kern_k));
+ auto p=prof.ScopedProfiler(PROFILE_KERNEL, (strategy::out_height() * bblocks * strategy::out_width() * kern_k));
#endif
strat.kernel(a_ptr, b_panel, c_panel, 1, bblocks, kern_k);
- a_ptr += (strategy::out_height * kern_k);
+ a_ptr += (strategy::out_height() * kern_k);
}
{
#ifdef CYCLE_PROFILING
- auto p=prof.ScopedProfiler(PROFILE_MERGE, (strategy::out_height * bblocks * strategy::out_width * sizeof(Tr)));
+ auto p=prof.ScopedProfiler(PROFILE_MERGE, (strategy::out_height() * bblocks * strategy::out_width() * sizeof(Tr)));
#endif
- MergeResults<strategy::out_width, strategy::out_height>(
- this->_Cptr + (batch * this->_C_batch_stride) + (current.multi() * this->_C_multi_stride),
- c_panel, this->_ldc, y, ymax, current.x0(), current.xmax(),
- _alpha, (current.k0()==0 ? _beta : static_cast<Tr>(1)));
+ strat.transforms.Merge(this->_Cptr + (batch * this->_C_batch_stride) + (current.multi() * this->_C_multi_stride),
+ c_panel, this->_ldc, y, ymax, current.x0(), current.xmax(),
+ _alpha, (current.k0()==0 ? _beta : static_cast<Tr>(1)));
}
}
}
if (pretransposed) {
- b_panel += (bblocks * strat.out_width * kern_k);
+ b_panel += (bblocks * strat.out_width() * kern_k);
} else {
_bm->release(current.index());
}
@@ -353,11 +332,11 @@ public:
// k_block: Find out how much of the larger array can be loaded into half the cache.
// This should account for associative caches.
- _k_block = (L1_size / 2) / (sizeof(Toi) * (std::max(strategy::out_width, strategy::out_height)));
+ _k_block = (L1_size / 2) / (sizeof(Toi) * (std::max(strategy::out_width(), strategy::out_height())));
// Needs to be (at least a single) multiple of the K unroll level.
- _k_block /= strategy::k_unroll;
- _k_block = std::max(_k_block, 1U) * strategy::k_unroll;
+ _k_block /= strategy::k_unroll();
+ _k_block = std::max(_k_block, 1U) * strategy::k_unroll();
// Now tune to presented problem size; this is how many blocks we need.
int num_k_blocks = iceildiv(K, _k_block);
@@ -366,28 +345,28 @@ public:
_k_block = iceildiv(K, num_k_blocks);
// And round UP to the K unroll level required.
- _k_block = iceildiv(_k_block, strategy::k_unroll);
- _k_block *= strategy::k_unroll;
+ _k_block = iceildiv(_k_block, strategy::k_unroll());
+ _k_block *= strategy::k_unroll();
// x_block: Work out how many rows (of length k_block) will fit in the L2
// Don't allocate more than 90% of the L2 to allow for overheads, and subtract off the L1 contents.
- _x_block = (((L2_size * 9) / 10) - (_k_block * sizeof(Toi) * (strategy::out_width + strategy::out_height))) /
+ _x_block = (((L2_size * 9) / 10) - (_k_block * sizeof(Toi) * (strategy::out_width() + strategy::out_height()))) /
(sizeof(Toi) * _k_block);
// Needs to be (at least a single) multiple of the kernel output width.
- _x_block /= strategy::out_width;
- _x_block = std::max(_x_block, 1U) * strategy::out_width;
+ _x_block /= strategy::out_width();
+ _x_block = std::max(_x_block, 1U) * strategy::out_width();
// And tune to the presented problem size.
int num_x_blocks = iceildiv(N, _x_block);
_x_block = iceildiv(N, num_x_blocks);
- _x_block = iceildiv(_x_block, strategy::out_width);
- _x_block *= strategy::out_width;
+ _x_block = iceildiv(_x_block, strategy::out_width());
+ _x_block *= strategy::out_width();
// Work out the rounded size of M - needed for some buffers.
- _Mround = iceildiv(M, strategy::out_height);
- _Mround *= strategy::out_height;
+ _Mround = iceildiv(M, strategy::out_height());
+ _Mround *= strategy::out_height();
}
// Interface implementation - Compulsory functions
@@ -398,7 +377,7 @@ public:
// manager).
unsigned int get_window_size() const override {
// _Mround is a multiple of out_height by definition.
- return (_Mround / strategy::out_height) * _nbatches;
+ return (_Mround / strategy::out_height()) * _nbatches;
}
// set_nthreads: pass on to buffer manager to avoid it waiting for non-existent threads.
@@ -483,11 +462,11 @@ public:
size_t k_size = (current.kmax() - current.k0());
/* Round sizes up as needed. */
- x_size = iceildiv(x_size, strategy::out_width);
- x_size *= strategy::out_width;
+ x_size = iceildiv(x_size, strategy::out_width());
+ x_size *= strategy::out_width();
- k_size = iceildiv(k_size, strategy::k_unroll);
- k_size *= strategy::k_unroll;
+ k_size = iceildiv(k_size, strategy::k_unroll());
+ k_size *= strategy::k_unroll();
total += x_size * k_size * sizeof(Toi);
} while (current.advance());
@@ -499,6 +478,7 @@ public:
blockwalker current(*this);
Toi *buffer = reinterpret_cast<Toi *>(in_buffer);
_B_transposed = buffer;
+ strategy strat(_ci);
do {
/* Figure out the size of each block. */
@@ -506,21 +486,14 @@ public:
size_t k_size = (current.kmax() - current.k0());
/* Round sizes up as needed. */
- x_size = iceildiv(x_size, strategy::out_width);
- x_size *= strategy::out_width;
+ x_size = iceildiv(x_size, strategy::out_width());
+ x_size *= strategy::out_width();
- k_size = iceildiv(k_size, strategy::k_unroll);
- k_size *= strategy::k_unroll;
+ k_size = iceildiv(k_size, strategy::k_unroll());
+ k_size *= strategy::k_unroll();
- if (_trB ^ strategy::B_transpose) {
- Transform<strategy::B_interleave, strategy::B_block, true>(
- buffer, B + (current.multi() * B_multi_stride), ldb,
- current.x0(), current.xmax(), current.k0(), current.kmax());
- } else {
- Transform<strategy::B_interleave, strategy::B_block, false>(
- buffer, B + (current.multi() * B_multi_stride), ldb,
- current.x0(), current.xmax(), current.k0(), current.kmax());
- }
+ strat.transforms.PrepareB(buffer, B + (current.multi() * B_multi_stride), ldb,
+ current.x0(), current.xmax(), current.k0(), current.kmax(), _trB);
buffer += (x_size * k_size);
} while (current.advance());
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp b/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp
index 4e8b811e83..6db55c02d0 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_uint16.cpp
@@ -39,10 +39,6 @@ UniqueGemmCommon<uint16_t, uint32_t> gemm<uint16_t, uint32_t>(const CPUInfo &ci,
return UniqueGemmCommon<uint16_t, uint32_t>(new GemmInterleaved<gemm_u16_12x8, uint16_t, uint32_t>(&ci, M, N, K, nbatches, nmulti, trA, trB, alpha, beta, maxthreads, pretransposed_hint));
}
-// Instantiate static class members
-const int gemm_u16_12x8::out_width;
-const int gemm_u16_12x8::out_height;
-
} // namespace arm_gemm
#endif // __aarch64__
diff --git a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
index 321aa65d83..1ca92f9d4e 100644
--- a/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
+++ b/src/core/NEON/kernels/arm_gemm/gemm_uint8.cpp
@@ -51,13 +51,6 @@ UniqueGemmCommon<uint8_t, uint32_t> gemm<uint8_t, uint32_t>(const CPUInfo &ci, c
// gemm = new GemmInterleaved<gemm_s16_12x8, int8_t, int32_t>(ci, M, N, K, trA, trB);
}
-// Instantiate static class members
-const int gemm_u8_12x8::out_width;
-const int gemm_u8_12x8::out_height;
-
-const int gemm_u8_4x4::out_width;
-const int gemm_u8_4x4::out_height;
-
} // namespace arm_gemm
#endif // aarch64
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6.hpp
index 01bf1f9297..06e62456dc 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6.hpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6.hpp
@@ -25,6 +25,8 @@
#ifdef __arm__
+#include "../std_transforms_fixed.hpp"
+
namespace arm_gemm {
// Actual kernel implementations
@@ -47,20 +49,21 @@ public:
typedef void (*kern_type)(const float *, const float *, float *, int, int, int);
- /* Describes the data layout for A input */
- static const int A_interleave = 6;
- static const int A_block = 1;
- static const int A_transpose = 0;
+ /* Kernel blocking parameters */
+ static int out_width() {
+ return 8;
+ }
- /* Same for B input */
- static const int B_interleave = 8;
- static const int B_block = 1;
- static const int B_transpose = 1;
+ static int out_height() {
+ return 6;
+ }
- /* Kernel blocking parameters */
- static const int out_width = 8;
- static const int out_height = 6;
- static const int k_unroll = 1;
+ static int k_unroll() {
+ return 1;
+ }
+
+ // Use the standard fixed size transforms.
+ StdTransformsFixed<operand_type, result_type, 6, 8> transforms = {};
kern_type kernel = a32_sgemm_8x6;
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s16_12x8.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s16_12x8.hpp
index 27700b47d1..95a2bc2fbc 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s16_12x8.hpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s16_12x8.hpp
@@ -25,6 +25,8 @@
#ifdef __aarch64__
+#include "../std_transforms_fixed.hpp"
+
namespace arm_gemm {
// Actual kernel implementations
@@ -45,20 +47,21 @@ public:
typedef void (*kern_type)(const int16_t *, const int16_t *, int32_t *, int, int, int);
- /* Describes the data layout for A input */
- static const int A_interleave = 8;
- static const int A_block = 1;
- static const int A_transpose = 0;
+ /* Kernel blocking parameters */
+ static int out_width() {
+ return 12;
+ }
+
+ static int out_height() {
+ return 8;
+ }
- /* Same for B input */
- static const int B_interleave = 12;
- static const int B_block = 1;
- static const int B_transpose = 1;
+ static int k_unroll() {
+ return 1;
+ }
- /* Kernel blocking parameters */
- static const int out_width = 12;
- static const int out_height = 8;
- static const int k_unroll = 1;
+ // Use the standard fixed size transforms.
+ StdTransformsFixed<operand_type, result_type, 8, 12> transforms = {};
kern_type kernel = a64_gemm_s16_asimd_12x8;
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_12x8.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_12x8.hpp
index cb97270c24..fdc0200435 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_12x8.hpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_12x8.hpp
@@ -27,6 +27,8 @@
#include "arm_gemm.hpp"
+#include "../std_transforms_fixed.hpp"
+
namespace arm_gemm {
// Load the actual kernel
@@ -40,20 +42,21 @@ public:
typedef void (*kern_type)(const int8_t *, const int8_t *, int32_t *, int, int, int);
- /* Describes the data layout for A input */
- static const int A_interleave = 8;
- static const int A_block = 4;
- static const bool A_transpose = false;
+ /* Kernel blocking parameters */
+ static int out_width() {
+ return 12;
+ }
- /* Same for B input */
- static const int B_interleave = 12;
- static const int B_block = 4;
- static const bool B_transpose = true;
+ static int out_height() {
+ return 8;
+ }
- /* Kernel blocking parameters */
- static const int out_width = 12;
- static const int out_height = 8;
- static const int k_unroll = 4;
+ static int k_unroll() {
+ return 4;
+ }
+
+ // Use the standard fixed size transforms.
+ StdTransformsFixed<operand_type, result_type, 8, 12, 4> transforms = {};
kern_type kernel = a64_gemm_s8_12x8;
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_4x4.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_4x4.hpp
index b5b07b2c56..be7ead9f48 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_4x4.hpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_s8_4x4.hpp
@@ -25,6 +25,8 @@
#ifdef __aarch64__
+#include "../std_transforms_fixed.hpp"
+
namespace arm_gemm {
// Load the actual kernel
@@ -39,20 +41,21 @@ public:
typedef void (*kern_type)(const int8_t *, const int8_t *, int32_t *, int, int, int);
- /* Describes the data layout for A input */
- static const int A_interleave = 4;
- static const int A_block = 16;
- static const bool A_transpose = false;
+ /* Kernel blocking parameters */
+ static int out_width() {
+ return 4;
+ }
+
+ static int out_height() {
+ return 4;
+ }
- /* Same for B input */
- static const int B_interleave = 4;
- static const int B_block = 16;
- static const bool B_transpose = true;
+ static int k_unroll() {
+ return 16;
+ }
- /* Kernel blocking parameters */
- static const int out_width = 4;
- static const int out_height = 4;
- static const int k_unroll = 16;
+ // Use the standard fixed size transforms.
+ StdTransformsFixed<operand_type, result_type, 4, 4, 16> transforms = {};
kern_type kernel=a64_gemm_s8_4x4;
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u16_12x8.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u16_12x8.hpp
index 13dd570677..d2692ba77f 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u16_12x8.hpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u16_12x8.hpp
@@ -25,6 +25,8 @@
#ifdef __aarch64__
+#include "../std_transforms_fixed.hpp"
+
namespace arm_gemm {
// Actual kernel implementations
@@ -45,20 +47,21 @@ public:
typedef void (*kern_type)(const uint16_t *, const uint16_t *, uint32_t *, int, int, int);
- /* Describes the data layout for A input */
- static const int A_interleave = 8;
- static const int A_block = 1;
- static const int A_transpose = 0;
+ /* Kernel blocking parameters */
+ static int out_width() {
+ return 12;
+ }
+
+ static int out_height() {
+ return 8;
+ }
- /* Same for B input */
- static const int B_interleave = 12;
- static const int B_block = 1;
- static const int B_transpose = 1;
+ static int k_unroll() {
+ return 1;
+ }
- /* Kernel blocking parameters */
- static const int out_width = 12;
- static const int out_height = 8;
- static const int k_unroll = 1;
+ // Use the standard fixed size transforms.
+ StdTransformsFixed<operand_type, result_type, 8, 12> transforms = {};
kern_type kernel = a64_gemm_u16_asimd_12x8;
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_12x8.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_12x8.hpp
index c67aed7275..a252abfd3e 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_12x8.hpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_12x8.hpp
@@ -27,6 +27,8 @@
#include "arm_gemm.hpp"
+#include "../std_transforms_fixed.hpp"
+
namespace arm_gemm {
// Load the actual kernel
@@ -51,9 +53,20 @@ public:
static const bool B_transpose = true;
/* Kernel blocking parameters */
- static const int out_width = 12;
- static const int out_height = 8;
- static const int k_unroll = 4;
+ static int out_width() {
+ return 12;
+ }
+
+ static int out_height() {
+ return 8;
+ }
+
+ static int k_unroll() {
+ return 4;
+ }
+
+ // Use the standard fixed size transforms.
+ StdTransformsFixed<operand_type, result_type, 8, 12, 4> transforms = {};
kern_type kernel = a64_gemm_u8_12x8;
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_4x4.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_4x4.hpp
index 23f4c1d84f..2da3ecd4f8 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_4x4.hpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u8_4x4.hpp
@@ -25,6 +25,8 @@
#ifdef __aarch64__
+#include "../std_transforms_fixed.hpp"
+
namespace arm_gemm {
// Kernel definition
@@ -48,14 +50,24 @@ public:
static const bool B_transpose = true;
/* Kernel blocking parameters */
- static const int out_width = 4;
- static const int out_height = 4;
- static const int k_unroll = 16;
+ static int out_width() {
+ return 4;
+ }
+
+ static int out_height() {
+ return 4;
+ }
+
+ static int k_unroll() {
+ return 16;
+ }
+
+ // Use the standard fixed size transforms.
+ StdTransformsFixed<operand_type, result_type, 4, 4, 16> transforms = {};
- kern_type kernel = nullptr;
+ kern_type kernel = a64_gemm_u8_4x4;
gemm_u8_4x4(const CPUInfo *ci) {
- kernel = a64_gemm_u8_4x4;
}
};
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_24x8.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_24x8.hpp
index fe74b994f5..911a4ebb01 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_24x8.hpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_24x8.hpp
@@ -27,6 +27,8 @@
#include "arm_gemm.hpp"
+#include "../std_transforms_fixed.hpp"
+
namespace arm_gemm {
// Actual kernel implementations
@@ -44,17 +46,21 @@ public:
typedef void (*kern_type)(const __fp16 *, const __fp16 *, __fp16 *, int, int, int);
- static const int A_block = 1;
- static const int A_interleave = 8;
- static const bool A_transpose = false;
+ /* Kernel blocking parameters */
+ static int out_width() {
+ return 24;
+ }
- static const int B_block = 1;
- static const int B_interleave = 24;
- static const bool B_transpose = true;
+ static int out_height() {
+ return 8;
+ }
+
+ static int k_unroll() {
+ return 1;
+ }
- static const int out_width = 24;
- static const int out_height = 8;
- static const int k_unroll = 1;
+ // Use the standard fixed size transforms.
+ StdTransformsFixed<operand_type, result_type, 8, 24> transforms = {};
// Default to the generic kernel
kern_type kernel = a64_hgemm_asimd_24x8;
diff --git a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8.hpp b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8.hpp
index c91d50469f..10d1069417 100644
--- a/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8.hpp
+++ b/src/core/NEON/kernels/arm_gemm/kernels/a64_sgemm_12x8.hpp
@@ -25,6 +25,8 @@
#ifdef __aarch64__
+#include "../std_transforms_fixed.hpp"
+
namespace arm_gemm {
// Actual kernel implementations
@@ -48,20 +50,21 @@ public:
typedef void (*kern_type)(const float *, const float *, float *, int, int, int);
- /* Describes the data layout for A input */
- static const int A_interleave = 8;
- static const int A_block = 1;
- static const int A_transpose = 0;
+ /* Kernel blocking parameters */
+ static int out_width() {
+ return 12;
+ }
- /* Same for B input */
- static const int B_interleave = 12;
- static const int B_block = 1;
- static const int B_transpose = 1;
+ static int out_height() {
+ return 8;
+ }
- /* Kernel blocking parameters */
- static const int out_width = 12;
- static const int out_height = 8;
- static const int k_unroll = 1;
+ static int k_unroll() {
+ return 1;
+ }
+
+ // Use the standard fixed size transforms.
+ StdTransformsFixed<operand_type, result_type, 8, 12> transforms = {};
kern_type kernel=a64_sgemm_asimd_12x8;
diff --git a/src/core/NEON/kernels/arm_gemm/mergeresults.hpp b/src/core/NEON/kernels/arm_gemm/mergeresults.hpp
index b1e2ca1daa..04d1343b1c 100644
--- a/src/core/NEON/kernels/arm_gemm/mergeresults.hpp
+++ b/src/core/NEON/kernels/arm_gemm/mergeresults.hpp
@@ -32,15 +32,19 @@
namespace arm_gemm {
-template<unsigned int width, unsigned int height, typename Tin, typename Tout>
+template<unsigned int twidth, unsigned int height, bool sve=false, typename Tin, typename Tout>
inline void MergeResults(Tout * out, const Tin * in, int ldc, int y0, int ymax, int x0, int xmax, const Tout alpha, const Tout beta) {
- int full_y_blocks = (ymax - y0) / height;
- int y_remainder = (ymax - y0) % height;
- int y_blocks = full_y_blocks + (y_remainder ? 1 : 0);
+ // For SVE cases, multiply the width up by the vector length.
+ // Use the *input* type to determine this, since this will be what the kernel operated on.
+ const int width = twidth * (sve ? get_vector_length<Tin>() : 1);
- int full_x_blocks = (xmax - x0) / width;
- int x_remainder = (xmax - x0) % width;
- int x_blocks = full_x_blocks + (x_remainder ? 1 : 0);
+ const int full_y_blocks = (ymax - y0) / height;
+ const int y_remainder = (ymax - y0) % height;
+ const int y_blocks = full_y_blocks + (y_remainder ? 1 : 0);
+
+ const int full_x_blocks = (xmax - x0) / width;
+ const int x_remainder = (xmax - x0) % width;
+ const int x_blocks = full_x_blocks + (x_remainder ? 1 : 0);
for (int y_block = 0; y_block < y_blocks; y_block++) {
int ybase = y0 + (y_block * height);
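To make the new width scaling concrete, here is a worked instance under the assumption of 128-bit NEON vectors (the length returned by get_vector_length() in utils.hpp below); the call itself is illustrative, as no SVE merge is instantiated in this patch:

// get_vector_length<float>() == 16 / sizeof(float) == 4, so with sve=true
// a merge declared 3 wide operates on blocks of 3 * 4 = 12 columns by
// 8 rows, matching the fixed 12x8 NEON merge:
// MergeResults<3, 8, true>(out, in, ldc, y0, ymax, x0, xmax, alpha, beta);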
diff --git a/src/core/NEON/kernels/arm_gemm/merges/a32_merge_float_8x6.hpp b/src/core/NEON/kernels/arm_gemm/merges/a32_merge_float_8x6.hpp
index 2b833937a8..f4485bcbb1 100644
--- a/src/core/NEON/kernels/arm_gemm/merges/a32_merge_float_8x6.hpp
+++ b/src/core/NEON/kernels/arm_gemm/merges/a32_merge_float_8x6.hpp
@@ -28,7 +28,7 @@
#include <arm_neon.h>
template<>
-inline void MergeResults<8, 6>(float *out, const float *in, const int ldout, const int y0, const int ymax, const int x0, const int xmax, const float alpha, const float beta) {
+inline void MergeResults<8, 6, false>(float *out, const float *in, const int ldout, const int y0, const int ymax, const int x0, const int xmax, const float alpha, const float beta) {
const float *inptr = in;
prefetch_6x(inptr);
prefetch_6x(inptr + 96);
diff --git a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_12x8.hpp b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_12x8.hpp
index f6befa2d14..be23978b80 100644
--- a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_12x8.hpp
+++ b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_12x8.hpp
@@ -26,7 +26,7 @@
#ifdef __aarch64__
template<>
-inline void MergeResults<12, 8>(float *out, const float *in, const int ldout, const int y0, const int ymax, const int x0, const int xmax, const float alpha, const float beta) {
+inline void MergeResults<12, 8, false>(float *out, const float *in, const int ldout, const int y0, const int ymax, const int x0, const int xmax, const float alpha, const float beta) {
const float *inptr = in;
prefetch_6x(inptr);
prefetch_6x(inptr + 96);
diff --git a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_to_half_12x8.hpp b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_to_half_12x8.hpp
index e7a7521823..9e5eb88dc1 100644
--- a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_to_half_12x8.hpp
+++ b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_to_half_12x8.hpp
@@ -29,7 +29,7 @@
#include <arm_neon.h>
template<>
-inline void MergeResults<12,8>(__fp16 *out, const float *in, int ldout, int y0, int ymax, int x0, int xmax, const __fp16 alpha, const __fp16 beta) {
+inline void MergeResults<12,8,false>(__fp16 *out, const float *in, int ldout, int y0, int ymax, int x0, int xmax, const __fp16 alpha, const __fp16 beta) {
const float *inptr = in;
prefetch_6x(inptr);
prefetch_6x(inptr + 24);
diff --git a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_int32_12x8.hpp b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_int32_12x8.hpp
index 1a51505a25..ee32ce7630 100644
--- a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_int32_12x8.hpp
+++ b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_int32_12x8.hpp
@@ -26,7 +26,7 @@
#ifdef __aarch64__
template<>
-inline void MergeResults<12, 8>(int32_t *out, const int32_t *in, const int ldout, const int y0, const int ymax, const int x0, const int xmax, const int32_t alpha, const int32_t beta) {
+inline void MergeResults<12, 8, false>(int32_t *out, const int32_t *in, const int ldout, const int y0, const int ymax, const int x0, const int xmax, const int32_t alpha, const int32_t beta) {
const int32_t *inptr = in;
prefetch_6x(inptr);
prefetch_6x(inptr + 96);
diff --git a/src/core/NEON/kernels/arm_gemm/std_transforms_fixed.hpp b/src/core/NEON/kernels/arm_gemm/std_transforms_fixed.hpp
new file mode 100644
index 0000000000..44124a7b41
--- /dev/null
+++ b/src/core/NEON/kernels/arm_gemm/std_transforms_fixed.hpp
@@ -0,0 +1,69 @@
+/*
+ * Copyright (c) 2017-2018 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#pragma once
+
+namespace arm_gemm {
+
+/*
+ * Define "standard" transforms for the blocked GEMMs with fixed vector
+ * length.
+ *
+ * This assumes that A is interleaved 'height' ways, B is interleaved
+ * 'width' ways and transposed, and that the merge needs to work in 'height'
+ * x 'width' blocks.
+ *
+ * The optional 'block' parameter is for kernels using dot-product type
+ * instructions like UDOT and SDOT.
+ */
+template<typename TOperand, typename TResult, unsigned int height, unsigned int width, unsigned int block=1>
+class StdTransformsFixed
+{
+public:
+ template<typename TIn>
+ void PrepareA(TOperand *out, const TIn *in, const int stride, const int y0,
+ const int ymax, const int k0, const int kmax, bool transposed) {
+ if (transposed) {
+ Transform<height, block, true>(out, in, stride, y0, ymax, k0, kmax);
+ } else {
+ Transform<height, block, false>(out, in, stride, y0, ymax, k0, kmax);
+ }
+ }
+
+ template<typename TIn>
+ void PrepareB(TOperand *out, const TIn *in, const int stride, const int x0,
+ const int xmax, const int k0, const int kmax, bool transposed) {
+ if (transposed) {
+ Transform<width, block, false>(out, in, stride, x0, xmax, k0, kmax);
+ } else {
+ Transform<width, block, true>(out, in, stride, x0, xmax, k0, kmax);
+ }
+ }
+
+ template<typename TOut>
+ void Merge(TOut *out, const TResult *in, int stride, int y0, int ymax, int x0, int xmax, const TOut alpha, const TOut beta) {
+ MergeResults<width, height>(out, in, stride, y0, ymax, x0, xmax, alpha, beta);
+ }
+};
+
+} // namespace arm_gemm
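The strategy classes updated above embed this helper as their transforms member, which is what lets GemmInterleaved stop dispatching on A_transpose/B_transpose itself. For example, from a64_sgemm_12x8.hpp and gemm_interleaved.hpp in this patch:

// The kernel declares its transforms once, with height 8 and width 12:
//   StdTransformsFixed<operand_type, result_type, 8, 12> transforms = {};
//
// GemmInterleaved then calls through the member, passing the runtime
// transpose flag instead of branching on it at each call site:
//   strat.transforms.PrepareB(b_panel, this->_Bptr + (next.multi() * this->_B_multi_stride),
//                             this->_ldb, next.x0(), next.xmax(),
//                             next.k0(), next.kmax(), _trB);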
diff --git a/src/core/NEON/kernels/arm_gemm/transform.hpp b/src/core/NEON/kernels/arm_gemm/transform.hpp
index 35e61b05a4..77d0d87a4d 100644
--- a/src/core/NEON/kernels/arm_gemm/transform.hpp
+++ b/src/core/NEON/kernels/arm_gemm/transform.hpp
@@ -34,11 +34,14 @@
* Need to cope with the work requested in either dimension not actually
* being a multiple of the block sizes.
*/
-template <unsigned IntBy, unsigned int BlockBy, bool Transposed, size_t TOutSize, size_t TInSize>
+template <unsigned int tIntBy, unsigned int BlockBy, bool Transposed, size_t TOutSize, size_t TInSize, bool sve>
struct TransformImpl {
template <typename TOut, typename TIn>
static void Transform(TOut* out, const TIn* const in, const int stride,
const int y0, const int ymax, const int x0, const int xmax) {
+ // For SVE cases we multiply the interleave factor by the vector length.
+ const unsigned int IntBy = tIntBy * (sve ? get_vector_length<TOut>() : 1);
+
const int n_whole_y_blocks = (ymax - y0) / IntBy;
const int y_remainders = (ymax - y0) % IntBy;
const int n_y_blocks = n_whole_y_blocks + (y_remainders ? 1 : 0);
@@ -95,17 +98,16 @@ struct TransformImpl {
};
/*****************************************************************************/
-template <unsigned int IntBy, unsigned int BlockBy, bool Transposed, typename TOut, typename TIn>
+template <unsigned int IntBy, unsigned int BlockBy, bool Transposed, bool sve=false, typename TOut, typename TIn>
void Transform(
TOut* out, const TIn* const in, const int stride,
const int k0, const int kmax, const int x0, const int xmax
) {
// Redirect to a specialised implementation predicated on argument size.
- TransformImpl<IntBy, BlockBy, Transposed, sizeof(TOut), sizeof(TIn)>::Transform(
+ TransformImpl<IntBy, BlockBy, Transposed, sizeof(TOut), sizeof(TIn), sve>::Transform(
out, in, stride, k0, kmax, x0, xmax
);
}
/*****************************************************************************/
#include "transforms/list.hpp"
-
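Existing generic call sites are unaffected by the new parameter, because sve defaults to false and the deduced type parameters follow it; only the explicit TransformImpl specialisations have to spell it out, as the remaining hunks below show:

// Generic call (as in StdTransformsFixed): sve defaults to false.
//   Transform<height, block, true>(out, in, stride, y0, ymax, k0, kmax);
//
// Explicit specialisation: the trailing false must now be written out.
//   template<> template<typename T>
//   inline void TransformImpl<6, 1, false, 4, 4, false>::Transform(
//       T *out, const T *in, int ldin, int y0, int ymax, int k0, int kmax);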
diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp
index e485ca7009..492abe51ed 100644
--- a/src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp
+++ b/src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp
@@ -31,7 +31,7 @@
template<>
template<typename T>
-inline void TransformImpl<6, 1, false, 4, 4>::Transform(T *out, const T *in, int ldin, int y0, int ymax, int k0, int kmax) {
+inline void TransformImpl<6, 1, false, 4, 4, false>::Transform(T *out, const T *in, int ldin, int y0, int ymax, int k0, int kmax) {
uint32_t *outptr = reinterpret_cast<uint32_t *>(out);
const uint32_t *inptr = reinterpret_cast<const uint32_t *>(in);
diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a32_transpose_interleave_8way_32bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a32_transpose_interleave_8way_32bit.hpp
index a7e17fa074..56a226fce0 100644
--- a/src/core/NEON/kernels/arm_gemm/transforms/a32_transpose_interleave_8way_32bit.hpp
+++ b/src/core/NEON/kernels/arm_gemm/transforms/a32_transpose_interleave_8way_32bit.hpp
@@ -30,12 +30,12 @@
// Generic unblocked transposed 8x32-bit sized specialisation
template <>
template <typename T>
-inline void TransformImpl<8, 1, true, 4, 4>::Transform(
+inline void TransformImpl<8, 1, true, 4, 4, false>::Transform(
T* out, const T* const in, const int stride,
const int x0, const int xmax, const int k0, const int kmax
) {
// Redirect to a 16x uint16_t specialisation
- TransformImpl<16, 1, true, 2, 2>::Transform(
+ TransformImpl<16, 1, true, 2, 2, false>::Transform(
reinterpret_cast<uint16_t *>(out),
reinterpret_cast<const uint16_t * const>(in),
stride*2, x0*2, xmax*2, k0, kmax
@@ -45,7 +45,7 @@ inline void TransformImpl<8, 1, true, 4, 4>::Transform(
// Generic 12x16-bit sized specialisation
template <>
template <typename T>
-inline void TransformImpl<16, 1, true, 2, 2>::Transform(
+inline void TransformImpl<16, 1, true, 2, 2, false>::Transform(
T* out, const T* const in, const int stride,
const int x0, const int xmax, const int k0, const int kmax
) {
@@ -117,7 +117,7 @@ inline void TransposeInterleaveCommon<16, uint16_t, uint16_t>::moveblock_1x4(con
template <>
template <>
-inline void TransformImpl<16, 1, true, 2, 2>::Transform(
+inline void TransformImpl<16, 1, true, 2, 2, false>::Transform(
uint16_t* out, const uint16_t* const in, const int stride,
const int x0, const int xmax, const int k0, const int kmax
) {
diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_block16_interleave4_8bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_block16_interleave4_8bit.hpp
index 7e61f425d4..8ea0483a50 100644
--- a/src/core/NEON/kernels/arm_gemm/transforms/a64_block16_interleave4_8bit.hpp
+++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_block16_interleave4_8bit.hpp
@@ -32,7 +32,7 @@
template<>
template<typename T>
-void TransformImpl<4, 16, false, 1, 1>::Transform(T *out, const T *in, int ldin, int y0, int ymax, int k0, int kmax) {
+void TransformImpl<4, 16, false, 1, 1, false>::Transform(T *out, const T *in, int ldin, int y0, int ymax, int k0, int kmax) {
uint8_t *outptr = (uint8_t *)out;
const uint8_t *inptr = (uint8_t *)in;
diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_16bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_16bit.hpp
index 99bb2d66bd..91ee49229b 100644
--- a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_16bit.hpp
+++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_16bit.hpp
@@ -31,7 +31,7 @@
template<>
template<typename T>
-void TransformImpl<8, 1, false, 2, 2>::Transform(T *out, const T *in, int ldin, int y0, int ymax, int k0, int kmax) {
+void TransformImpl<8, 1, false, 2, 2, false>::Transform(T *out, const T *in, int ldin, int y0, int ymax, int k0, int kmax) {
uint16_t *outptr = (uint16_t *)out;
const uint16_t *inptr = (const uint16_t *)in;
diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_32bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_32bit.hpp
index 83391cc59f..7a32f331ea 100644
--- a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_32bit.hpp
+++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_32bit.hpp
@@ -31,7 +31,7 @@
template<>
template<typename T>
-inline void TransformImpl<8, 1, false, 4, 4>::Transform(T *out, const T *in, int ldin, int y0, int ymax, int k0, int kmax) {
+inline void TransformImpl<8, 1, false, 4, 4, false>::Transform(T *out, const T *in, int ldin, int y0, int ymax, int k0, int kmax) {
uint32_t *outptr = (uint32_t *)out;
const uint32_t *inptr = (uint32_t *)in;
diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp
index fd812165fd..773d56d913 100644
--- a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp
+++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp
@@ -31,7 +31,7 @@
template<>
template<>
-inline void TransformImpl<8, 1, false, 4, 2>::Transform(float *out, const __fp16 *in, int ldin, int y0, int ymax, int k0, int kmax) {
+inline void TransformImpl<8, 1, false, 4, 2, false>::Transform(float *out, const __fp16 *in, int ldin, int y0, int ymax, int k0, int kmax) {
float *outptr = out;
const __fp16 *inptr = in;
diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12way_16bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12way_16bit.hpp
index 6e07064a0c..16fa31eb67 100644
--- a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12way_16bit.hpp
+++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12way_16bit.hpp
@@ -30,12 +30,12 @@
// Generic unblocked transposed 6x32-bit sized specialisation
template <>
template <typename T>
-inline void TransformImpl<6, 1, true, 4, 4>::Transform(
+inline void TransformImpl<6, 1, true, 4, 4, false>::Transform(
T* out, const T* const in, const int stride,
const int x0, const int xmax, const int k0, const int kmax
) {
// Redirect to a 12 x uint16_t specialisation
- TransformImpl<12, 1, true, 2, 2>::Transform(
+ TransformImpl<12, 1, true, 2, 2, false>::Transform(
reinterpret_cast<uint16_t *>(out),
reinterpret_cast<const uint16_t * const>(in),
stride*2, x0*2, xmax*2, k0, kmax
@@ -45,7 +45,7 @@ inline void TransformImpl<6, 1, true, 4, 4>::Transform(
// Generic 12x16-bit sized specialisation
template <>
template <typename T>
-inline void TransformImpl<12, 1, true, 2, 2>::Transform(
+inline void TransformImpl<12, 1, true, 2, 2, false>::Transform(
T* out, const T* const in, const int stride,
const int x0, const int xmax, const int k0, const int kmax
) {
@@ -135,7 +135,7 @@ inline void TransposeInterleaveCommon<12, uint16_t, uint16_t>::moveblock_1x4(con
template <>
template <>
-inline void TransformImpl<12, 1, true, 2, 2>::Transform(
+inline void TransformImpl<12, 1, true, 2, 2, false>::Transform(
uint16_t* out, const uint16_t* const in, const int stride,
const int x0, const int xmax, const int k0, const int kmax
) {
diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12way_half_to_float.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12way_half_to_float.hpp
index 2f90c18ebd..46b4bf5149 100644
--- a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12way_half_to_float.hpp
+++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_12way_half_to_float.hpp
@@ -110,7 +110,7 @@ inline void TransposeInterleaveCommon<12, __fp16, float>::moveblock_1x4(const __
template <>
template <>
-inline void TransformImpl<12, 1, true, 4, 2>::Transform(
+inline void TransformImpl<12, 1, true, 4, 2, false>::Transform(
float* out, const __fp16* const in, const int stride,
const int x0, const int xmax, const int k0, const int kmax
) {
diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_24way_16bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_24way_16bit.hpp
index b6565baa23..c39dd82119 100644
--- a/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_24way_16bit.hpp
+++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_transpose_interleave_24way_16bit.hpp
@@ -30,12 +30,12 @@
// Generic unblocked transposed 12x32-bit sized specialisation
template <>
template <typename T>
-inline void TransformImpl<12, 1, true, 4, 4>::Transform(
+inline void TransformImpl<12, 1, true, 4, 4, false>::Transform(
T* out, const T* const in, const int stride,
const int x0, const int xmax, const int k0, const int kmax
) {
// Redirect to a 24 x uint16_t specialisation
- TransformImpl<24, 1, true, 2, 2>::Transform(
+ TransformImpl<24, 1, true, 2, 2, false>::Transform(
reinterpret_cast<uint16_t *>(out),
reinterpret_cast<const uint16_t * const>(in),
stride*2, x0*2, xmax*2, k0, kmax
@@ -45,7 +45,7 @@ inline void TransformImpl<12, 1, true, 4, 4>::Transform(
// Generic 24x16-bit sized specialisation
template <>
template <typename T>
-inline void TransformImpl<24, 1, true, 2, 2>::Transform(
+inline void TransformImpl<24, 1, true, 2, 2, false>::Transform(
T* out, const T* const in, const int stride,
const int x0, const int xmax, const int k0, const int kmax
) {
@@ -120,7 +120,7 @@ inline void TransposeInterleaveCommon<24, uint16_t, uint16_t>::moveblock_1x4(con
template <>
template <>
-inline void TransformImpl<24, 1, true, 2, 2>::Transform(
+inline void TransformImpl<24, 1, true, 2, 2, false>::Transform(
uint16_t* out, const uint16_t* const in, const int stride,
const int x0, const int xmax, const int k0, const int kmax
) {
diff --git a/src/core/NEON/kernels/arm_gemm/utils.hpp b/src/core/NEON/kernels/arm_gemm/utils.hpp
index c1977d5f3e..b77bc7a566 100644
--- a/src/core/NEON/kernels/arm_gemm/utils.hpp
+++ b/src/core/NEON/kernels/arm_gemm/utils.hpp
@@ -44,3 +44,10 @@ inline T roundup(const T a, const T b) {
return a;
}
}
+
+template <typename T>
+inline unsigned long get_vector_length() {
+ const unsigned long length = 16;
+
+ return length / sizeof(T);
+}
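The hard-coded 16 is the 128-bit NEON register size in bytes; presumably a real SVE implementation would replace it with a runtime query. For the types used in this patch, that gives:

// With 16-byte (128-bit) vectors:
//   get_vector_length<float>()    == 4
//   get_vector_length<uint16_t>() == 8
//   get_vector_length<uint8_t>()  == 16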