From 37d080f2f11cfd734104b76512e1fb191486216e Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Fri, 21 Jun 2019 18:43:12 +0100 Subject: COMPMID-2378: Sanitize GEMM configuration for NEON Change-Id: I7859b82b2059e14685f8792424648ac5eacd67f1 Signed-off-by: Georgios Pinitas Reviewed-on: https://review.mlplatform.org/c/1418 Comments-Addressed: Arm Jenkins Reviewed-by: Michele Di Giorgio Reviewed-by: Michalis Spyrou Tested-by: Arm Jenkins --- SConstruct | 15 ++--- arm_compute/core/Dimensions.h | 23 +++++++- .../NEON/kernels/assembly/INEGEMMWrapperKernel.h | 18 +++--- .../NEGEMMInterleavedMatrixMultiplyWrapper.h | 55 +++++++++-------- .../assembly/NEGEMMInterleavedTransformAWrapper.h | 38 ++++++------ arm_compute/core/Types.h | 64 +++++++++++++++----- arm_compute/core/WindowIterator.h | 11 +++- .../NEON/functions/NEGEMMAssemblyDispatch.h | 46 +++++++-------- .../functions/assembly/NEGEMMInterleavedWrapper.h | 14 ++--- .../arm_gemm/merges/a32_merge_float_8x6.hpp | 4 ++ .../arm_gemm/merges/a64_merge_float_12x8.hpp | 6 ++ .../merges/a64_merge_float_to_half_12x8.hpp | 6 ++ .../arm_gemm/merges/a64_merge_half_24x8.hpp | 6 ++ .../arm_gemm/merges/a64_merge_int32_12x8.hpp | 6 ++ .../transforms/a32_interleave_6way_32bit.hpp | 4 ++ .../transforms/a64_block16_interleave4_8bit.hpp | 4 ++ .../transforms/a64_interleave_8way_16bit.hpp | 6 ++ .../transforms/a64_interleave_8way_32bit.hpp | 6 ++ .../a64_interleave_8way_half_to_float.hpp | 6 ++ .../NEON/kernels/assembly/INEGEMMWrapperKernel.cpp | 25 +++++--- .../kernels/assembly/NEGEMMInterleavedStrategies.h | 37 ++++++------ .../kernels/assembly/NEGEMMNativeWrapperKernel.cpp | 18 ++++-- src/runtime/NEON/functions/NEGEMM.cpp | 10 ++-- .../NEON/functions/NEGEMMAssemblyDispatch.cpp | 68 +++++++++++++--------- .../NEGEMMLowpAssemblyMatrixMultiplyCore.cpp | 2 +- .../functions/NEGEMMLowpMatrixMultiplyCore.cpp | 9 ++- .../assembly/NEGEMMInterleavedWrapper.cpp | 12 ++-- 27 files changed, 338 insertions(+), 181 deletions(-) diff --git a/SConstruct b/SConstruct index f18b13782c..5fa6cdfc48 100644 --- a/SConstruct +++ b/SConstruct @@ -187,18 +187,15 @@ elif env['arch'] == 'arm64-v8a': env.Append(CXXFLAGS = ['-no-integrated-as']) elif 'arm64-v8.2-a' in env['arch']: if env['arch'] == 'arm64-v8.2-a-sve': - if env['os'] != 'bare_metal': - print("Only bare metal SVE is supported at the moment") - Exit(1) env.Append(CXXFLAGS = ['-march=armv8.2-a+sve+fp16+dotprod']) else: env.Append(CXXFLAGS = ['-march=armv8.2-a+fp16']) # explicitly enable fp16 extension otherwise __ARM_FEATURE_FP16_VECTOR_ARITHMETIC is undefined - if env['os'] == 'linux': - prefix = "aarch64-linux-gnu-" - elif env['os'] == 'bare_metal': - prefix = "aarch64-elf-" - elif env['os'] == 'android': - prefix = "aarch64-linux-android-" + if env['os'] == 'linux': + prefix = "aarch64-linux-gnu-" + elif env['os'] == 'bare_metal': + prefix = "aarch64-elf-" + elif env['os'] == 'android': + prefix = "aarch64-linux-android-" env.Append(CPPDEFINES = ['ARM_COMPUTE_AARCH64_V8_2','NO_DOT_IN_TOOLCHAIN']) if 'clang++' in cpp_compiler: env.Append(CXXFLAGS = ['-no-integrated-as']) diff --git a/arm_compute/core/Dimensions.h b/arm_compute/core/Dimensions.h index 0a9264f6b0..9c38c60779 100644 --- a/arm_compute/core/Dimensions.h +++ b/arm_compute/core/Dimensions.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2018 ARM Limited. + * Copyright (c) 2017-2019 ARM Limited. 
* * SPDX-License-Identifier: MIT * @@ -166,6 +166,27 @@ public: collapse(num_dimensions() - start, start); } + /** Remove dimension of a given index + * + * @note If index is greater than the number of dimensions no operation is performed + * + * @param[in] idx Dimension index to remove + */ + void remove(size_t idx) + { + ARM_COMPUTE_ERROR_ON(_num_dimensions < 1); + if(idx >= _num_dimensions) + { + return; + } + + std::copy(_id.begin() + idx + 1, _id.end(), _id.begin() + idx); + _num_dimensions--; + + // Make sure all empty dimensions are filled with 0 + std::fill(_id.begin() + _num_dimensions, _id.end(), 0); + } + /** Returns a read/write iterator that points to the first element in the dimension array. * * @return an iterator. diff --git a/arm_compute/core/NEON/kernels/assembly/INEGEMMWrapperKernel.h b/arm_compute/core/NEON/kernels/assembly/INEGEMMWrapperKernel.h index 63178a738a..352f73d7f1 100644 --- a/arm_compute/core/NEON/kernels/assembly/INEGEMMWrapperKernel.h +++ b/arm_compute/core/NEON/kernels/assembly/INEGEMMWrapperKernel.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -45,7 +45,7 @@ public: unsigned int multis{ 0 }; /**< Number of "multi" GEMMs (unique A, B and C). */ }; - static Params extract_parameters(const ITensor *a, const ITensor *b, const ITensor *c); + static Params extract_parameters(const ITensor *a, const ITensor *b, const ITensor *c, const GEMMInfo &gemm_info); /** Constructor */ INEGEMMWrapperKernel(); @@ -61,13 +61,14 @@ public: * * @note The input and output tensor must have the same dimensions * - * @param[in] a Input tensor (Matrix A) - * @param[in] b Input tensor (Matrix B) - * @param[out] c Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0. - * @param[in] alpha Scalar multiplier to apply to AB matrix product. - * @param[in] beta Scalar multiplier to apply to input C matrix before adding product. + * @param[in] a Input tensor (Matrix A) + * @param[in] b Input tensor (Matrix B) + * @param[out] c Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0. + * @param[in] alpha Scalar multiplier to apply to AB matrix product. + * @param[in] beta Scalar multiplier to apply to input C matrix before adding product. + * @param[in] gemm_info GEMM meta-data */ - void configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta); + void configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta, const GEMMInfo &gemm_info); // Inherited methods overridden: void run(const Window &window, const ThreadInfo &info) override; @@ -95,6 +96,7 @@ protected: const ITensor *_b; ITensor *_c; Params _params; + GEMMInfo _gemm_info; private: Window _window3d; diff --git a/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h b/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h index e2b849aa3d..40b6f5da39 100644 --- a/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h +++ b/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedMatrixMultiplyWrapper.h @@ -95,31 +95,32 @@ class NEGEMMInterleavedMatrixMultiplyWrapperTemplate : public NEGEMMInterleavedM public: /** Configure the matrix multiplication: C = alpha * A * B + beta * C * - * @param[in] prepared_a Already reshaped matrix A. - * @param[in] transformed_b Already reshaped matrix B. 
- * @param[out] tmp_c Temporary buffer to be used to store intermediate results. - * @param[in,out] c Result matrix C. - * @param[in] block_walker Window containing iteration information for the M and batch dimensions. - * @param[in] block_sizes Block sizes to use for the matrix multiplication (A & B must have been reshaped using these same block sizes). - * @param[in] params M, N, K sizes. - * @param[in] is_pretransposed Is B also pretransposed ? - * @param[in] alpha Alpha value - * @param[in] beta Beta value - * @param[in] max_num_threads Maximum number of threads that might be used for the calculations. + * @param[in] prepared_a Already reshaped matrix A. + * @param[in] transformed_b Already reshaped matrix B. + * @param[out] tmp_c Temporary buffer to be used to store intermediate results. + * @param[in,out] c Result matrix C. + * @param[in] block_walker Window containing iteration information for the M and batch dimensions. + * @param[in] block_sizes Block sizes to use for the matrix multiplication (A & B must have been reshaped using these same block sizes). + * @param[in] params M, N, K sizes. + * @param[in] gemm_info GEMM meta-data + * @param[in] alpha Alpha value + * @param[in] beta Beta value + * @param[in] max_num_threads Maximum number of threads that might be used for the calculations. */ void configure(const ITensor *prepared_a, const ITensor *transformed_b, ITensor *tmp_c, ITensor *c, const Window &block_walker, const BlockSizes &block_sizes, - const INEGEMMWrapperKernel::Params ¶ms, bool b_is_pretransposed, float alpha, float beta, unsigned int max_num_threads) + const INEGEMMWrapperKernel::Params ¶ms, const GEMMInfo &gemm_info, float alpha, float beta, unsigned int max_num_threads) { - _prepared_a = prepared_a; - _transformed_b = transformed_b; - _tmp_c = tmp_c; - _c = c; - _block_walker = block_walker; - _block_sizes = block_sizes; - _params = params; - _b_is_pretransposed = b_is_pretransposed; - _alpha = alpha; - _beta = beta; + _prepared_a = prepared_a; + _transformed_b = transformed_b; + _tmp_c = tmp_c; + _c = c; + _block_walker = block_walker; + _block_sizes = block_sizes; + _params = params; + _b_is_pretransposed = gemm_info.pretranpose_B(); + _reinterpret_c_as_3d = gemm_info.depth_output_gemm3d() != 0; + _alpha = alpha; + _beta = beta; auto_init_if_empty(*_tmp_c->info(), c->info()->clone()->set_tensor_shape(TensorShape{ _block_sizes.x_block * strategy::out_height(), max_num_threads })); } @@ -133,6 +134,14 @@ public: TensorAccessor c(*_c); TensorAccessor tmp_c(*_tmp_c); + // Handle 3d output re-interpretation + if(_reinterpret_c_as_3d) + { + Strides c_strides_as_3d = _c->info()->strides_in_bytes(); + c_strides_as_3d.remove(Window::DimZ); + c.set_strides(c_strides_as_3d); + } + int prev_batch = -1; typename strategy::operand_type *a_ptr = nullptr; auto window_iterator = arm_compute::create_window_iterator(batch_window, start_offset, end_offset, [&](const Coordinates & id) @@ -216,9 +225,9 @@ private: INEGEMMWrapperKernel::Params _params{}; Window _block_walker{}; bool _b_is_pretransposed{ false }; + bool _reinterpret_c_as_3d{ false }; typename strategy::result_type _alpha{}; typename strategy::result_type _beta{}; }; - } // namespace arm_compute #endif /* __ARM_COMPUTE_NEGEMMINTERLEAVEDMATRIXMULTIPLYWRAPPER_H__ */ diff --git a/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h b/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h index 5d6cd02398..b18d327339 100644 --- 
a/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h +++ b/arm_compute/core/NEON/kernels/assembly/NEGEMMInterleavedTransformAWrapper.h @@ -87,20 +87,22 @@ class NEGEMMInterleavedTransformAWrapperTemplate : public NEGEMMInterleavedTrans public: /** Configure the reshape A routine. * - * @param[in] a Input matrix A. - * @param[out] transformed_a Reshaped matrix A. - * @param[in] transpose_a Also transpose A ? - * @param[in] block_walker Window representing the layout of the matrix's blocks - * @param[in] params M, N, K sizes. + * @param[in] a Input matrix A. + * @param[out] transformed_a Reshaped matrix A. + * @param[in] transpose_a Also transpose A ? + * @param[in] reinterpret_a_as_3d Re-interpret as 3D ? + * @param[in] block_walker Window representing the layout of the matrix's blocks + * @param[in] params M, N, K sizes. */ - void configure(const ITensor *a, ITensor *transformed_a, bool transpose_a, const Window &block_walker, const INEGEMMWrapperKernel::Params ¶ms) + void configure(const ITensor *a, ITensor *transformed_a, bool transpose_a, bool reinterpret_a_as_3d, const Window &block_walker, const INEGEMMWrapperKernel::Params ¶ms) { - _a = a; - _transformed_a = transformed_a; - _transpose_a = transpose_a; - _Ksize = params.K; - _Msize = params.M; - _k_multi_window = block_walker.shift_dimensions(1); // block_walker contains (M,K,Multi) --> shift by 1 to get rid of the "M" dimension + _a = a; + _transformed_a = transformed_a; + _transpose_a = transpose_a; + _reinterpret_a_as_3d = reinterpret_a_as_3d; + _Ksize = params.K; + _Msize = params.M; + _k_multi_window = block_walker.shift_dimensions(1); // block_walker contains (M,K,Multi) --> shift by 1 to get rid of the "M" dimension } // Inherited methods overridden: @@ -110,12 +112,12 @@ public: TensorAccessor a(*_a); TensorAccessor transformed_a(*_transformed_a); - if(_a->info()->data_layout() == DataLayout::NHWC) + // Handle 3d input re-interpretation + if(_reinterpret_a_as_3d) { - // In the case of NHWC we want to interpret the output shape as 3D. Thus, the batch stride for A is - // the relevant multiple of the row stride. 
- const size_t nhwc_batch_stride = _a->info()->strides_in_bytes().y() * _Msize; - a.set_stride(2, nhwc_batch_stride); + Strides a_strides_as_3d = _a->info()->strides_in_bytes(); + a_strides_as_3d.remove(Window::DimZ); + a.set_strides(a_strides_as_3d); } unsigned int last_m = 0; @@ -164,8 +166,8 @@ private: unsigned int _Msize{ 0 }; unsigned int _Ksize{ 0 }; bool _transpose_a{ false }; + bool _reinterpret_a_as_3d{ false }; Window _k_multi_window{}; }; - } // namespace arm_compute #endif /* __ARM_COMPUTE_NEGEMMINTERLEAVEDTRANSFORMAWRAPPER_H__ */ diff --git a/arm_compute/core/Types.h b/arm_compute/core/Types.h index ad679d6786..b4d94eced4 100644 --- a/arm_compute/core/Types.h +++ b/arm_compute/core/Types.h @@ -1765,9 +1765,17 @@ class GEMMInfo { public: /** Default constructor */ - GEMMInfo() - : _is_a_reshaped(false), _is_b_reshaped(false), _reshape_b_only_on_first_run(true), _depth_output_gemm3d(0), _reinterpret_input_as_3d(false), _retain_internal_weights(false), _gemmlowp_output_stage(), - _fp_mixed_precision(false), _broadcast_bias(false) + GEMMInfo() noexcept + : _is_a_reshaped(false), + _is_b_reshaped(false), + _reshape_b_only_on_first_run(true), + _depth_output_gemm3d(0), + _reinterpret_input_as_3d(false), + _retain_internal_weights(false), + _gemmlowp_output_stage(), + _fp_mixed_precision(false), + _broadcast_bias(false), + _pretranpose_B(true) { } /** Constructor @@ -1785,10 +1793,17 @@ public: * @param[in] broadcast_bias (Optional) Broadcast the shape of the bias tensor from a vector to a matrix. */ GEMMInfo(bool is_a_reshaped, bool is_b_reshaped, bool reshape_b_only_on_first_run, int depth_output_gemm3d = 0, bool reinterpret_input_as_3d = false, bool retain_internal_weights = false, - GEMMLowpOutputStageInfo gemmlowp_output_stage = GEMMLowpOutputStageInfo(), bool fp_mixed_precision = false, bool broadcast_bias = false) - : _is_a_reshaped(is_a_reshaped), _is_b_reshaped(is_b_reshaped), _reshape_b_only_on_first_run(reshape_b_only_on_first_run), _depth_output_gemm3d(depth_output_gemm3d), - _reinterpret_input_as_3d(reinterpret_input_as_3d), _retain_internal_weights(retain_internal_weights), _gemmlowp_output_stage(gemmlowp_output_stage), _fp_mixed_precision(fp_mixed_precision), - _broadcast_bias(broadcast_bias) + GEMMLowpOutputStageInfo gemmlowp_output_stage = GEMMLowpOutputStageInfo(), bool fp_mixed_precision = false, bool broadcast_bias = false) noexcept + : _is_a_reshaped(is_a_reshaped), + _is_b_reshaped(is_b_reshaped), + _reshape_b_only_on_first_run(reshape_b_only_on_first_run), + _depth_output_gemm3d(depth_output_gemm3d), + _reinterpret_input_as_3d(reinterpret_input_as_3d), + _retain_internal_weights(retain_internal_weights), + _gemmlowp_output_stage(gemmlowp_output_stage), + _fp_mixed_precision(fp_mixed_precision), + _broadcast_bias(broadcast_bias), + _pretranpose_B(reshape_b_only_on_first_run) { } /** Flag which specifies if the matrix A has been reshaped @@ -1865,17 +1880,34 @@ public: { return _broadcast_bias; }; + /** Flag which specifies whether b should be pre-transposed if supported. + * + * @return True if b should be pre-transposed else false. 
+ */ + bool pretranpose_B() const + { + return _pretranpose_B; + }; + /** Set pre-transpose b flag + * + * @param[in] flag Flag to set + */ + void set_pretranpose_B(bool flag) + { + _pretranpose_B = flag; + } private: - const bool _is_a_reshaped; - const bool _is_b_reshaped; - const bool _reshape_b_only_on_first_run; - const int _depth_output_gemm3d; - const bool _reinterpret_input_as_3d; - const bool _retain_internal_weights; - const GEMMLowpOutputStageInfo _gemmlowp_output_stage; - const bool _fp_mixed_precision; - const bool _broadcast_bias; + bool _is_a_reshaped; + bool _is_b_reshaped; + bool _reshape_b_only_on_first_run; + int _depth_output_gemm3d; + bool _reinterpret_input_as_3d; + bool _retain_internal_weights; + GEMMLowpOutputStageInfo _gemmlowp_output_stage; + bool _fp_mixed_precision; + bool _broadcast_bias; + bool _pretranpose_B; }; /** Winograd information */ diff --git a/arm_compute/core/WindowIterator.h b/arm_compute/core/WindowIterator.h index 32d6293a5a..15289b6d69 100644 --- a/arm_compute/core/WindowIterator.h +++ b/arm_compute/core/WindowIterator.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -86,6 +86,15 @@ public: _strides[dim] = size; } + /** Manually set the strides + * + * @param[in] strides Strides to set + */ + void set_strides(const Strides &strides) + { + _strides = strides; + } + /** Returns a pointer to the element at coordinates (x,y,z,w) * * @param[in] x X coordinates diff --git a/arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h b/arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h index 2fc2cf4a99..b5a2978ea1 100644 --- a/arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h +++ b/arm_compute/runtime/NEON/functions/NEGEMMAssemblyDispatch.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -64,17 +64,17 @@ private: /** If supported create the ACL function corresponding to the GemmMethod provided to process the other passed parameters * - * @param[in] method GemmMethod to use to perform the matrix multiplication. - * @param[in] a Input tensor (Matrix A). - * @param[in] b Input tensor (Matrix B). - * @param[out] d Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0. - * @param[in] alpha Scalar multiplier to apply to AB matrix product. - * @param[in] beta Scalar multiplier to apply to input D matrix before adding product. - * @param[in] pretransposed_hint Can the B tensor can be pretransposed (ie shared across invocations)? + * @param[in] method GemmMethod to use to perform the matrix multiplication. + * @param[in] a Input tensor (Matrix A). + * @param[in] b Input tensor (Matrix B). + * @param[out] d Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0. + * @param[in] alpha Scalar multiplier to apply to AB matrix product. + * @param[in] beta Scalar multiplier to apply to input D matrix before adding product. + * @param[in] gemm_info GEMM meta-data * * @return True if the method is supported and the function was successfully created, false otherwise. 
*/ - bool create_function(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint); + bool create_function(arm_gemm::GemmMethod method, const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, const GEMMInfo &gemm_info); /** Interface for the arm_gemm fallback */ std::unique_ptr _arm_gemm; @@ -83,27 +83,27 @@ private: public: /** If supported create an ACL function else fallback to the arm_gemm function. * - * @param[in] a Input tensor (Matrix A) - * @param[in] b Input tensor (Matrix B) - * @param[out] d Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0. - * @param[in] alpha Scalar multiplier to apply to AB matrix product. - * @param[in] beta Scalar multiplier to apply to input D matrix before adding product. - * @param[in] pretranspose_hint Can the B tensor can be pretransposed (ie shared across invocations)? + * @param[in] a Input tensor (Matrix A) + * @param[in] b Input tensor (Matrix B) + * @param[out] d Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0. + * @param[in] alpha Scalar multiplier to apply to AB matrix product. + * @param[in] beta Scalar multiplier to apply to input D matrix before adding product. + * @param[in] gemm_info GEMM meta-data */ - void configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint); + void configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, const GEMMInfo &gemm_info); /** Indicates whether or not this function can be used to process the given parameters. * - * @param[in] a Input tensor (Matrix A) - * @param[in] b Input tensor (Matrix B) - * @param[in] d Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0. - * @param[in] alpha Scalar multiplier to apply to AB matrix product. - * @param[in] beta Scalar multiplier to apply to input D matrix before adding product. - * @param[in] pretranspose_hint Can the B tensor can be pretransposed (ie shared across invocations)? + * @param[in] a Input tensor (Matrix A) + * @param[in] b Input tensor (Matrix B) + * @param[in] d Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0. + * @param[in] alpha Scalar multiplier to apply to AB matrix product. + * @param[in] beta Scalar multiplier to apply to input D matrix before adding product. + * @param[in] gemm_info GEMM meta-data * * @return a status. */ - static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, float alpha, float beta, bool pretranspose_hint); + static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, float alpha, float beta, const GEMMInfo &gemm_info); /** Was the function successfully configured ? 
* * @return True if the function is configured and ready to run diff --git a/arm_compute/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.h b/arm_compute/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.h index 949564750b..ad89e1fbec 100644 --- a/arm_compute/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.h +++ b/arm_compute/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.h @@ -104,14 +104,14 @@ public: * * @note The input and output tensor must have the same dimensions * - * @param[in] a Input tensor (Matrix A) - * @param[in] b Input tensor (Matrix B) - * @param[out] c Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0. - * @param[in] alpha Scalar multiplier to apply to AB matrix product. - * @param[in] beta Scalar multiplier to apply to input C matrix before adding product. - * @param[in] pretranspose_b If true, pretranspose B once during the prepare() stage instead of on the fly every time. + * @param[in] a Input tensor (Matrix A) + * @param[in] b Input tensor (Matrix B) + * @param[out] c Output tensor to store the result of matrix multiplication. Data type supported: same as @p input0. + * @param[in] alpha Scalar multiplier to apply to AB matrix product. + * @param[in] beta Scalar multiplier to apply to input C matrix before adding product. + * @param[in] gemm_info GEMM meta-data */ - void configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta, bool pretranspose_b); + void configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta, const GEMMInfo &gemm_info); // Inherited methods overridden: void run() override; diff --git a/src/core/NEON/kernels/arm_gemm/merges/a32_merge_float_8x6.hpp b/src/core/NEON/kernels/arm_gemm/merges/a32_merge_float_8x6.hpp index f4485bcbb1..e1af2d4490 100644 --- a/src/core/NEON/kernels/arm_gemm/merges/a32_merge_float_8x6.hpp +++ b/src/core/NEON/kernels/arm_gemm/merges/a32_merge_float_8x6.hpp @@ -61,12 +61,16 @@ inline void MergeResults<8, 6, false>(float *out, const float *in, const int ldo switch ((y + 5) - ymax) { case 4: outptr1 = dummyres; + // fall through case 3: outptr2 = dummyres; + // fall through case 2: outptr3 = dummyres; + // fall through case 1: outptr4 = dummyres; + // fall through case 0: outptr5 = dummyres; break; diff --git a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_12x8.hpp b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_12x8.hpp index be23978b80..9fca4e3a84 100644 --- a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_12x8.hpp +++ b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_12x8.hpp @@ -63,16 +63,22 @@ inline void MergeResults<12, 8, false>(float *out, const float *in, const int ld switch ((y + 7) - ymax) { case 6: outptr1 = dummyres; + // fall through case 5: outptr2 = dummyres; + // fall through case 4: outptr3 = dummyres; + // fall through case 3: outptr4 = dummyres; + // fall through case 2: outptr5 = dummyres; + // fall through case 1: outptr6 = dummyres; + // fall through case 0: outptr7 = dummyres; break; diff --git a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_to_half_12x8.hpp b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_to_half_12x8.hpp index 9e5eb88dc1..0e638eef1c 100644 --- a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_to_half_12x8.hpp +++ b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_float_to_half_12x8.hpp @@ -66,16 +66,22 @@ inline void MergeResults<12,8,false>(__fp16 *out, const float *in, int ldout, in switch ((y + 7) - 
ymax) { case 6: outptr1 = dummyres; + // fall through case 5: outptr2 = dummyres; + // fall through case 4: outptr3 = dummyres; + // fall through case 3: outptr4 = dummyres; + // fall through case 2: outptr5 = dummyres; + // fall through case 1: outptr6 = dummyres; + // fall through case 0: outptr7 = dummyres; break; diff --git a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_half_24x8.hpp b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_half_24x8.hpp index 3ed43b10bd..60cc2f32da 100644 --- a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_half_24x8.hpp +++ b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_half_24x8.hpp @@ -65,16 +65,22 @@ inline void MergeResults<24, 8>(__fp16 *out, const __fp16 *in, const int ldout, switch ((y + 7) - ymax) { case 6: outptr1 = dummyres; + // fall through case 5: outptr2 = dummyres; + // fall through case 4: outptr3 = dummyres; + // fall through case 3: outptr4 = dummyres; + // fall through case 2: outptr5 = dummyres; + // fall through case 1: outptr6 = dummyres; + // fall through case 0: outptr7 = dummyres; break; diff --git a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_int32_12x8.hpp b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_int32_12x8.hpp index 35d4cc5d73..0212dfdbb6 100644 --- a/src/core/NEON/kernels/arm_gemm/merges/a64_merge_int32_12x8.hpp +++ b/src/core/NEON/kernels/arm_gemm/merges/a64_merge_int32_12x8.hpp @@ -63,16 +63,22 @@ inline void MergeResults<12, 8, false>(int32_t *out, const int32_t *in, const in switch ((y + 7) - ymax) { case 6: outptr1 = dummyres; + // fall through case 5: outptr2 = dummyres; + // fall through case 4: outptr3 = dummyres; + // fall through case 3: outptr4 = dummyres; + // fall through case 2: outptr5 = dummyres; + // fall through case 1: outptr6 = dummyres; + // fall through case 0: outptr7 = dummyres; break; diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp index 20ad301a18..a460fdfcf4 100644 --- a/src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp +++ b/src/core/NEON/kernels/arm_gemm/transforms/a32_interleave_6way_32bit.hpp @@ -60,12 +60,16 @@ inline void TransformImpl<6, 1, false, 4, 4, false>::Transform(T *out, const T * /* Everything falls through in here */ case 4: inptr1 = zerobuff; + // fall through case 3: inptr2 = zerobuff; + // fall through case 2: inptr3 = zerobuff; + // fall through case 1: inptr4 = zerobuff; + // fall through case 0: inptr5 = zerobuff; break; diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_block16_interleave4_8bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_block16_interleave4_8bit.hpp index 2f513a6118..6a15fc42e4 100644 --- a/src/core/NEON/kernels/arm_gemm/transforms/a64_block16_interleave4_8bit.hpp +++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_block16_interleave4_8bit.hpp @@ -57,8 +57,10 @@ void TransformImpl<4, 16, false, 1, 1, false>::Transform(T *out, const T *in, in /* Everything falls through in here */ case 2: inptr1 = zerobuff; + // fall through case 1: inptr2 = zerobuff; + // fall through case 0: inptr3 = zerobuff; break; @@ -93,8 +95,10 @@ void TransformImpl<4, 16, false, 1, 1, false>::Transform(T *out, const T *in, in /* Everything falls through in here */ case 2: inptr1 = zerobuff; + // fall through case 1: inptr2 = zerobuff; + // fall through case 0: inptr3 = zerobuff; break; diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_16bit.hpp 
b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_16bit.hpp index 27136d144a..0028ab08a9 100644 --- a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_16bit.hpp +++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_16bit.hpp @@ -64,16 +64,22 @@ void TransformImpl<8, 1, false, 2, 2, false>::Transform(T *out, const T *in, int /* Everything falls through in here */ case 6: inptr1 = zerobuff; + // fall through case 5: inptr2 = zerobuff; + // fall through case 4: inptr3 = zerobuff; + // fall through case 3: inptr4 = zerobuff; + // fall through case 2: inptr5 = zerobuff; + // fall through case 1: inptr6 = zerobuff; + // fall through case 0: inptr7 = zerobuff; break; diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_32bit.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_32bit.hpp index 54822c81b0..758c084a46 100644 --- a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_32bit.hpp +++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_32bit.hpp @@ -64,16 +64,22 @@ inline void TransformImpl<8, 1, false, 4, 4, false>::Transform(T *out, const T * /* Everything falls through in here */ case 6: inptr1 = zerobuff; + // fall through case 5: inptr2 = zerobuff; + // fall through case 4: inptr3 = zerobuff; + // fall through case 3: inptr4 = zerobuff; + // fall through case 2: inptr5 = zerobuff; + // fall through case 1: inptr6 = zerobuff; + // fall through case 0: inptr7 = zerobuff; break; diff --git a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp index 0606330d27..de8e95a6d7 100644 --- a/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp +++ b/src/core/NEON/kernels/arm_gemm/transforms/a64_interleave_8way_half_to_float.hpp @@ -64,16 +64,22 @@ inline void TransformImpl<8, 1, false, 4, 2, false>::Transform(float *out, const /* Everything falls through in here */ case 6: inptr1 = zerobuff; + // fall through case 5: inptr2 = zerobuff; + // fall through case 4: inptr3 = zerobuff; + // fall through case 3: inptr4 = zerobuff; + // fall through case 2: inptr5 = zerobuff; + // fall through case 1: inptr6 = zerobuff; + // fall through case 0: inptr7 = zerobuff; break; diff --git a/src/core/NEON/kernels/assembly/INEGEMMWrapperKernel.cpp b/src/core/NEON/kernels/assembly/INEGEMMWrapperKernel.cpp index 0fc3610014..d00f204b81 100644 --- a/src/core/NEON/kernels/assembly/INEGEMMWrapperKernel.cpp +++ b/src/core/NEON/kernels/assembly/INEGEMMWrapperKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018 ARM Limited. + * Copyright (c) 2018-2019 ARM Limited. 
* * SPDX-License-Identifier: MIT * @@ -33,11 +33,11 @@ using namespace arm_compute; INEGEMMWrapperKernel::INEGEMMWrapperKernel() - : _a(nullptr), _b(nullptr), _c(nullptr), _params(), _window3d(), _window_shape() + : _a(nullptr), _b(nullptr), _c(nullptr), _params(), _gemm_info(), _window3d(), _window_shape() { } -INEGEMMWrapperKernel::Params INEGEMMWrapperKernel::extract_parameters(const ITensor *a, const ITensor *b, const ITensor *c) +INEGEMMWrapperKernel::Params INEGEMMWrapperKernel::extract_parameters(const ITensor *a, const ITensor *b, const ITensor *c, const GEMMInfo &gemm_info) { Params p; @@ -45,21 +45,30 @@ INEGEMMWrapperKernel::Params INEGEMMWrapperKernel::extract_parameters(const ITen ARM_COMPUTE_ERROR_ON_NULLPTR(b); ARM_COMPUTE_ERROR_ON_NULLPTR(c); + // Initalize params p.M = c->info()->tensor_shape().y(); p.N = c->info()->tensor_shape().x(); p.K = a->info()->tensor_shape().x(); p.multis = b->info()->tensor_shape().z(); p.batches = c->info()->tensor_shape().total_size_upper(2) / p.multis; //COMPMID-1423: Agree on and document the layout of gemm inputs/outputs + // Update M in case of GEMM3D for output + if(gemm_info.depth_output_gemm3d() != 0) + { + p.M = c->info()->tensor_shape().y() * c->info()->tensor_shape().z(); + p.batches = c->info()->tensor_shape().total_size_upper(3) / p.multis; + } + return p; } -void INEGEMMWrapperKernel::configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta) +void INEGEMMWrapperKernel::configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta, const GEMMInfo &gemm_info) { - _params = extract_parameters(a, b, c); - _a = a; - _b = b; - _c = c; + _gemm_info = gemm_info; + _params = extract_parameters(a, b, c, gemm_info); + _a = a; + _b = b; + _c = c; _window3d = configure_internal(alpha, beta); _window_shape = _window3d.shape(); diff --git a/src/core/NEON/kernels/assembly/NEGEMMInterleavedStrategies.h b/src/core/NEON/kernels/assembly/NEGEMMInterleavedStrategies.h index 26d9e9999d..6e30148b5d 100644 --- a/src/core/NEON/kernels/assembly/NEGEMMInterleavedStrategies.h +++ b/src/core/NEON/kernels/assembly/NEGEMMInterleavedStrategies.h @@ -76,32 +76,34 @@ public: * @param[in] transformed_a Reshaped tensor A. * @param[in] block_walker Window representing the layout of the matrix's blocks. * @param[in] params M, N, K sizes. + * @param[in] gemm_info GEMM meta-data * * @return A wrapped specialized transformA kernel */ virtual std::unique_ptr instantiate_transformA(const ITensor *a, ITensor *transformed_a, const Window &block_walker, - const INEGEMMWrapperKernel::Params ¶ms) = 0; + const INEGEMMWrapperKernel::Params ¶ms, + const GEMMInfo &gemm_info) = 0; /** Instantiate and configure a prepareB Kernel * - * @param transformed_a Already reshaped tensor A. - * @param transformed_b Already reshaped tensor B. - * @param tmp_c Temporary buffer to be used to store intermediate results. - * @param c Result tensor C. - * @param block_walker Window containing iteration information for the M and batch dimensions. - * @param block_sizes Block sizes to use for the matrix multiplication (A & B must have been reshaped using these same block sizes). - * @param params M, N, K sizes. - * @param alpha Alpha value - * @param beta Beta value - * @param pretranspose_b Is B also pretransposed ? - * @param num_threads Maximum number of threads that might be used for the calculations. + * @param[in] transformed_a Already reshaped tensor A. + * @param[in] transformed_b Already reshaped tensor B. 
+ * @param[in] tmp_c Temporary buffer to be used to store intermediate results. + * @param[in] c Result tensor C. + * @param[in] block_walker Window containing iteration information for the M and batch dimensions. + * @param[in] block_sizes Block sizes to use for the matrix multiplication (A & B must have been reshaped using these same block sizes). + * @param[in] params M, N, K sizes. + * @param[in] alpha Alpha value + * @param[in] beta Beta value + * @param[in] gemm_info GEMM meta-data + * @param[in] num_threads Maximum number of threads that might be used for the calculations. * * @return A wrapped specialized MatrixMultiply kernel */ virtual std::unique_ptr instantiate_matrix_multiply(const ITensor *transformed_a, const ITensor *transformed_b, ITensor *tmp_c, ITensor *c, const Window &block_walker, const BlockSizes &block_sizes, - const INEGEMMWrapperKernel::Params ¶ms, float alpha, float beta, bool pretranspose_b, + const INEGEMMWrapperKernel::Params ¶ms, float alpha, float beta, const GEMMInfo &gemm_info, unsigned int num_threads) = 0; /** Calculates the block sizes of a given strategy * @@ -138,19 +140,20 @@ public: std::unique_ptr instantiate_transformA(const ITensor *a, ITensor *transformed_a, const Window &block_walker, - const INEGEMMWrapperKernel::Params ¶ms) override + const INEGEMMWrapperKernel::Params ¶ms, + const GEMMInfo &gemm_info) override { auto transform_a = support::cpp14::make_unique>(); - transform_a->configure(a, transformed_a, false, block_walker, params); + transform_a->configure(a, transformed_a, false, gemm_info.reinterpret_input_as_3d(), block_walker, params); return std::move(transform_a); } std::unique_ptr instantiate_matrix_multiply(const ITensor *transformed_a, const ITensor *transformed_b, ITensor *tmp_c, ITensor *c, const Window &block_walker, const BlockSizes &block_sizes, - const INEGEMMWrapperKernel::Params ¶ms, float alpha, float beta, bool pretranspose_b, + const INEGEMMWrapperKernel::Params ¶ms, float alpha, float beta, const GEMMInfo &gemm_info, unsigned int num_threads) override { auto matrix_multiply = support::cpp14::make_unique>(); - matrix_multiply->configure(transformed_a, transformed_b, tmp_c, c, block_walker, block_sizes, params, pretranspose_b, alpha, beta, num_threads); + matrix_multiply->configure(transformed_a, transformed_b, tmp_c, c, block_walker, block_sizes, params, gemm_info, alpha, beta, num_threads); return std::move(matrix_multiply); } diff --git a/src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp b/src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp index 97c20dbd4e..ecdb5a938c 100644 --- a/src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp +++ b/src/core/NEON/kernels/assembly/NEGEMMNativeWrapperKernel.cpp @@ -81,12 +81,20 @@ void NEGEMMNativeWrapperKernel::run_internal(const Window &window, const TensorAccessor b(*_b); TensorAccessor c(*_c); - if(_a->info()->data_layout() == DataLayout::NHWC) + // Handle 3d input re-interpretation + if(_gemm_info.reinterpret_input_as_3d()) { - // In the case of NHWC we want to interpret the output shape as 3D. Thus, the batch stride for A is - // the relevant multiple of the row stride. 
- const size_t nhwc_batch_stride = _a->info()->strides_in_bytes().y() * _c->info()->dimension(1); - a.set_stride(2, nhwc_batch_stride); + Strides a_strides_as_3d = _a->info()->strides_in_bytes(); + a_strides_as_3d.remove(Window::DimZ); + a.set_strides(a_strides_as_3d); + } + + // Handle 3d output re-interpretation + if(_gemm_info.depth_output_gemm3d() != 0) + { + Strides c_strides_as_3d = _c->info()->strides_in_bytes(); + c_strides_as_3d.remove(Window::DimZ); + c.set_strides(c_strides_as_3d); } unsigned int m_end = 0; diff --git a/src/runtime/NEON/functions/NEGEMM.cpp b/src/runtime/NEON/functions/NEGEMM.cpp index 55bcc45d12..2f36397c8e 100644 --- a/src/runtime/NEON/functions/NEGEMM.cpp +++ b/src/runtime/NEON/functions/NEGEMM.cpp @@ -58,17 +58,19 @@ void NEGEMM::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITe _run_vector_matrix_multiplication = a->info()->dimension(1) < 2; _original_b = b; - bool run_optimised = c == nullptr && bool(NEGEMMAssemblyDispatch::validate(a->info(), b->info(), d->info(), alpha, beta, _reshape_b_only_on_first_run)); + bool run_optimised = c == nullptr && bool(NEGEMMAssemblyDispatch::validate(a->info(), b->info(), d->info(), alpha, beta, gemm_info)); if(run_optimised) { if(MEMInfo::get_policy() == MemoryPolicy::MINIMIZE) { - _asm_glue.configure(a, b, d, alpha, beta, false); + GEMMInfo gemm_info_ntb = gemm_info; + gemm_info_ntb.set_pretranpose_B(false); + _asm_glue.configure(a, b, d, alpha, beta, gemm_info_ntb); } else { - _asm_glue.configure(a, b, d, alpha, beta, _reshape_b_only_on_first_run); + _asm_glue.configure(a, b, d, alpha, beta, gemm_info); } ARM_COMPUTE_ERROR_ON(!_asm_glue.is_configured()); } @@ -176,7 +178,7 @@ Status NEGEMM::validate(const ITensorInfo *a, const ITensorInfo *b, const ITenso } // Check if we need to run the optimized assembly kernel - const bool run_optimised = c == nullptr && bool(NEGEMMAssemblyDispatch::validate(a, b, output, alpha, beta, true)); + const bool run_optimised = c == nullptr && bool(NEGEMMAssemblyDispatch::validate(a, b, output, alpha, beta, gemm_info)); if(!run_optimised) { diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp index 55e067f52d..2de7d2b279 100644 --- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp +++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp @@ -36,21 +36,22 @@ namespace arm_compute namespace { std::unique_ptr create_function_all_types(const arm_gemm::KernelDescription &gemm_kernel_info, - const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint, + const ITensor *a, const ITensor *b, ITensor *d, + float alpha, float beta, const GEMMInfo &gemm_info, std::shared_ptr memory_manager) { - //Note: It's safe to not check for FP16 support because this was already checked in NEGEMMAssemblyDispatch::configure() + // Note: It's safe to not check for FP16 support because this was already checked in NEGEMMAssemblyDispatch::configure() switch(gemm_kernel_info.method) { case arm_gemm::GemmMethod::GEMM_INTERLEAVED: { - if(!pretranspose_hint) + if(!gemm_info.pretranpose_B()) { return nullptr; } auto function = support::cpp14::make_unique(memory_manager); - function->configure(a, b, d, alpha, beta, pretranspose_hint); + function->configure(a, b, d, alpha, beta, gemm_info); return std::move(function); } #if defined(__aarch64__) @@ -59,7 +60,7 @@ std::unique_ptr create_function_all_types(const arm_gemm::KernelDescr if(gemm_kernel_info.name.find("sgemm_native_16x4") != 
std::string::npos) { auto kernel = support::cpp14::make_unique>(); - kernel->configure(a, b, d, alpha, beta); + kernel->configure(a, b, d, alpha, beta, gemm_info); auto function = support::cpp14::make_unique(); function->configure(std::move(kernel)); return std::move(function); @@ -83,9 +84,11 @@ public: * @param[in] b Input tensor containing the Matrix B. * @param[out] d Output tensor to store the result of matrix multiplication. * @param[in] args Matrix multiplication information. + * @param[in] gemm_info GEMM meta-data * @param[in] memory_group Memory group to be used by the function. */ - void configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs args, MemoryGroup &memory_group); + void configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs args, + const GEMMInfo &gemm_info, MemoryGroup &memory_group); // Inherited methods overridden: void run() override; @@ -123,10 +126,13 @@ private: Tensor _pretranspose{}; /** Prepared flag */ bool _is_prepared{ false }; + /** GEMM meta-data */ + GEMMInfo _gemm_info{}; }; template -void Fallback::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs args, MemoryGroup &memory_group) +void Fallback::configure(const ITensor *a, const ITensor *b, ITensor *d, arm_gemm::GemmArgs args, + const GEMMInfo &gemm_info, MemoryGroup &memory_group) { arm_gemm::GemmConfig gemm_cfg; const arm_gemm::KernelDescription gemm_kernel_info = arm_gemm::get_gemm_method(args); @@ -168,6 +174,7 @@ void Fallback::configure(const ITensor *a, const ITensor _a = a; _b = b; _d = d; + _gemm_info = gemm_info; // Check for pre-transposed support if(_gemm_kernel_asm->B_pretranspose_required()) { @@ -222,17 +229,17 @@ void Fallback::run() int ldb = 0; const int ldd = _d->info()->strides_in_bytes().y() / sizeof(TypeOutput); - // In the case of NHWC we want to interpret the output shape as 3D. Thus, the batch stride for A is - // the relevant multiple of the row stride. - const bool is_nhwc = _a->info()->data_layout() == DataLayout::NHWC; - const int stride_in_bytes_a = is_nhwc ? _a->info()->strides_in_bytes().y() * _d->info()->dimension(1) : _a->info()->strides_in_bytes().z(); + const size_t a_batch_idx = _gemm_info.reinterpret_input_as_3d() != 0 ? 3 : 2; + const size_t a_multi_idx = a_batch_idx + 1; + const size_t d_batch_idx = _gemm_info.depth_output_gemm3d() != 0 ? 
3 : 2; + const size_t d_multi_idx = d_batch_idx + 1; - const int batch_stride_a = stride_in_bytes_a / sizeof(TypeInput); - const int batch_stride_d = _d->info()->strides_in_bytes().z() / sizeof(TypeOutput); + const int batch_stride_a = _a->info()->strides_in_bytes()[a_batch_idx] / sizeof(TypeInput); + const int batch_stride_d = _d->info()->strides_in_bytes()[d_batch_idx] / sizeof(TypeOutput); - const int multi_stride_a = _a->info()->strides_in_bytes()[3] / sizeof(TypeInput); + const int multi_stride_a = _a->info()->strides_in_bytes()[a_multi_idx] / sizeof(TypeInput); int multi_stride_b = 0; - const int multi_stride_d = _d->info()->strides_in_bytes()[3] / sizeof(TypeOutput); + const int multi_stride_d = _d->info()->strides_in_bytes()[d_multi_idx] / sizeof(TypeOutput); const auto in0_ptr = reinterpret_cast(_a->buffer() + _a->info()->offset_first_element_in_bytes()); const TypeInput *in1_ptr = nullptr; @@ -270,24 +277,27 @@ void Fallback::run() } template -void create_function_or_arm_gemm(std::unique_ptr &acl_function, std::unique_ptr &arm_gemm, MemoryGroup &memory_group, const ITensor *a, const ITensor *b, - ITensor *d, float alpha, float beta, bool pretranspose_hint, std::shared_ptr memory_manager) +void create_function_or_arm_gemm(std::unique_ptr &acl_function, + std::unique_ptr &arm_gemm, + MemoryGroup &memory_group, const ITensor *a, const ITensor *b, + ITensor *d, float alpha, float beta, const GEMMInfo &gemm_info, + std::shared_ptr memory_manager) { - INEGEMMWrapperKernel::Params p = INEGEMMWrapperKernel::extract_parameters(a, b, d); + INEGEMMWrapperKernel::Params p = INEGEMMWrapperKernel::extract_parameters(a, b, d, gemm_info); const CPUInfo &ci = NEScheduler::get().cpu_info(); unsigned int num_threads = NEScheduler::get().num_threads(); - arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, pretranspose_hint); + arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.batches, p.multis, false, false, alpha, beta, num_threads, gemm_info.pretranpose_B()); //Try to create an ACL function: - acl_function = create_function_all_types(arm_gemm::get_gemm_method(args), a, b, d, alpha, beta, pretranspose_hint, std::move(memory_manager)); + acl_function = create_function_all_types(arm_gemm::get_gemm_method(args), a, b, d, alpha, beta, gemm_info, std::move(memory_manager)); //If we still don't have an ACL function: if(acl_function == nullptr) { //Fallback onto arm_gemm function if ACL doesn't support this method. 
auto fallback = support::cpp14::make_unique>(); - fallback->configure(a, b, d, args, memory_group); + fallback->configure(a, b, d, args, gemm_info, memory_group); arm_gemm = std::move(fallback); } } @@ -299,11 +309,11 @@ NEGEMMAssemblyDispatch::NEGEMMAssemblyDispatch(std::shared_ptr m { } -Status NEGEMMAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, float alpha, float beta, bool pretranspose_hint) +Status NEGEMMAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, float alpha, float beta, const GEMMInfo &gemm_info) { ARM_COMPUTE_UNUSED(alpha); ARM_COMPUTE_UNUSED(beta); - ARM_COMPUTE_UNUSED(pretranspose_hint); + ARM_COMPUTE_UNUSED(gemm_info); ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(a, b, d); ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a); #ifndef __aarch64__ @@ -319,14 +329,14 @@ Status NEGEMMAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo return Status{}; } -void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, bool pretranspose_hint) +void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, ITensor *d, float alpha, float beta, const GEMMInfo &gemm_info) { ARM_COMPUTE_ERROR_ON_NULLPTR(a); ARM_COMPUTE_ERROR_ON_NULLPTR(b); ARM_COMPUTE_ERROR_ON_NULLPTR(d); //If we don't support a combination of data types, silently return: it is the caller's responsibility to check if configure() was successful via is_configured() - if(!NEGEMMAssemblyDispatch::validate(a->info(), b->info(), d->info(), alpha, beta, pretranspose_hint)) + if(!NEGEMMAssemblyDispatch::validate(a->info(), b->info(), d->info(), alpha, beta, gemm_info)) { return; } @@ -334,20 +344,20 @@ void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, ITens switch(a->info()->data_type()) { case DataType::F32: - create_function_or_arm_gemm(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager); + create_function_or_arm_gemm(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, gemm_info, _memory_manager); break; #ifdef __aarch64__ case DataType::U8: case DataType::QASYMM8: - create_function_or_arm_gemm(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager); + create_function_or_arm_gemm(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, gemm_info, _memory_manager); break; case DataType::S8: - create_function_or_arm_gemm(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager); + create_function_or_arm_gemm(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, gemm_info, _memory_manager); break; #endif /* __aarch64__ */ #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC case DataType::F16: - create_function_or_arm_gemm(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, pretranspose_hint, _memory_manager); + create_function_or_arm_gemm(_function, _arm_gemm, _memory_group, a, b, d, alpha, beta, gemm_info, _memory_manager); break; #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ default: diff --git a/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp b/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp index ede89bf558..5b70c8724c 100644 --- a/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp +++ b/src/runtime/NEON/functions/NEGEMMLowpAssemblyMatrixMultiplyCore.cpp @@ -59,7 +59,7 @@ void NEGEMMLowpAssemblyMatrixMultiplyCore::configure(const ITensor *a, const ITe case 
DataType::QASYMM8: case DataType::U8: { - _asm_glue.configure(a, b, output, 1.f, 0.f, true); + _asm_glue.configure(a, b, output, 1.f, 0.f, GEMMInfo(false, false, true)); run_optimised = _asm_glue.is_configured(); break; } diff --git a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp index d8773e37ab..f10f114287 100644 --- a/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp +++ b/src/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.cpp @@ -87,7 +87,7 @@ void NEGEMMLowpMatrixMultiplyCore::configure(const ITensor *a, const ITensor *b, case DataType::U8: case DataType::S8: { - _asm_glue.configure(a, b, _fuse_output_stage ? &_mm_result_s32 : output, 1.f, 0.f, _reshape_b_only_on_first_run); + _asm_glue.configure(a, b, _fuse_output_stage ? &_mm_result_s32 : output, 1.f, 0.f, gemm_info); _dot_product_path = _asm_glue.is_configured(); break; } @@ -224,9 +224,8 @@ Status NEGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITenso TensorInfo tmp_b_info{}; TensorInfo mm_result_s32_info{}; - int32_t a_offset = a->quantization_info().uniform().offset; - int32_t b_offset = b->quantization_info().uniform().offset; - const bool reshape_b_only_on_first_run = gemm_info.reshape_b_only_on_first_run(); + int32_t a_offset = a->quantization_info().uniform().offset; + int32_t b_offset = b->quantization_info().uniform().offset; bool fuse_output_stage = gemm_info.gemmlowp_output_stage().type != GEMMLowpOutputStageType::NONE; if(fuse_output_stage) @@ -235,7 +234,7 @@ Status NEGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITenso } // Check if we need to run the optimized assembly kernel - const bool run_optimised = bool(NEGEMMAssemblyDispatch::validate(a, b, fuse_output_stage ? &mm_result_s32_info : output, 1.f, 0.f, reshape_b_only_on_first_run)); + const bool run_optimised = bool(NEGEMMAssemblyDispatch::validate(a, b, fuse_output_stage ? 
&mm_result_s32_info : output, 1.f, 0.f, gemm_info)); if(run_optimised) { diff --git a/src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp b/src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp index 20aa1496b6..ac809fa142 100644 --- a/src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp +++ b/src/runtime/NEON/functions/assembly/NEGEMMInterleavedWrapper.cpp @@ -339,19 +339,19 @@ void NEGEMMInterleavedWrapper::prepare() } } -void NEGEMMInterleavedWrapper::configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta, bool pretranspose_b) +void NEGEMMInterleavedWrapper::configure(const ITensor *a, const ITensor *b, ITensor *c, float alpha, float beta, const GEMMInfo &gemm_info) { - _params = INEGEMMWrapperKernel::extract_parameters(a, b, c); + _params = INEGEMMWrapperKernel::extract_parameters(a, b, c, gemm_info); _a = a; _b = b; _c = c; - _pretranspose_b = pretranspose_b; + _pretranspose_b = gemm_info.pretranpose_B(); const DataType input_type = a->info()->data_type(); const CPUInfo &ci = NEScheduler::get().cpu_info(); const unsigned int num_threads = NEScheduler::get().num_threads(); - const arm_gemm::KernelDescription gemm_kernel_info = get_gemm_info(input_type, ci, num_threads, _params, alpha, beta, pretranspose_b); + const arm_gemm::KernelDescription gemm_kernel_info = get_gemm_info(input_type, ci, num_threads, _params, alpha, beta, _pretranspose_b); ARM_COMPUTE_ERROR_ON(gemm_kernel_info.method != arm_gemm::GemmMethod::GEMM_INTERLEAVED); // Forcing 128-byte alignment (required by 32-bit kernels) @@ -411,8 +411,8 @@ void NEGEMMInterleavedWrapper::configure(const ITensor *a, const ITensor *b, ITe _memory_group.manage(&_transformed_a); _memory_group.manage(&_tmp_c); - _transform_a = strategy->instantiate_transformA(_a, &_transformed_a, _block_walker, _params); - _matrix_multiply = strategy->instantiate_matrix_multiply(&_transformed_a, &_transformed_b, &_tmp_c, c, _block_walker, _block_sizes, _params, alpha, beta, pretranspose_b, num_threads); + _transform_a = strategy->instantiate_transformA(_a, &_transformed_a, _block_walker, _params, gemm_info); + _matrix_multiply = strategy->instantiate_matrix_multiply(&_transformed_a, &_transformed_b, &_tmp_c, c, _block_walker, _block_sizes, _params, alpha, beta, gemm_info, num_threads); ARM_COMPUTE_ERROR_ON(_transform_a == nullptr); ARM_COMPUTE_ERROR_ON(_matrix_multiply == nullptr); -- cgit v1.2.1
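
For callers migrating from the old boolean pretranspose hint to the GEMMInfo-based interface introduced by this patch, the following is a minimal usage sketch, not an official example from the library: it assumes arm_compute::Tensor objects a, b and d have already been initialised and allocated elsewhere, and that NEGEMMAssemblyDispatch is default-constructible with a null memory manager; the member names (including the pretranpose_B spelling) are taken verbatim from the diff above.

    // Sketch: driving the assembly dispatch with GEMM meta-data instead of a bool hint.
    // Assumption: a, b, d are pre-allocated arm_compute::Tensor objects with compatible
    // shapes and data types.
    arm_compute::NEGEMMAssemblyDispatch asm_gemm;

    // is_a_reshaped = false, is_b_reshaped = false, reshape_b_only_on_first_run = true.
    // Per the constructor change above, _pretranpose_B is initialised from
    // reshape_b_only_on_first_run.
    arm_compute::GEMMInfo gemm_info(false, false, true);

    // The pre-transpose flag can still be overridden explicitly, e.g. when the
    // MINIMIZE memory policy is in effect (see the NEGEMM::configure hunk).
    gemm_info.set_pretranpose_B(false);

    // Validate with the same meta-data that will be used to configure.
    if(bool(arm_compute::NEGEMMAssemblyDispatch::validate(a.info(), b.info(), d.info(),
                                                          1.f, 0.f, gemm_info)))
    {
        asm_gemm.configure(&a, &b, &d, 1.f, 0.f, gemm_info);
        asm_gemm.run();
    }

The same GEMMInfo object is what now carries the 3D re-interpretation hints (reinterpret_input_as_3d, depth_output_gemm3d) down to the interleaved and native wrapper kernels, which is why the bool parameters were removed from their configure() signatures in the hunks above.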