From 2d7e683e79c8ad328d4930c1f82a46827313faf4 Mon Sep 17 00:00:00 2001
From: George Wort
Date: Fri, 22 Feb 2019 16:37:41 +0000
Subject: COMPMID-1694: Fuse offset contribution with the output stage when we
 use NEGEMMLowpMatrixMultiplyCore

Change-Id: Ic1a681e4cc03e1eba3bf8485d9cdb17b3e926047
Signed-off-by: giuros01
Reviewed-on: https://review.mlplatform.org/c/561
Reviewed-by: Gian Marco Iodice
Tested-by: Arm Jenkins
---
 .../kernels/CLGEMMLowpOffsetContributionKernel.h    |   4 +-
 arm_compute/core/NEON/NEAsymm.h                     |  50 ++++++++
 arm_compute/core/NEON/NEAsymm.inl                   |   9 +-
 arm_compute/core/NEON/NEKernels.h                   |   1 +
 ...NEGEMMLowpOffsetContributionOutputStageKernel.h  | 136 +++++++++++++++++++++
 .../CL/functions/CLGEMMLowpMatrixMultiplyCore.h     |   2 +-
 .../NEON/functions/NEGEMMConvolutionLayer.h         |  37 +++---
 .../NEON/functions/NEGEMMLowpMatrixMultiplyCore.h   |  60 +++++----
 8 files changed, 254 insertions(+), 45 deletions(-)
 create mode 100644 arm_compute/core/NEON/kernels/NEGEMMLowpOffsetContributionOutputStageKernel.h

(limited to 'arm_compute')

diff --git a/arm_compute/core/CL/kernels/CLGEMMLowpOffsetContributionKernel.h b/arm_compute/core/CL/kernels/CLGEMMLowpOffsetContributionKernel.h
index e6b79176b5..c66a470a75 100644
--- a/arm_compute/core/CL/kernels/CLGEMMLowpOffsetContributionKernel.h
+++ b/arm_compute/core/CL/kernels/CLGEMMLowpOffsetContributionKernel.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -58,7 +58,7 @@ public:
     CLGEMMLowpOffsetContributionKernel &operator=(CLGEMMLowpOffsetContributionKernel &&) = default;
     /** Initialise the kernel's input and output.
      *
-     * @param[in, out] mm_result      Input tensor containing the result of @ref CLGEMMLowpMatrixMultiplyKernel
+     * @param[in, out] mm_result      Input tensor containing the result of @ref CLGEMMLowpMatrixMultiplyKernel. Data type supported: S32
      * @param[in]      vector_sum_col Input row-vector of sums of all the entries in each column of matrix B.
      *                                Note: vector_sum_col can be a nullptr in case a_offset = 0. Data type supported: same as @p mm_result
      * @param[in]      vector_sum_row Input row-vector of sums of all the entries in each row of matrix A.
diff --git a/arm_compute/core/NEON/NEAsymm.h b/arm_compute/core/NEON/NEAsymm.h
index c7f59e9eba..997f28f1f0 100644
--- a/arm_compute/core/NEON/NEAsymm.h
+++ b/arm_compute/core/NEON/NEAsymm.h
@@ -45,6 +45,17 @@ using qasymm8x16_t = uint8x16_t; /**< 8 bit quantized asymmetric vector with 1
  */
 int32x4_t rounding_divide_by_pow2(int32x4_t x, int exponent);
 
+/** Round to the nearest division by a power-of-two using exponent
+ *
+ * @note This function calculates the following expression: (x + 2^n -1 ) / 2^n where n = exponent
+ *
+ * @param[in] x        Element to divide.
+ * @param[in] exponent Integer value used to round to nearest division by a power-of-two
+ *
+ * @return the nearest division by a power-of-two using exponent
+ */
+int32_t rounding_divide_by_pow2(int32_t x, int exponent);
+
 /** Perform a multiply-accumulate on all 16 components of a QASYMM8 vector
  *
  * vd*vs + vo
@@ -125,6 +136,45 @@ uint8x16_t finalize_quantization(int32x4x4_t &in_s32,
     return out_u8;
 }
 
+/** Performs final quantization step on single element
+ *
+ * @tparam is_bounded_relu Specified if a fused bounded relu should be applied
+ *
+ * @param[in] in_value                      Input to be quantized.
+ * @param[in] result_fixedpoint_multiplier  Result multiplier parameter
+ * @param[in] result_shift                  Result shift parameter
+ * @param[in] result_offset_after_shift_s32 Result offset parameter
+ * @param[in] min_u8                        Relu lower bound
+ * @param[in] max_u8                        Relu upper bound
+ *
+ * @return Quantized value
+ */
+template <bool is_bounded_relu>
+inline uint8_t finalize_quantization(int32_t in_value, int result_fixedpoint_multiplier,
+                                     int32_t result_shift, int32_t result_offset_after_shift_s32,
+                                     uint8_t min_u8, uint8_t max_u8)
+{
+    int32x4_t in_s32 = vdupq_n_s32(in_value);
+
+    // Fixed point multiplication with vector saturating rounding doubling multiply high with scalar
+    in_value = vgetq_lane_s32(vqrdmulhq_n_s32(in_s32, result_fixedpoint_multiplier), 0);
+
+    // Shift value by result_shift_s32
+    in_value = rounding_divide_by_pow2(in_value, result_shift);
+
+    // Add the offset term
+    in_value += result_offset_after_shift_s32;
+
+    // Bound the result
+    uint8_t out_u8 = static_cast<uint8_t>(std::max(0, std::min(255, in_value)));
+    if(is_bounded_relu)
+    {
+        out_u8 = static_cast<uint8_t>(std::max(min_u8, std::min(max_u8, out_u8)));
+    }
+
+    return out_u8;
+}
+
 /** Dequantize a neon vector holding 16 quantized values.
  *
  * @param qv Input values to be dequantized.
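
For readers checking the arithmetic of the new single-element finalize_quantization() above, a minimal usage sketch follows. It is plain C++, the wrapper name requantize_one is local to the sketch, and the multiplier/shift/offset values are illustrative placeholders rather than values from a real network:

    #include <cstdint>
    #include "arm_compute/core/NEON/NEAsymm.h"

    // Requantize one int32 accumulator to QASYMM8 without a fused bounded ReLU.
    // With multiplier = 1 << 30 (0.5 in Q0.31), shift = 1 and offset = 0, an
    // accumulator of 100 maps to round(100 * 0.5) = 50, then 50 >> 1 (rounded) = 25.
    inline uint8_t requantize_one(int32_t acc)
    {
        return arm_compute::finalize_quantization<false>(acc, 1 << 30 /* multiplier */, 1 /* shift */,
                                                         0 /* offset */, 0 /* min_u8 */, 255 /* max_u8 */);
    }

This scalar helper mirrors, per element, what the vectorised finalize_quantization() earlier in the same header does for 16 values at a time.
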
diff --git a/arm_compute/core/NEON/NEAsymm.inl b/arm_compute/core/NEON/NEAsymm.inl
index ce999a5413..209785d94e 100644
--- a/arm_compute/core/NEON/NEAsymm.inl
+++ b/arm_compute/core/NEON/NEAsymm.inl
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -31,6 +31,13 @@ inline int32x4_t rounding_divide_by_pow2(int32x4_t x, int exponent)
     return vrshlq_s32(fixed_up_x, shift_vec);
 }
 
+inline int32_t rounding_divide_by_pow2(int32_t x, int exponent)
+{
+    const int32_t mask      = (1 << exponent) - 1;
+    const int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
+    return (x >> exponent) + ((x & mask) > threshold ? 1 : 0);
+}
+
 inline qasymm8x16_t vmlaq_qasymm8(qasymm8x16_t vd, float32x4_t vs, float32x4_t vo)
 {
     // Convert uint8 vectors to uint16 vectors
diff --git a/arm_compute/core/NEON/NEKernels.h b/arm_compute/core/NEON/NEKernels.h
index f1d94c89db..5b1b701a9d 100644
--- a/arm_compute/core/NEON/NEKernels.h
+++ b/arm_compute/core/NEON/NEKernels.h
@@ -73,6 +73,7 @@
 #include "arm_compute/core/NEON/kernels/NEGEMMInterleave4x4Kernel.h"
 #include "arm_compute/core/NEON/kernels/NEGEMMLowpMatrixMultiplyKernel.h"
 #include "arm_compute/core/NEON/kernels/NEGEMMLowpOffsetContributionKernel.h"
+#include "arm_compute/core/NEON/kernels/NEGEMMLowpOffsetContributionOutputStageKernel.h"
 #include "arm_compute/core/NEON/kernels/NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel.h"
 #include "arm_compute/core/NEON/kernels/NEGEMMLowpQuantizeDownInt32ToUint8ScaleKernel.h"
 #include "arm_compute/core/NEON/kernels/NEGEMMLowpReductionKernel.h"
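
A few sample values make the behaviour of the scalar rounding_divide_by_pow2() added to NEAsymm.inl easier to see: it rounds to the nearest multiple with ties rounded away from zero. The checks below are illustrative only, and the function name wrapping them is local to the sketch:

    #include <cassert>
    #include "arm_compute/core/NEON/NEAsymm.h"

    inline void rounding_divide_by_pow2_examples()
    {
        using arm_compute::rounding_divide_by_pow2;
        assert(rounding_divide_by_pow2(9, 2) == 2);    //   9 / 4 =  2.25 ->  2
        assert(rounding_divide_by_pow2(10, 2) == 3);   //  10 / 4 =  2.5  ->  3 (tie, away from zero)
        assert(rounding_divide_by_pow2(-10, 2) == -3); // -10 / 4 = -2.5  -> -3 (tie, away from zero)
    }
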
diff --git a/arm_compute/core/NEON/kernels/NEGEMMLowpOffsetContributionOutputStageKernel.h b/arm_compute/core/NEON/kernels/NEGEMMLowpOffsetContributionOutputStageKernel.h
new file mode 100644
index 0000000000..c284ca5c5f
--- /dev/null
+++ b/arm_compute/core/NEON/kernels/NEGEMMLowpOffsetContributionOutputStageKernel.h
@@ -0,0 +1,136 @@
+/*
+ * Copyright (c) 2019 ARM Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef __ARM_COMPUTE_NEGEMMLOWPOFFSETCONTRIBUTIONOUTPUTSTAGEKERNEL_H__
+#define __ARM_COMPUTE_NEGEMMLOWPOFFSETCONTRIBUTIONOUTPUTSTAGEKERNEL_H__
+
+#include "arm_compute/core/NEON/INEKernel.h"
+
+namespace arm_compute
+{
+class ITensor;
+
+/** NEON kernel used to add the offset contribution and perform the output stage after @ref NEGEMMLowpMatrixMultiplyKernel.
+ *
+ * The computation is performed in-place
+ *
+ * This kernel takes a final int32 accumulator value (the output of @ref NEGEMMLowpMatrixMultiplyKernel),
+ * and adds to it the offset contribution of matrix A and matrix B in-place.
+ *
+ * The output stage can perform either QuantizeDownInt32ToUint8Scale or QuantizeDownInt32ToUint8ScaleByFixedPoint.
+ *
+ * For QuantizeDownInt32ToUint8Scale the final result is:
+ *
+ * ((mm_result'[i][k] + result_offset) * result_mult_int) >> result_shift
+ *
+ * For QuantizeDownInt32ToUint8ScaleByFixedPoint the final result is:
+ *
+ * (FixedPointMul(mm_result'[i][k], result_fixedpoint_multiplier) >> result_shift) + result_offset_after_shift
+ *
+ * where FixedPointMul(x, y) is the nearest integer to the following
+ * mathematical expression, evaluated without overflow or intermediate rounding:
+ *
+ * (x * y) / 2^31
+ *
+ * and mm_result'[i][k] = mm_result[i][k] +
+ *                        (vector_sum_col[k] * a_offset) +
+ *                        (vector_sum_row[i] * b_offset) +
+ *                        (a_offset * b_offset * k)
+ */
+
+class NEGEMMLowpOffsetContributionOutputStageKernel : public INEKernel
+{
+public:
+    const char *name() const override
+    {
+        return "NEGEMMLowpOffsetContributionOutputStageKernel";
+    }
+    /** Constructor */
+    NEGEMMLowpOffsetContributionOutputStageKernel();
+    /** Prevent instances of this class from being copied (As this class contains pointers)*/
+    NEGEMMLowpOffsetContributionOutputStageKernel(const NEGEMMLowpOffsetContributionOutputStageKernel &) = delete;
+    /** Prevent instances of this class from being copied (As this class contains pointers)*/
+    NEGEMMLowpOffsetContributionOutputStageKernel &operator=(const NEGEMMLowpOffsetContributionOutputStageKernel &) = delete;
+    /** Allow instances of this class to be moved */
+    NEGEMMLowpOffsetContributionOutputStageKernel(NEGEMMLowpOffsetContributionOutputStageKernel &&) = default;
+    /** Allow instances of this class to be moved */
+    NEGEMMLowpOffsetContributionOutputStageKernel &operator=(NEGEMMLowpOffsetContributionOutputStageKernel &&) = default;
+    /** Initialise the kernel's input and output.
+     *
+     * @param[in]  mm_result      Input tensor containing the result of @ref NEGEMMLowpMatrixMultiplyKernel. Data type supported: S32
+     * @param[in]  vector_sum_col Input row-vector of sums of all the entries in each column of matrix B.
+     *                            Note: vector_sum_col can be a nullptr in case a_offset = 0. Data type supported: same as @p mm_result
+     * @param[in]  vector_sum_row Input row-vector of sums of all the entries in each row of matrix A.
+     * @param[in]  bias           Biases tensor. Only shared biases supported and it can be a nullptr if the addition of biases is not required.
+     *                            Biases are 1D tensor with dimensions [OFM]. Data type supported: Same as @p mm_result.
+     * @param[out] output         Output tensor containing the final quantized result. Data type supported: QASYMM8
+     * @param[in]  k              Number of matrix A columns or Matrix B rows
+     * @param[in]  a_offset       Offset to be added to each element of the matrix A.
+     * @param[in]  b_offset       Offset to be added to each element of the matrix B.
+     * @param[in]  output_stage   GEMMLowp output stage info, providing the type of quantization and the necessary parameters.
+     */
+    void configure(const ITensor *mm_result, const ITensor *vector_sum_col, const ITensor *vector_sum_row, const ITensor *bias, ITensor *output, int32_t k, int32_t a_offset, int32_t b_offset,
+                   GEMMLowpOutputStageInfo output_stage);
+    /** Static function to check if given info will lead to a valid configuration of @ref NEGEMMLowpOffsetContributionOutputStageKernel
+     *
+     * @param[in] mm_result      Input tensor info containing the result of @ref NEGEMMLowpMatrixMultiplyKernel. Data type supported: S32
+     * @param[in] vector_sum_col Tensor info for the input row-vector of sums of all the entries in each column of matrix B.
+     *                           Note: vector_sum_col can be a nullptr in case a_offset = 0. Data type supported: same as @p mm_result
+     * @param[in] vector_sum_row Tensor info for the input row-vector of sums of all the entries in each row of matrix A.
+     *                           Note: vector_sum_row can be a nullptr in case b_offset = 0. Data type supported: same as @p mm_result
+     * @param[in] bias           Biases tensor info. Only shared biases supported and it can be a nullptr if the addition of biases is not required.
+     *                           Biases are 1D tensor with dimensions [OFM]. Data type supported: Same as @p mm_result.
+     * @param[in] output         Output tensor info containing the final quantized result. Data type supported: QASYMM8
+     * @param[in] a_offset       Offset to be added to each element of the matrix A.
+     * @param[in] b_offset       Offset to be added to each element of the matrix B.
+     * @param[in] output_stage   GEMMLowp output stage info, providing the type of quantization and the necessary parameters.
+     *
+     * @return a status
+     */
+    static Status validate(const ITensorInfo *mm_result, const ITensorInfo *vector_sum_col, const ITensorInfo *vector_sum_row, const ITensorInfo *bias, const ITensorInfo *output, int32_t a_offset,
+                           int32_t b_offset,
+                           GEMMLowpOutputStageInfo output_stage);
+
+    // Inherited methods overridden:
+    void run(const Window &window, const ThreadInfo &info) override;
+
+    using NEGEMMLowpOffsetContributionOutputStageFunction = std::function;
+
+private:
+    /** Function to use for the particular tensors passed to configure() */
+    NEGEMMLowpOffsetContributionOutputStageFunction _function;
+    const ITensor *_vector_sum_col;
+    const ITensor *_vector_sum_row;
+    const ITensor *_bias;
+    const ITensor *_mm_result;
+    ITensor *_output;
+    int32_t _a_offset;
+    int32_t _b_offset;
+    int32_t _k_offset;
+    bool _slide_vector_sum_col;
+    GEMMLowpOutputStageInfo _output_stage;
+};
+} // namespace arm_compute
+
+#endif /* __ARM_COMPUTE_NEGEMMLOWPOFFSETCONTRIBUTIONOUTPUTSTAGEKERNEL_H__ */
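
As a conceptual reference for the kernel documented above, the per-element computation can be written as a plain scalar loop. This sketch uses the QuantizeDownInt32ToUint8Scale variant, assumes the bias is present and is added to the accumulator before the output stage, and is not the kernel's actual NEON implementation; all names are local to the sketch:

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    void offset_contribution_output_stage_ref(const std::vector<int32_t> &mm_result,      // M x N, row-major
                                              const std::vector<int32_t> &vector_sum_col, // N entries
                                              const std::vector<int32_t> &vector_sum_row, // M entries
                                              const std::vector<int32_t> &bias,           // N entries ([OFM])
                                              std::vector<uint8_t>       &output,         // M x N, row-major
                                              int M, int N, int32_t k,
                                              int32_t a_offset, int32_t b_offset,
                                              int32_t result_offset, int32_t result_mult_int, int32_t result_shift)
    {
        for(int i = 0; i < M; ++i)
        {
            for(int j = 0; j < N; ++j)
            {
                // Offset contribution: mm_result'[i][j]
                int32_t acc = mm_result[i * N + j]
                              + vector_sum_col[j] * a_offset
                              + vector_sum_row[i] * b_offset
                              + a_offset * b_offset * k
                              + bias[j];
                // QuantizeDownInt32ToUint8Scale output stage
                acc = ((acc + result_offset) * result_mult_int) >> result_shift;
                output[i * N + j] = static_cast<uint8_t>(std::max(0, std::min(255, acc)));
            }
        }
    }

The column index is written as j here to avoid clashing with the reduction depth k used in the class documentation.
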
diff --git a/arm_compute/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.h b/arm_compute/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.h
index 4345ff267b..67b22821da 100644
--- a/arm_compute/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.h
+++ b/arm_compute/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.h
@@ -71,7 +71,7 @@ public:
      * This kernel performs the following computations:
      *
      * -# Convert a values from QASYMM8 to int32 and add a_offset to each of them.
-     * -# Convert b values from QASYMM8 to int32 add b_offset to each of them.
+     * -# Convert b values from QASYMM8 to int32 and add b_offset to each of them.
      * -# Compute the matrix product of the resulting a * b in int32.
      * -# Quantize to uint8 if gemm_info.gemmlowp_output_stage != NONE
      *
diff --git a/arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h b/arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h
index 6df7af0d86..ace924f146 100644
--- a/arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h
+++ b/arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -151,44 +151,51 @@ private:
      *
      * @param[in]  input         Input tensor. Data types supported: QASYMM8/F16/F32.
      * @param[in]  weights       Weights tensor. Data type supported: Same as @p input.
+     * @param[in]  biases        Biases tensor. Shared biases supported. Biases are 1D tensor with dimensions [OFM].
+     *                           Data type supported: Should match @p input data type, except for input of QASYMM8 type where biases should be of S32 type.
      * @param[out] output        Output tensor. Data types supported: Same as @p input,
      *                           except for input of QASYMM8 type where output should be of S32 type.
+     * @param[in]  act_info      (Optional) Activation layer information in case of a fused activation. Only RELU, BOUNDED_RELU and LU_BOUNDED_RELU supported.
      * @param[in]  gemm_3d_depth (Optional) Depth of GEMM 3D (Defaults to 1)
      */
-    void configure_mm(const ITensor *input, const ITensor *weights, ITensor *output, int gemm_3d_depth = 1);
+    void configure_mm(const ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const ActivationLayerInfo &act_info = ActivationLayerInfo(), int gemm_3d_depth = 1);
     /** Static function to check if given info will lead to a valid configuration of @ref NEGEMMConvolutionLayer matrix multiply routines
      *
      * @param[in] input         Input tensor. Data types supported: QASYMM8/F16/F32.
      * @param[in] weights       Weights tensor. Data type supported: Same as @p input.
+     * @param[in] biases        Biases tensor. Shared biases supported. Biases are 1D tensor with dimensions [OFM].
+     *                          Data type supported: Should match @p input data type, except for input of QASYMM8 type where biases should be of S32 type.
     * @param[in] output        Output tensor. Data types supported: Same as @p input,
      *                          except for input of QASYMM8 type where output should be of S32 type.
+     * @param[in] act_info      (Optional) Activation layer information in case of a fused activation. Only RELU, BOUNDED_RELU and LU_BOUNDED_RELU supported.
      * @param[in] gemm_3d_depth (Optional) Depth of GEMM 3D (Defaults to 1)
      * @param[in] skip_im2col   (Optional) Flag which specifies if im2col has to be skipped. i.e. 1x1 convolution with NHWC data layout. (Default to false)
      *
      * @return a status
      */
-    static Status validate_mm(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *output, int gemm_3d_depth = 1, bool skip_im2col = false);
+    static Status validate_mm(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const ActivationLayerInfo &act_info = ActivationLayerInfo(),
+                              int gemm_3d_depth = 1, bool skip_im2col = false);
     /** Static function to check if GEMM3D is supported in @ref NEGEMM or in @ref NEGEMMLowpMatrixMultiplyCore
      *
-     * @param[in] data_type     Input data type
+     * @param[in] input_info    Input tensor info. Data types supported: QASYMM8/F16/F32.
+     * @param[in] act_info      Activation layer information in case of a fused activation. Only RELU, BOUNDED_RELU and LU_BOUNDED_RELU supported.
      * @param[in] gemm_3d_depth Depth of GEMM 3D
      * @param[in] skip_im2col   Flag which specifies if im2col has to be skipped. i.e. 1x1 convolution with NHWC data layout
      *
      * @return a status
      */
-    static Status validate_gemm3d(DataType data_type, int gemm_3d_depth, bool skip_im2col);
+    static Status validate_gemm3d(const ITensorInfo *input_info, const ActivationLayerInfo &act_info, int gemm_3d_depth, bool skip_im2col);
 
 private:
-    MemoryGroup _memory_group;
-    NEConvolutionLayerReshapeWeights _reshape_weights;
-    NEIm2ColKernel _im2col_kernel;
-    NEGEMM _mm_gemm;
-    NEGEMMLowpMatrixMultiplyCore _mm_gemmlowp;
-    NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint _gemmlowp_output_stage;
-    NECol2ImKernel _col2im_kernel;
-    NEActivationLayer _activationlayer_function;
-    NEArithmeticAdditionKernel _add_bias_kernel;
-    NEReshapeLayer _reshape_layer;
+    MemoryGroup _memory_group;
+    NEConvolutionLayerReshapeWeights _reshape_weights;
+    NEIm2ColKernel _im2col_kernel;
+    NEGEMM _mm_gemm;
+    NEGEMMLowpMatrixMultiplyCore _mm_gemmlowp;
+    NECol2ImKernel _col2im_kernel;
+    NEActivationLayer _activationlayer_function;
+    NEArithmeticAdditionKernel _add_bias_kernel;
+    NEReshapeLayer _reshape_layer;
 
     const ITensor *_original_weights;
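
The act_info argument threaded through configure_mm()/validate_mm() above is a standard ActivationLayerInfo. A brief sketch of the three variants that can be fused follows; the bound values are illustrative and the variable names are local to the sketch:

    #include "arm_compute/core/Types.h"

    using arm_compute::ActivationLayerInfo;

    // Only these three activations can be fused into the quantized GEMM path;
    // other activations still run through the separate NEActivationLayer member.
    const ActivationLayerInfo act_relu(ActivationLayerInfo::ActivationFunction::RELU);
    const ActivationLayerInfo act_bounded_relu(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU, 6.f);            // clamp to [0, 6]
    const ActivationLayerInfo act_lu_bounded_relu(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 6.f, -1.f); // clamp to [-1, 6]
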
diff --git a/arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h b/arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h
index 682475c824..d3b27e4faf 100644
--- a/arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h
+++ b/arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2018 ARM Limited.
+ * Copyright (c) 2017-2019 ARM Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -26,6 +26,7 @@
 #include "arm_compute/core/NEON/INEKernel.h"
 #include "arm_compute/core/NEON/kernels/NEGEMMLowpOffsetContributionKernel.h"
+#include "arm_compute/core/NEON/kernels/NEGEMMLowpOffsetContributionOutputStageKernel.h"
 #include "arm_compute/core/NEON/kernels/NEGEMMLowpReductionKernel.h"
 #include "arm_compute/runtime/IFunction.h"
 #include "arm_compute/runtime/IMemoryManager.h"
@@ -73,20 +74,24 @@ public:
      * -# Convert b values from QASYMM8 to int32 add b_offset to each of them.
      * -# Compute the matrix product of the resulting a * b in int32.
      *
+     * @note The @p output type is S32 if @p gemm_info.type == GEMMLowpOutputStageType::NONE. It is QASYMM8 otherwise
+     *
      * @param[in]  a         First input tensor (Matrix A). Data type supported: QASYMM8.
      * @param[in]  b         Second input tensor (Matrix B). Data type supported: same as @p a
      * @param[in]  c         Third input tensor (Matrix C). It can be a nullptr. Data type supported: S32
-     * @param[out] output    Output tensor. Data type supported: Data type supported: S32
+     * @param[out] output    Output tensor. Data type supported: Data type supported: S32/QASYMM8
      * @param[in]  gemm_info (Optional) Specifies if the matrix A and/or matrix B have been reshaped and
      *                       if the reshape of matrix B should be executed only for the first run
      */
     void configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *output, const GEMMInfo &gemm_info = GEMMInfo());
     /** Static function to check if given info will lead to a valid configuration of @ref NEGEMMLowpMatrixMultiplyCore
      *
-     * @param[in] a         First input tensor (Matrix A). Data type supported: QASYMM8.
-     * @param[in] b         Second input tensor (Matrix B). Data type supported: same as @p a
-     * @param[in] c         Third input tensor (Matrix C). It can be a nullptr. Data type supported: S32
-     * @param[in] output    Output tensor. Data type supported: Data type supported: S32
+     * @note The @p output type is S32 if @p gemm_info.type == GEMMLowpOutputStageType::NONE. It is QASYMM8 otherwise
+     *
+     * @param[in] a         First input tensor info (Matrix A). Data type supported: QASYMM8.
+     * @param[in] b         Second input tensor info (Matrix B). Data type supported: same as @p a
+     * @param[in] c         Third input tensor info (Matrix C). It can be a nullptr. Data type supported: S32
+     * @param[in] output    Output tensor info. Data type supported: Data type supported: S32/QASYMM8
      * @param[in] gemm_info (Optional) Specifies if the matrix A and/or matrix B have been reshaped and
      *                      if the reshape of matrix B should be executed only for the first run
      *
      * @return a status
@@ -99,25 +104,28 @@ public:
     void prepare() override;
 
 private:
-    MemoryGroup _memory_group;
-    NEGEMMAssemblyDispatch _asm_glue;
-    std::unique_ptr<INEKernel> _mm_kernel;
-    std::unique_ptr<INEKernel> _mtx_a_reshape_kernel;
-    std::unique_ptr<INEKernel> _mtx_b_reshape_kernel;
-    NEGEMMLowpMatrixAReductionKernel _mtx_a_reduction_kernel;
-    NEGEMMLowpMatrixBReductionKernel _mtx_b_reduction_kernel;
-    NEGEMMLowpOffsetContributionKernel _offset_contribution_kernel;
-    Tensor _vector_sum_col;
-    Tensor _vector_sum_row;
-    Tensor _tmp_a;
-    Tensor _tmp_b;
-    const ITensor *_original_b;
-    int32_t _a_offset;
-    int32_t _b_offset;
-    bool _run_vector_matrix_multiplication;
-    bool _dot_product_path;
-    bool _reshape_b_only_on_first_run;
-    bool _is_prepared;
+    MemoryGroup _memory_group;
+    NEGEMMAssemblyDispatch _asm_glue;
+    std::unique_ptr<INEKernel> _mm_kernel;
+    std::unique_ptr<INEKernel> _mtx_a_reshape_kernel;
+    std::unique_ptr<INEKernel> _mtx_b_reshape_kernel;
+    NEGEMMLowpMatrixAReductionKernel _mtx_a_reduction_kernel;
+    NEGEMMLowpMatrixBReductionKernel _mtx_b_reduction_kernel;
+    NEGEMMLowpOffsetContributionKernel _offset_contribution_kernel;
+    NEGEMMLowpOffsetContributionOutputStageKernel _offset_contribution_output_stage_kernel;
+    Tensor _vector_sum_col;
+    Tensor _vector_sum_row;
+    Tensor _tmp_a;
+    Tensor _tmp_b;
+    Tensor _mm_result_s32;
+    const ITensor *_original_b;
+    int32_t _a_offset;
+    int32_t _b_offset;
+    bool _run_vector_matrix_multiplication;
+    bool _dot_product_path;
+    bool _reshape_b_only_on_first_run;
+    bool _is_prepared;
+    bool _fuse_output_stage;
 };
-}
+} // namespace arm_compute
 #endif /*__ARM_COMPUTE_NEGEMMLOWPMATRIXMULTIPLYCORE_H__ */
-- 
cgit v1.2.1
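
To illustrate how the fused path is meant to be driven from user code, here is a hedged sketch. The exact way GEMMInfo carries the GEMMLowpOutputStageInfo (shown here as a trailing constructor argument) should be checked against the GEMMInfo definition in this version, the function name is local to the sketch, and all numeric values are placeholders:

    #include "arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h"

    using namespace arm_compute;

    void run_fused_gemmlowp(const ITensor *a, const ITensor *b, ITensor *output_qasymm8)
    {
        // Requantization parameters consumed by the fused offset-contribution/output-stage kernel.
        GEMMLowpOutputStageInfo stage{};
        stage.type                = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
        stage.gemmlowp_multiplier = 1 << 30; // placeholder fixed-point multiplier
        stage.gemmlowp_shift      = 1;       // placeholder right shift
        stage.gemmlowp_offset     = 0;       // placeholder output offset
        stage.gemmlowp_min_bound  = 0;
        stage.gemmlowp_max_bound  = 255;

        // With an output stage other than NONE the output tensor is expected to be QASYMM8,
        // and the offset contribution plus quantization run fused inside the core function.
        GEMMInfo gemm_info(false /* is_a_reshaped */, false /* is_b_reshaped */, true /* reshape_b_only_on_first_run */,
                           0 /* depth_output_gemm3d */, false /* reinterpret_input_as_3d */, false /* retain_internal_weights */,
                           stage);

        NEGEMMLowpMatrixMultiplyCore gemmlowp;
        gemmlowp.configure(a, b, nullptr /* c */, output_qasymm8, gemm_info);
        gemmlowp.run();
    }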