diff options
Diffstat (limited to 'arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h')
-rw-r--r-- | arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h | 33 |
1 files changed, 22 insertions, 11 deletions
diff --git a/arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h b/arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h index c81a432295..0c441df4b9 100644 --- a/arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h +++ b/arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h @@ -25,6 +25,8 @@ #define __ARM_COMPUTE_NEGEMMLOWPMATRIXMULTIPLYCORE_H__ #include "arm_compute/core/NEON/INEKernel.h" +#include "arm_compute/core/NEON/kernels/NEGEMMLowpOffsetContributionKernel.h" +#include "arm_compute/core/NEON/kernels/NEGEMMLowpReductionKernel.h" #include "arm_compute/runtime/IFunction.h" #include "arm_compute/runtime/IMemoryManager.h" #include "arm_compute/runtime/MemoryGroup.h" @@ -41,11 +43,13 @@ class ITensor; * -# @ref NEGEMMInterleave4x4Kernel * -# @ref NEGEMMTranspose1xWKernel * -# @ref NEGEMMLowpMatrixMultiplyKernel + * -# @ref NEGEMMLowpOffsetContributionKernel * * otherwise if the DOT product instruction is available: * * -# @ref NEGEMMInterleaveBlockedKernel * -# @ref NEGEMMLowpAArch64V8P4Kernel + * -# @ref NEGEMMLowpOffsetContributionKernel * */ class NEGEMMLowpMatrixMultiplyCore : public IFunction @@ -58,11 +62,11 @@ public: * @note GEMM_LOWP: low precision GEMM kernel * This kernel performs the following computations: * - * -# Convert a values from uint8 to int32 - * -# Convert b values from uint8 to int32 - * -# Compute the int32 matrix product of the resulting a * b. + * -# Convert a values from QASYMM8 to int32 and add a_offset to each of them. + * -# Convert b values from QASYMM8 to int32 add b_offset to each of them. + * -# Compute the matrix product of the resulting a * b in int32. * - * @param[in] a First input tensor (Matrix A). Data type supported: U8. + * @param[in] a First input tensor (Matrix A). Data type supported: QASYMM8. * @param[in] b Second input tensor (Matrix B). Data type supported: same as @p a * @param[out] output Output tensor. Data type supported: Data type supported: S32 */ @@ -72,13 +76,20 @@ public: void run() override; private: - MemoryGroup _memory_group; - std::unique_ptr<INEKernel> _mm_kernel; - std::unique_ptr<INEKernel> _mtx_a_reshape_kernel; - std::unique_ptr<INEKernel> _mtx_b_reshape_kernel; - Tensor _tmp_a; - Tensor _tmp_b; - Tensor _workspace; + MemoryGroup _memory_group; + std::unique_ptr<INEKernel> _mm_kernel; + std::unique_ptr<INEKernel> _mtx_a_reshape_kernel; + std::unique_ptr<INEKernel> _mtx_b_reshape_kernel; + NEGEMMLowpMatrixAReductionKernel _mtx_a_reduction_kernel; + NEGEMMLowpMatrixBReductionKernel _mtx_b_reduction_kernel; + NEGEMMLowpOffsetContributionKernel _offset_contribution_kernel; + Tensor _vector_sum_col; + Tensor _vector_sum_row; + Tensor _tmp_a; + Tensor _tmp_b; + Tensor _workspace; + int32_t _a_offset; + int32_t _b_offset; }; } #endif /*__ARM_COMPUTE_NEGEMMLOWPMATRIXMULTIPLYCORE_H__ */ |