diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/core/CL/CLHelpers.cpp | 23 | ||||
-rw-r--r-- | src/core/CL/cl_kernels/gemm_helpers.h | 99 | ||||
-rw-r--r-- | src/core/CL/cl_kernels/gemmlowp.cl | 165 | ||||
-rw-r--r-- | src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp | 2 | ||||
-rw-r--r-- | src/core/CL/kernels/CLGEMMLowpMatrixMultiplyNativeKernel.cpp | 2 | ||||
-rw-r--r-- | src/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedKernel.cpp | 11 | ||||
-rw-r--r-- | src/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel.cpp | 3 | ||||
-rw-r--r-- | src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp | 2 |
8 files changed, 181 insertions, 126 deletions
diff --git a/src/core/CL/CLHelpers.cpp b/src/core/CL/CLHelpers.cpp index 17274d38ad..28b1a3224f 100644 --- a/src/core/CL/CLHelpers.cpp +++ b/src/core/CL/CLHelpers.cpp @@ -42,6 +42,7 @@ std::string get_cl_type_from_data_type(const DataType &dt) case DataType::QASYMM8: return "uchar"; case DataType::S8: + case DataType::QASYMM8_SIGNED: case DataType::QSYMM8: case DataType::QSYMM8_PER_CHANNEL: return "char"; @@ -77,6 +78,7 @@ std::string get_cl_promoted_type_from_data_type(const DataType &dt) case DataType::QASYMM8: return "ushort"; case DataType::S8: + case DataType::QASYMM8_SIGNED: case DataType::QSYMM8: case DataType::QSYMM8_PER_CHANNEL: return "short"; @@ -124,6 +126,7 @@ std::string get_cl_select_type_from_data_type(const DataType &dt) case DataType::QASYMM8: return "uchar"; case DataType::S8: + case DataType::QASYMM8_SIGNED: case DataType::QSYMM8: case DataType::QSYMM8_PER_CHANNEL: return "char"; @@ -149,6 +152,24 @@ std::string get_cl_select_type_from_data_type(const DataType &dt) } } +std::string get_cl_dot8_acc_type_from_data_type(const DataType &dt) +{ + switch(dt) + { + case DataType::U8: + case DataType::QASYMM8: + return "uint"; + case DataType::S8: + case DataType::QASYMM8_SIGNED: + case DataType::QSYMM8: + case DataType::QSYMM8_PER_CHANNEL: + return "int"; + default: + ARM_COMPUTE_ERROR("Unsupported data type."); + return ""; + } +} + std::string get_data_size_from_data_type(const DataType &dt) { switch(dt) @@ -157,6 +178,7 @@ std::string get_data_size_from_data_type(const DataType &dt) case DataType::S8: case DataType::QSYMM8: case DataType::QASYMM8: + case DataType::QASYMM8_SIGNED: case DataType::QSYMM8_PER_CHANNEL: return "8"; case DataType::U16: @@ -300,6 +322,7 @@ size_t preferred_vector_width(const cl::Device &device, const DataType dt) case DataType::U8: case DataType::S8: case DataType::QASYMM8: + case DataType::QASYMM8_SIGNED: case DataType::QSYMM8: case DataType::QSYMM8_PER_CHANNEL: return device.getInfo<CL_DEVICE_PREFERRED_VECTOR_WIDTH_CHAR>(); diff --git a/src/core/CL/cl_kernels/gemm_helpers.h b/src/core/CL/cl_kernels/gemm_helpers.h index 64914259a4..66e83c3558 100644 --- a/src/core/CL/cl_kernels/gemm_helpers.h +++ b/src/core/CL/cl_kernels/gemm_helpers.h @@ -559,20 +559,26 @@ * @param[in] IDX_COL The index value * @param[in] BASENAME The basename of the destination vectors * @param[in] X The basename of the source vectors + * @param[in] TYPE The data type of the destination vectors * @{ */ -#define COLUMN_VECTOR1(IDX_COL, BASENAME, X) \ - uchar BASENAME##IDX_COL = (uchar)((X##0).s##IDX_COL); -#define COLUMN_VECTOR2(IDX_COL, BASENAME, X) \ - uchar2 BASENAME##IDX_COL = (uchar2)((X##0).s##IDX_COL, (X##1).s##IDX_COL); -#define COLUMN_VECTOR3(IDX_COL, BASENAME, X) \ - uchar3 BASENAME##IDX_COL = (uchar3)((X##0).s##IDX_COL, (X##1).s##IDX_COL, (X##2).s##IDX_COL); -#define COLUMN_VECTOR4(IDX_COL, BASENAME, X) \ - uchar4 BASENAME##IDX_COL = (uchar4)((X##0).s##IDX_COL, (X##1).s##IDX_COL, (X##2).s##IDX_COL, (X##3).s##IDX_COL); -#define COLUMN_VECTOR8(IDX_COL, BASENAME, X) \ - uchar8 BASENAME##IDX_COL = (uchar8)((X##0).s##IDX_COL, (X##1).s##IDX_COL, (X##2).s##IDX_COL, (X##3).s##IDX_COL, (X##4).s##IDX_COL, (X##5).s##IDX_COL, (X##6).s##IDX_COL, (X##7).s##IDX_COL); -#define COLUMN_VECTOR16(IDX_COL, BASENAME, X) \ - uchar16 BASENAME##IDX_COL = (uchar16)((X##0).s##IDX_COL, (X##1).s##IDX_COL, (X##2).s##IDX_COL, (X##3).s##IDX_COL, (X##4).s##IDX_COL, (X##5).s##IDX_COL, (X##6).s##IDX_COL, (X##7).s##IDX_COL, (X##8).s##IDX_COL, (X##9).s##IDX_COL, (X##A).s##IDX_COL, (X##B).s##IDX_COL, (X##C).s##IDX_COL, (X##D).s##IDX_COL, (X##E).s##IDX_COL, (X##F).s##IDX_COL); +#define COLUMN_VECTOR1(IDX_COL, BASENAME, X, TYPE) \ + TYPE BASENAME##IDX_COL = (TYPE)((X##0).s##IDX_COL); +#define COLUMN_VECTOR2(IDX_COL, BASENAME, X, TYPE) \ + VEC_DATA_TYPE(TYPE, 2) \ + BASENAME##IDX_COL = (VEC_DATA_TYPE(TYPE, 2))((X##0).s##IDX_COL, (X##1).s##IDX_COL); +#define COLUMN_VECTOR3(IDX_COL, BASENAME, X, TYPE) \ + VEC_DATA_TYPE(TYPE, 3) \ + BASENAME##IDX_COL = (VEC_DATA_TYPE(TYPE, 3))((X##0).s##IDX_COL, (X##1).s##IDX_COL, (X##2).s##IDX_COL); +#define COLUMN_VECTOR4(IDX_COL, BASENAME, X, TYPE) \ + VEC_DATA_TYPE(TYPE, 4) \ + BASENAME##IDX_COL = (VEC_DATA_TYPE(TYPE, 4))((X##0).s##IDX_COL, (X##1).s##IDX_COL, (X##2).s##IDX_COL, (X##3).s##IDX_COL); +#define COLUMN_VECTOR8(IDX_COL, BASENAME, X, TYPE) \ + VEC_DATA_TYPE(TYPE, 8) \ + BASENAME##IDX_COL = (VEC_DATA_TYPE(TYPE, 8))((X##0).s##IDX_COL, (X##1).s##IDX_COL, (X##2).s##IDX_COL, (X##3).s##IDX_COL, (X##4).s##IDX_COL, (X##5).s##IDX_COL, (X##6).s##IDX_COL, (X##7).s##IDX_COL); +#define COLUMN_VECTOR16(IDX_COL, BASENAME, X, TYPE) \ + VEC_DATA_TYPE(TYPE, 16) \ + BASENAME##IDX_COL = (VEC_DATA_TYPE(TYPE, 16))((X##0).s##IDX_COL, (X##1).s##IDX_COL, (X##2).s##IDX_COL, (X##3).s##IDX_COL, (X##4).s##IDX_COL, (X##5).s##IDX_COL, (X##6).s##IDX_COL, (X##7).s##IDX_COL, (X##8).s##IDX_COL, (X##9).s##IDX_COL, (X##A).s##IDX_COL, (X##B).s##IDX_COL, (X##C).s##IDX_COL, (X##D).s##IDX_COL, (X##E).s##IDX_COL, (X##F).s##IDX_COL); /** @} */ // end of group COLUMN_VECTORn /** Create transposed vectors of the given vectors @@ -581,35 +587,36 @@ * @param[in] K0 The size of the source vectors * @param[in] BASENAME The basename of transposed vectors * @param[in] B The basename of source vectors for transposition + * @param[in] TYPE The data type of the transposed vectors * @{ */ -#define TRANSPOSE_K0X1(K0, BASENAME, B) \ - COLUMN_VECTOR(K0, 0, BASENAME, B); -#define TRANSPOSE_K0X2(K0, BASENAME, B) \ - TRANSPOSE_K0X1(K0, BASENAME, B); \ - COLUMN_VECTOR(K0, 1, BASENAME, B); -#define TRANSPOSE_K0X3(K0, BASENAME, B) \ - TRANSPOSE_K0X2(K0, BASENAME, B); \ - COLUMN_VECTOR(K0, 2, BASENAME, B); -#define TRANSPOSE_K0X4(K0, BASENAME, B) \ - TRANSPOSE_K0X3(K0, BASENAME, B); \ - COLUMN_VECTOR(K0, 3, BASENAME, B); -#define TRANSPOSE_K0X8(K0, BASENAME, B) \ - TRANSPOSE_K0X4(K0, BASENAME, B); \ - COLUMN_VECTOR(K0, 4, BASENAME, B); \ - COLUMN_VECTOR(K0, 5, BASENAME, B); \ - COLUMN_VECTOR(K0, 6, BASENAME, B); \ - COLUMN_VECTOR(K0, 7, BASENAME, B); -#define TRANSPOSE_K0X16(K0, BASENAME, B) \ - TRANSPOSE_K0X8(K0, BASENAME, B); \ - COLUMN_VECTOR(K0, 8, BASENAME, B); \ - COLUMN_VECTOR(K0, 9, BASENAME, B); \ - COLUMN_VECTOR(K0, A, BASENAME, B); \ - COLUMN_VECTOR(K0, B, BASENAME, B); \ - COLUMN_VECTOR(K0, C, BASENAME, B); \ - COLUMN_VECTOR(K0, D, BASENAME, B); \ - COLUMN_VECTOR(K0, E, BASENAME, B); \ - COLUMN_VECTOR(K0, F, BASENAME, B); +#define TRANSPOSE_K0X1(K0, BASENAME, B, TYPE) \ + COLUMN_VECTOR(K0, 0, BASENAME, B, TYPE); +#define TRANSPOSE_K0X2(K0, BASENAME, B, TYPE) \ + TRANSPOSE_K0X1(K0, BASENAME, B, TYPE); \ + COLUMN_VECTOR(K0, 1, BASENAME, B, TYPE); +#define TRANSPOSE_K0X3(K0, BASENAME, B, TYPE) \ + TRANSPOSE_K0X2(K0, BASENAME, B, TYPE); \ + COLUMN_VECTOR(K0, 2, BASENAME, B, TYPE); +#define TRANSPOSE_K0X4(K0, BASENAME, B, TYPE) \ + TRANSPOSE_K0X3(K0, BASENAME, B, TYPE); \ + COLUMN_VECTOR(K0, 3, BASENAME, B, TYPE); +#define TRANSPOSE_K0X8(K0, BASENAME, B, TYPE) \ + TRANSPOSE_K0X4(K0, BASENAME, B, TYPE); \ + COLUMN_VECTOR(K0, 4, BASENAME, B, TYPE); \ + COLUMN_VECTOR(K0, 5, BASENAME, B, TYPE); \ + COLUMN_VECTOR(K0, 6, BASENAME, B, TYPE); \ + COLUMN_VECTOR(K0, 7, BASENAME, B, TYPE); +#define TRANSPOSE_K0X16(K0, BASENAME, B, TYPE) \ + TRANSPOSE_K0X8(K0, BASENAME, B, TYPE); \ + COLUMN_VECTOR(K0, 8, BASENAME, B, TYPE); \ + COLUMN_VECTOR(K0, 9, BASENAME, B, TYPE); \ + COLUMN_VECTOR(K0, A, BASENAME, B, TYPE); \ + COLUMN_VECTOR(K0, B, BASENAME, B, TYPE); \ + COLUMN_VECTOR(K0, C, BASENAME, B, TYPE); \ + COLUMN_VECTOR(K0, D, BASENAME, B, TYPE); \ + COLUMN_VECTOR(K0, E, BASENAME, B, TYPE); \ + COLUMN_VECTOR(K0, F, BASENAME, B, TYPE); /** @} */ // end of group TRANSPOSE_K0Xn @@ -619,10 +626,11 @@ * @param[in] IDX_COL The index value * @param[in] BASENAME The basename of the destination vectors * @param[in] B The basename of the source vectors + * @param[in] TYPE The data type of the destination vectors */ -#define COLUMN_VECTOR(K0, IDX_COL, BASENAME, B) \ - CONCAT(COLUMN_VECTOR, K0) \ - (IDX_COL, BASENAME, B); +#define COLUMN_VECTOR(K0, IDX_COL, BASENAME, B, TYPE) \ + CONCAT(COLUMN_VECTOR, K0) \ + (IDX_COL, BASENAME, B, TYPE); /** Create transposed vectors form the given source vectors * @@ -630,11 +638,12 @@ * @param[in] N0 The number of source vectors * @param[in] BASENAME The basename of transposed vectors * @param[in] B The basename of source vectors for transposition + * @param[in] TYPE The data type of the transposed vectors * */ -#define TRANSPOSE_K0XN0(K0, N0, BASENAME, B) \ - CONCAT(TRANSPOSE_K0X, N0) \ - (K0, BASENAME, B); +#define TRANSPOSE_K0XN0(K0, N0, BASENAME, B, TYPE) \ + CONCAT(TRANSPOSE_K0X, N0) \ + (K0, BASENAME, B, TYPE); /** Add the variables (BIAS0 to BIASn-1) to the others (BASENAME0 to BASENAMEn-1) * @name ADD_ROW_n diff --git a/src/core/CL/cl_kernels/gemmlowp.cl b/src/core/CL/cl_kernels/gemmlowp.cl index fa08b149e4..47791fbe74 100644 --- a/src/core/CL/cl_kernels/gemmlowp.cl +++ b/src/core/CL/cl_kernels/gemmlowp.cl @@ -25,6 +25,8 @@ #include "helpers_asymm.h" #include "repeat.h" +#if defined(DATA_TYPE) && defined(ACC_DATA_TYPE) + #if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8) #if defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED) && defined(cl_arm_integer_dot_product_accumulate_int8) #define ARM_DOT(x, y, val) val = arm_dot_acc((x), (y), (val)); @@ -36,17 +38,17 @@ #if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8) /** Specialized macros to perform the dot product instruction between two vectors of size N [1,16]. These macros use the dot8 instruction */ -#define ARM_DOT1(a, b, c) \ - ({ \ - ARM_DOT((uchar4)(a, (uchar3)0), (uchar4)(b, (uchar3)0), c); \ +#define ARM_DOT1(a, b, c) \ + ({ \ + ARM_DOT((VEC_DATA_TYPE(DATA_TYPE, 4))(a, (VEC_DATA_TYPE(DATA_TYPE, 3))0), (VEC_DATA_TYPE(DATA_TYPE, 4))(b, (VEC_DATA_TYPE(DATA_TYPE, 3))0), c); \ }) -#define ARM_DOT2(a, b, c) \ - ({ \ - ARM_DOT((uchar4)(a, (uchar2)0), (uchar4)(b, (uchar2)0), c); \ +#define ARM_DOT2(a, b, c) \ + ({ \ + ARM_DOT((VEC_DATA_TYPE(DATA_TYPE, 4))(a, (VEC_DATA_TYPE(DATA_TYPE, 2))0), (VEC_DATA_TYPE(DATA_TYPE, 4))(b, (VEC_DATA_TYPE(DATA_TYPE, 2))0), c); \ }) -#define ARM_DOT3(a, b, c) \ - ({ \ - ARM_DOT((uchar4)(a, (uchar)0), (uchar4)(b, (uchar)0), c); \ +#define ARM_DOT3(a, b, c) \ + ({ \ + ARM_DOT((VEC_DATA_TYPE(DATA_TYPE, 4))(a, (DATA_TYPE)0), (VEC_DATA_TYPE(DATA_TYPE, 4))(b, (DATA_TYPE)0), c); \ }) #define ARM_DOT4(a, b, c) \ ({ \ @@ -66,24 +68,24 @@ #else // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8) /** Specialized macros to perform the dot product instruction between two vectors of size K0 [1,16] without using the dot8 instruction. */ -#define ARM_DOT1(a, b, c) \ - ({ \ - c += (uint)a * b; \ +#define ARM_DOT1(a, b, c) \ + ({ \ + c += (ACC_DATA_TYPE)a * b; \ }) -#define ARM_DOT2(a, b, c) \ - ({ \ - c += (uint)a.s0 * b.s0; \ - c += (uint)a.s1 * b.s1; \ +#define ARM_DOT2(a, b, c) \ + ({ \ + c += (ACC_DATA_TYPE)a.s0 * b.s0; \ + c += (ACC_DATA_TYPE)a.s1 * b.s1; \ }) -#define ARM_DOT3(a, b, c) \ - ({ \ - ARM_DOT2(a, b, c); \ - c += (uint)a.s2 * b.s2; \ +#define ARM_DOT3(a, b, c) \ + ({ \ + ARM_DOT2(a, b, c); \ + c += (ACC_DATA_TYPE)a.s2 * b.s2; \ }) -#define ARM_DOT4(a, b, c) \ - ({ \ - ARM_DOT3(a, b, c); \ - c += (uint)a.s3 * b.s3; \ +#define ARM_DOT4(a, b, c) \ + ({ \ + ARM_DOT3(a, b, c); \ + c += (ACC_DATA_TYPE)a.s3 * b.s3; \ }) #define ARM_DOT8(a, b, c) \ ({ \ @@ -194,13 +196,15 @@ }) #if defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_Y) && defined(COLS_A) -#define VECTOR_UCHAR VEC_DATA_TYPE(uchar, NUM_ELEMS_PROCESSED_PER_THREAD_X) -#define VECTOR_UINT VEC_DATA_TYPE(uint, NUM_ELEMS_PROCESSED_PER_THREAD_X) +#define VECTOR_TYPE VEC_DATA_TYPE(DATA_TYPE, NUM_ELEMS_PROCESSED_PER_THREAD_X) +#define VECTOR_ACC_TYPE VEC_DATA_TYPE(ACC_DATA_TYPE, NUM_ELEMS_PROCESSED_PER_THREAD_X) #define VECTOR_INT VEC_DATA_TYPE(int, NUM_ELEMS_PROCESSED_PER_THREAD_X) /** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1) in case both matrices have not beed reshaped * * @attention The number of matrix A columns needs to be passed at compile time using -DCOLS_A * + * @note The input data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=uchar) + * @note The accumulator data type must be passed at compile time using -DACC_DATA_TYPE (i.e. -DACC_DATA_TYPE=uint) * @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time: * -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D * -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D @@ -302,93 +306,98 @@ __kernel void gemmlowp_mm_midgard(IMAGE_DECLARATION(src0), int end_row_vec_a = src_addr.s0 + COLS_A; - VECTOR_UINT acc0 = 0; + VECTOR_ACC_TYPE acc0 = 0; #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - VECTOR_UINT acc1 = 0; + VECTOR_ACC_TYPE acc1 = 0; #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - VECTOR_UINT acc2 = 0; + VECTOR_ACC_TYPE acc2 = 0; #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - VECTOR_UINT acc3 = 0; + VECTOR_ACC_TYPE acc3 = 0; #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 - VECTOR_UINT acc4 = 0; + VECTOR_ACC_TYPE acc4 = 0; #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 for(; src_addr.s0 <= (end_row_vec_a - 2); src_addr += (int2)(2, 2 * src1_stride_y)) { // Load values from matrix A - uchar2 a0 = vload2(0, src0_ptr + src_addr.s0 + 0 * src0_stride_y); + VEC_DATA_TYPE(DATA_TYPE, 2) + a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y)); #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - uchar2 a1 = vload2(0, src0_ptr + src_addr.s0 + 1 * src0_stride_y); + VEC_DATA_TYPE(DATA_TYPE, 2) + a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y)); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - uchar2 a2 = vload2(0, src0_ptr + src_addr.s0 + 2 * src0_stride_y); + VEC_DATA_TYPE(DATA_TYPE, 2) + a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y)); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - uchar2 a3 = vload2(0, src0_ptr + src_addr.s0 + 3 * src0_stride_y); + VEC_DATA_TYPE(DATA_TYPE, 2) + a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y)); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 - uchar2 a4 = vload2(0, src0_ptr + src_addr.s0 + 4 * src0_stride_y); + VEC_DATA_TYPE(DATA_TYPE, 2) + a4 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 4 * src0_stride_y)); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 // Load values from matrix B - VECTOR_UCHAR b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, src1_ptr + src_addr.s1); - VECTOR_UCHAR b1 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, src1_ptr + src_addr.s1 + src1_stride_y); + VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1)); + VECTOR_TYPE b1 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1 + src1_stride_y)); // Accumulate - acc0 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a0.s0; - acc0 += CONVERT(b1, VECTOR_UINT) * (VECTOR_UINT)a0.s1; + acc0 += CONVERT(b0, VECTOR_ACC_TYPE) * (VECTOR_ACC_TYPE)a0.s0; + acc0 += CONVERT(b1, VECTOR_ACC_TYPE) * (VECTOR_ACC_TYPE)a0.s1; #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - acc1 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a1.s0; - acc1 += CONVERT(b1, VECTOR_UINT) * (VECTOR_UINT)a1.s1; + acc1 += CONVERT(b0, VECTOR_ACC_TYPE) * (VECTOR_ACC_TYPE)a1.s0; + acc1 += CONVERT(b1, VECTOR_ACC_TYPE) * (VECTOR_ACC_TYPE)a1.s1; #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - acc2 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a2.s0; - acc2 += CONVERT(b1, VECTOR_UINT) * (VECTOR_UINT)a2.s1; + acc2 += CONVERT(b0, VECTOR_ACC_TYPE) * (VECTOR_ACC_TYPE)a2.s0; + acc2 += CONVERT(b1, VECTOR_ACC_TYPE) * (VECTOR_ACC_TYPE)a2.s1; #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - acc3 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a3.s0; - acc3 += CONVERT(b1, VECTOR_UINT) * (VECTOR_UINT)a3.s1; + acc3 += CONVERT(b0, VECTOR_ACC_TYPE) * (VECTOR_ACC_TYPE)a3.s0; + acc3 += CONVERT(b1, VECTOR_ACC_TYPE) * (VECTOR_ACC_TYPE)a3.s1; #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 - acc4 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a4.s0; - acc4 += CONVERT(b1, VECTOR_UINT) * (VECTOR_UINT)a4.s1; + acc4 += CONVERT(b0, VECTOR_ACC_TYPE) * (VECTOR_ACC_TYPE)a4.s0; + acc4 += CONVERT(b1, VECTOR_ACC_TYPE) * (VECTOR_ACC_TYPE)a4.s1; #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 } for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(1, src1_stride_y)) { // Load values from matrix A - uchar a0 = *(src0_ptr + src_addr.s0 + 0 * src0_stride_y); + DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y)); #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - uchar a1 = *(src0_ptr + src_addr.s0 + 1 * src0_stride_y); + DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y)); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - uchar a2 = *(src0_ptr + src_addr.s0 + 2 * src0_stride_y); + DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y)); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - uchar a3 = *(src0_ptr + src_addr.s0 + 3 * src0_stride_y); + DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y)); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 - uchar a4 = *(src0_ptr + src_addr.s0 + 4 * src0_stride_y); + DATA_TYPE a4 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 4 * src0_stride_y)); #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 // Load values from matrix B - VECTOR_UCHAR b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, src1_ptr + src_addr.s1); + VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1)); // Accumulate - acc0 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a0; + acc0 += CONVERT(b0, VECTOR_ACC_TYPE) * (VECTOR_ACC_TYPE)a0; #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 - acc1 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a1; + acc1 += CONVERT(b0, VECTOR_ACC_TYPE) * (VECTOR_ACC_TYPE)a1; #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 - acc2 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a2; + acc2 += CONVERT(b0, VECTOR_ACC_TYPE) * (VECTOR_ACC_TYPE)a2; #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 - acc3 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a3; + acc3 += CONVERT(b0, VECTOR_ACC_TYPE) * (VECTOR_ACC_TYPE)a3; #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3 #if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 - acc4 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a4; + acc4 += CONVERT(b0, VECTOR_ACC_TYPE) * (VECTOR_ACC_TYPE)a4; #endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4 } @@ -476,6 +485,8 @@ __kernel void gemmlowp_mm_midgard(IMAGE_DECLARATION(src0), * The LHS matrix must be reshaped with @ref CLGEMMReshapeLHSMatrixKernel and the M0xK0 must be NOT transposed * The RHS matrix must be reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the K0xN0 must be transposed * + * @note The input data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=uchar) + * @note The accumulator data type must be passed at compile time using -DACC_DATA_TYPE (i.e. -DACC_DATA_TYPE=uint) * @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time. * @note The GEMM's dimensions M and N must be passed at compile time using -DM and -DN (i.e. -DM=52 and -DN=90). * @note The block's dimensions used for reshaping the LHS matrix and the RHS matrix (M0, N0 and K0) must be passed at compile time using -DM0, -DN0 and -DK0 (i.e. -DM0=4, -DN0=8, -DK0=4). @@ -588,15 +599,15 @@ __kernel void gemmlowp_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs), REPEAT_VAR_INIT_TO_CONST(16, uint, zrhs, 0); // Initialize the accumulators - REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(uint, N0), c, 0); //VEC_DATA_TYPE(uint, N0) c0=0,c1=0,c2=0,... c(M0-1)=0; + REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(ACC_DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(ACC_DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0; for(int i = 0; i < k; i += K0) { // Load values from LHS matrix - LOAD_BLOCK(M0, K0, uchar, a, lhs_addr, 0, LHS_STEP_X, zlhs); + LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_addr, 0, LHS_STEP_X, zlhs); // Load values from RHS matrix - LOAD_BLOCK(N0, K0, uchar, b, rhs_addr, 0, RHS_STEP_X, zrhs); + LOAD_BLOCK(N0, K0, DATA_TYPE, b, rhs_addr, 0, RHS_STEP_X, zrhs); // Partial matrix multiplication M0,N0,K0 ARM_MM_K0XN0XM0(M0, N0, K0, a, b, c); @@ -643,6 +654,8 @@ __kernel void gemmlowp_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs), * The LHS matrix is NOT reshaped * The RHS matrix is reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the block K0xN0 is transposed * + * @note The input data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=uchar) + * @note The accumulator data type must be passed at compile time using -DACC_DATA_TYPE (i.e. -DACC_DATA_TYPE=uint) * @note The number of columns of LHS matrix must be passed at compile time using -DK (i.e. -DK=64) * @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (i.e. -DN0=8, -DK0=4). * @note The number of M0 rows to process must be passed at compile time using -DM0 (i.e. -DM0=2) @@ -661,7 +674,7 @@ __kernel void gemmlowp_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs), * -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor * (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix * - * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F16/F32 + * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: QASYMM8/QASYMM8_SIGNED * @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes) * @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes) * @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes) @@ -673,7 +686,7 @@ __kernel void gemmlowp_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs), * @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes) * @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes) * @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix - * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr + * @param[out] dst_ptr Pointer to the destination matrix Supported data type: S32 * @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes) * @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes) * @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes) @@ -758,15 +771,15 @@ __kernel void gemmlowp_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs), #endif // defined(REINTERPRET_INPUT_AS_3D) // Initialize the accumulators - REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(uint, N0), c, 0); //VEC_DATA_TYPE(uint, N0) c0=0,c1=0,c2=0,... c(N0-1)=0; + REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(ACC_DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(ACC_DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(N0-1)=0; for(int i = 0; i < K; i += K0) { // Load values from LHS matrix - LOAD_BLOCK(M0, K0, uchar, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs); + LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs); // Load values from RHS matrix - LOAD_BLOCK(N0, K0, uchar, b, rhs_ptr, rhs_offset, RHS_STEP_X, zrhs); + LOAD_BLOCK(N0, K0, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X, zrhs); // Partial matrix multiplication M0,N0,K0 ARM_MM_K0XN0XM0(M0, N0, K0, a, b, c); @@ -809,6 +822,8 @@ __kernel void gemmlowp_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs), * The LHS matrix is NOT reshaped * The RHS matrix is NOT reshaped * + * @note The input data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=uchar) + * @note The accumulator data type must be passed at compile time using -DACC_DATA_TYPE (i.e. -DACC_DATA_TYPE=uint) * @note The number of columns of LHS matrix must be passed at compile time using -DK (i.e. -DK=64) * @note The number of M0 rows to process must be passed at compile time using -DM0 (i.e. -DM0=2) * @note The number of N0 columns to process must be passed at compile time using -DN0 (i.e. -DN0=2) @@ -908,20 +923,20 @@ __kernel void gemmlowp_mm_native(IMAGE_DECLARATION(lhs), #endif // defined(REINTERPRET_INPUT_AS_3D) // Initialize the accumulators - REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(uint, N0), c, 0); //VEC_DATA_TYPE(uint, N0) c0=0,c1=0,c2=0,... c(M0-1)=0; + REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(ACC_DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(ACC_DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0; int i = 0; for(; i <= (K - K0); i += K0) { // Load values from LHS matrix - LOAD_BLOCK(M0, K0, uchar, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs); + LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs); // Load values from RHS matrix - LOAD_BLOCK(K0, N0, uchar, b, rhs_ptr, rhs_offset, rhs_stride_y, zrhs); + LOAD_BLOCK(K0, N0, DATA_TYPE, b, rhs_ptr, rhs_offset, rhs_stride_y, zrhs); // Transpose the values from RHS matrix - TRANSPOSE_K0XN0(K0, N0, b_t, b); + TRANSPOSE_K0XN0(K0, N0, b_t, b, DATA_TYPE); // Partial matrix multiplication M0,N0,K0 ARM_MM_K0XN0XM0(M0, N0, K0, a, b_t, c); @@ -935,13 +950,13 @@ __kernel void gemmlowp_mm_native(IMAGE_DECLARATION(lhs), for(; i < K; ++i) { // Load values from LHS matrix - LOAD_BLOCK(M0, 1, uchar, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs); + LOAD_BLOCK(M0, 1, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs); // Load values from RHS matrix - LOAD_BLOCK(1, N0, uchar, b, rhs_ptr, rhs_offset, rhs_stride_y, zrhs); + LOAD_BLOCK(1, N0, DATA_TYPE, b, rhs_ptr, rhs_offset, rhs_stride_y, zrhs); // Transpose the values from RHS matrix - TRANSPOSE_K0XN0(1, N0, b_t, b); + TRANSPOSE_K0XN0(1, N0, b_t, b, DATA_TYPE); // Partial matrix multiplication M0,N0,1 ARM_MM_K0XN0XM0(M0, N0, 1, a, b_t, c); @@ -975,6 +990,8 @@ __kernel void gemmlowp_mm_native(IMAGE_DECLARATION(lhs), } #endif // defined(M0) && defined(N0) && defined(K0) && defined(K) +#endif // defined(DATA_TYPE) && defined(ACC_DATA_TYPE) + #if defined(COLS_A) /** OpenCL kernel used to compute the row-vectors of sums of all the entries in each row of Matrix A. * diff --git a/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp index cda7a83de7..78df0eec16 100644 --- a/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp +++ b/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp @@ -212,6 +212,8 @@ void CLGEMMLowpMatrixMultiplyKernel::configure(const ICLTensor *input0, const IC build_opts.add_option("-DCOLS_A=" + support::cpp11::to_string(input0->info()->dimension(0))); build_opts.add_option("-DNUM_ELEMS_PROCESSED_PER_THREAD_X=" + support::cpp11::to_string(num_elements_processed.x())); build_opts.add_option("-DNUM_ELEMS_PROCESSED_PER_THREAD_Y=" + support::cpp11::to_string(num_elements_processed.y())); + build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input0->info()->data_type())); + build_opts.add_option("-DACC_DATA_TYPE=" + get_cl_dot8_acc_type_from_data_type(input0->info()->data_type())); kernel_name = "gemmlowp_mm_midgard"; diff --git a/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyNativeKernel.cpp b/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyNativeKernel.cpp index 09caeeea55..3e887d8163 100644 --- a/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyNativeKernel.cpp +++ b/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyNativeKernel.cpp @@ -216,6 +216,8 @@ void CLGEMMLowpMatrixMultiplyNativeKernel::configure(const ICLTensor *input0, co build_opts.add_option("-DM0=" + support::cpp11::to_string(lhs_info.m0)); build_opts.add_option("-DN0=" + support::cpp11::to_string(rhs_info.n0)); build_opts.add_option("-DK0=" + support::cpp11::to_string(rhs_info.k0)); + build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input0->info()->data_type())); + build_opts.add_option("-DACC_DATA_TYPE=" + get_cl_dot8_acc_type_from_data_type(input0->info()->data_type())); std::string kernel_name("gemmlowp_mm_native"); diff --git a/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedKernel.cpp b/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedKernel.cpp index 050b792c4e..8d3aff6603 100644 --- a/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedKernel.cpp +++ b/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedKernel.cpp @@ -42,13 +42,9 @@ #include <cstdint> #include <tuple> -using namespace arm_compute; -using namespace arm_compute::misc::shape_calculator; - namespace arm_compute { -class Coordinates; -} // namespace arm_compute +using namespace misc::shape_calculator; namespace { @@ -210,6 +206,8 @@ void CLGEMMLowpMatrixMultiplyReshapedKernel::configure(const ICLTensor *input0, build_opts.add_option("-DK0=" + support::cpp11::to_string(lhs_info.k0)); build_opts.add_option("-DV0=" + support::cpp11::to_string(lhs_info.v0)); build_opts.add_option("-DH0=" + support::cpp11::to_string(rhs_info.h0)); + build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input0->info()->data_type())); + build_opts.add_option("-DACC_DATA_TYPE=" + get_cl_dot8_acc_type_from_data_type(input0->info()->data_type())); std::string kernel_name("gemmlowp_mm_reshaped_"); kernel_name += lhs_info.transpose ? "lhs_t_" : "lhs_nt_"; @@ -310,4 +308,5 @@ void CLGEMMLowpMatrixMultiplyReshapedKernel::run(const Window &window, cl::Comma enqueue(queue, *this, slice, lws_hint(), _use_dummy_work_items); } while(window.slide_window_slice_3D(slice)); -}
\ No newline at end of file +} +} // namespace arm_compute
\ No newline at end of file diff --git a/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel.cpp b/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel.cpp index 779f96e7cf..3fa2fad8fd 100644 --- a/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel.cpp +++ b/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel.cpp @@ -54,6 +54,7 @@ Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, const GEMMReshapeInfo &gemm_info) { ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input0, input1, output); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1); ARM_COMPUTE_RETURN_ERROR_ON_MSG(input0->num_dimensions() > 4, "The number of dimensions for the LHS matrix must be <= 4"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->num_dimensions() > 3, "The number of dimensions for the RHS matrix must be <= 3"); @@ -218,6 +219,8 @@ void CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel::configure(const ICLTensor *i build_opts.add_option("-DN0=" + support::cpp11::to_string(rhs_info.n0)); build_opts.add_option("-DK0=" + support::cpp11::to_string(rhs_info.k0)); build_opts.add_option("-DH0=" + support::cpp11::to_string(rhs_info.h0)); + build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input0->info()->data_type())); + build_opts.add_option("-DACC_DATA_TYPE=" + get_cl_dot8_acc_type_from_data_type(input0->info()->data_type())); std::string kernel_name("gemmlowp_mm_reshaped_only_rhs_"); kernel_name += rhs_info.transpose ? "t" : "nt"; diff --git a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp index d322723150..a4270d7923 100644 --- a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp +++ b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp @@ -402,7 +402,7 @@ Status CLGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI { ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output); ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights_info.are_reshaped(), "Weights already reshaped are not supported!"); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::QSYMM8_PER_CHANNEL, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32); const bool is_quantized_per_channel = is_data_type_quantized_per_channel(weights->data_type()); if(is_quantized_per_channel) |