author     Michele Di Giorgio <michele.digiorgio@arm.com>    2019-11-27 16:17:30 +0000
committer  Michele Di Giorgio <michele.digiorgio@arm.com>    2019-12-03 11:15:36 +0000
commit     f9179d393a07eb9eed753e315df79d22391906c6 (patch)
tree       d8a1fd9d984bdd335d3ecac117ec33c4523211ef
parent     b714b1d6a53e6c33df2ea3c1e8340f20480d799b (diff)
download   ComputeLibrary-f9179d393a07eb9eed753e315df79d22391906c6.tar.gz
COMPMID-2793: Add support for QASYMM8_SIGNED in CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel
Change-Id: I8abfdd3372cc394b98ec038b9fcb4abfe9216894
Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-on: https://review.mlplatform.org/c/2401
Reviewed-by: Giorgio Arena <giorgio.arena@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--  arm_compute/core/CL/CLHelpers.h                                              |   8
-rw-r--r--  arm_compute/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel.h  |  10
-rw-r--r--  src/core/CL/CLHelpers.cpp                                                    |  23
-rw-r--r--  src/core/CL/cl_kernels/gemm_helpers.h                                        |  99
-rw-r--r--  src/core/CL/cl_kernels/gemmlowp.cl                                           | 165
-rw-r--r--  src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp                       |   2
-rw-r--r--  src/core/CL/kernels/CLGEMMLowpMatrixMultiplyNativeKernel.cpp                 |   2
-rw-r--r--  src/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedKernel.cpp               |  11
-rw-r--r--  src/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel.cpp        |   3
-rw-r--r--  src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp                          |   2
-rw-r--r--  tests/validation/CL/GEMMLowpMatrixMultiplyReshapedOnlyRHS.cpp                |  70
-rw-r--r--  tests/validation/fixtures/GEMMLowpFixture.h                                  | 133
12 files changed, 348 insertions, 180 deletions
diff --git a/arm_compute/core/CL/CLHelpers.h b/arm_compute/core/CL/CLHelpers.h
index 8801af579e..cd65eafc9c 100644
--- a/arm_compute/core/CL/CLHelpers.h
+++ b/arm_compute/core/CL/CLHelpers.h
@@ -74,6 +74,14 @@ std::string get_cl_unsigned_type_from_element_size(size_t element_size);
*/
std::string get_cl_select_type_from_data_type(const DataType &dt);
+/** Translates a tensor data type to the appropriate OpenCL dot8 accumulator type.
+ *
+ * @param[in] dt @ref DataType to be translated to OpenCL dot8 accumulator type.
+ *
+ * @return The string specifying the OpenCL dot8 accumulator type to be used.
+ */
+std::string get_cl_dot8_acc_type_from_data_type(const DataType &dt);
+
/** Get the size of a data type in number of bits.
*
* @param[in] dt @ref DataType.
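
The new helper complements get_cl_type_from_data_type(): the former names the element type, the latter the accumulator type used on the dot8 paths. A minimal sketch of the mapping, assuming this header and the implementation added in src/core/CL/CLHelpers.cpp below:

#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/Types.h"
#include <cassert>

int main()
{
    using arm_compute::DataType;
    // Unsigned 8-bit types accumulate in "uint", signed 8-bit types in "int".
    assert(arm_compute::get_cl_dot8_acc_type_from_data_type(DataType::QASYMM8) == "uint");
    assert(arm_compute::get_cl_dot8_acc_type_from_data_type(DataType::QASYMM8_SIGNED) == "int");
    return 0;
}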
diff --git a/arm_compute/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel.h b/arm_compute/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel.h
index 5328ee44bc..9dd5496c00 100644
--- a/arm_compute/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel.h
+++ b/arm_compute/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel.h
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef __ARM_COMPUTE_CLGEMMLOWPMATRIXMULTIPLYRESHAPEDONLYRHSKERNEL_H__
-#define __ARM_COMPUTE_CLGEMMLOWPMATRIXMULTIPLYRESHAPEDONLYRHSKERNEL_H__
+#ifndef ARM_COMPUTE_CLGEMMLOWPMATRIXMULTIPLYRESHAPEDONLYRHSKERNEL_H
+#define ARM_COMPUTE_CLGEMMLOWPMATRIXMULTIPLYRESHAPEDONLYRHSKERNEL_H
#include "arm_compute/core/CL/ICLKernel.h"
@@ -49,7 +49,7 @@ public:
CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel &operator=(CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel &&) = default;
/** Initialise the kernel's input and output.
*
- * @param[in] input0 Input tensor containing the LHS matrix. Data type supported: QASYMM8
+ * @param[in] input0 Input tensor containing the LHS matrix. Data type supported: QASYMM8/QASYMM8_SIGNED
* @param[in] input1 Input tensor containing the RHS reshaped matrix. Data type supported: same as @p input0
* @param[out] output Output tensor to store the result of matrix multiplication. Data type supported: S32
* @param[in] lhs_info LHS matrix information used to retrieve the number of rows to be processed by each thread
@@ -64,7 +64,7 @@ public:
void configure(const ICLTensor *input0, const ICLTensor *input1, ICLTensor *output, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, const GEMMReshapeInfo &gemm_info);
/** Static function to check if given info will lead to a valid configuration of @ref CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel
*
- * @param[in] input0 Input tensor info for the LHS matrix. Data type supported: QASYMM8
+ * @param[in] input0 Input tensor info for the LHS matrix. Data type supported: QASYMM8/QASYMM8_SIGNED
* @param[in] input1 Input tensor info for the RHS reshaped matrix. Data type supported: same as @p input0
* @param[in] output Output tensor info. Data type supported: S32
* @param[in] lhs_info LHS matrix information used to retrieve the number of rows to be processed by each thread
@@ -94,4 +94,4 @@ private:
bool _use_dummy_work_items;
};
} // namespace arm_compute
-#endif /*__ARM_COMPUTE_CLGEMMLOWPMATRIXMULTIPLYRESHAPEDONLYRHSKERNEL_H__*/
\ No newline at end of file
+#endif /* ARM_COMPUTE_CLGEMMLOWPMATRIXMULTIPLYRESHAPEDONLYRHSKERNEL_H */
\ No newline at end of file
diff --git a/src/core/CL/CLHelpers.cpp b/src/core/CL/CLHelpers.cpp
index 17274d38ad..28b1a3224f 100644
--- a/src/core/CL/CLHelpers.cpp
+++ b/src/core/CL/CLHelpers.cpp
@@ -42,6 +42,7 @@ std::string get_cl_type_from_data_type(const DataType &dt)
case DataType::QASYMM8:
return "uchar";
case DataType::S8:
+ case DataType::QASYMM8_SIGNED:
case DataType::QSYMM8:
case DataType::QSYMM8_PER_CHANNEL:
return "char";
@@ -77,6 +78,7 @@ std::string get_cl_promoted_type_from_data_type(const DataType &dt)
case DataType::QASYMM8:
return "ushort";
case DataType::S8:
+ case DataType::QASYMM8_SIGNED:
case DataType::QSYMM8:
case DataType::QSYMM8_PER_CHANNEL:
return "short";
@@ -124,6 +126,7 @@ std::string get_cl_select_type_from_data_type(const DataType &dt)
case DataType::QASYMM8:
return "uchar";
case DataType::S8:
+ case DataType::QASYMM8_SIGNED:
case DataType::QSYMM8:
case DataType::QSYMM8_PER_CHANNEL:
return "char";
@@ -149,6 +152,24 @@ std::string get_cl_select_type_from_data_type(const DataType &dt)
}
}
+std::string get_cl_dot8_acc_type_from_data_type(const DataType &dt)
+{
+ switch(dt)
+ {
+ case DataType::U8:
+ case DataType::QASYMM8:
+ return "uint";
+ case DataType::S8:
+ case DataType::QASYMM8_SIGNED:
+ case DataType::QSYMM8:
+ case DataType::QSYMM8_PER_CHANNEL:
+ return "int";
+ default:
+ ARM_COMPUTE_ERROR("Unsupported data type.");
+ return "";
+ }
+}
+
std::string get_data_size_from_data_type(const DataType &dt)
{
switch(dt)
@@ -157,6 +178,7 @@ std::string get_data_size_from_data_type(const DataType &dt)
case DataType::S8:
case DataType::QSYMM8:
case DataType::QASYMM8:
+ case DataType::QASYMM8_SIGNED:
case DataType::QSYMM8_PER_CHANNEL:
return "8";
case DataType::U16:
@@ -300,6 +322,7 @@ size_t preferred_vector_width(const cl::Device &device, const DataType dt)
case DataType::U8:
case DataType::S8:
case DataType::QASYMM8:
+ case DataType::QASYMM8_SIGNED:
case DataType::QSYMM8:
case DataType::QSYMM8_PER_CHANNEL:
return device.getInfo<CL_DEVICE_PREFERRED_VECTOR_WIDTH_CHAR>();
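
The kernels changed below consume both helpers when composing their OpenCL build options. A hedged sketch of that pattern in isolation (make_type_options is an illustrative name, not library API):

#include "arm_compute/core/CL/CLHelpers.h"
#include "arm_compute/core/Types.h"
#include <string>
#include <vector>

// Illustrative helper: mirrors the options the GEMMLowp kernels below add in configure().
// QASYMM8        -> "-DDATA_TYPE=uchar", "-DACC_DATA_TYPE=uint"
// QASYMM8_SIGNED -> "-DDATA_TYPE=char",  "-DACC_DATA_TYPE=int"
std::vector<std::string> make_type_options(arm_compute::DataType dt)
{
    return { "-DDATA_TYPE=" + arm_compute::get_cl_type_from_data_type(dt),
             "-DACC_DATA_TYPE=" + arm_compute::get_cl_dot8_acc_type_from_data_type(dt) };
}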
diff --git a/src/core/CL/cl_kernels/gemm_helpers.h b/src/core/CL/cl_kernels/gemm_helpers.h
index 64914259a4..66e83c3558 100644
--- a/src/core/CL/cl_kernels/gemm_helpers.h
+++ b/src/core/CL/cl_kernels/gemm_helpers.h
@@ -559,20 +559,26 @@
* @param[in] IDX_COL The index value
* @param[in] BASENAME The basename of the destination vectors
* @param[in] X The basename of the source vectors
+ * @param[in] TYPE The data type of the destination vectors
* @{
*/
-#define COLUMN_VECTOR1(IDX_COL, BASENAME, X) \
- uchar BASENAME##IDX_COL = (uchar)((X##0).s##IDX_COL);
-#define COLUMN_VECTOR2(IDX_COL, BASENAME, X) \
- uchar2 BASENAME##IDX_COL = (uchar2)((X##0).s##IDX_COL, (X##1).s##IDX_COL);
-#define COLUMN_VECTOR3(IDX_COL, BASENAME, X) \
- uchar3 BASENAME##IDX_COL = (uchar3)((X##0).s##IDX_COL, (X##1).s##IDX_COL, (X##2).s##IDX_COL);
-#define COLUMN_VECTOR4(IDX_COL, BASENAME, X) \
- uchar4 BASENAME##IDX_COL = (uchar4)((X##0).s##IDX_COL, (X##1).s##IDX_COL, (X##2).s##IDX_COL, (X##3).s##IDX_COL);
-#define COLUMN_VECTOR8(IDX_COL, BASENAME, X) \
- uchar8 BASENAME##IDX_COL = (uchar8)((X##0).s##IDX_COL, (X##1).s##IDX_COL, (X##2).s##IDX_COL, (X##3).s##IDX_COL, (X##4).s##IDX_COL, (X##5).s##IDX_COL, (X##6).s##IDX_COL, (X##7).s##IDX_COL);
-#define COLUMN_VECTOR16(IDX_COL, BASENAME, X) \
- uchar16 BASENAME##IDX_COL = (uchar16)((X##0).s##IDX_COL, (X##1).s##IDX_COL, (X##2).s##IDX_COL, (X##3).s##IDX_COL, (X##4).s##IDX_COL, (X##5).s##IDX_COL, (X##6).s##IDX_COL, (X##7).s##IDX_COL, (X##8).s##IDX_COL, (X##9).s##IDX_COL, (X##A).s##IDX_COL, (X##B).s##IDX_COL, (X##C).s##IDX_COL, (X##D).s##IDX_COL, (X##E).s##IDX_COL, (X##F).s##IDX_COL);
+#define COLUMN_VECTOR1(IDX_COL, BASENAME, X, TYPE) \
+ TYPE BASENAME##IDX_COL = (TYPE)((X##0).s##IDX_COL);
+#define COLUMN_VECTOR2(IDX_COL, BASENAME, X, TYPE) \
+ VEC_DATA_TYPE(TYPE, 2) \
+ BASENAME##IDX_COL = (VEC_DATA_TYPE(TYPE, 2))((X##0).s##IDX_COL, (X##1).s##IDX_COL);
+#define COLUMN_VECTOR3(IDX_COL, BASENAME, X, TYPE) \
+ VEC_DATA_TYPE(TYPE, 3) \
+ BASENAME##IDX_COL = (VEC_DATA_TYPE(TYPE, 3))((X##0).s##IDX_COL, (X##1).s##IDX_COL, (X##2).s##IDX_COL);
+#define COLUMN_VECTOR4(IDX_COL, BASENAME, X, TYPE) \
+ VEC_DATA_TYPE(TYPE, 4) \
+ BASENAME##IDX_COL = (VEC_DATA_TYPE(TYPE, 4))((X##0).s##IDX_COL, (X##1).s##IDX_COL, (X##2).s##IDX_COL, (X##3).s##IDX_COL);
+#define COLUMN_VECTOR8(IDX_COL, BASENAME, X, TYPE) \
+ VEC_DATA_TYPE(TYPE, 8) \
+ BASENAME##IDX_COL = (VEC_DATA_TYPE(TYPE, 8))((X##0).s##IDX_COL, (X##1).s##IDX_COL, (X##2).s##IDX_COL, (X##3).s##IDX_COL, (X##4).s##IDX_COL, (X##5).s##IDX_COL, (X##6).s##IDX_COL, (X##7).s##IDX_COL);
+#define COLUMN_VECTOR16(IDX_COL, BASENAME, X, TYPE) \
+ VEC_DATA_TYPE(TYPE, 16) \
+ BASENAME##IDX_COL = (VEC_DATA_TYPE(TYPE, 16))((X##0).s##IDX_COL, (X##1).s##IDX_COL, (X##2).s##IDX_COL, (X##3).s##IDX_COL, (X##4).s##IDX_COL, (X##5).s##IDX_COL, (X##6).s##IDX_COL, (X##7).s##IDX_COL, (X##8).s##IDX_COL, (X##9).s##IDX_COL, (X##A).s##IDX_COL, (X##B).s##IDX_COL, (X##C).s##IDX_COL, (X##D).s##IDX_COL, (X##E).s##IDX_COL, (X##F).s##IDX_COL);
/** @} */ // end of group COLUMN_VECTORn
/** Create transposed vectors of the given vectors
@@ -581,35 +587,36 @@
* @param[in] K0 The size of the source vectors
* @param[in] BASENAME The basename of transposed vectors
* @param[in] B The basename of source vectors for transposition
+ * @param[in] TYPE The data type of the transposed vectors
* @{
*/
-#define TRANSPOSE_K0X1(K0, BASENAME, B) \
- COLUMN_VECTOR(K0, 0, BASENAME, B);
-#define TRANSPOSE_K0X2(K0, BASENAME, B) \
- TRANSPOSE_K0X1(K0, BASENAME, B); \
- COLUMN_VECTOR(K0, 1, BASENAME, B);
-#define TRANSPOSE_K0X3(K0, BASENAME, B) \
- TRANSPOSE_K0X2(K0, BASENAME, B); \
- COLUMN_VECTOR(K0, 2, BASENAME, B);
-#define TRANSPOSE_K0X4(K0, BASENAME, B) \
- TRANSPOSE_K0X3(K0, BASENAME, B); \
- COLUMN_VECTOR(K0, 3, BASENAME, B);
-#define TRANSPOSE_K0X8(K0, BASENAME, B) \
- TRANSPOSE_K0X4(K0, BASENAME, B); \
- COLUMN_VECTOR(K0, 4, BASENAME, B); \
- COLUMN_VECTOR(K0, 5, BASENAME, B); \
- COLUMN_VECTOR(K0, 6, BASENAME, B); \
- COLUMN_VECTOR(K0, 7, BASENAME, B);
-#define TRANSPOSE_K0X16(K0, BASENAME, B) \
- TRANSPOSE_K0X8(K0, BASENAME, B); \
- COLUMN_VECTOR(K0, 8, BASENAME, B); \
- COLUMN_VECTOR(K0, 9, BASENAME, B); \
- COLUMN_VECTOR(K0, A, BASENAME, B); \
- COLUMN_VECTOR(K0, B, BASENAME, B); \
- COLUMN_VECTOR(K0, C, BASENAME, B); \
- COLUMN_VECTOR(K0, D, BASENAME, B); \
- COLUMN_VECTOR(K0, E, BASENAME, B); \
- COLUMN_VECTOR(K0, F, BASENAME, B);
+#define TRANSPOSE_K0X1(K0, BASENAME, B, TYPE) \
+ COLUMN_VECTOR(K0, 0, BASENAME, B, TYPE);
+#define TRANSPOSE_K0X2(K0, BASENAME, B, TYPE) \
+ TRANSPOSE_K0X1(K0, BASENAME, B, TYPE); \
+ COLUMN_VECTOR(K0, 1, BASENAME, B, TYPE);
+#define TRANSPOSE_K0X3(K0, BASENAME, B, TYPE) \
+ TRANSPOSE_K0X2(K0, BASENAME, B, TYPE); \
+ COLUMN_VECTOR(K0, 2, BASENAME, B, TYPE);
+#define TRANSPOSE_K0X4(K0, BASENAME, B, TYPE) \
+ TRANSPOSE_K0X3(K0, BASENAME, B, TYPE); \
+ COLUMN_VECTOR(K0, 3, BASENAME, B, TYPE);
+#define TRANSPOSE_K0X8(K0, BASENAME, B, TYPE) \
+ TRANSPOSE_K0X4(K0, BASENAME, B, TYPE); \
+ COLUMN_VECTOR(K0, 4, BASENAME, B, TYPE); \
+ COLUMN_VECTOR(K0, 5, BASENAME, B, TYPE); \
+ COLUMN_VECTOR(K0, 6, BASENAME, B, TYPE); \
+ COLUMN_VECTOR(K0, 7, BASENAME, B, TYPE);
+#define TRANSPOSE_K0X16(K0, BASENAME, B, TYPE) \
+ TRANSPOSE_K0X8(K0, BASENAME, B, TYPE); \
+ COLUMN_VECTOR(K0, 8, BASENAME, B, TYPE); \
+ COLUMN_VECTOR(K0, 9, BASENAME, B, TYPE); \
+ COLUMN_VECTOR(K0, A, BASENAME, B, TYPE); \
+ COLUMN_VECTOR(K0, B, BASENAME, B, TYPE); \
+ COLUMN_VECTOR(K0, C, BASENAME, B, TYPE); \
+ COLUMN_VECTOR(K0, D, BASENAME, B, TYPE); \
+ COLUMN_VECTOR(K0, E, BASENAME, B, TYPE); \
+ COLUMN_VECTOR(K0, F, BASENAME, B, TYPE);
/** @} */ // end of group TRANSPOSE_K0Xn
@@ -619,10 +626,11 @@
* @param[in] IDX_COL The index value
* @param[in] BASENAME The basename of the destination vectors
* @param[in] B The basename of the source vectors
+ * @param[in] TYPE The data type of the destination vectors
*/
-#define COLUMN_VECTOR(K0, IDX_COL, BASENAME, B) \
- CONCAT(COLUMN_VECTOR, K0) \
- (IDX_COL, BASENAME, B);
+#define COLUMN_VECTOR(K0, IDX_COL, BASENAME, B, TYPE) \
+ CONCAT(COLUMN_VECTOR, K0) \
+ (IDX_COL, BASENAME, B, TYPE);
/** Create transposed vectors from the given source vectors
*
@@ -630,11 +638,12 @@
* @param[in] N0 The number of source vectors
* @param[in] BASENAME The basename of transposed vectors
* @param[in] B The basename of source vectors for transposition
+ * @param[in] TYPE The data type of the transposed vectors
*
*/
-#define TRANSPOSE_K0XN0(K0, N0, BASENAME, B) \
- CONCAT(TRANSPOSE_K0X, N0) \
- (K0, BASENAME, B);
+#define TRANSPOSE_K0XN0(K0, N0, BASENAME, B, TYPE) \
+ CONCAT(TRANSPOSE_K0X, N0) \
+ (K0, BASENAME, B, TYPE);
/** Add the variables (BIAS0 to BIASn-1) to the others (BASENAME0 to BASENAMEn-1)
* @name ADD_ROW_n
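
Because COLUMN_VECTORn and TRANSPOSE_K0Xn now take a TYPE argument, the same macros serve uchar and char blocks. As an illustration (a sketch of the preprocessor output, not additional library code), with TYPE=char, K0=4 and four source vectors b0..b3 of type char4, COLUMN_VECTOR4(0, b_t, b, char) expands to:

// VEC_DATA_TYPE(char, 4) is char4
char4 b_t0 = (char4)((b0).s0, (b1).s0, (b2).s0, (b3).s0);

and TRANSPOSE_K0XN0(4, 4, b_t, b, char) repeats the same expansion for columns 0..3, producing b_t0..b_t3.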
diff --git a/src/core/CL/cl_kernels/gemmlowp.cl b/src/core/CL/cl_kernels/gemmlowp.cl
index fa08b149e4..47791fbe74 100644
--- a/src/core/CL/cl_kernels/gemmlowp.cl
+++ b/src/core/CL/cl_kernels/gemmlowp.cl
@@ -25,6 +25,8 @@
#include "helpers_asymm.h"
#include "repeat.h"
+#if defined(DATA_TYPE) && defined(ACC_DATA_TYPE)
+
#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
#if defined(ARM_COMPUTE_OPENCL_DOT8_ACC_ENABLED) && defined(cl_arm_integer_dot_product_accumulate_int8)
#define ARM_DOT(x, y, val) val = arm_dot_acc((x), (y), (val));
@@ -36,17 +38,17 @@
#if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
/** Specialized macros to perform the dot product instruction between two vectors of size N [1,16]. These macros use the dot8 instruction */
-#define ARM_DOT1(a, b, c) \
- ({ \
- ARM_DOT((uchar4)(a, (uchar3)0), (uchar4)(b, (uchar3)0), c); \
+#define ARM_DOT1(a, b, c) \
+ ({ \
+ ARM_DOT((VEC_DATA_TYPE(DATA_TYPE, 4))(a, (VEC_DATA_TYPE(DATA_TYPE, 3))0), (VEC_DATA_TYPE(DATA_TYPE, 4))(b, (VEC_DATA_TYPE(DATA_TYPE, 3))0), c); \
})
-#define ARM_DOT2(a, b, c) \
- ({ \
- ARM_DOT((uchar4)(a, (uchar2)0), (uchar4)(b, (uchar2)0), c); \
+#define ARM_DOT2(a, b, c) \
+ ({ \
+ ARM_DOT((VEC_DATA_TYPE(DATA_TYPE, 4))(a, (VEC_DATA_TYPE(DATA_TYPE, 2))0), (VEC_DATA_TYPE(DATA_TYPE, 4))(b, (VEC_DATA_TYPE(DATA_TYPE, 2))0), c); \
})
-#define ARM_DOT3(a, b, c) \
- ({ \
- ARM_DOT((uchar4)(a, (uchar)0), (uchar4)(b, (uchar)0), c); \
+#define ARM_DOT3(a, b, c) \
+ ({ \
+ ARM_DOT((VEC_DATA_TYPE(DATA_TYPE, 4))(a, (DATA_TYPE)0), (VEC_DATA_TYPE(DATA_TYPE, 4))(b, (DATA_TYPE)0), c); \
})
#define ARM_DOT4(a, b, c) \
({ \
@@ -66,24 +68,24 @@
#else // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
/** Specialized macros to perform the dot product instruction between two vectors of size K0 [1,16] without using the dot8 instruction. */
-#define ARM_DOT1(a, b, c) \
- ({ \
- c += (uint)a * b; \
+#define ARM_DOT1(a, b, c) \
+ ({ \
+ c += (ACC_DATA_TYPE)a * b; \
})
-#define ARM_DOT2(a, b, c) \
- ({ \
- c += (uint)a.s0 * b.s0; \
- c += (uint)a.s1 * b.s1; \
+#define ARM_DOT2(a, b, c) \
+ ({ \
+ c += (ACC_DATA_TYPE)a.s0 * b.s0; \
+ c += (ACC_DATA_TYPE)a.s1 * b.s1; \
})
-#define ARM_DOT3(a, b, c) \
- ({ \
- ARM_DOT2(a, b, c); \
- c += (uint)a.s2 * b.s2; \
+#define ARM_DOT3(a, b, c) \
+ ({ \
+ ARM_DOT2(a, b, c); \
+ c += (ACC_DATA_TYPE)a.s2 * b.s2; \
})
-#define ARM_DOT4(a, b, c) \
- ({ \
- ARM_DOT3(a, b, c); \
- c += (uint)a.s3 * b.s3; \
+#define ARM_DOT4(a, b, c) \
+ ({ \
+ ARM_DOT3(a, b, c); \
+ c += (ACC_DATA_TYPE)a.s3 * b.s3; \
})
#define ARM_DOT8(a, b, c) \
({ \
@@ -194,13 +196,15 @@
})
#if defined(NUM_ELEMS_PROCESSED_PER_THREAD_X) && defined(NUM_ELEMS_PROCESSED_PER_THREAD_Y) && defined(COLS_A)
-#define VECTOR_UCHAR VEC_DATA_TYPE(uchar, NUM_ELEMS_PROCESSED_PER_THREAD_X)
-#define VECTOR_UINT VEC_DATA_TYPE(uint, NUM_ELEMS_PROCESSED_PER_THREAD_X)
+#define VECTOR_TYPE VEC_DATA_TYPE(DATA_TYPE, NUM_ELEMS_PROCESSED_PER_THREAD_X)
+#define VECTOR_ACC_TYPE VEC_DATA_TYPE(ACC_DATA_TYPE, NUM_ELEMS_PROCESSED_PER_THREAD_X)
#define VECTOR_INT VEC_DATA_TYPE(int, NUM_ELEMS_PROCESSED_PER_THREAD_X)
/** This OpenCL kernel computes the matrix multiplication between matrix A (src0) and matrix B (src1) in case both matrices have not been reshaped
*
* @attention The number of matrix A columns needs to be passed at compile time using -DCOLS_A
*
+ * @note The input data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=uchar)
+ * @note The accumulator data type must be passed at compile time using -DACC_DATA_TYPE (i.e. -DACC_DATA_TYPE=uint)
* @note In case the input or output have to be reinterpreted as a 3D tensor, the following information must be passed at compile time:
* -# REINTERPRET_INPUT_AS_3D: To reinterpret the input as 3D
* -# REINTERPRET_OUTPUT_AS_3D: To reinterpret the output as 3D
@@ -302,93 +306,98 @@ __kernel void gemmlowp_mm_midgard(IMAGE_DECLARATION(src0),
int end_row_vec_a = src_addr.s0 + COLS_A;
- VECTOR_UINT acc0 = 0;
+ VECTOR_ACC_TYPE acc0 = 0;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- VECTOR_UINT acc1 = 0;
+ VECTOR_ACC_TYPE acc1 = 0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- VECTOR_UINT acc2 = 0;
+ VECTOR_ACC_TYPE acc2 = 0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- VECTOR_UINT acc3 = 0;
+ VECTOR_ACC_TYPE acc3 = 0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
- VECTOR_UINT acc4 = 0;
+ VECTOR_ACC_TYPE acc4 = 0;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
for(; src_addr.s0 <= (end_row_vec_a - 2); src_addr += (int2)(2, 2 * src1_stride_y))
{
// Load values from matrix A
- uchar2 a0 = vload2(0, src0_ptr + src_addr.s0 + 0 * src0_stride_y);
+ VEC_DATA_TYPE(DATA_TYPE, 2)
+ a0 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- uchar2 a1 = vload2(0, src0_ptr + src_addr.s0 + 1 * src0_stride_y);
+ VEC_DATA_TYPE(DATA_TYPE, 2)
+ a1 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- uchar2 a2 = vload2(0, src0_ptr + src_addr.s0 + 2 * src0_stride_y);
+ VEC_DATA_TYPE(DATA_TYPE, 2)
+ a2 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- uchar2 a3 = vload2(0, src0_ptr + src_addr.s0 + 3 * src0_stride_y);
+ VEC_DATA_TYPE(DATA_TYPE, 2)
+ a3 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
- uchar2 a4 = vload2(0, src0_ptr + src_addr.s0 + 4 * src0_stride_y);
+ VEC_DATA_TYPE(DATA_TYPE, 2)
+ a4 = vload2(0, (__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 4 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
// Load values from matrix B
- VECTOR_UCHAR b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, src1_ptr + src_addr.s1);
- VECTOR_UCHAR b1 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, src1_ptr + src_addr.s1 + src1_stride_y);
+ VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
+ VECTOR_TYPE b1 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1 + src1_stride_y));
// Accumulate
- acc0 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a0.s0;
- acc0 += CONVERT(b1, VECTOR_UINT) * (VECTOR_UINT)a0.s1;
+ acc0 += CONVERT(b0, VECTOR_ACC_TYPE) * (VECTOR_ACC_TYPE)a0.s0;
+ acc0 += CONVERT(b1, VECTOR_ACC_TYPE) * (VECTOR_ACC_TYPE)a0.s1;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- acc1 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a1.s0;
- acc1 += CONVERT(b1, VECTOR_UINT) * (VECTOR_UINT)a1.s1;
+ acc1 += CONVERT(b0, VECTOR_ACC_TYPE) * (VECTOR_ACC_TYPE)a1.s0;
+ acc1 += CONVERT(b1, VECTOR_ACC_TYPE) * (VECTOR_ACC_TYPE)a1.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- acc2 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a2.s0;
- acc2 += CONVERT(b1, VECTOR_UINT) * (VECTOR_UINT)a2.s1;
+ acc2 += CONVERT(b0, VECTOR_ACC_TYPE) * (VECTOR_ACC_TYPE)a2.s0;
+ acc2 += CONVERT(b1, VECTOR_ACC_TYPE) * (VECTOR_ACC_TYPE)a2.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- acc3 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a3.s0;
- acc3 += CONVERT(b1, VECTOR_UINT) * (VECTOR_UINT)a3.s1;
+ acc3 += CONVERT(b0, VECTOR_ACC_TYPE) * (VECTOR_ACC_TYPE)a3.s0;
+ acc3 += CONVERT(b1, VECTOR_ACC_TYPE) * (VECTOR_ACC_TYPE)a3.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
- acc4 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a4.s0;
- acc4 += CONVERT(b1, VECTOR_UINT) * (VECTOR_UINT)a4.s1;
+ acc4 += CONVERT(b0, VECTOR_ACC_TYPE) * (VECTOR_ACC_TYPE)a4.s0;
+ acc4 += CONVERT(b1, VECTOR_ACC_TYPE) * (VECTOR_ACC_TYPE)a4.s1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
}
for(; src_addr.s0 < end_row_vec_a; src_addr += (int2)(1, src1_stride_y))
{
// Load values from matrix A
- uchar a0 = *(src0_ptr + src_addr.s0 + 0 * src0_stride_y);
+ DATA_TYPE a0 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 0 * src0_stride_y));
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- uchar a1 = *(src0_ptr + src_addr.s0 + 1 * src0_stride_y);
+ DATA_TYPE a1 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 1 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- uchar a2 = *(src0_ptr + src_addr.s0 + 2 * src0_stride_y);
+ DATA_TYPE a2 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 2 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- uchar a3 = *(src0_ptr + src_addr.s0 + 3 * src0_stride_y);
+ DATA_TYPE a3 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 3 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
- uchar a4 = *(src0_ptr + src_addr.s0 + 4 * src0_stride_y);
+ DATA_TYPE a4 = *((__global DATA_TYPE *)(src0_ptr + src_addr.s0 + 4 * src0_stride_y));
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
// Load values from matrix B
- VECTOR_UCHAR b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, src1_ptr + src_addr.s1);
+ VECTOR_TYPE b0 = VLOAD(NUM_ELEMS_PROCESSED_PER_THREAD_X)(0, (__global DATA_TYPE *)(src1_ptr + src_addr.s1));
// Accumulate
- acc0 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a0;
+ acc0 += CONVERT(b0, VECTOR_ACC_TYPE) * (VECTOR_ACC_TYPE)a0;
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
- acc1 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a1;
+ acc1 += CONVERT(b0, VECTOR_ACC_TYPE) * (VECTOR_ACC_TYPE)a1;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 1
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
- acc2 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a2;
+ acc2 += CONVERT(b0, VECTOR_ACC_TYPE) * (VECTOR_ACC_TYPE)a2;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 2
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
- acc3 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a3;
+ acc3 += CONVERT(b0, VECTOR_ACC_TYPE) * (VECTOR_ACC_TYPE)a3;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 3
#if NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
- acc4 += CONVERT(b0, VECTOR_UINT) * (VECTOR_UINT)a4;
+ acc4 += CONVERT(b0, VECTOR_ACC_TYPE) * (VECTOR_ACC_TYPE)a4;
#endif // NUM_ELEMS_PROCESSED_PER_THREAD_Y > 4
}
@@ -476,6 +485,8 @@ __kernel void gemmlowp_mm_midgard(IMAGE_DECLARATION(src0),
* The LHS matrix must be reshaped with @ref CLGEMMReshapeLHSMatrixKernel and the M0xK0 must be NOT transposed
* The RHS matrix must be reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the K0xN0 must be transposed
*
+ * @note The input data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=uchar)
+ * @note The accumulator data type must be passed at compile time using -DACC_DATA_TYPE (i.e. -DACC_DATA_TYPE=uint)
* @note If the first two dimensions of NDRange have been dispatched with "dummy_work_items" support, the option -DDUMMY_WORK_ITEMS must be passed at compile time.
* @note The GEMM's dimensions M and N must be passed at compile time using -DM and -DN (i.e. -DM=52 and -DN=90).
* @note The block's dimensions used for reshaping the LHS matrix and the RHS matrix (M0, N0 and K0) must be passed at compile time using -DM0, -DN0 and -DK0 (i.e. -DM0=4, -DN0=8, -DK0=4).
@@ -588,15 +599,15 @@ __kernel void gemmlowp_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
REPEAT_VAR_INIT_TO_CONST(16, uint, zrhs, 0);
// Initialize the accumulators
- REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(uint, N0), c, 0); //VEC_DATA_TYPE(uint, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;
+ REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(ACC_DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(ACC_DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;
for(int i = 0; i < k; i += K0)
{
// Load values from LHS matrix
- LOAD_BLOCK(M0, K0, uchar, a, lhs_addr, 0, LHS_STEP_X, zlhs);
+ LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_addr, 0, LHS_STEP_X, zlhs);
// Load values from RHS matrix
- LOAD_BLOCK(N0, K0, uchar, b, rhs_addr, 0, RHS_STEP_X, zrhs);
+ LOAD_BLOCK(N0, K0, DATA_TYPE, b, rhs_addr, 0, RHS_STEP_X, zrhs);
// Partial matrix multiplication M0,N0,K0
ARM_MM_K0XN0XM0(M0, N0, K0, a, b, c);
@@ -643,6 +654,8 @@ __kernel void gemmlowp_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
* The LHS matrix is NOT reshaped
* The RHS matrix is reshaped with @ref CLGEMMReshapeRHSMatrixKernel and the block K0xN0 is transposed
*
+ * @note The input data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=uchar)
+ * @note The accumulator data type must be passed at compile time using -DACC_DATA_TYPE (i.e. -DACC_DATA_TYPE=uint)
* @note The number of columns of LHS matrix must be passed at compile time using -DK (i.e. -DK=64)
* @note The block's dimensions used for reshaping the RHS matrix (N0 and K0) must be passed at compile time using -DN0 and -DK0 (i.e. -DN0=8, -DK0=4).
* @note The number of M0 rows to process must be passed at compile time using -DM0 (i.e. -DM0=2)
@@ -661,7 +674,7 @@ __kernel void gemmlowp_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
* -# DEPTH_GEMM3D: The depth of the output in case it has to be reinterpreted as a 3D tensor
* (HEIGHT_GEMM3D * DEPTH_GEMM3D) = columns LHS matrix
*
- * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: F16/F32
+ * @param[in] lhs_ptr Pointer to the LHS reshaped matrix. Supported data type: QASYMM8/QASYMM8_SIGNED
* @param[in] lhs_stride_x Stride of the LHS reshaped matrix in X dimension (in bytes)
* @param[in] lhs_step_x src_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] lhs_stride_y Stride of the LHS reshaped matrix in Y dimension (in bytes)
@@ -673,7 +686,7 @@ __kernel void gemmlowp_mm_reshaped_lhs_nt_rhs_t(IMAGE_DECLARATION(lhs),
* @param[in] rhs_stride_y Stride of the RHS reshaped matrix in Y dimension (in bytes)
* @param[in] rhs_step_y src_stride_y * number of elements along Y processed per workitem(in bytes)
* @param[in] rhs_offset_first_element_in_bytes The offset of the first element in the RHS reshaped matrix
- * @param[out] dst_ptr Pointer to the destination matrix Supported data type: same as @p lhs_ptr
+ * @param[out] dst_ptr Pointer to the destination matrix Supported data type: S32
* @param[in] dst_stride_x Stride of the destination matrix in X dimension (in bytes)
* @param[in] dst_step_x dst_stride_x * number of elements along X processed per workitem(in bytes)
* @param[in] dst_stride_y Stride of the destination matrix in Y dimension (in bytes)
@@ -758,15 +771,15 @@ __kernel void gemmlowp_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs),
#endif // defined(REINTERPRET_INPUT_AS_3D)
// Initialize the accumulators
- REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(uint, N0), c, 0); //VEC_DATA_TYPE(uint, N0) c0=0,c1=0,c2=0,... c(N0-1)=0;
+ REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(ACC_DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(ACC_DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(N0-1)=0;
for(int i = 0; i < K; i += K0)
{
// Load values from LHS matrix
- LOAD_BLOCK(M0, K0, uchar, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
+ LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
// Load values from RHS matrix
- LOAD_BLOCK(N0, K0, uchar, b, rhs_ptr, rhs_offset, RHS_STEP_X, zrhs);
+ LOAD_BLOCK(N0, K0, DATA_TYPE, b, rhs_ptr, rhs_offset, RHS_STEP_X, zrhs);
// Partial matrix multiplication M0,N0,K0
ARM_MM_K0XN0XM0(M0, N0, K0, a, b, c);
@@ -809,6 +822,8 @@ __kernel void gemmlowp_mm_reshaped_only_rhs_t(IMAGE_DECLARATION(lhs),
* The LHS matrix is NOT reshaped
* The RHS matrix is NOT reshaped
*
+ * @note The input data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=uchar)
+ * @note The accumulator data type must be passed at compile time using -DACC_DATA_TYPE (i.e. -DACC_DATA_TYPE=uint)
* @note The number of columns of LHS matrix must be passed at compile time using -DK (i.e. -DK=64)
* @note The number of M0 rows to process must be passed at compile time using -DM0 (i.e. -DM0=2)
* @note The number of N0 columns to process must be passed at compile time using -DN0 (i.e. -DN0=2)
@@ -908,20 +923,20 @@ __kernel void gemmlowp_mm_native(IMAGE_DECLARATION(lhs),
#endif // defined(REINTERPRET_INPUT_AS_3D)
// Initialize the accumulators
- REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(uint, N0), c, 0); //VEC_DATA_TYPE(uint, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;
+ REPEAT_VAR_INIT_TO_CONST(M0, VEC_DATA_TYPE(ACC_DATA_TYPE, N0), c, 0); //VEC_DATA_TYPE(ACC_DATA_TYPE, N0) c0=0,c1=0,c2=0,... c(M0-1)=0;
int i = 0;
for(; i <= (K - K0); i += K0)
{
// Load values from LHS matrix
- LOAD_BLOCK(M0, K0, uchar, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
+ LOAD_BLOCK(M0, K0, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
// Load values from RHS matrix
- LOAD_BLOCK(K0, N0, uchar, b, rhs_ptr, rhs_offset, rhs_stride_y, zrhs);
+ LOAD_BLOCK(K0, N0, DATA_TYPE, b, rhs_ptr, rhs_offset, rhs_stride_y, zrhs);
// Transpose the values from RHS matrix
- TRANSPOSE_K0XN0(K0, N0, b_t, b);
+ TRANSPOSE_K0XN0(K0, N0, b_t, b, DATA_TYPE);
// Partial matrix multiplication M0,N0,K0
ARM_MM_K0XN0XM0(M0, N0, K0, a, b_t, c);
@@ -935,13 +950,13 @@ __kernel void gemmlowp_mm_native(IMAGE_DECLARATION(lhs),
for(; i < K; ++i)
{
// Load values from LHS matrix
- LOAD_BLOCK(M0, 1, uchar, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
+ LOAD_BLOCK(M0, 1, DATA_TYPE, a, lhs_ptr, lhs_offset, lhs_stride_y, zlhs);
// Load values from RHS matrix
- LOAD_BLOCK(1, N0, uchar, b, rhs_ptr, rhs_offset, rhs_stride_y, zrhs);
+ LOAD_BLOCK(1, N0, DATA_TYPE, b, rhs_ptr, rhs_offset, rhs_stride_y, zrhs);
// Transpose the values from RHS matrix
- TRANSPOSE_K0XN0(1, N0, b_t, b);
+ TRANSPOSE_K0XN0(1, N0, b_t, b, DATA_TYPE);
// Partial matrix multiplication M0,N0,1
ARM_MM_K0XN0XM0(M0, N0, 1, a, b_t, c);
@@ -975,6 +990,8 @@ __kernel void gemmlowp_mm_native(IMAGE_DECLARATION(lhs),
}
#endif // defined(M0) && defined(N0) && defined(K0) && defined(K)
+#endif // defined(DATA_TYPE) && defined(ACC_DATA_TYPE)
+
#if defined(COLS_A)
/** OpenCL kernel used to compute the row-vectors of sums of all the entries in each row of Matrix A.
*
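
With the accumulator type generic, the non-dot8 fallback accumulates into ACC_DATA_TYPE instead of a hard-coded uint. For example, when the kernels are compiled with -DDATA_TYPE=char -DACC_DATA_TYPE=int (the QASYMM8_SIGNED case), ARM_DOT4(a, b, c) with a and b of type char4 and c an int accumulator reduces to (ignoring the statement-expression wrappers):

c += (int)a.s0 * b.s0;
c += (int)a.s1 * b.s1;
c += (int)a.s2 * b.s2;
c += (int)a.s3 * b.s3;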
diff --git a/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp b/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp
index cda7a83de7..78df0eec16 100644
--- a/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyKernel.cpp
@@ -212,6 +212,8 @@ void CLGEMMLowpMatrixMultiplyKernel::configure(const ICLTensor *input0, const IC
build_opts.add_option("-DCOLS_A=" + support::cpp11::to_string(input0->info()->dimension(0)));
build_opts.add_option("-DNUM_ELEMS_PROCESSED_PER_THREAD_X=" + support::cpp11::to_string(num_elements_processed.x()));
build_opts.add_option("-DNUM_ELEMS_PROCESSED_PER_THREAD_Y=" + support::cpp11::to_string(num_elements_processed.y()));
+ build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input0->info()->data_type()));
+ build_opts.add_option("-DACC_DATA_TYPE=" + get_cl_dot8_acc_type_from_data_type(input0->info()->data_type()));
kernel_name = "gemmlowp_mm_midgard";
diff --git a/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyNativeKernel.cpp b/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyNativeKernel.cpp
index 09caeeea55..3e887d8163 100644
--- a/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyNativeKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyNativeKernel.cpp
@@ -216,6 +216,8 @@ void CLGEMMLowpMatrixMultiplyNativeKernel::configure(const ICLTensor *input0, co
build_opts.add_option("-DM0=" + support::cpp11::to_string(lhs_info.m0));
build_opts.add_option("-DN0=" + support::cpp11::to_string(rhs_info.n0));
build_opts.add_option("-DK0=" + support::cpp11::to_string(rhs_info.k0));
+ build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input0->info()->data_type()));
+ build_opts.add_option("-DACC_DATA_TYPE=" + get_cl_dot8_acc_type_from_data_type(input0->info()->data_type()));
std::string kernel_name("gemmlowp_mm_native");
diff --git a/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedKernel.cpp b/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedKernel.cpp
index 050b792c4e..8d3aff6603 100644
--- a/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedKernel.cpp
@@ -42,13 +42,9 @@
#include <cstdint>
#include <tuple>
-using namespace arm_compute;
-using namespace arm_compute::misc::shape_calculator;
-
namespace arm_compute
{
-class Coordinates;
-} // namespace arm_compute
+using namespace misc::shape_calculator;
namespace
{
@@ -210,6 +206,8 @@ void CLGEMMLowpMatrixMultiplyReshapedKernel::configure(const ICLTensor *input0,
build_opts.add_option("-DK0=" + support::cpp11::to_string(lhs_info.k0));
build_opts.add_option("-DV0=" + support::cpp11::to_string(lhs_info.v0));
build_opts.add_option("-DH0=" + support::cpp11::to_string(rhs_info.h0));
+ build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input0->info()->data_type()));
+ build_opts.add_option("-DACC_DATA_TYPE=" + get_cl_dot8_acc_type_from_data_type(input0->info()->data_type()));
std::string kernel_name("gemmlowp_mm_reshaped_");
kernel_name += lhs_info.transpose ? "lhs_t_" : "lhs_nt_";
@@ -310,4 +308,5 @@ void CLGEMMLowpMatrixMultiplyReshapedKernel::run(const Window &window, cl::Comma
enqueue(queue, *this, slice, lws_hint(), _use_dummy_work_items);
}
while(window.slide_window_slice_3D(slice));
-}
\ No newline at end of file
+}
+} // namespace arm_compute
\ No newline at end of file
diff --git a/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel.cpp b/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel.cpp
index 779f96e7cf..3fa2fad8fd 100644
--- a/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel.cpp
@@ -54,6 +54,7 @@ Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1,
const GEMMReshapeInfo &gemm_info)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input0, input1, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input0, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(input0->num_dimensions() > 4, "The number of dimensions for the LHS matrix must be <= 4");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->num_dimensions() > 3, "The number of dimensions for the RHS matrix must be <= 3");
@@ -218,6 +219,8 @@ void CLGEMMLowpMatrixMultiplyReshapedOnlyRHSKernel::configure(const ICLTensor *i
build_opts.add_option("-DN0=" + support::cpp11::to_string(rhs_info.n0));
build_opts.add_option("-DK0=" + support::cpp11::to_string(rhs_info.k0));
build_opts.add_option("-DH0=" + support::cpp11::to_string(rhs_info.h0));
+ build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(input0->info()->data_type()));
+ build_opts.add_option("-DACC_DATA_TYPE=" + get_cl_dot8_acc_type_from_data_type(input0->info()->data_type()));
std::string kernel_name("gemmlowp_mm_reshaped_only_rhs_");
kernel_name += rhs_info.transpose ? "t" : "nt";
diff --git a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
index d322723150..a4270d7923 100644
--- a/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLGEMMConvolutionLayer.cpp
@@ -402,7 +402,7 @@ Status CLGEMMConvolutionLayer::validate(const ITensorInfo *input, const ITensorI
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(weights_info.are_reshaped(), "Weights already reshaped are not supported!");
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::QSYMM8_PER_CHANNEL, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
const bool is_quantized_per_channel = is_data_type_quantized_per_channel(weights->data_type());
if(is_quantized_per_channel)
diff --git a/tests/validation/CL/GEMMLowpMatrixMultiplyReshapedOnlyRHS.cpp b/tests/validation/CL/GEMMLowpMatrixMultiplyReshapedOnlyRHS.cpp
index 6ead11ab23..106d650109 100644
--- a/tests/validation/CL/GEMMLowpMatrixMultiplyReshapedOnlyRHS.cpp
+++ b/tests/validation/CL/GEMMLowpMatrixMultiplyReshapedOnlyRHS.cpp
@@ -82,7 +82,8 @@ const auto k_values = framework::dataset::make("K", 23);
const auto b_values = framework::dataset::make("batch_size", 1, 3);
/** M0 values to test - Precommit */
-const auto m0_values_precommit = framework::dataset::make("M0", {4, 6});
+const auto m0_values_precommit_1 = framework::dataset::make("M0", {4});
+const auto m0_values_precommit_2 = framework::dataset::make("M0", {6});
/** N0 values to test - Precommit */
const auto n0_values_precommit = framework::dataset::make("N0", { 4 });
@@ -162,7 +163,7 @@ DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(combine(combi
n_values),
k_values),
framework::dataset::make("batch_size", 1)),
- m0_values_precommit),
+ m0_values_precommit_1),
n0_values_precommit),
k0_values_precommit),
h0_values_precommit),
@@ -172,25 +173,44 @@ m_value, n_value, k_value, b_value, m0_value, n0_value, k0_value, h0_value, i_va
validate_configuration(m_value, n_value, k_value, b_value, m0_value, n0_value, k0_value, h0_value, i_value_rhs);
}
-FIXTURE_DATA_TEST_CASE(RunSmall, CLGEMMLowpMatrixMultiplyReshapedOnlyRHSFixture, framework::DatasetMode::ALL,
- combine(combine(combine(combine(combine(combine(combine(combine(combine(
+FIXTURE_DATA_TEST_CASE(RunSmall_1, CLGEMMLowpMatrixMultiplyReshapedOnlyRHSFixture, framework::DatasetMode::ALL,
+ combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+ m_values,
+ n_values),
+ k_values),
+ b_values),
+ m0_values_precommit_1),
+ n0_values_precommit),
+ k0_values_precommit),
+ h0_values_precommit),
+ i_values_rhs),
+ t_values_rhs),
+ framework::dataset::make("DataType", { DataType::QASYMM8 })))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference);
+}
+
+FIXTURE_DATA_TEST_CASE(RunSmall_2, CLGEMMLowpMatrixMultiplyReshapedOnlyRHSFixture, framework::DatasetMode::ALL,
+ combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
m_values,
n_values),
k_values),
b_values),
- m0_values_precommit),
+ m0_values_precommit_2),
n0_values_precommit),
k0_values_precommit),
h0_values_precommit),
i_values_rhs),
- t_values_rhs))
+ t_values_rhs),
+ framework::dataset::make("DataType", { DataType::QASYMM8_SIGNED })))
{
// Validate output
validate(CLAccessor(_target), _reference);
}
FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMLowpMatrixMultiplyReshapedOnlyRHSFixture, framework::DatasetMode::NIGHTLY,
- combine(combine(combine(combine(combine(combine(combine(combine(combine(
+ combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
m_values,
n_values),
k_values),
@@ -200,32 +220,53 @@ FIXTURE_DATA_TEST_CASE(RunLarge, CLGEMMLowpMatrixMultiplyReshapedOnlyRHSFixture,
k0_values_nightly),
h0_values_nightly),
i_values_rhs),
- t_values_rhs))
+ t_values_rhs),
+ framework::dataset::make("DataType", { DataType::QASYMM8 })))
{
// Validate output
validate(CLAccessor(_target), _reference);
}
-FIXTURE_DATA_TEST_CASE(RunSmall3D, CLGEMMLowpMatrixMultiplyReshapedOnlyRHS3DFixture, framework::DatasetMode::ALL,
- combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+FIXTURE_DATA_TEST_CASE(RunSmall3D_1, CLGEMMLowpMatrixMultiplyReshapedOnlyRHS3DFixture, framework::DatasetMode::ALL,
+ combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+ m_w_values,
+ m_h_values),
+ n_values),
+ k_values),
+ b_values),
+ m0_values_precommit_1),
+ n0_values_precommit),
+ k0_values_precommit),
+ h0_values_precommit),
+ i_values_rhs),
+ t_values_rhs),
+ framework::dataset::make("DataType", { DataType::QASYMM8 })))
+{
+ // Validate output
+ validate(CLAccessor(_target), _reference);
+}
+
+FIXTURE_DATA_TEST_CASE(RunSmall3D_2, CLGEMMLowpMatrixMultiplyReshapedOnlyRHS3DFixture, framework::DatasetMode::ALL,
+ combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
m_w_values,
m_h_values),
n_values),
k_values),
b_values),
- m0_values_precommit),
+ m0_values_precommit_2),
n0_values_precommit),
k0_values_precommit),
h0_values_precommit),
i_values_rhs),
- t_values_rhs))
+ t_values_rhs),
+ framework::dataset::make("DataType", { DataType::QASYMM8_SIGNED })))
{
// Validate output
validate(CLAccessor(_target), _reference);
}
FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMLowpMatrixMultiplyReshapedOnlyRHS3DFixture, framework::DatasetMode::NIGHTLY,
- combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
+ combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(combine(
m_w_values,
m_h_values),
n_values),
@@ -236,7 +277,8 @@ FIXTURE_DATA_TEST_CASE(RunLarge3D, CLGEMMLowpMatrixMultiplyReshapedOnlyRHS3DFixt
k0_values_nightly),
h0_values_nightly),
i_values_rhs),
- t_values_rhs))
+ t_values_rhs),
+ framework::dataset::make("DataType", { DataType::QASYMM8 })))
{
// Validate output
validate(CLAccessor(_target), _reference);
diff --git a/tests/validation/fixtures/GEMMLowpFixture.h b/tests/validation/fixtures/GEMMLowpFixture.h
index c17105edad..db52be5062 100644
--- a/tests/validation/fixtures/GEMMLowpFixture.h
+++ b/tests/validation/fixtures/GEMMLowpFixture.h
@@ -877,7 +877,8 @@ class GEMMLowpMatrixMultiplyReshapedOnlyRHSValidationFixture : public framework:
{
public:
template <typename...>
- void setup(unsigned int m, unsigned int n, unsigned int k, unsigned int batch_size, unsigned int m0, unsigned int n0, unsigned int k0, unsigned int h0, bool interleave_rhs, bool transpose_rhs)
+ void setup(unsigned int m, unsigned int n, unsigned int k, unsigned int batch_size, unsigned int m0, unsigned int n0,
+ unsigned int k0, unsigned int h0, bool interleave_rhs, bool transpose_rhs, DataType data_type)
{
GEMMLHSMatrixInfo lhs_info;
lhs_info.m0 = m0;
@@ -894,24 +895,40 @@ public:
const TensorShape lhs_shape(k, m, batch_size);
const TensorShape rhs_shape(n, k, batch_size);
- _target = compute_target(lhs_shape, rhs_shape, lhs_info, rhs_info);
- _reference = compute_reference(lhs_shape, rhs_shape);
+ _target = compute_target(lhs_shape, rhs_shape, lhs_info, rhs_info, data_type);
+ _reference = compute_reference(lhs_shape, rhs_shape, data_type);
}
protected:
template <typename U>
void fill(U &&tensor, int i)
{
- // Between 1 and 254 in order to avoid having -128 and 128 for the DOT product path
- std::uniform_int_distribution<> distribution(1, 254);
- library->fill(tensor, distribution, i);
+ switch(tensor.data_type())
+ {
+ case DataType::QASYMM8:
+ {
+ // Between 1 and 254 in order to avoid having -128 and 128 for the DOT product path
+ std::uniform_int_distribution<> distribution(1, 254);
+ library->fill(tensor, distribution, i);
+ }
+ break;
+ case DataType::QASYMM8_SIGNED:
+ {
+ std::uniform_int_distribution<> distribution(-127, 126);
+ library->fill(tensor, distribution, i);
+ }
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Unsupported data type");
+ }
}
- TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info)
+ TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const GEMMLHSMatrixInfo &lhs_info,
+ const GEMMRHSMatrixInfo &rhs_info, DataType data_type)
{
// Create tensors
- TensorType lhs = create_tensor<TensorType>(lhs_shape, DataType::QASYMM8, 1);
- TensorType rhs = create_tensor<TensorType>(rhs_shape, DataType::QASYMM8, 1);
+ TensorType lhs = create_tensor<TensorType>(lhs_shape, data_type, 1);
+ TensorType rhs = create_tensor<TensorType>(rhs_shape, data_type, 1);
TensorType rhs_reshaped;
TensorType dst;
@@ -952,21 +969,36 @@ protected:
return dst;
}
- SimpleTensor<int32_t> compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape)
+ SimpleTensor<int32_t> compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape, DataType data_type)
{
TensorShape dst_shape = lhs_shape;
dst_shape[0] = rhs_shape[0];
dst_shape[1] = lhs_shape[1];
- // Create reference
- SimpleTensor<uint8_t> lhs{ lhs_shape, DataType::QASYMM8, 1 };
- SimpleTensor<uint8_t> rhs{ rhs_shape, DataType::QASYMM8, 1 };
+ if(data_type == DataType::QASYMM8)
+ {
+ // Create reference
+ SimpleTensor<uint8_t> lhs{ lhs_shape, data_type, 1 };
+ SimpleTensor<uint8_t> rhs{ rhs_shape, data_type, 1 };
- // Fill reference
- fill(lhs, 0);
- fill(rhs, 1);
+ // Fill reference
+ fill(lhs, 0);
+ fill(rhs, 1);
- return reference::gemmlowp_matrix_multiply_core<int32_t, uint8_t>(lhs, rhs, dst_shape, 0, 0);
+ return reference::gemmlowp_matrix_multiply_core<int32_t, uint8_t>(lhs, rhs, dst_shape, 0, 0);
+ }
+ else
+ {
+ // Create reference
+ SimpleTensor<int8_t> lhs{ lhs_shape, data_type, 1 };
+ SimpleTensor<int8_t> rhs{ rhs_shape, data_type, 1 };
+
+ // Fill reference
+ fill(lhs, 0);
+ fill(rhs, 1);
+
+ return reference::gemmlowp_matrix_multiply_core<int32_t, int8_t>(lhs, rhs, dst_shape, 0, 0);
+ }
}
TensorType _target{};
@@ -978,8 +1010,8 @@ class GEMMLowpMatrixMultiplyReshapedOnlyRHS3DValidationFixture : public framewor
{
public:
template <typename...>
- void setup(unsigned int m_w, unsigned int m_h, unsigned int n, unsigned int k, unsigned int batch_size, unsigned int m0, unsigned int n0, unsigned int k0, unsigned int h0,
- bool interleave_rhs, bool transpose_rhs)
+ void setup(unsigned int m_w, unsigned int m_h, unsigned int n, unsigned int k, unsigned int batch_size, unsigned int m0, unsigned int n0,
+ unsigned int k0, unsigned int h0, bool interleave_rhs, bool transpose_rhs, DataType data_type)
{
GEMMLHSMatrixInfo lhs_info;
lhs_info.m0 = m0;
@@ -999,24 +1031,40 @@ public:
const TensorShape lhs_shape(k, m, batch_size);
const TensorShape rhs_shape(n, k, batch_size);
- _target = compute_target(lhs_shape, rhs_shape, lhs_info, rhs_info, m_h);
- _reference = compute_reference(lhs_shape, rhs_shape, m_h);
+ _target = compute_target(lhs_shape, rhs_shape, lhs_info, rhs_info, m_h, data_type);
+ _reference = compute_reference(lhs_shape, rhs_shape, m_h, data_type);
}
protected:
template <typename U>
void fill(U &&tensor, int i)
{
- // Between 1 and 254 in order to avoid having -128 and 128 for the DOT product path
- std::uniform_int_distribution<> distribution(1, 254);
- library->fill(tensor, distribution, i);
+ switch(tensor.data_type())
+ {
+ case DataType::QASYMM8:
+ {
+ // Between 1 and 254 in order to avoid having -128 and 128 for the DOT product path
+ std::uniform_int_distribution<> distribution(1, 254);
+ library->fill(tensor, distribution, i);
+ }
+ break;
+ case DataType::QASYMM8_SIGNED:
+ {
+ std::uniform_int_distribution<> distribution(-127, 126);
+ library->fill(tensor, distribution, i);
+ }
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Unsupported data type");
+ }
}
- TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const GEMMLHSMatrixInfo &lhs_info, const GEMMRHSMatrixInfo &rhs_info, unsigned int m_h)
+ TensorType compute_target(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const GEMMLHSMatrixInfo &lhs_info,
+ const GEMMRHSMatrixInfo &rhs_info, unsigned int m_h, DataType data_type)
{
// Create tensors
- TensorType lhs = create_tensor<TensorType>(lhs_shape, DataType::QASYMM8, 1);
- TensorType rhs = create_tensor<TensorType>(rhs_shape, DataType::QASYMM8, 1);
+ TensorType lhs = create_tensor<TensorType>(lhs_shape, data_type, 1);
+ TensorType rhs = create_tensor<TensorType>(rhs_shape, data_type, 1);
TensorType rhs_reshaped;
TensorType dst;
@@ -1057,7 +1105,7 @@ protected:
return dst;
}
- SimpleTensor<int32_t> compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape, unsigned int m_h)
+ SimpleTensor<int32_t> compute_reference(const TensorShape &lhs_shape, const TensorShape &rhs_shape, unsigned int m_h, DataType data_type)
{
TensorShape dst_shape = lhs_shape;
dst_shape.set(0, rhs_shape[0]);
@@ -1065,15 +1113,30 @@ protected:
dst_shape.set(2, m_h);
dst_shape.set(3, lhs_shape[2]);
- // Create reference
- SimpleTensor<uint8_t> lhs{ lhs_shape, DataType::QASYMM8, 1 };
- SimpleTensor<uint8_t> rhs{ rhs_shape, DataType::QASYMM8, 1 };
+ if(data_type == DataType::QASYMM8)
+ {
+ // Create reference
+ SimpleTensor<uint8_t> lhs{ lhs_shape, data_type, 1 };
+ SimpleTensor<uint8_t> rhs{ rhs_shape, data_type, 1 };
- // Fill reference
- fill(lhs, 0);
- fill(rhs, 1);
+ // Fill reference
+ fill(lhs, 0);
+ fill(rhs, 1);
- return reference::gemmlowp_matrix_multiply_core<int32_t, uint8_t>(lhs, rhs, dst_shape, 0, 0);
+ return reference::gemmlowp_matrix_multiply_core<int32_t, uint8_t>(lhs, rhs, dst_shape, 0, 0);
+ }
+ else
+ {
+ // Create reference
+ SimpleTensor<int8_t> lhs{ lhs_shape, data_type, 1 };
+ SimpleTensor<int8_t> rhs{ rhs_shape, data_type, 1 };
+
+ // Fill reference
+ fill(lhs, 0);
+ fill(rhs, 1);
+
+ return reference::gemmlowp_matrix_multiply_core<int32_t, int8_t>(lhs, rhs, dst_shape, 0, 0);
+ }
}
TensorType _target{};
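
For quick reference, a condensed sketch of the fill ranges the updated fixtures draw from (taken from the fill() switches above; the helper name is illustrative):

#include <random>

// Illustrative only: the value ranges used when filling GEMMLowp test inputs.
inline std::uniform_int_distribution<> make_gemmlowp_fill_distribution(bool is_signed)
{
    // QASYMM8:        1..254 keeps the extreme values off the dot-product path.
    // QASYMM8_SIGNED: -127..126, the same idea shifted into the signed range.
    return is_signed ? std::uniform_int_distribution<>(-127, 126)
                     : std::uniform_int_distribution<>(1, 254);
}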