author    Georgios Pinitas <georgios.pinitas@arm.com>    2019-06-21 18:43:12 +0100
committer Georgios Pinitas <georgios.pinitas@arm.com>    2019-07-05 15:30:24 +0000
commit    37d080f2f11cfd734104b76512e1fb191486216e (patch)
tree      d5df067c826aacc0676e7e9557a54b61a9a3b7eb /src/core/NEON/kernels
parent    11de30da8a9f79943255ddba7bb70a66b076673b (diff)
download  ComputeLibrary-37d080f2f11cfd734104b76512e1fb191486216e.tar.gz
COMPMID-2378: Sanitize GEMM configuration for NEON
Change-Id: I7859b82b2059e14685f8792424648ac5eacd67f1
Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1418
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
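
This patch threads a GEMMInfo descriptor from INEGEMMWrapperKernel::configure() through extract_parameters() and down into the native and interleaved wrapper kernels, so that the 3D re-interpretation flags (reinterpret_input_as_3d() for the input, depth_output_gemm3d() for the output) are honoured instead of being inferred from the data layout. As a rough sketch of that flow, using simplified stand-in types rather than the library's real headers:

    // Illustrative sketch only: GemmInfoLite, ParamsLite and WrapperLite are
    // stand-ins for arm_compute::GEMMInfo, INEGEMMWrapperKernel::Params and the kernel.
    #include <cstddef>

    struct GemmInfoLite
    {
        bool   reinterpret_input_as_3d = false; // A is a collapsed 3D tensor
        size_t depth_output_gemm3d     = 0;     // non-zero: C must be re-split into 3D
    };

    struct ParamsLite
    {
        size_t M = 0, N = 0, K = 0, batches = 1, multis = 1;
    };

    class WrapperLite
    {
    public:
        // configure() now also receives the GEMM meta-data and stores it,
        // so the run stage can later pick the right strides for the 3D cases.
        void configure(const ParamsLite &params, const GemmInfoLite &gemm_info)
        {
            _params    = params;
            _gemm_info = gemm_info;
        }

    private:
        ParamsLite   _params{};
        GemmInfoLite _gemm_info{};
    };
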
Diffstat (limited to 'src/core/NEON/kernels')
-rw-r--r--  src/core/NEON/kernels/arm_gemm/merges/a32_merge_float_8x6.hpp                     4
-rw-r--r--  src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_12x8.hpp                    6
-rw-r--r--  src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_to_half_12x8.hpp            6
-rw-r--r--  src/core/NEON/kernels/arm_gemm/merges/a64_merge_half_24x8.hpp                     6
-rw-r--r--  src/core/NEON/kernels/arm_gemm/merges/a64_merge_int32_12x8.hpp                    6
-rw-r--r--  src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp           4
-rw-r--r--  src/core/NEON/kernels/arm_gemm/transforms/a64_block16_interleave4_8bit.hpp        4
-rw-r--r--  src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_16bit.hpp           6
-rw-r--r--  src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_32bit.hpp           6
-rw-r--r--  src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp   6
-rw-r--r--  src/core/NEON/kernels/assembly/INEGEMMWrapperKernel.cpp                          25
-rw-r--r--  src/core/NEON/kernels/assembly/NEGEMMInterleavedStrategies.h                     37
-rw-r--r--  src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp                     18
13 files changed, 104 insertions, 30 deletions
diff --git a/src/core/NEON/kernels/arm_gemm/merges/a32_merge_float_8x6.hpp b/src/core/NEON/kernels/arm_gemm/merges/a32_merge_float_8x6.hpp
index f4485bcbb1..e1af2d4490 100644
--- a/src/core/NEON/kernels/arm_gemm/merges/a32_merge_float_8x6.hpp
+++ b/src/core/NEON/kernels/arm_gemm/merges/a32_merge_float_8x6.hpp
@@ -61,12 +61,16 @@ inline void MergeResults<8, 6, false>(float *out, const float *in, const int ldo
switch ((y + 5) - ymax) {
case 4:
outptr1 = dummyres;
+ // fall through
case 3:
outptr2 = dummyres;
+ // fall through
case 2:
outptr3 = dummyres;
+ // fall through
case 1:
outptr4 = dummyres;
+ // fall through
case 0:
outptr5 = dummyres;
break;
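
The added "// fall through" comments document an intentional pattern in the merge kernels: when fewer than the full 6 (or 8) output rows remain at the bottom edge of the matrix, every out-of-range row pointer is redirected to a scratch "dummyres" buffer, so the fixed-shape store code runs unmodified and the excess rows are simply discarded. A standalone scalar sketch of the idea (illustrative names and a 4x4 block, not the NEON kernel itself):

    // Redirect out-of-range output rows to a scratch buffer so a fixed-height
    // block store can always write 4 rows; rows at or past 'ymax' land in 'dummy'.
    // Caller guarantees y < ymax, so the switch value is at most 2.
    void store_block_4rows(const float *in, int y, int ymax, float *out, int ldout)
    {
        float dummy[4]; // scratch row: stores to clamped rows land here and are ignored

        float *outptr0 = out + (y + 0) * ldout;
        float *outptr1 = out + (y + 1) * ldout;
        float *outptr2 = out + (y + 2) * ldout;
        float *outptr3 = out + (y + 3) * ldout;

        // Same fall-through structure as the arm_gemm merges: each case also
        // executes every case below it, clamping all trailing rows at once.
        switch ((y + 3) - ymax)
        {
            case 2:
                outptr1 = dummy;
                // fall through
            case 1:
                outptr2 = dummy;
                // fall through
            case 0:
                outptr3 = dummy;
                break;
            default:
                break;
        }

        // Unconditional fixed-shape stores; clamped rows harmlessly hit 'dummy'.
        for (int i = 0; i < 4; i++)
        {
            outptr0[i] = in[0 * 4 + i];
            outptr1[i] = in[1 * 4 + i];
            outptr2[i] = in[2 * 4 + i];
            outptr3[i] = in[3 * 4 + i];
        }
    }
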
diff --git a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_12x8.hpp b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_12x8.hpp
index be23978b80..9fca4e3a84 100644
--- a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_12x8.hpp
+++ b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_12x8.hpp
@@ -63,16 +63,22 @@ inline void MergeResults<12, 8, false>(float *out, const float *in, const int ld
switch ((y + 7) - ymax) {
case 6:
outptr1 = dummyres;
+ // fall through
case 5:
outptr2 = dummyres;
+ // fall through
case 4:
outptr3 = dummyres;
+ // fall through
case 3:
outptr4 = dummyres;
+ // fall through
case 2:
outptr5 = dummyres;
+ // fall through
case 1:
outptr6 = dummyres;
+ // fall through
case 0:
outptr7 = dummyres;
break;
diff --git a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_to_half_12x8.hpp b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_to_half_12x8.hpp
index 9e5eb88dc1..0e638eef1c 100644
--- a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_to_half_12x8.hpp
+++ b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_to_half_12x8.hpp
@@ -66,16 +66,22 @@ inline void MergeResults<12,8,false>(__fp16 *out, const float *in, int ldout, in
switch ((y + 7) - ymax) {
case 6:
outptr1 = dummyres;
+ // fall through
case 5:
outptr2 = dummyres;
+ // fall through
case 4:
outptr3 = dummyres;
+ // fall through
case 3:
outptr4 = dummyres;
+ // fall through
case 2:
outptr5 = dummyres;
+ // fall through
case 1:
outptr6 = dummyres;
+ // fall through
case 0:
outptr7 = dummyres;
break;
diff --git a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_half_24x8.hpp b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_half_24x8.hpp
index 3ed43b10bd..60cc2f32da 100644
--- a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_half_24x8.hpp
+++ b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_half_24x8.hpp
@@ -65,16 +65,22 @@ inline void MergeResults<24, 8>(__fp16 *out, const __fp16 *in, const int ldout,
switch ((y + 7) - ymax) {
case 6:
outptr1 = dummyres;
+ // fall through
case 5:
outptr2 = dummyres;
+ // fall through
case 4:
outptr3 = dummyres;
+ // fall through
case 3:
outptr4 = dummyres;
+ // fall through
case 2:
outptr5 = dummyres;
+ // fall through
case 1:
outptr6 = dummyres;
+ // fall through
case 0:
outptr7 = dummyres;
break;
diff --git a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_int32_12x8.hpp b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_int32_12x8.hpp
index 35d4cc5d73..0212dfdbb6 100644
--- a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_int32_12x8.hpp
+++ b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_int32_12x8.hpp
@@ -63,16 +63,22 @@ inline void MergeResults<12, 8, false>(int32_t *out, const int32_t *in, const in
switch ((y + 7) - ymax) {
case 6:
outptr1 = dummyres;
+ // fall through
case 5:
outptr2 = dummyres;
+ // fall through
case 4:
outptr3 = dummyres;
+ // fall through
case 3:
outptr4 = dummyres;
+ // fall through
case 2:
outptr5 = dummyres;
+ // fall through
case 1:
outptr6 = dummyres;
+ // fall through
case 0:
outptr7 = dummyres;
break;
diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp
index 20ad301a18..a460fdfcf4 100644
--- a/src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp
+++ b/src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp
@@ -60,12 +60,16 @@ inline void TransformImpl<6, 1, false, 4, 4, false>::Transform(T *out, const T *
/* Everything falls through in here */
case 4:
inptr1 = zerobuff;
+ // fall through
case 3:
inptr2 = zerobuff;
+ // fall through
case 2:
inptr3 = zerobuff;
+ // fall through
case 1:
inptr4 = zerobuff;
+ // fall through
case 0:
inptr5 = zerobuff;
break;
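
The interleave transforms use the mirror image of the same trick on the input side: when the last block of A has fewer valid rows than the interleave height, the out-of-range input row pointers are pointed at a shared zero buffer, so the packed output is zero-padded instead of reading past the end of the tensor. A simplified scalar sketch for a 4-row interleave (illustrative only, not the NEON transform):

    #include <vector>

    // Interleave a block of 4 rows column-by-column into 'out'. Rows past 'ymax'
    // read from a zero buffer, so the packed block is zero-padded, mirroring the
    // zerobuff fall-through in the arm_gemm transforms. Caller guarantees y < ymax.
    void interleave_block_4rows(const float *in, int y, int ymax, int width, int ldin, float *out)
    {
        std::vector<float> zerobuff(width, 0.0f);

        const float *inptr0 = in + (y + 0) * ldin;
        const float *inptr1 = in + (y + 1) * ldin;
        const float *inptr2 = in + (y + 2) * ldin;
        const float *inptr3 = in + (y + 3) * ldin;

        /* Everything falls through in here */
        switch ((y + 3) - ymax)
        {
            case 2:
                inptr1 = zerobuff.data();
                // fall through
            case 1:
                inptr2 = zerobuff.data();
                // fall through
            case 0:
                inptr3 = zerobuff.data();
                break;
            default:
                break;
        }

        for (int x = 0; x < width; x++)
        {
            *out++ = inptr0[x];
            *out++ = inptr1[x];
            *out++ = inptr2[x];
            *out++ = inptr3[x];
        }
    }
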
diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_block16_interleave4_8bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_block16_interleave4_8bit.hpp
index 2f513a6118..6a15fc42e4 100644
--- a/src/core/NEON/kernels/arm_gemm/transforms/a64_block16_interleave4_8bit.hpp
+++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_block16_interleave4_8bit.hpp
@@ -57,8 +57,10 @@ void TransformImpl<4, 16, false, 1, 1, false>::Transform(T *out, const T *in, in
/* Everything falls through in here */
case 2:
inptr1 = zerobuff;
+ // fall through
case 1:
inptr2 = zerobuff;
+ // fall through
case 0:
inptr3 = zerobuff;
break;
@@ -93,8 +95,10 @@ void TransformImpl<4, 16, false, 1, 1, false>::Transform(T *out, const T *in, in
/* Everything falls through in here */
case 2:
inptr1 = zerobuff;
+ // fall through
case 1:
inptr2 = zerobuff;
+ // fall through
case 0:
inptr3 = zerobuff;
break;
diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_16bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_16bit.hpp
index 27136d144a..0028ab08a9 100644
--- a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_16bit.hpp
+++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_16bit.hpp
@@ -64,16 +64,22 @@ void TransformImpl<8, 1, false, 2, 2, false>::Transform(T *out, const T *in, int
/* Everything falls through in here */
case 6:
inptr1 = zerobuff;
+ // fall through
case 5:
inptr2 = zerobuff;
+ // fall through
case 4:
inptr3 = zerobuff;
+ // fall through
case 3:
inptr4 = zerobuff;
+ // fall through
case 2:
inptr5 = zerobuff;
+ // fall through
case 1:
inptr6 = zerobuff;
+ // fall through
case 0:
inptr7 = zerobuff;
break;
diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_32bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_32bit.hpp
index 54822c81b0..758c084a46 100644
--- a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_32bit.hpp
+++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_32bit.hpp
@@ -64,16 +64,22 @@ inline void TransformImpl<8, 1, false, 4, 4, false>::Transform(T *out, const T *
/* Everything falls through in here */
case 6:
inptr1 = zerobuff;
+ // fall through
case 5:
inptr2 = zerobuff;
+ // fall through
case 4:
inptr3 = zerobuff;
+ // fall through
case 3:
inptr4 = zerobuff;
+ // fall through
case 2:
inptr5 = zerobuff;
+ // fall through
case 1:
inptr6 = zerobuff;
+ // fall through
case 0:
inptr7 = zerobuff;
break;
diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp
index 0606330d27..de8e95a6d7 100644
--- a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp
+++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp
@@ -64,16 +64,22 @@ inline void TransformImpl<8, 1, false, 4, 2, false>::Transform(float *out, const
/* Everything falls through in here */
case 6:
inptr1 = zerobuff;
+ // fall through
case 5:
inptr2 = zerobuff;
+ // fall through
case 4:
inptr3 = zerobuff;
+ // fall through
case 3:
inptr4 = zerobuff;
+ // fall through
case 2:
inptr5 = zerobuff;
+ // fall through
case 1:
inptr6 = zerobuff;
+ // fall through
case 0:
inptr7 = zerobuff;
break;
diff --git a/src/core/NEON/kernels/assembly/INEGEMMWrapperKernel.cpp b/src/core/NEON/kernels/assembly/INEGEMMWrapperKernel.cpp
index 0fc3610014..d00f204b81 100644
--- a/src/core/NEON/kernels/assembly/INEGEMMWrapperKernel.cpp
+++ b/src/core/NEON/kernels/assembly/INEGEMMWrapperKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018 ARM Limited.
+ * Copyright (c) 2018-2019 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -33,11 +33,11 @@
using namespace arm_compute;
INEGEMMWrapperKernel::INEGEMMWrapperKernel()
- : _a(nullptr), _b(nullptr), _c(nullptr), _params(), _window3d(), _window_shape()
+ : _a(nullptr), _b(nullptr), _c(nullptr), _params(), _gemm_info(), _window3d(), _window_shape()
{
}
-INEGEMMWrapperKernel::Params INEGEMMWrapperKernel::extract_parameters(const ITensor *a, const ITensor *b, const ITensor *c)
+INEGEMMWrapperKernel::Params INEGEMMWrapperKernel::extract_parameters(const ITensor *a, const ITensor *b, const ITensor *c, const GEMMInfo &gemm_info)
{
Params p;
@@ -45,21 +45,30 @@ INEGEMMWrapperKernel::Params INEGEMMWrapperKernel::extract_parameters(const ITen
ARM_COMPUTE_ERROR_ON_NULLPTR(b);
ARM_COMPUTE_ERROR_ON_NULLPTR(c);
+ // Initialize params
p.M = c->info()->tensor_shape().y();
p.N = c->info()->tensor_shape().x();
p.K = a->info()->tensor_shape().x();
p.multis = b->info()->tensor_shape().z();
p.batches = c->info()->tensor_shape().total_size_upper(2) / p.multis; //COMPMID-1423: Agree on and document the layout of gemm inputs/outputs
+ // Update M in case of GEMM3D for output
+ if(gemm_info.depth_output_gemm3d() != 0)
+ {
+ p.M = c->info()->tensor_shape().y() * c->info()->tensor_shape().z();
+ p.batches = c->info()->tensor_shape().total_size_upper(3) / p.multis;
+ }
+
return p;
}
-void INEGEMMWrapperKernel::configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta)
+void INEGEMMWrapperKernel::configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta, const GEMMInfo &gemm_info)
{
- _params = extract_parameters(a, b, c);
- _a = a;
- _b = b;
- _c = c;
+ _gemm_info = gemm_info;
+ _params = extract_parameters(a, b, c, gemm_info);
+ _a = a;
+ _b = b;
+ _c = c;
_window3d = configure_internal(alpha, beta);
_window_shape = _window3d.shape();
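
When depth_output_gemm3d() is non-zero, the output tensor C is logically 3D: what the GEMM treats as M rows is stored across C's y and z dimensions, and the batch count moves one dimension up. A hedged standalone sketch of the extraction logic above, using plain structs in place of ITensorInfo (Shape4D and ParamsLite are illustrative stand-ins):

    #include <cstddef>

    struct Shape4D
    {
        size_t x = 1, y = 1, z = 1, w = 1; // w stands in for "everything above dim 2"
    };

    struct ParamsLite
    {
        size_t M = 0, N = 0, K = 0, batches = 1, multis = 1;
    };

    // Mirrors the shape arithmetic of INEGEMMWrapperKernel::extract_parameters().
    ParamsLite extract_parameters_lite(const Shape4D &a, const Shape4D &b, const Shape4D &c,
                                       size_t depth_output_gemm3d)
    {
        ParamsLite p;
        p.M       = c.y;
        p.N       = c.x;
        p.K       = a.x;
        p.multis  = b.z;
        p.batches = (c.z * c.w) / p.multis; // total_size_upper(2) / multis

        if (depth_output_gemm3d != 0)
        {
            // Output re-interpreted as 3D: M spans both y and z of C,
            // and batches are counted from dimension 3 upwards instead.
            p.M       = c.y * c.z;
            p.batches = c.w / p.multis;     // total_size_upper(3) / multis
        }
        return p;
    }
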
diff --git a/src/core/NEON/kernels/assembly/NEGEMMInterleavedStrategies.h b/src/core/NEON/kernels/assembly/NEGEMMInterleavedStrategies.h
index 26d9e9999d..6e30148b5d 100644
--- a/src/core/NEON/kernels/assembly/NEGEMMInterleavedStrategies.h
+++ b/src/core/NEON/kernels/assembly/NEGEMMInterleavedStrategies.h
@@ -76,32 +76,34 @@ public:
* @param[in] transformed_a Reshaped tensor A.
* @param[in] block_walker Window representing the layout of the matrix's blocks.
* @param[in] params M, N, K sizes.
+ * @param[in] gemm_info GEMM meta-data
*
* @return A wrapped specialized transformA kernel
*/
virtual std::unique_ptr<NEGEMMInterleavedTransformAWrapper> instantiate_transformA(const ITensor *a,
ITensor *transformed_a,
const Window &block_walker,
- const INEGEMMWrapperKernel::Params &params) = 0;
+ const INEGEMMWrapperKernel::Params &params,
+ const GEMMInfo &gemm_info) = 0;
/** Instantiate and configure a prepareB Kernel
*
- * @param transformed_a Already reshaped tensor A.
- * @param transformed_b Already reshaped tensor B.
- * @param tmp_c Temporary buffer to be used to store intermediate results.
- * @param c Result tensor C.
- * @param block_walker Window containing iteration information for the M and batch dimensions.
- * @param block_sizes Block sizes to use for the matrix multiplication (A & B must have been reshaped using these same block sizes).
- * @param params M, N, K sizes.
- * @param alpha Alpha value
- * @param beta Beta value
- * @param pretranspose_b Is B also pretransposed ?
- * @param num_threads Maximum number of threads that might be used for the calculations.
+ * @param[in] transformed_a Already reshaped tensor A.
+ * @param[in] transformed_b Already reshaped tensor B.
+ * @param[in] tmp_c Temporary buffer to be used to store intermediate results.
+ * @param[in] c Result tensor C.
+ * @param[in] block_walker Window containing iteration information for the M and batch dimensions.
+ * @param[in] block_sizes Block sizes to use for the matrix multiplication (A & B must have been reshaped using these same block sizes).
+ * @param[in] params M, N, K sizes.
+ * @param[in] alpha Alpha value
+ * @param[in] beta Beta value
+ * @param[in] gemm_info GEMM meta-data
+ * @param[in] num_threads Maximum number of threads that might be used for the calculations.
*
* @return A wrapped specialized MatrixMultiply kernel
*/
virtual std::unique_ptr<NEGEMMInterleavedMatrixMultiplyWrapper> instantiate_matrix_multiply(const ITensor *transformed_a, const ITensor *transformed_b, ITensor *tmp_c, ITensor *c,
const Window &block_walker, const BlockSizes &block_sizes,
- const INEGEMMWrapperKernel::Params &params, float alpha, float beta, bool pretranspose_b,
+ const INEGEMMWrapperKernel::Params &params, float alpha, float beta, const GEMMInfo &gemm_info,
unsigned int num_threads) = 0;
/** Calculates the block sizes of a given strategy
*
@@ -138,19 +140,20 @@ public:
std::unique_ptr<NEGEMMInterleavedTransformAWrapper> instantiate_transformA(const ITensor *a,
ITensor *transformed_a,
const Window &block_walker,
- const INEGEMMWrapperKernel::Params &params) override
+ const INEGEMMWrapperKernel::Params &params,
+ const GEMMInfo &gemm_info) override
{
auto transform_a = support::cpp14::make_unique<NEGEMMInterleavedTransformAWrapperTemplate<strategy>>();
- transform_a->configure(a, transformed_a, false, block_walker, params);
+ transform_a->configure(a, transformed_a, false, gemm_info.reinterpret_input_as_3d(), block_walker, params);
return std::move(transform_a);
}
std::unique_ptr<NEGEMMInterleavedMatrixMultiplyWrapper> instantiate_matrix_multiply(const ITensor *transformed_a, const ITensor *transformed_b, ITensor *tmp_c, ITensor *c,
const Window &block_walker, const BlockSizes &block_sizes,
- const INEGEMMWrapperKernel::Params &params, float alpha, float beta, bool pretranspose_b,
+ const INEGEMMWrapperKernel::Params &params, float alpha, float beta, const GEMMInfo &gemm_info,
unsigned int num_threads) override
{
auto matrix_multiply = support::cpp14::make_unique<NEGEMMInterleavedMatrixMultiplyWrapperTemplate<strategy>>();
- matrix_multiply->configure(transformed_a, transformed_b, tmp_c, c, block_walker, block_sizes, params, pretranspose_b, alpha, beta, num_threads);
+ matrix_multiply->configure(transformed_a, transformed_b, tmp_c, c, block_walker, block_sizes, params, gemm_info, alpha, beta, num_threads);
return std::move(matrix_multiply);
}
diff --git a/src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp b/src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp
index 97c20dbd4e..ecdb5a938c 100644
--- a/src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp
+++ b/src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp
@@ -81,12 +81,20 @@ void NEGEMMNativeWrapperKernel<To, Tr>::run_internal(const Window &window, const
TensorAccessor<To> b(*_b);
TensorAccessor<Tr> c(*_c);
- if(_a->info()->data_layout() == DataLayout::NHWC)
+ // Handle 3d input re-interpretation
+ if(_gemm_info.reinterpret_input_as_3d())
{
- // In the case of NHWC we want to interpret the output shape as 3D. Thus, the batch stride for A is
- // the relevant multiple of the row stride.
- const size_t nhwc_batch_stride = _a->info()->strides_in_bytes().y() * _c->info()->dimension(1);
- a.set_stride(2, nhwc_batch_stride);
+ Strides a_strides_as_3d = _a->info()->strides_in_bytes();
+ a_strides_as_3d.remove(Window::DimZ);
+ a.set_strides(a_strides_as_3d);
+ }
+
+ // Handle 3d output re-interpretation
+ if(_gemm_info.depth_output_gemm3d() != 0)
+ {
+ Strides c_strides_as_3d = _c->info()->strides_in_bytes();
+ c_strides_as_3d.remove(Window::DimZ);
+ c.set_strides(c_strides_as_3d);
}
unsigned int m_end = 0;
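
The run-time change above replaces the NHWC-specific batch-stride fix-up with generic 3D handling: when the input or output is re-interpreted as 3D, the Z entry is dropped from the tensor's byte-stride array, so the accessor walks y and z as one contiguous run of rows while its batch index jumps a whole 3D volume at a time. A simplified illustration of that indexing effect (Strides and TensorAccessor are library types; this stand-in only mimics the stride removal):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Stand-in for arm_compute::Strides: one byte stride per dimension.
    using StridesLite = std::vector<size_t>;

    // Remove one dimension's stride, as Strides::remove(Window::DimZ) does in the patch.
    // Whatever indexed dimension 'dim' before now uses the next stride up, so a dense
    // 3D tensor (x, y, z, batches) is addressed as a 2D matrix (x, y*z, batches).
    StridesLite remove_dim(StridesLite s, size_t dim)
    {
        s.erase(s.begin() + static_cast<std::ptrdiff_t>(dim));
        return s;
    }

    // Example: a dense float tensor of shape x=8, y=4, z=3, batches=2.
    //   strides (bytes) before:            {4, 32, 128, 384}
    //   after remove_dim(strides, 2):      {4, 32, 384}
    // The collapsed matrix has 4*3 = 12 rows, 32 bytes apart, and the batch index
    // now advances by 384 bytes, i.e. a full 3D volume per batch.
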