From 37d080f2f11cfd734104b76512e1fb191486216e Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Fri, 21 Jun 2019 18:43:12 +0100 Subject: COMPMID-2378: Sanitize GEMM configuration for NEON Change-Id: I7859b82b2059e14685f8792424648ac5eacd67f1 Signed-off-by: Georgios Pinitas Reviewed-on: https://review.mlplatform.org/c/1418 Comments-Addressed: Arm Jenkins Reviewed-by: Michele Di Giorgio Reviewed-by: Michalis Spyrou Tested-by: Arm Jenkins --- .../arm_gemm/merges/a32_merge_float_8x6.hpp | 4 +++ .../arm_gemm/merges/a64_merge_float_12x8.hpp | 6 ++++ .../merges/a64_merge_float_to_half_12x8.hpp | 6 ++++ .../arm_gemm/merges/a64_merge_half_24x8.hpp | 6 ++++ .../arm_gemm/merges/a64_merge_int32_12x8.hpp | 6 ++++ .../transforms/a32_interleave_6way_32bit.hpp | 4 +++ .../transforms/a64_block16_interleave4_8bit.hpp | 4 +++ .../transforms/a64_interleave_8way_16bit.hpp | 6 ++++ .../transforms/a64_interleave_8way_32bit.hpp | 6 ++++ .../a64_interleave_8way_half_to_float.hpp | 6 ++++ .../NEON/kernels/assembly/INEGEMMWrapperKernel.cpp | 25 ++++++++++----- .../kernels/assembly/NEGEMMInterleavedStrategies.h | 37 ++++++++++++---------- .../kernels/assembly/NEGEMMNativeWrapperKernel.cpp | 18 ++++++++--- 13 files changed, 104 insertions(+), 30 deletions(-) (limited to 'src/core/NEON/kernels') diff --git a/src/core/NEON/kernels/arm_gemm/merges/a32_merge_float_8x6.hpp b/src/core/NEON/kernels/arm_gemm/merges/a32_merge_float_8x6.hpp index f4485bcbb1..e1af2d4490 100644 --- a/src/core/NEON/kernels/arm_gemm/merges/a32_merge_float_8x6.hpp +++ b/src/core/NEON/kernels/arm_gemm/merges/a32_merge_float_8x6.hpp @@ -61,12 +61,16 @@ inline void MergeResults<8, 6, false>(float *out, const float *in, const int ldo switch ((y + 5) - ymax) { case 4: outptr1 = dummyres; + // fall through case 3: outptr2 = dummyres; + // fall through case 2: outptr3 = dummyres; + // fall through case 1: outptr4 = dummyres; + // fall through case 0: outptr5 = dummyres; break; diff --git a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_12x8.hpp b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_12x8.hpp index be23978b80..9fca4e3a84 100644 --- a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_12x8.hpp +++ b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_12x8.hpp @@ -63,16 +63,22 @@ inline void MergeResults<12, 8, false>(float *out, const float *in, const int ld switch ((y + 7) - ymax) { case 6: outptr1 = dummyres; + // fall through case 5: outptr2 = dummyres; + // fall through case 4: outptr3 = dummyres; + // fall through case 3: outptr4 = dummyres; + // fall through case 2: outptr5 = dummyres; + // fall through case 1: outptr6 = dummyres; + // fall through case 0: outptr7 = dummyres; break; diff --git a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_to_half_12x8.hpp b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_to_half_12x8.hpp index 9e5eb88dc1..0e638eef1c 100644 --- a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_to_half_12x8.hpp +++ b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_to_half_12x8.hpp @@ -66,16 +66,22 @@ inline void MergeResults<12,8,false>(__fp16 *out, const float *in, int ldout, in switch ((y + 7) - ymax) { case 6: outptr1 = dummyres; + // fall through case 5: outptr2 = dummyres; + // fall through case 4: outptr3 = dummyres; + // fall through case 3: outptr4 = dummyres; + // fall through case 2: outptr5 = dummyres; + // fall through case 1: outptr6 = dummyres; + // fall through case 0: outptr7 = dummyres; break; diff --git a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_half_24x8.hpp b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_half_24x8.hpp index 3ed43b10bd..60cc2f32da 100644 --- a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_half_24x8.hpp +++ b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_half_24x8.hpp @@ -65,16 +65,22 @@ inline void MergeResults<24, 8>(__fp16 *out, const __fp16 *in, const int ldout, switch ((y + 7) - ymax) { case 6: outptr1 = dummyres; + // fall through case 5: outptr2 = dummyres; + // fall through case 4: outptr3 = dummyres; + // fall through case 3: outptr4 = dummyres; + // fall through case 2: outptr5 = dummyres; + // fall through case 1: outptr6 = dummyres; + // fall through case 0: outptr7 = dummyres; break; diff --git a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_int32_12x8.hpp b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_int32_12x8.hpp index 35d4cc5d73..0212dfdbb6 100644 --- a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_int32_12x8.hpp +++ b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_int32_12x8.hpp @@ -63,16 +63,22 @@ inline void MergeResults<12, 8, false>(int32_t *out, const int32_t *in, const in switch ((y + 7) - ymax) { case 6: outptr1 = dummyres; + // fall through case 5: outptr2 = dummyres; + // fall through case 4: outptr3 = dummyres; + // fall through case 3: outptr4 = dummyres; + // fall through case 2: outptr5 = dummyres; + // fall through case 1: outptr6 = dummyres; + // fall through case 0: outptr7 = dummyres; break; diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp index 20ad301a18..a460fdfcf4 100644 --- a/src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp +++ b/src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp @@ -60,12 +60,16 @@ inline void TransformImpl<6, 1, false, 4, 4, false>::Transform(T *out, const T * /* Everything falls through in here */ case 4: inptr1 = zerobuff; + // fall through case 3: inptr2 = zerobuff; + // fall through case 2: inptr3 = zerobuff; + // fall through case 1: inptr4 = zerobuff; + // fall through case 0: inptr5 = zerobuff; break; diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_block16_interleave4_8bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_block16_interleave4_8bit.hpp index 2f513a6118..6a15fc42e4 100644 --- a/src/core/NEON/kernels/arm_gemm/transforms/a64_block16_interleave4_8bit.hpp +++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_block16_interleave4_8bit.hpp @@ -57,8 +57,10 @@ void TransformImpl<4, 16, false, 1, 1, false>::Transform(T *out, const T *in, in /* Everything falls through in here */ case 2: inptr1 = zerobuff; + // fall through case 1: inptr2 = zerobuff; + // fall through case 0: inptr3 = zerobuff; break; @@ -93,8 +95,10 @@ void TransformImpl<4, 16, false, 1, 1, false>::Transform(T *out, const T *in, in /* Everything falls through in here */ case 2: inptr1 = zerobuff; + // fall through case 1: inptr2 = zerobuff; + // fall through case 0: inptr3 = zerobuff; break; diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_16bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_16bit.hpp index 27136d144a..0028ab08a9 100644 --- a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_16bit.hpp +++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_16bit.hpp @@ -64,16 +64,22 @@ void TransformImpl<8, 1, false, 2, 2, false>::Transform(T *out, const T *in, int /* Everything falls through in here */ case 6: inptr1 = zerobuff; + // fall through case 5: inptr2 = zerobuff; + // fall through case 4: inptr3 = zerobuff; + // fall through case 3: inptr4 = zerobuff; + // fall through case 2: inptr5 = zerobuff; + // fall through case 1: inptr6 = zerobuff; + // fall through case 0: inptr7 = zerobuff; break; diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_32bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_32bit.hpp index 54822c81b0..758c084a46 100644 --- a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_32bit.hpp +++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_32bit.hpp @@ -64,16 +64,22 @@ inline void TransformImpl<8, 1, false, 4, 4, false>::Transform(T *out, const T * /* Everything falls through in here */ case 6: inptr1 = zerobuff; + // fall through case 5: inptr2 = zerobuff; + // fall through case 4: inptr3 = zerobuff; + // fall through case 3: inptr4 = zerobuff; + // fall through case 2: inptr5 = zerobuff; + // fall through case 1: inptr6 = zerobuff; + // fall through case 0: inptr7 = zerobuff; break; diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp index 0606330d27..de8e95a6d7 100644 --- a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp +++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp @@ -64,16 +64,22 @@ inline void TransformImpl<8, 1, false, 4, 2, false>::Transform(float *out, const /* Everything falls through in here */ case 6: inptr1 = zerobuff; + // fall through case 5: inptr2 = zerobuff; + // fall through case 4: inptr3 = zerobuff; + // fall through case 3: inptr4 = zerobuff; + // fall through case 2: inptr5 = zerobuff; + // fall through case 1: inptr6 = zerobuff; + // fall through case 0: inptr7 = zerobuff; break; diff --git a/src/core/NEON/kernels/assembly/INEGEMMWrapperKernel.cpp b/src/core/NEON/kernels/assembly/INEGEMMWrapperKernel.cpp index 0fc3610014..d00f204b81 100644 --- a/src/core/NEON/kernels/assembly/INEGEMMWrapperKernel.cpp +++ b/src/core/NEON/kernels/assembly/INEGEMMWrapperKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -33,11 +33,11 @@ using namespace arm_compute; INEGEMMWrapperKernel::INEGEMMWrapperKernel() - : _a(nullptr), _b(nullptr), _c(nullptr), _params(), _window3d(), _window_shape() + : _a(nullptr), _b(nullptr), _c(nullptr), _params(), _gemm_info(), _window3d(), _window_shape() { } -INEGEMMWrapperKernel::Params INEGEMMWrapperKernel::extract_parameters(const ITensor *a, const ITensor *b, const ITensor *c) +INEGEMMWrapperKernel::Params INEGEMMWrapperKernel::extract_parameters(const ITensor *a, const ITensor *b, const ITensor *c, const GEMMInfo &gemm_info) { Params p; @@ -45,21 +45,30 @@ INEGEMMWrapperKernel::Params INEGEMMWrapperKernel::extract_parameters(const ITen ARM_COMPUTE_ERROR_ON_NULLPTR(b); ARM_COMPUTE_ERROR_ON_NULLPTR(c); + // Initalize params p.M = c->info()->tensor_shape().y(); p.N = c->info()->tensor_shape().x(); p.K = a->info()->tensor_shape().x(); p.multis = b->info()->tensor_shape().z(); p.batches = c->info()->tensor_shape().total_size_upper(2) / p.multis; //COMPMID-1423: Agree on and document the layout of gemm inputs/outputs + // Update M in case of GEMM3D for output + if(gemm_info.depth_output_gemm3d() != 0) + { + p.M = c->info()->tensor_shape().y() * c->info()->tensor_shape().z(); + p.batches = c->info()->tensor_shape().total_size_upper(3) / p.multis; + } + return p; } -void INEGEMMWrapperKernel::configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta) +void INEGEMMWrapperKernel::configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta, const GEMMInfo &gemm_info) { - _params = extract_parameters(a, b, c); - _a = a; - _b = b; - _c = c; + _gemm_info = gemm_info; + _params = extract_parameters(a, b, c, gemm_info); + _a = a; + _b = b; + _c = c; _window3d = configure_internal(alpha, beta); _window_shape = _window3d.shape(); diff --git a/src/core/NEON/kernels/assembly/NEGEMMInterleavedStrategies.h b/src/core/NEON/kernels/assembly/NEGEMMInterleavedStrategies.h index 26d9e9999d..6e30148b5d 100644 --- a/src/core/NEON/kernels/assembly/NEGEMMInterleavedStrategies.h +++ b/src/core/NEON/kernels/assembly/NEGEMMInterleavedStrategies.h @@ -76,32 +76,34 @@ public: * @param[in] transformed_a Reshaped tensor A. * @param[in] block_walker Window representing the layout of the matrix's blocks. * @param[in] params M, N, K sizes. + * @param[in] gemm_info GEMM meta-data * * @return A wrapped specialized transformA kernel */ virtual std::unique_ptr instantiate_transformA(const ITensor *a, ITensor *transformed_a, const Window &block_walker, - const INEGEMMWrapperKernel::Params ¶ms) = 0; + const INEGEMMWrapperKernel::Params ¶ms, + const GEMMInfo &gemm_info) = 0; /** Instantiate and configure a prepareB Kernel * - * @param transformed_a Already reshaped tensor A. - * @param transformed_b Already reshaped tensor B. - * @param tmp_c Temporary buffer to be used to store intermediate results. - * @param c Result tensor C. - * @param block_walker Window containing iteration information for the M and batch dimensions. - * @param block_sizes Block sizes to use for the matrix multiplication (A & B must have been reshaped using these same block sizes). - * @param params M, N, K sizes. - * @param alpha Alpha value - * @param beta Beta value - * @param pretranspose_b Is B also pretransposed ? - * @param num_threads Maximum number of threads that might be used for the calculations. + * @param[in] transformed_a Already reshaped tensor A. + * @param[in] transformed_b Already reshaped tensor B. + * @param[in] tmp_c Temporary buffer to be used to store intermediate results. + * @param[in] c Result tensor C. + * @param[in] block_walker Window containing iteration information for the M and batch dimensions. + * @param[in] block_sizes Block sizes to use for the matrix multiplication (A & B must have been reshaped using these same block sizes). + * @param[in] params M, N, K sizes. + * @param[in] alpha Alpha value + * @param[in] beta Beta value + * @param[in] gemm_info GEMM meta-data + * @param[in] num_threads Maximum number of threads that might be used for the calculations. * * @return A wrapped specialized MatrixMultiply kernel */ virtual std::unique_ptr instantiate_matrix_multiply(const ITensor *transformed_a, const ITensor *transformed_b, ITensor *tmp_c, ITensor *c, const Window &block_walker, const BlockSizes &block_sizes, - const INEGEMMWrapperKernel::Params ¶ms, float alpha, float beta, bool pretranspose_b, + const INEGEMMWrapperKernel::Params ¶ms, float alpha, float beta, const GEMMInfo &gemm_info, unsigned int num_threads) = 0; /** Calculates the block sizes of a given strategy * @@ -138,19 +140,20 @@ public: std::unique_ptr instantiate_transformA(const ITensor *a, ITensor *transformed_a, const Window &block_walker, - const INEGEMMWrapperKernel::Params ¶ms) override + const INEGEMMWrapperKernel::Params ¶ms, + const GEMMInfo &gemm_info) override { auto transform_a = support::cpp14::make_unique>(); - transform_a->configure(a, transformed_a, false, block_walker, params); + transform_a->configure(a, transformed_a, false, gemm_info.reinterpret_input_as_3d(), block_walker, params); return std::move(transform_a); } std::unique_ptr instantiate_matrix_multiply(const ITensor *transformed_a, const ITensor *transformed_b, ITensor *tmp_c, ITensor *c, const Window &block_walker, const BlockSizes &block_sizes, - const INEGEMMWrapperKernel::Params ¶ms, float alpha, float beta, bool pretranspose_b, + const INEGEMMWrapperKernel::Params ¶ms, float alpha, float beta, const GEMMInfo &gemm_info, unsigned int num_threads) override { auto matrix_multiply = support::cpp14::make_unique>(); - matrix_multiply->configure(transformed_a, transformed_b, tmp_c, c, block_walker, block_sizes, params, pretranspose_b, alpha, beta, num_threads); + matrix_multiply->configure(transformed_a, transformed_b, tmp_c, c, block_walker, block_sizes, params, gemm_info, alpha, beta, num_threads); return std::move(matrix_multiply); } diff --git a/src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp b/src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp index 97c20dbd4e..ecdb5a938c 100644 --- a/src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp +++ b/src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp @@ -81,12 +81,20 @@ void NEGEMMNativeWrapperKernel::run_internal(const Window &window, const TensorAccessor b(*_b); TensorAccessor c(*_c); - if(_a->info()->data_layout() == DataLayout::NHWC) + // Handle 3d input re-interpretation + if(_gemm_info.reinterpret_input_as_3d()) { - // In the case of NHWC we want to interpret the output shape as 3D. Thus, the batch stride for A is - // the relevant multiple of the row stride. - const size_t nhwc_batch_stride = _a->info()->strides_in_bytes().y() * _c->info()->dimension(1); - a.set_stride(2, nhwc_batch_stride); + Strides a_strides_as_3d = _a->info()->strides_in_bytes(); + a_strides_as_3d.remove(Window::DimZ); + a.set_strides(a_strides_as_3d); + } + + // Handle 3d output re-interpretation + if(_gemm_info.depth_output_gemm3d() != 0) + { + Strides c_strides_as_3d = _c->info()->strides_in_bytes(); + c_strides_as_3d.remove(Window::DimZ); + c.set_strides(c_strides_as_3d); } unsigned int m_end = 0; -- cgit v1.2.1