From f9179d393a07eb9eed753e315df79d22391906c6 Mon Sep 17 00:00:00 2001 From: Michele Di Giorgio Date: Wed, 27 Nov 2019 16:17:30 +0000 Subject: COMPMID-2793: Add support for QASYMM8_SIGNED in CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel Change-Id: I8abfdd3372cc394b98ec038b9fcb4abfe9216894 Signed-off-by: Michele Di Giorgio Reviewed-on: https://review.mlplatform.org/c/2401 Reviewed-by: Giorgio Arena Tested-by: Arm Jenkins Reviewed-by: Georgios Pinitas Comments-Addressed: Arm Jenkins --- arm_compute/core/CL/CLHelpers.h | 8 ++++++++ .../CL/kernels/CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel.h | 10 +++++----- 2 files changed, 13 insertions(+), 5 deletions(-) (limited to 'arm_compute') diff --git a/arm_compute/core/CL/CLHelpers.h b/arm_compute/core/CL/CLHelpers.h index 8801af579e..cd65eafc9c 100644 --- a/arm_compute/core/CL/CLHelpers.h +++ b/arm_compute/core/CL/CLHelpers.h @@ -74,6 +74,14 @@ std::string get_cl_unsigned_type_from_element_size(size_t element_size); */ std::string get_cl_select_type_from_data_type(const DataType &dt); +/** Translates a tensor data type to the appropriate OpenCL dot8 accumulator type. + * + * @param[in] dt @ref DataType to be translated to OpenCL dot8 accumulator type. + * + * @return The string specifying the OpenCL dot8 accumulator type to be used. + */ +std::string get_cl_dot8_acc_type_from_data_type(const DataType &dt); + /** Get the size of a data type in number of bits. * * @param[in] dt @ref DataType. diff --git a/arm_compute/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel.h b/arm_compute/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel.h index 5328ee44bc..9dd5496c00 100644 --- a/arm_compute/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel.h +++ b/arm_compute/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel.h @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#ifndef __ARM_COMPUTE_CLGEMMLOWPMATRIXMULTIPLYRESHAPEDONLYRHSKERNEL_H__ -#define __ARM_COMPUTE_CLGEMMLOWPMATRIXMULTIPLYRESHAPEDONLYRHSKERNEL_H__ +#ifndef ARM_COMPUTE_CLGEMMLOWPMATRIXMULTIPLYRESHAPEDONLYRHSKERNEL_H +#define ARM_COMPUTE_CLGEMMLOWPMATRIXMULTIPLYRESHAPEDONLYRHSKERNEL_H #include "arm_compute/core/CL/ICLKernel.h" @@ -49,7 +49,7 @@ public: CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel &operator=(CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel &&) = default; /** Initialise the kernel's input and output. * - * @param[in] input0 Input tensor containing the LHS matrix. Data type supported: QASYMM8 + * @param[in] input0 Input tensor containing the LHS matrix. Data type supported: QASYMM8/QASYMM8_SIGNED * @param[in] input1 Input tensor containing the RHS reshaped matrix. Data type supported: same as @p input0 * @param[out] output Output tensor to store the result of matrix multiplication. Data type supported: S32 * @param[in] lhs_info LHS matrix information used to retrieve the number of rows to be processed by each thread @@ -64,7 +64,7 @@ public: void configure(const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const GEMMReshapeInfo &gemm_info); /** Static function to check if given info will lead to a valid configuration of @ref CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel * - * @param[in] input0 Input tensor info for the LHS matrix. Data type supported: QASYMM8 + * @param[in] input0 Input tensor info for the LHS matrix. Data type supported: QASYMM8/QASYMM8_SIGNED * @param[in] input1 Input tensor info for the RHS reshaped matrix. Data type supported: same as @p input0 * @param[in] output Output tensor info. Data type supported: S32 * @param[in] lhs_info LHS matrix information used to retrieve the number of rows to be processed by each thread @@ -94,4 +94,4 @@ private: bool _use_dummy_work_items; }; } // namespace arm_compute -#endif /*__ARM_COMPUTE_CLGEMMLOWPMATRIXMULTIPLYRESHAPEDONLYRHSKERNEL_H__*/ \ No newline at end of file +#endif /* ARM_COMPUTE_CLGEMMLOWPMATRIXMULTIPLYRESHAPEDONLYRHSKERNEL_H */ \ No newline at end of file -- cgit v1.2.1