author    Georgios Pinitas <georgios.pinitas@arm.com>  2021-07-02 09:22:14 +0100
committer Georgios Pinitas <georgios.pinitas@arm.com>  2021-07-08 19:13:10 +0000
commit    da816752cad76c8e1b367e8e9c648994a1af599a (patch)
tree      d6efc7e559e717c07c719037947b2533a835b583
parent    4411e1f1d49efb78fb07a3c183f386307f951bad (diff)
download  ComputeLibrary-da816752cad76c8e1b367e8e9c648994a1af599a.tar.gz
Remove redundant implementations of Add/Sub operators
Allows only implementations where inputs/output are of the same data type
and removes legacy Computer Vision ones.

Signed-off-by: Georgios Pinitas <georgios.pinitas@arm.com>
Change-Id: Ia2b3d23a04236aab682f0c36a1110a30f7c06d1c
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5900
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
-rw-r--r--  Android.bp                                                    |    3
-rw-r--r--  arm_compute/runtime/NEON/functions/NEArithmeticAddition.h    |    3
-rw-r--r--  arm_compute/runtime/NEON/functions/NEArithmeticSubtraction.h |    3
-rw-r--r--  filelist.json                                                 |    9
-rw-r--r--  src/core/CL/cl_kernels/elementwise_operation.cl               |   35
-rw-r--r--  src/core/CL/cl_kernels/elementwise_operation_quantized.cl     |   24
-rw-r--r--  src/core/cpu/kernels/CpuAddKernel.cpp                         |  141
-rw-r--r--  src/core/cpu/kernels/CpuAddKernel.h                           |    3
-rw-r--r--  src/core/cpu/kernels/CpuSubKernel.cpp                         |   87
-rw-r--r--  src/core/cpu/kernels/CpuSubKernel.h                           |    3
-rw-r--r--  src/core/cpu/kernels/add/neon/integer.cpp                     |  170
-rw-r--r--  src/core/cpu/kernels/add/neon/list.h                          |    3
-rw-r--r--  src/core/cpu/kernels/add/sve/integer.cpp                      |  201
-rw-r--r--  src/core/cpu/kernels/add/sve/list.h                           |    3
-rw-r--r--  src/core/cpu/kernels/sub/neon/integer.cpp                     |  183
-rw-r--r--  src/core/cpu/kernels/sub/neon/list.h                          |    3
-rw-r--r--  src/core/gpu/cl/kernels/ClElementwiseKernel.cpp               |   77
-rw-r--r--  src/runtime/cpu/operators/CpuAdd.h                            |    3
-rw-r--r--  src/runtime/cpu/operators/CpuSub.h                            |    3
-rw-r--r--  tests/validation/CL/ArithmeticAddition.cpp                    |   76
-rw-r--r--  tests/validation/CL/ArithmeticSubtraction.cpp                 |   63
-rw-r--r--  tests/validation/CL/ElementwiseMax.cpp                        |   11
-rw-r--r--  tests/validation/CL/ElementwiseMin.cpp                        |   11
-rw-r--r--  tests/validation/CL/ElementwiseSquaredDiff.cpp                |   11
-rw-r--r--  tests/validation/CL/PReluLayer.cpp                            |    9
-rw-r--r--  tests/validation/NEON/ArithmeticAddition.cpp                  |   57
-rw-r--r--  tests/validation/NEON/ArithmeticSubtraction.cpp               |   94
-rw-r--r--  tests/validation/fixtures/ArithmeticOperationsFixture.h       |   96
28 files changed, 249 insertions(+), 1136 deletions(-)
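The user-visible effect is easiest to see at the function level. Below is a minimal sketch (illustrative, not code from this patch; shapes and policy are arbitrary) of a configuration that remains valid after the change, with both inputs and the output sharing one data type:

#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/NEON/functions/NEArithmeticAddition.h"
#include "arm_compute/runtime/Tensor.h"

using namespace arm_compute;

int main()
{
    // All three tensors use S16: (S16,S16) -> S16 is the only S16
    // configuration that survives this patch.
    Tensor a, b, out;
    a.allocator()->init(TensorInfo(TensorShape(16U, 4U), 1, DataType::S16));
    b.allocator()->init(TensorInfo(TensorShape(16U, 4U), 1, DataType::S16));
    out.allocator()->init(TensorInfo(TensorShape(16U, 4U), 1, DataType::S16));

    NEArithmeticAddition add;
    add.configure(&a, &b, &out, ConvertPolicy::SATURATE);

    a.allocator()->allocate();
    b.allocator()->allocate();
    out.allocator()->allocate();

    add.run();
    return 0;
}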
diff --git a/Android.bp b/Android.bp
index 90facab3ca..6a76e4c92a 100644
--- a/Android.bp
+++ b/Android.bp
@@ -295,12 +295,10 @@ cc_library_static {
"src/core/cpu/kernels/activation/sve/qasymm8.cpp",
"src/core/cpu/kernels/activation/sve/qasymm8_signed.cpp",
"src/core/cpu/kernels/activation/sve/qsymm16.cpp",
- "src/core/cpu/kernels/add/neon/integer.cpp",
"src/core/cpu/kernels/add/neon/qasymm8.cpp",
"src/core/cpu/kernels/add/neon/qasymm8_signed.cpp",
"src/core/cpu/kernels/add/neon/qsymm16.cpp",
"src/core/cpu/kernels/add/sve/impl.cpp",
- "src/core/cpu/kernels/add/sve/integer.cpp",
"src/core/cpu/kernels/add/sve/qasymm8.cpp",
"src/core/cpu/kernels/add/sve/qasymm8_signed.cpp",
"src/core/cpu/kernels/add/sve/qsymm16.cpp",
@@ -325,7 +323,6 @@ cc_library_static {
"src/core/cpu/kernels/scale/sve/qasymm8.cpp",
"src/core/cpu/kernels/scale/sve/qasymm8_signed.cpp",
"src/core/cpu/kernels/softmax/impl/sve/impl.cpp",
- "src/core/cpu/kernels/sub/neon/integer.cpp",
"src/core/cpu/kernels/sub/neon/qasymm8.cpp",
"src/core/cpu/kernels/sub/neon/qasymm8_signed.cpp",
"src/core/cpu/kernels/sub/neon/qsymm16.cpp",
diff --git a/arm_compute/runtime/NEON/functions/NEArithmeticAddition.h b/arm_compute/runtime/NEON/functions/NEArithmeticAddition.h
index b8e46ff36e..b9012b02a9 100644
--- a/arm_compute/runtime/NEON/functions/NEArithmeticAddition.h
+++ b/arm_compute/runtime/NEON/functions/NEArithmeticAddition.h
@@ -62,9 +62,6 @@ public:
* |QSYMM16 |QSYMM16 |QASYMM16 |
* |QSYMM16 |QSYMM16 |S32 |
* |U8 |U8 |U8 |
- * |U8 |U8 |S16 |
- * |U8 |S16 |S16 |
- * |S16 |U8 |S16 |
* |S16 |S16 |S16 |
* |S32 |S32 |S32 |
* |F16 |F16 |F16 |
diff --git a/arm_compute/runtime/NEON/functions/NEArithmeticSubtraction.h b/arm_compute/runtime/NEON/functions/NEArithmeticSubtraction.h
index 0c72e946f6..0b4db61d29 100644
--- a/arm_compute/runtime/NEON/functions/NEArithmeticSubtraction.h
+++ b/arm_compute/runtime/NEON/functions/NEArithmeticSubtraction.h
@@ -68,9 +68,6 @@ public:
* |QSYMM16 |QSYMM16 |QASYMM16 |
* |QSYMM16 |QSYMM16 |S32 |
* |U8 |U8 |U8 |
- * |U8 |U8 |S16 |
- * |U8 |S16 |S16 |
- * |S16 |U8 |S16 |
* |S16 |S16 |S16 |
* |S32 |S32 |S32 |
* |F16 |F16 |F16 |
diff --git a/filelist.json b/filelist.json
index 56274954c8..9d07492e6a 100644
--- a/filelist.json
+++ b/filelist.json
@@ -708,9 +708,6 @@
],
"qasymm8_signed": [
"src/core/cpu/kernels/add/sve/qasymm8_signed.cpp"
- ],
- "integer": [
- "src/core/cpu/kernels/add/sve/integer.cpp"
]
},
"neon": {
@@ -722,9 +719,6 @@
],
"qasymm8_signed": [
"src/core/cpu/kernels/add/neon/qasymm8_signed.cpp"
- ],
- "integer": [
- "src/core/cpu/kernels/add/neon/integer.cpp"
]
}
}
@@ -1661,9 +1655,6 @@
],
"qasymm8_signed": [
"src/core/cpu/kernels/sub/neon/qasymm8_signed.cpp"
- ],
- "integer": [
- "src/core/cpu/kernels/sub/neon/integer.cpp"
]
}
}
diff --git a/src/core/CL/cl_kernels/elementwise_operation.cl b/src/core/CL/cl_kernels/elementwise_operation.cl
index c8250045dc..99f725645d 100644
--- a/src/core/CL/cl_kernels/elementwise_operation.cl
+++ b/src/core/CL/cl_kernels/elementwise_operation.cl
@@ -23,7 +23,7 @@
*/
#include "helpers.h"
-#if defined(OP) && defined(VEC_SIZE_IN1) && defined(VEC_SIZE_IN2) && defined(VEC_SIZE_OUT) && defined(DATA_TYPE_IN1) && defined(DATA_TYPE_IN2) && defined(DATA_TYPE_OUT)
+#if defined(OP) && defined(VEC_SIZE_IN1) && defined(VEC_SIZE_IN2) && defined(VEC_SIZE_OUT) && defined(DATA_TYPE)
/** List of all the operations supported by this kernel.
* @note ADD and SUB operations, when executed on integers, support saturation */
@@ -43,17 +43,17 @@
#if VEC_SIZE_OUT == 1
#define PRELU(x, y) (x > 0 ? x : x * y)
#else // VEC_SIZE_OUT == 1
-#define PRELU(x, y) (select(y * x, x, CONVERT((x > (DATA_TYPE_OUT)0), SELECT_VEC_DATA_TYPE(DATA_TYPE_OUT, VEC_SIZE_OUT))))
+#define PRELU(x, y) (select(y * x, x, CONVERT((x > (DATA_TYPE)0), SELECT_VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE_OUT))))
#endif // VEC_SIZE_OUT == 1
#if defined(S32)
-#define DIV(x, y) CONVERT(floor(CONVERT(x, VEC_DATA_TYPE(float, VEC_SIZE_OUT)) / CONVERT(y, VEC_DATA_TYPE(float, VEC_SIZE_OUT))), VEC_DATA_TYPE(DATA_TYPE_OUT, VEC_SIZE_OUT));
+#define DIV(x, y) CONVERT(floor(CONVERT(x, VEC_DATA_TYPE(float, VEC_SIZE_OUT)) / CONVERT(y, VEC_DATA_TYPE(float, VEC_SIZE_OUT))), VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE_OUT));
#else /* S32 */
#define DIV(x, y) (x / y)
#endif /* S32 */
-#define AND(x, y) (CONVERT((x && y), VEC_DATA_TYPE(DATA_TYPE_OUT, VEC_SIZE_OUT)) & ((VEC_DATA_TYPE(DATA_TYPE_OUT, VEC_SIZE_OUT))1))
-#define OR(x, y) (CONVERT((x || y), VEC_DATA_TYPE(DATA_TYPE_OUT, VEC_SIZE_OUT)) & ((VEC_DATA_TYPE(DATA_TYPE_OUT, VEC_SIZE_OUT))1))
+#define AND(x, y) (CONVERT((x && y), VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE_OUT)) & ((VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE_OUT))1))
+#define OR(x, y) (CONVERT((x || y), VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE_OUT)) & ((VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE_OUT))1))
#define OP_FUN_NAME_STR(op) elementwise_operation_##op
#define OP_FUN_NAME(op) OP_FUN_NAME_STR(op)
@@ -66,8 +66,7 @@
*
* @note Vector sizes of inputs and output have to be passed at compile time using -DVEC_SIZE_IN1, -DVEC_SIZE_IN2, -DVEC_SIZE_OUT.
* @note Leftover vector size has to be passed at compile time using -DVEC_SIZE_LEFTOVER. e.g. -DVEC_SIZE_OUT=3. It is defined as the remainder between the input's first dimension and VEC_SIZE_OUT
- * @note The input and output data_types need to be passed at compile time using -DDATA_TYPE_IN1, -DDATA_TYPE_IN2 and -DDATA_TYPE_OUT:
- * e.g. -DDATA_TYPE_IN1=uchar -DDATA_TYPE_IN2=uchar -DDATA_TYPE_OUT=short
+ * @note The input and output data_types need to be passed at compile time using -DDATA_TYPE: e.g. -DDATA_TYPE=uchar
* @note To perform saturating operation -DSATURATE has to be passed to the compiler otherwise wrapping policy will be used.
* @note The element-wise operation to be executed has to be passed at compile time using -DOP (e.g., -DOP=ADD)
*
@@ -114,23 +113,23 @@ __kernel void OP_FUN_NAME(OP)(
uint out_x_offs = max((int)(get_global_id(0) * VEC_SIZE_OUT - (VEC_SIZE_OUT - VEC_SIZE_LEFTOVER) % VEC_SIZE_OUT), 0);
// Get pixels pointer
- __global uchar *in1_addr = in1_ptr + in1_offset_first_element_in_bytes + in1_x_offs * sizeof(DATA_TYPE_IN1) + get_global_id(1) * in1_step_y + get_global_id(2) * in1_step_z;
- __global uchar *in2_addr = in2_ptr + in2_offset_first_element_in_bytes + in2_x_offs * sizeof(DATA_TYPE_IN2) + get_global_id(1) * in2_step_y + get_global_id(2) * in2_step_z;
- __global uchar *out_addr = out_ptr + out_offset_first_element_in_bytes + out_x_offs * sizeof(DATA_TYPE_OUT) + get_global_id(1) * out_step_y + get_global_id(2) * out_step_z;
+ __global uchar *in1_addr = in1_ptr + in1_offset_first_element_in_bytes + in1_x_offs * sizeof(DATA_TYPE) + get_global_id(1) * in1_step_y + get_global_id(2) * in1_step_z;
+ __global uchar *in2_addr = in2_ptr + in2_offset_first_element_in_bytes + in2_x_offs * sizeof(DATA_TYPE) + get_global_id(1) * in2_step_y + get_global_id(2) * in2_step_z;
+ __global uchar *out_addr = out_ptr + out_offset_first_element_in_bytes + out_x_offs * sizeof(DATA_TYPE) + get_global_id(1) * out_step_y + get_global_id(2) * out_step_z;
// Load values
- VEC_DATA_TYPE(DATA_TYPE_OUT, VEC_SIZE_OUT)
- in_a = CONVERT((VEC_DATA_TYPE(DATA_TYPE_IN1, VEC_SIZE_OUT))(VLOAD(VEC_SIZE_IN1)(0, (__global DATA_TYPE_IN1 *)in1_addr)), VEC_DATA_TYPE(DATA_TYPE_OUT, VEC_SIZE_OUT));
- VEC_DATA_TYPE(DATA_TYPE_OUT, VEC_SIZE_OUT)
- in_b = CONVERT((VEC_DATA_TYPE(DATA_TYPE_IN2, VEC_SIZE_OUT))(VLOAD(VEC_SIZE_IN2)(0, (__global DATA_TYPE_IN2 *)in2_addr)), VEC_DATA_TYPE(DATA_TYPE_OUT, VEC_SIZE_OUT));
+ VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE_OUT)
+ in_a = CONVERT((VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE_OUT))(VLOAD(VEC_SIZE_IN1)(0, (__global DATA_TYPE *)in1_addr)), VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE_OUT));
+ VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE_OUT)
+ in_b = CONVERT((VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE_OUT))(VLOAD(VEC_SIZE_IN2)(0, (__global DATA_TYPE *)in2_addr)), VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE_OUT));
// Calculate and store result
- VEC_DATA_TYPE(DATA_TYPE_OUT, VEC_SIZE_OUT)
+ VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE_OUT)
res0 = OP(in_a, in_b);
#if defined(ACTIVATION_TYPE)
- res0 = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE_OUT, VEC_SIZE_OUT, res0, A_VAL, B_VAL);
+ res0 = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE_OUT, res0, A_VAL, B_VAL);
#endif // defined(ACTIVATION_TYPE)
- STORE_VECTOR_SELECT(res, DATA_TYPE_OUT, out_addr, VEC_SIZE_OUT, VEC_SIZE_LEFTOVER, VEC_SIZE_LEFTOVER != 0 && get_global_id(0) == 0)
+ STORE_VECTOR_SELECT(res, DATA_TYPE, out_addr, VEC_SIZE_OUT, VEC_SIZE_LEFTOVER, VEC_SIZE_LEFTOVER != 0 && get_global_id(0) == 0)
}
-#endif /* defined(OP) && defined(VEC_SIZE_IN1) && defined(VEC_SIZE_IN2) && defined(VEC_SIZE_OUT) && defined(DATA_TYPE_IN1) && defined(DATA_TYPE_IN2) && defined(DATA_TYPE_OUT) */
+#endif /* defined(OP) && defined(VEC_SIZE_IN1) && defined(VEC_SIZE_IN2) && defined(VEC_SIZE_OUT) && defined(DATA_TYPE) */
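For context, here is a hedged host-side sketch (not code from this patch; function name and vec_size_* parameters are illustrative) of how a kernel built from this file might now be configured — the three per-tensor type defines collapse into a single -DDATA_TYPE:

#include "arm_compute/core/CL/CLCompileContext.h" // CLBuildOptions
#include "arm_compute/core/CL/CLHelpers.h"        // get_cl_type_from_data_type
#include "arm_compute/core/ITensorInfo.h"

#include <string>

using namespace arm_compute;

// Illustrative only: the vec_size_* values would normally be derived from
// the tensor shapes; here they are plain parameters.
CLBuildOptions make_elementwise_build_options(const ITensorInfo *dst,
                                              unsigned int vec_size_in1,
                                              unsigned int vec_size_in2,
                                              unsigned int vec_size_out,
                                              unsigned int vec_size_leftover)
{
    CLBuildOptions build_opts;
    build_opts.add_option("-DOP=ADD");
    // One define for all three tensors, since src0, src1 and dst now share a type.
    build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(dst->data_type()));
    build_opts.add_option("-DVEC_SIZE_IN1=" + std::to_string(vec_size_in1));
    build_opts.add_option("-DVEC_SIZE_IN2=" + std::to_string(vec_size_in2));
    build_opts.add_option("-DVEC_SIZE_OUT=" + std::to_string(vec_size_out));
    build_opts.add_option("-DVEC_SIZE_LEFTOVER=" + std::to_string(vec_size_leftover));
    return build_opts;
}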
diff --git a/src/core/CL/cl_kernels/elementwise_operation_quantized.cl b/src/core/CL/cl_kernels/elementwise_operation_quantized.cl
index a08c3b2d47..0051babf03 100644
--- a/src/core/CL/cl_kernels/elementwise_operation_quantized.cl
+++ b/src/core/CL/cl_kernels/elementwise_operation_quantized.cl
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -28,7 +28,7 @@
#define MAX(x, y) max((x), (y))
#define MIN(x, y) min((x), (y))
#define SQUARED_DIFF(x, y) (x - y) * (x - y)
-#define PRELU(x, y) (select(y * x, x, CONVERT((x > (DATA_TYPE_OUT)0), SELECT_VEC_DATA_TYPE(float, VEC_SIZE_OUT))))
+#define PRELU(x, y) (select(y * x, x, CONVERT((x > (DATA_TYPE)0), SELECT_VEC_DATA_TYPE(float, VEC_SIZE_OUT))))
#define DIV(x, y) (x / y)
#define CONVERT_RTE(x, type) (convert_##type##_rte((x)))
@@ -37,11 +37,11 @@
#define OP_FUN_NAME_STR(op) elementwise_operation_##op##_quantized
#define OP_FUN_NAME(op) OP_FUN_NAME_STR(op)
-#if defined(OP) && defined(VEC_SIZE_IN1) && defined(VEC_SIZE_IN2) && defined(VEC_SIZE_OUT) && defined(OFFSET_IN1) && defined(OFFSET_IN2) && defined(OFFSET_OUT) && defined(SCALE_IN1) && defined(SCALE_IN2) && defined(SCALE_OUT) && defined(DATA_TYPE_OUT)
+#if defined(OP) && defined(VEC_SIZE_IN1) && defined(VEC_SIZE_IN2) && defined(VEC_SIZE_OUT) && defined(OFFSET_IN1) && defined(OFFSET_IN2) && defined(OFFSET_OUT) && defined(SCALE_IN1) && defined(SCALE_IN2) && defined(SCALE_OUT) && defined(DATA_TYPE)
#define VEC_FLOAT VEC_DATA_TYPE(float, VEC_SIZE_OUT)
#define VEC_INT VEC_DATA_TYPE(int, VEC_SIZE_OUT)
-#define VEC_TYPE VEC_DATA_TYPE(DATA_TYPE_OUT, VEC_SIZE_OUT)
+#define VEC_TYPE VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE_OUT)
/** This function executes an element-wise operation among two tensors.
*
@@ -57,7 +57,7 @@
* @note To perform saturating operation -DSATURATE has to be passed to the compiler otherwise wrapping policy will be used.
* @note The element-wise operation to be executed has to be passed at compile time using -DOP (e.g., -DOP=ADD)
* @note For QSYMM16 operations OFFSET_IN1, OFFSET_IN2 and OFFSET_OUT must be set to zero
- * @note The data type must be passed at compile time using -DDATA_TYPE_OUT, i.e. -DDATA_TYPE_OUT=uchar
+ * @note The data type must be passed at compile time using -DDATA_TYPE, i.e. -DDATA_TYPE=uchar
*
* @param[in] in1_ptr Pointer to the source tensor. Supported data types: QASYMM8/QSYMM16
* @param[in] in1_stride_x Stride of the source tensor in X dimension (in bytes)
@@ -102,12 +102,12 @@ __kernel void OP_FUN_NAME(OP)(
uint out_x_offs = max((int)(get_global_id(0) * VEC_SIZE_OUT - (VEC_SIZE_OUT - VEC_SIZE_LEFTOVER) % VEC_SIZE_OUT), 0);
// Get pixels pointer
- __global uchar *in1_addr = in1_ptr + in1_offset_first_element_in_bytes + in1_x_offs * sizeof(DATA_TYPE_OUT) + get_global_id(1) * in1_step_y + get_global_id(2) * in1_step_z;
- __global uchar *in2_addr = in2_ptr + in2_offset_first_element_in_bytes + in2_x_offs * sizeof(DATA_TYPE_OUT) + get_global_id(1) * in2_step_y + get_global_id(2) * in2_step_z;
- __global uchar *out_addr = out_ptr + out_offset_first_element_in_bytes + out_x_offs * sizeof(DATA_TYPE_OUT) + get_global_id(1) * out_step_y + get_global_id(2) * out_step_z;
+ __global uchar *in1_addr = in1_ptr + in1_offset_first_element_in_bytes + in1_x_offs * sizeof(DATA_TYPE) + get_global_id(1) * in1_step_y + get_global_id(2) * in1_step_z;
+ __global uchar *in2_addr = in2_ptr + in2_offset_first_element_in_bytes + in2_x_offs * sizeof(DATA_TYPE) + get_global_id(1) * in2_step_y + get_global_id(2) * in2_step_z;
+ __global uchar *out_addr = out_ptr + out_offset_first_element_in_bytes + out_x_offs * sizeof(DATA_TYPE) + get_global_id(1) * out_step_y + get_global_id(2) * out_step_z;
- VEC_INT in_a = CONVERT((VEC_DATA_TYPE(DATA_TYPE_OUT, VEC_SIZE_OUT))(VLOAD(VEC_SIZE_IN1)(0, (__global DATA_TYPE_OUT *)in1_addr)), VEC_INT);
- VEC_INT in_b = CONVERT((VEC_DATA_TYPE(DATA_TYPE_OUT, VEC_SIZE_OUT))(VLOAD(VEC_SIZE_IN2)(0, (__global DATA_TYPE_OUT *)in2_addr)), VEC_INT);
+ VEC_INT in_a = CONVERT((VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE_OUT))(VLOAD(VEC_SIZE_IN1)(0, (__global DATA_TYPE *)in1_addr)), VEC_INT);
+ VEC_INT in_b = CONVERT((VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE_OUT))(VLOAD(VEC_SIZE_IN2)(0, (__global DATA_TYPE *)in2_addr)), VEC_INT);
in_a = SUB(in_a, (VEC_INT)((int)OFFSET_IN1));
in_b = SUB(in_b, (VEC_INT)((int)OFFSET_IN2));
@@ -118,6 +118,6 @@ __kernel void OP_FUN_NAME(OP)(
const VEC_TYPE res0 = CONVERT_SAT(CONVERT_DOWN(qresf32, VEC_INT), VEC_TYPE);
// Store result
- STORE_VECTOR_SELECT(res, DATA_TYPE_OUT, out_addr, VEC_SIZE_OUT, VEC_SIZE_LEFTOVER, VEC_SIZE_LEFTOVER != 0 && get_global_id(0) == 0)
+ STORE_VECTOR_SELECT(res, DATA_TYPE, out_addr, VEC_SIZE_OUT, VEC_SIZE_LEFTOVER, VEC_SIZE_LEFTOVER != 0 && get_global_id(0) == 0)
}
-#endif /* defined(OP) && defined(VEC_SIZE_IN1) && defined(VEC_SIZE_IN2) && defined(VEC_SIZE_OUT) && defined(OFFSET_IN1) && defined(OFFSET_IN2) && defined(OFFSET_OUT) && defined(SCALE_IN1) && defined(SCALE_IN2) && defined(SCALE_OUT) && defined(DATA_TYPE_OUT) */
+#endif /* defined(OP) && defined(VEC_SIZE_IN1) && defined(VEC_SIZE_IN2) && defined(VEC_SIZE_OUT) && defined(OFFSET_IN1) && defined(OFFSET_IN2) && defined(OFFSET_OUT) && defined(SCALE_IN1) && defined(SCALE_IN2) && defined(SCALE_OUT) && defined(DATA_TYPE) */
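Continuing the sketch above for the quantized kernel (again illustrative, not patch code; src0/src1/dst and build_opts are assumed from the previous snippet): the offset/scale defines are unchanged, but only the single type define is needed.

// Quantization parameters come from each tensor's UniformQuantizationInfo.
const UniformQuantizationInfo iq1 = src0->quantization_info().uniform();
const UniformQuantizationInfo iq2 = src1->quantization_info().uniform();
const UniformQuantizationInfo oq  = dst->quantization_info().uniform();

build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(dst->data_type()));
build_opts.add_option("-DOFFSET_IN1=" + std::to_string(iq1.offset));
build_opts.add_option("-DOFFSET_IN2=" + std::to_string(iq2.offset));
build_opts.add_option("-DOFFSET_OUT=" + std::to_string(oq.offset));
build_opts.add_option("-DSCALE_IN1=" + std::to_string(iq1.scale));
build_opts.add_option("-DSCALE_IN2=" + std::to_string(iq2.scale));
build_opts.add_option("-DSCALE_OUT=" + std::to_string(oq.scale));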
diff --git a/src/core/cpu/kernels/CpuAddKernel.cpp b/src/core/cpu/kernels/CpuAddKernel.cpp
index 12766037a7..61b7b19443 100644
--- a/src/core/cpu/kernels/CpuAddKernel.cpp
+++ b/src/core/cpu/kernels/CpuAddKernel.cpp
@@ -45,14 +45,7 @@ namespace
{
struct AddSelectorData
{
- /* Data types for all ITensorInfos:
- dt1 -> src0
- dt2 -> src1
- dt3 -> dst
- */
- DataType dt1;
- DataType dt2;
- DataType dt3;
+ DataType dt;
const CPUInfo &ci;
};
@@ -72,7 +65,7 @@ static const AddKernel available_kernels[] =
"sve2_qu8_add",
[](const AddSelectorData & data)
{
- return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QASYMM8)) && data.ci.has_sve();
+ return (data.dt == DataType::QASYMM8) && data.ci.has_sve();
},
REGISTER_QASYMM8_SVE(arm_compute::cpu::add_qasymm8_sve)
},
@@ -80,7 +73,7 @@ static const AddKernel available_kernels[] =
"sve2_qs8_add",
[](const AddSelectorData & data)
{
- return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QASYMM8_SIGNED)) && data.ci.has_sve();
+ return (data.dt == DataType::QASYMM8_SIGNED) && data.ci.has_sve();
},
REGISTER_QASYMM8_SIGNED_SVE(arm_compute::cpu::add_qasymm8_signed_sve)
},
@@ -88,7 +81,7 @@ static const AddKernel available_kernels[] =
"sve2_qs16_add",
[](const AddSelectorData & data)
{
- return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QSYMM16)) && data.ci.has_sve();
+ return (data.dt == DataType::QSYMM16) && data.ci.has_sve();
},
REGISTER_QSYMM16_SVE(arm_compute::cpu::add_qsymm16_sve)
},
@@ -98,7 +91,7 @@ static const AddKernel available_kernels[] =
"sve_fp32_add",
[](const AddSelectorData & data)
{
- return ((data.dt1 == data.dt2) && (data.dt1 == DataType::F32)) && data.ci.has_sve();
+ return (data.dt == DataType::F32) && data.ci.has_sve();
},
REGISTER_FP32_SVE(arm_compute::cpu::add_same_sve<float>)
},
@@ -106,7 +99,7 @@ static const AddKernel available_kernels[] =
"sve_fp16_add",
[](const AddSelectorData & data)
{
- return ((data.dt1 == data.dt2) && (data.dt1 == DataType::F16)) && data.ci.has_sve();
+ return (data.dt == DataType::F16) && data.ci.has_sve();
},
REGISTER_FP16_SVE(arm_compute::cpu::add_same_sve<float16_t>)
},
@@ -114,7 +107,7 @@ static const AddKernel available_kernels[] =
"sve_u8_add",
[](const AddSelectorData & data)
{
- return ((data.dt1 == data.dt2) && (data.dt1 == data.dt3) && (data.dt1 == DataType::U8)) && data.ci.has_sve();
+ return (data.dt == DataType::U8) && data.ci.has_sve();
},
REGISTER_INTEGER_SVE(arm_compute::cpu::add_same_sve<uint8_t>)
},
@@ -122,7 +115,7 @@ static const AddKernel available_kernels[] =
"sve_s16_add",
[](const AddSelectorData & data)
{
- return ((data.dt1 == data.dt2) && (data.dt1 == data.dt3) && (data.dt1 == DataType::S16)) && data.ci.has_sve();
+ return (data.dt == DataType::S16) && data.ci.has_sve();
},
REGISTER_INTEGER_SVE(arm_compute::cpu::add_same_sve<int16_t>)
},
@@ -130,39 +123,15 @@ static const AddKernel available_kernels[] =
"sve_s32_add",
[](const AddSelectorData & data)
{
- return ((data.dt1 == data.dt2) && (data.dt1 == data.dt3) && (data.dt1 == DataType::S32)) && data.ci.has_sve();
+ return (data.dt == DataType::S32) && data.ci.has_sve();
},
REGISTER_INTEGER_SVE(arm_compute::cpu::add_same_sve<int32_t>)
},
- {
- "sve_u8_s16_s16_add",
- [](const AddSelectorData & data)
- {
- return ((data.dt1 == DataType::U8) && (data.dt2 == DataType::S16)) && data.ci.has_sve();
- },
- REGISTER_INTEGER_SVE(arm_compute::cpu::add_u8_s16_s16_sve)
- },
- {
- "sve_s16_u8_s16_add",
- [](const AddSelectorData & data)
- {
- return ((data.dt1 == DataType::S16) && (data.dt2 == DataType::U8)) && data.ci.has_sve();
- },
- REGISTER_INTEGER_SVE(arm_compute::cpu::add_s16_u8_s16_sve)
- },
- {
- "sve_u8_u8_s16_add",
- [](const AddSelectorData & data)
- {
- return ((data.dt1 == data.dt2) && (data.dt3 == DataType::S16)) && data.ci.has_sve();
- },
- REGISTER_INTEGER_SVE(arm_compute::cpu::add_u8_u8_s16_sve)
- },
#endif /* defined(ARM_COMPUTE_ENABLE_SVE) */
#if defined(ARM_COMPUTE_ENABLE_NEON)
{
"neon_fp32_add",
- [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::F32)); },
+ [](const AddSelectorData & data) { return (data.dt == DataType::F32); },
REGISTER_FP32_NEON(arm_compute::cpu::add_same_neon<float>)
},
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
@@ -170,56 +139,41 @@ static const AddKernel available_kernels[] =
"neon_fp16_add",
[](const AddSelectorData & data)
{
- return ((data.dt1 == data.dt2) && (data.dt1 == DataType::F16)) && data.ci.has_fp16();
+ return (data.dt == DataType::F16) && data.ci.has_fp16();
},
REGISTER_FP16_NEON(arm_compute::cpu::add_same_neon<float16_t>)
},
#endif /* defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) */
{
"neon_u8_add",
- [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == data.dt3) && (data.dt1 == DataType::U8)); },
+ [](const AddSelectorData & data) { return (data.dt == DataType::U8); },
REGISTER_INTEGER_NEON(arm_compute::cpu::add_same_neon<uint8_t>)
},
{
"neon_s16_add",
- [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == data.dt3) && (data.dt1 == DataType::S16)); },
+ [](const AddSelectorData & data) { return (data.dt == DataType::S16); },
REGISTER_INTEGER_NEON(arm_compute::cpu::add_same_neon<int16_t>)
},
{
"neon_s32_add",
- [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == data.dt3) && (data.dt1 == DataType::S32)); },
+ [](const AddSelectorData & data) { return (data.dt == DataType::S32); },
REGISTER_INTEGER_NEON(arm_compute::cpu::add_same_neon<int32_t>)
},
- {
- "neon_u8_s16_s16_add",
- [](const AddSelectorData & data) { return ((data.dt1 == DataType::U8) && (data.dt2 == DataType::S16)); },
- REGISTER_INTEGER_NEON(arm_compute::cpu::add_u8_s16_s16_neon)
- },
- {
- "neon_s16_u8_s16_add",
- [](const AddSelectorData & data) { return ((data.dt1 == DataType::S16) && (data.dt2 == DataType::U8)); },
- REGISTER_INTEGER_NEON(arm_compute::cpu::add_s16_u8_s16_neon)
- },
- {
- "neon_u8_u8_s16_add",
- [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt3 == DataType::S16)); },
- REGISTER_INTEGER_NEON(arm_compute::cpu::add_u8_u8_s16_neon)
- },
#endif /* defined(ARM_COMPUTE_ENABLE_NEON) */
#if defined(ARM_COMPUTE_ENABLE_NEON) || defined(ARM_COMPUTE_ENABLE_SVE)
{
"neon_qu8_add",
- [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QASYMM8)); },
+ [](const AddSelectorData & data) { return (data.dt == DataType::QASYMM8); },
REGISTER_QASYMM8_NEON(arm_compute::cpu::add_qasymm8_neon)
},
{
"neon_qs8_add",
- [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QASYMM8_SIGNED)); },
+ [](const AddSelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED); },
REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::add_qasymm8_signed_neon)
},
{
"neon_qs16_add",
- [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QSYMM16)); },
+ [](const AddSelectorData & data) { return (data.dt == DataType::QSYMM16); },
REGISTER_QSYMM16_NEON(arm_compute::cpu::add_qsymm16_neon)
},
#endif /* defined(ARM_COMPUTE_ENABLE_NEON) || defined(ARM_COMPUTE_ENABLE_SVE) */
@@ -231,11 +185,11 @@ static const AddKernel available_kernels[] =
*
* @return A matching micro-kernel else nullptr
*/
-const AddKernel *get_implementation(const CPUInfo &cpuinfo, DataType dt1, DataType dt2, DataType dt3)
+const AddKernel *get_implementation(const CPUInfo &cpuinfo, DataType dt)
{
for(const auto &uk : available_kernels)
{
- if(uk.is_selected({ dt1, dt2, dt3, cpuinfo }))
+ if(uk.is_selected({ dt, cpuinfo }))
{
return &uk;
}
@@ -251,9 +205,7 @@ Status validate_arguments(const ITensorInfo &src0, const ITensorInfo &src1, cons
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src0, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
DataType::S16, DataType::QSYMM16, DataType::F16,
DataType::S32, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src1, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
- DataType::S16, DataType::QSYMM16, DataType::F16,
- DataType::S32, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src0, &src1);
const TensorShape out_shape = TensorShape::broadcast_shape(src0.tensor_shape(), src1.tensor_shape());
@@ -265,25 +217,12 @@ Status validate_arguments(const ITensorInfo &src0, const ITensorInfo &src1, cons
// Validate in case of configured dst
if(dst.total_size() > 0)
{
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- !(src0.data_type() == DataType::U8 && src1.data_type() == DataType::U8 && dst.data_type() == DataType::U8)
- && !(src0.data_type() == DataType::U8 && src1.data_type() == DataType::U8 && dst.data_type() == DataType::S16)
- && !(src0.data_type() == DataType::U8 && src1.data_type() == DataType::S16 && dst.data_type() == DataType::S16)
- && !(src0.data_type() == DataType::S16 && src1.data_type() == DataType::U8 && dst.data_type() == DataType::S16)
- && !(src0.data_type() == DataType::S16 && src1.data_type() == DataType::S16 && dst.data_type() == DataType::S16)
- && !(src0.data_type() == DataType::S32 && src1.data_type() == DataType::S32 && dst.data_type() == DataType::S32)
- && !(src0.data_type() == DataType::F32 && src1.data_type() == DataType::F32 && dst.data_type() == DataType::F32)
- && !(src0.data_type() == DataType::F16 && src1.data_type() == DataType::F16 && dst.data_type() == DataType::F16)
- && !(src0.data_type() == DataType::QASYMM8 && src1.data_type() == DataType::QASYMM8 && dst.data_type() == DataType::QASYMM8)
- && !(src0.data_type() == DataType::QASYMM8_SIGNED && src1.data_type() == DataType::QASYMM8_SIGNED && dst.data_type() == DataType::QASYMM8_SIGNED)
- && !(src0.data_type() == DataType::QSYMM16 && src1.data_type() == DataType::QSYMM16 && dst.data_type() == DataType::QSYMM16),
- "You called addition with the wrong image formats");
-
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src0, &dst);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, dst.tensor_shape(), 0),
"Wrong shape for dst");
}
- const auto *uk = get_implementation(CPUInfo::get(), src0.data_type(), src1.data_type(), dst.data_type());
+ const auto *uk = get_implementation(CPUInfo::get(), src0.data_type());
ARM_COMPUTE_RETURN_ERROR_ON(uk == nullptr || uk->ukernel == nullptr);
return Status{};
@@ -294,38 +233,8 @@ std::pair<Status, Window> validate_and_configure_window(const ITensorInfo &src0,
const TensorShape &out_shape = TensorShape::broadcast_shape(src0.tensor_shape(), src1.tensor_shape());
// Auto initialize dst if not initialized
- {
- set_shape_if_empty(dst, out_shape);
-
- if(src0.data_type() == DataType::S16 || src1.data_type() == DataType::S16)
- {
- set_format_if_unknown(dst, Format::S16);
- }
- if(src0.data_type() == DataType::S32 || src1.data_type() == DataType::S32)
- {
- set_format_if_unknown(dst, Format::S32);
- }
- else if(src0.data_type() == DataType::F16 || src1.data_type() == DataType::F16)
- {
- set_format_if_unknown(dst, Format::F16);
- }
- else if(src0.data_type() == DataType::F32 || src1.data_type() == DataType::F32)
- {
- set_format_if_unknown(dst, Format::F32);
- }
- else if(src0.data_type() == DataType::QASYMM8 || src1.data_type() == DataType::QASYMM8)
- {
- set_data_type_if_unknown(dst, DataType::QASYMM8);
- }
- else if(src0.data_type() == DataType::QASYMM8_SIGNED || src1.data_type() == DataType::QASYMM8_SIGNED)
- {
- set_data_type_if_unknown(dst, DataType::QASYMM8_SIGNED);
- }
- else if(src0.data_type() == DataType::QSYMM16 || src1.data_type() == DataType::QSYMM16)
- {
- set_data_type_if_unknown(dst, DataType::QSYMM16);
- }
- }
+ set_shape_if_empty(dst, out_shape);
+ set_data_type_if_unknown(dst, src0.data_type());
Window win = calculate_max_window(out_shape, Steps());
@@ -339,7 +248,7 @@ void CpuAddKernel::configure(const ITensorInfo *src0, const ITensorInfo *src1, I
ARM_COMPUTE_ERROR_ON_NULLPTR(src0, src1, dst);
ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*src0, *src1, *dst, policy));
- const auto uk = get_implementation(CPUInfo::get(), src0->data_type(), src1->data_type(), dst->data_type());
+ const auto uk = get_implementation(CPUInfo::get(), src0->data_type());
ARM_COMPUTE_ERROR_ON_NULLPTR(uk);
_policy = policy;
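The structural change in this file is the collapse of the three-type selector (dt1, dt2, dt3) into a single-type one. A standalone sketch of the pattern with simplified stand-in types (not library code) may make the refactor easier to follow:

#include <cstdint>

// Stand-in enum and function types; the real code uses arm_compute::DataType
// and the kernel's actual signature.
enum class DataType : uint8_t { U8, S16, S32, F16, F32 };

struct SelectorData
{
    DataType dt; // single data type shared by src0, src1 and dst
};

using SelectorPtr = bool (*)(const SelectorData &);
using KernelPtr   = void (*)(); // placeholder for the real ukernel signature

struct Kernel
{
    const char *name;
    SelectorPtr is_selected;
    KernelPtr   ukernel;
};

void add_fp32() {} // placeholder bodies
void add_s16() {}

static const Kernel available_kernels[] = {
    { "neon_fp32_add", [](const SelectorData &d) { return d.dt == DataType::F32; }, &add_fp32 },
    { "neon_s16_add",  [](const SelectorData &d) { return d.dt == DataType::S16; }, &add_s16  },
};

// First kernel whose predicate accepts the (single) data type wins.
const Kernel *get_implementation(DataType dt)
{
    for(const auto &uk : available_kernels)
    {
        if(uk.is_selected({ dt }))
        {
            return &uk;
        }
    }
    return nullptr;
}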
diff --git a/src/core/cpu/kernels/CpuAddKernel.h b/src/core/cpu/kernels/CpuAddKernel.h
index 717d0132c6..1205b45dfb 100644
--- a/src/core/cpu/kernels/CpuAddKernel.h
+++ b/src/core/cpu/kernels/CpuAddKernel.h
@@ -44,9 +44,6 @@ public:
* Valid configurations (src0,src1) -> dst :
*
* - (U8,U8) -> U8
- * - (U8,U8) -> S16
- * - (S16,U8) -> S16
- * - (U8,S16) -> S16
* - (S16,S16) -> S16
* - (S32,S32) -> S32
* - (F16,F16) -> F16
diff --git a/src/core/cpu/kernels/CpuSubKernel.cpp b/src/core/cpu/kernels/CpuSubKernel.cpp
index 098a324377..fa7a55805e 100644
--- a/src/core/cpu/kernels/CpuSubKernel.cpp
+++ b/src/core/cpu/kernels/CpuSubKernel.cpp
@@ -41,9 +41,7 @@ namespace
{
struct SubSelectorData
{
- DataType dt1;
- DataType dt2;
- DataType dt3;
+ DataType dt;
};
using SubSelectorPtr = std::add_pointer<bool(const SubSelectorData &data)>::type;
@@ -60,59 +58,44 @@ static const SubKernel available_kernels[] =
{
{
"neon_fp32_sub",
- [](const SubSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::F32)); },
+ [](const SubSelectorData & data) { return (data.dt == DataType::F32); },
REGISTER_FP32_NEON(arm_compute::cpu::sub_same_neon<float>)
},
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
{
"neon_fp16_sub",
- [](const SubSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::F16)); },
+ [](const SubSelectorData & data) { return (data.dt == DataType::F16); },
REGISTER_FP16_NEON(arm_compute::cpu::sub_same_neon<float16_t>)
},
#endif /* defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) */
{
"neon_u8_sub",
- [](const SubSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == data.dt3) && (data.dt1 == DataType::U8)); },
+ [](const SubSelectorData & data) { return (data.dt == DataType::U8); },
REGISTER_INTEGER_NEON(arm_compute::cpu::sub_same_neon<uint8_t>)
},
{
"neon_s16_sub",
- [](const SubSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == data.dt3) && (data.dt1 == DataType::S16)); },
+ [](const SubSelectorData & data) { return (data.dt == DataType::S16); },
REGISTER_INTEGER_NEON(arm_compute::cpu::sub_same_neon<int16_t>)
},
{
"neon_s32_sub",
- [](const SubSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == data.dt3) && (data.dt1 == DataType::S32)); },
+ [](const SubSelectorData & data) { return (data.dt == DataType::S32); },
REGISTER_INTEGER_NEON(arm_compute::cpu::sub_same_neon<int32_t>)
},
{
- "neon_u8_s16_s16_sub",
- [](const SubSelectorData & data) { return ((data.dt1 == DataType::U8) && (data.dt2 == DataType::S16)); },
- REGISTER_INTEGER_NEON(arm_compute::cpu::sub_u8_s16_s16_neon)
- },
- {
- "neon_s16_u8_s16_sub",
- [](const SubSelectorData & data) { return ((data.dt1 == DataType::S16) && (data.dt2 == DataType::U8)); },
- REGISTER_INTEGER_NEON(arm_compute::cpu::sub_s16_u8_s16_neon)
- },
- {
- "neon_u8_u8_s16_sub",
- [](const SubSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt3 == DataType::S16)); },
- REGISTER_INTEGER_NEON(arm_compute::cpu::sub_u8_u8_s16_neon)
- },
- {
"neon_qu8_sub",
- [](const SubSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QASYMM8)); },
+ [](const SubSelectorData & data) { return (data.dt == DataType::QASYMM8); },
REGISTER_QASYMM8_NEON(arm_compute::cpu::sub_qasymm8_neon)
},
{
"neon_qs8_sub",
- [](const SubSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QASYMM8_SIGNED)); },
+ [](const SubSelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED); },
REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::sub_qasymm8_signed_neon)
},
{
- "neon_s16_sub",
- [](const SubSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QSYMM16)); },
+ "neon_qs16_sub",
+ [](const SubSelectorData & data) { return (data.dt == DataType::QSYMM16); },
REGISTER_QSYMM16_NEON(arm_compute::cpu::sub_qsymm16_neon)
},
};
@@ -123,11 +106,11 @@ static const SubKernel available_kernels[] =
*
* @return A matching micro-kernel else nullptr
*/
-const SubKernel *get_implementation(DataType dt1, DataType dt2, DataType dt3)
+const SubKernel *get_implementation(DataType dt)
{
for(const auto &uk : available_kernels)
{
- if(uk.is_selected({ dt1, dt2, dt3 }))
+ if(uk.is_selected({ dt }))
{
return &uk;
}
@@ -141,54 +124,21 @@ inline Status validate_arguments(const ITensorInfo &src0, const ITensorInfo &src
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&src0);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src0, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM16, DataType::S16, DataType::S32, DataType::F16,
DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src1, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM16, DataType::S16, DataType::S32, DataType::F16,
- DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&dst, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM16, DataType::S16, DataType::S32, DataType::F16,
- DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src0, &src1);
- const auto *uk = get_implementation(src0.data_type(), src1.data_type(), dst.data_type());
+ const auto *uk = get_implementation(src0.data_type());
ARM_COMPUTE_RETURN_ERROR_ON(uk == nullptr || uk->ukernel == nullptr);
const TensorShape out_shape = TensorShape::broadcast_shape(src0.tensor_shape(), src1.tensor_shape());
ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- !(src0.data_type() == DataType::U8 && src1.data_type() == DataType::U8)
- && !(src0.data_type() == DataType::QASYMM8 && src1.data_type() == DataType::QASYMM8)
- && !(src0.data_type() == DataType::QASYMM8_SIGNED && src1.data_type() == DataType::QASYMM8_SIGNED)
- && !(src0.data_type() == DataType::QSYMM16 && src1.data_type() == DataType::QSYMM16)
- && !(src0.data_type() == DataType::U8 && src1.data_type() == DataType::U8)
- && !(src0.data_type() == DataType::U8 && src1.data_type() == DataType::S16)
- && !(src0.data_type() == DataType::S16 && src1.data_type() == DataType::U8)
- && !(src0.data_type() == DataType::S16 && src1.data_type() == DataType::S16)
- && !(src0.data_type() == DataType::S32 && src1.data_type() == DataType::S32)
- && !(src0.data_type() == DataType::F32 && src1.data_type() == DataType::F32)
- && !(src0.data_type() == DataType::F16 && src1.data_type() == DataType::F16),
- "You called subtract with the wrong image formats");
-
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- (src0.data_type() == DataType::QASYMM8_SIGNED && src1.data_type() == DataType::QASYMM8_SIGNED && policy == ConvertPolicy::WRAP)
- || (src0.data_type() == DataType::QASYMM8 && src1.data_type() == DataType::QASYMM8 && policy == ConvertPolicy::WRAP)
- || (src0.data_type() == DataType::QSYMM16 && src1.data_type() == DataType::QSYMM16 && policy == ConvertPolicy::WRAP),
- "Convert policy cannot be WRAP if datatype is quantized");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(is_data_type_quantized(src0.data_type()) && (policy == ConvertPolicy::WRAP),
+ "Convert policy cannot be WRAP if datatype is quantized");
// Validate in case of configured dst
if(dst.total_size() > 0)
{
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- !(src0.data_type() == DataType::U8 && src1.data_type() == DataType::U8 && dst.data_type() == DataType::U8)
- && !(src0.data_type() == DataType::QASYMM8 && src1.data_type() == DataType::QASYMM8 && dst.data_type() == DataType::QASYMM8)
- && !(src0.data_type() == DataType::QASYMM8_SIGNED && src1.data_type() == DataType::QASYMM8_SIGNED && dst.data_type() == DataType::QASYMM8_SIGNED)
- && !(src0.data_type() == DataType::QSYMM16 && src1.data_type() == DataType::QSYMM16 && dst.data_type() == DataType::QSYMM16)
- && !(src0.data_type() == DataType::U8 && src1.data_type() == DataType::U8 && dst.data_type() == DataType::S16)
- && !(src0.data_type() == DataType::U8 && src1.data_type() == DataType::S16 && dst.data_type() == DataType::S16)
- && !(src0.data_type() == DataType::S16 && src1.data_type() == DataType::U8 && dst.data_type() == DataType::S16)
- && !(src0.data_type() == DataType::S16 && src1.data_type() == DataType::S16 && dst.data_type() == DataType::S16)
- && !(src0.data_type() == DataType::S32 && src1.data_type() == DataType::S32 && dst.data_type() == DataType::S32)
- && !(src0.data_type() == DataType::F32 && src1.data_type() == DataType::F32 && dst.data_type() == DataType::F32)
- && !(src0.data_type() == DataType::F16 && src1.data_type() == DataType::F16 && dst.data_type() == DataType::F16),
- "You called subtract with the wrong image formats");
-
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src0, &dst);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, dst.tensor_shape(), 0),
"Wrong shape for dst");
}
@@ -205,8 +155,9 @@ void CpuSubKernel::configure(const ITensorInfo *src0, const ITensorInfo *src1, I
// Auto initialize dst if not initialized
set_shape_if_empty(*dst, out_shape);
+ set_data_type_if_unknown(*dst, src0->data_type());
- const auto *uk = get_implementation(src0->data_type(), src1->data_type(), dst->data_type());
+ const auto *uk = get_implementation(src0->data_type());
ARM_COMPUTE_ERROR_ON_NULLPTR(uk);
_policy = policy;
diff --git a/src/core/cpu/kernels/CpuSubKernel.h b/src/core/cpu/kernels/CpuSubKernel.h
index b9160bd150..cb64e64cfa 100644
--- a/src/core/cpu/kernels/CpuSubKernel.h
+++ b/src/core/cpu/kernels/CpuSubKernel.h
@@ -45,11 +45,8 @@ public:
* Valid configurations (src0,src1) -> dst :
*
* - (U8,U8) -> U8
- * - (U8,U8) -> S16
* - (QASYMM8, QASYMM8) -> QASYMM8
* - (QASYMM8_SIGNED, QASYMM8_SIGNED) -> QASYMM8_SIGNED
- * - (S16,U8) -> S16
- * - (U8,S16) -> S16
* - (S16,S16) -> S16
* - (S32,S32) -> S32
* - (F16,F16) -> F16
diff --git a/src/core/cpu/kernels/add/neon/integer.cpp b/src/core/cpu/kernels/add/neon/integer.cpp
deleted file mode 100644
index 24a0ac3b7c..0000000000
--- a/src/core/cpu/kernels/add/neon/integer.cpp
+++ /dev/null
@@ -1,170 +0,0 @@
-/*
- * Copyright (c) 2020-2021 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#include "arm_compute/core/Helpers.h"
-#include "arm_compute/core/ITensor.h"
-#include "arm_compute/core/Types.h"
-#include "arm_compute/core/utils/misc/Traits.h"
-#include "src/core/NEON/wrapper/wrapper.h"
-#include "src/core/helpers/WindowHelpers.h"
-
-namespace arm_compute
-{
-namespace cpu
-{
-void add_u8_u8_s16_neon(const ITensor *src0, const ITensor *src1, ITensor *dst, const ConvertPolicy &policy, const Window &window)
-{
- // Create input windows
- Window win = window;
- Window input1_win = window.broadcast_if_dimension_le_one(src0->info()->tensor_shape());
- Window input2_win = window.broadcast_if_dimension_le_one(src1->info()->tensor_shape());
-
- // Clear X Dimension on execution window as we handle manually
- win.set(Window::DimX, Window::Dimension(0, 1, 1));
- input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
- input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
-
- Iterator input1(src0, input1_win);
- Iterator input2(src1, input2_win);
- Iterator output(dst, win);
-
- const int window_step_x = 8;
- const auto window_start_x = static_cast<int>(window.x().start());
- const auto window_end_x = static_cast<int>(window.x().end());
-
- execute_window_loop(win, [&](const Coordinates &)
- {
- const auto input1_ptr = reinterpret_cast<const uint8_t *>(input1.ptr());
- const auto input2_ptr = reinterpret_cast<const uint8_t *>(input2.ptr());
- const auto output_ptr = reinterpret_cast<int16_t *>(output.ptr());
-
- if(policy == ConvertPolicy::WRAP)
- {
- // Compute S elements per iteration
- int x = window_start_x;
- for(; x <= (window_end_x - window_step_x); x += window_step_x)
- {
- const auto vin1 = vreinterpretq_s16_u16(wrapper::vmovl(wrapper::vload(input1_ptr + x)));
- const auto vin2 = vreinterpretq_s16_u16(wrapper::vmovl(wrapper::vload(input2_ptr + x)));
- wrapper::vstore(output_ptr + x, wrapper::vadd(vin1, vin2));
- }
-
- // Compute left-over elements
- for(; x < window_end_x; ++x)
- {
- *(output_ptr + x) = static_cast<int16_t>(*(input1_ptr + x)) + static_cast<int16_t>(*(input2_ptr + x));
- }
- }
- else
- {
- // Compute S elements per iteration
- int x = window_start_x;
- for(; x <= (window_end_x - window_step_x); x += window_step_x)
- {
- const auto vin1 = vreinterpretq_s16_u16(wrapper::vmovl(wrapper::vload(input1_ptr + x)));
- const auto vin2 = vreinterpretq_s16_u16(wrapper::vmovl(wrapper::vload(input2_ptr + x)));
- wrapper::vstore(output_ptr + x, wrapper::vqadd(vin1, vin2));
- }
-
- // Compute left-over elements
- for(; x < window_end_x; ++x)
- {
- *(output_ptr + x) = wrapper::add_sat(static_cast<int16_t>(*(input1_ptr + x)),
- static_cast<int16_t>(*(input2_ptr + x)));
- }
- }
- },
- input1, input2, output);
-}
-
-void add_s16_u8_s16_neon(const ITensor *src0, const ITensor *src1, ITensor *dst, const ConvertPolicy &policy, const Window &window)
-{
- // Create input windows
- Window win = window;
- Window input1_win = window.broadcast_if_dimension_le_one(src0->info()->tensor_shape());
- Window input2_win = window.broadcast_if_dimension_le_one(src1->info()->tensor_shape());
-
- // Clear X Dimension on execution window as we handle manually
- win.set(Window::DimX, Window::Dimension(0, 1, 1));
- input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
- input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
-
- Iterator input1(src0, input1_win);
- Iterator input2(src1, input2_win);
- Iterator output(dst, win);
-
- const int window_step_x = 8;
- const auto window_start_x = static_cast<int>(window.x().start());
- const auto window_end_x = static_cast<int>(window.x().end());
-
- execute_window_loop(win, [&](const Coordinates &)
- {
- const auto input1_ptr = reinterpret_cast<const int16_t *>(input1.ptr());
- const auto input2_ptr = reinterpret_cast<const uint8_t *>(input2.ptr());
- const auto output_ptr = reinterpret_cast<int16_t *>(output.ptr());
-
- if(policy == ConvertPolicy::WRAP)
- {
- // Compute S elements per iteration
- int x = window_start_x;
- for(; x <= (window_end_x - window_step_x); x += window_step_x)
- {
- const auto vin1 = wrapper::vloadq(input1_ptr + x);
- const auto vin2 = vreinterpretq_s16_u16(wrapper::vmovl(wrapper::vload(input2_ptr + x)));
- wrapper::vstore(output_ptr + x, wrapper::vadd(vin1, vin2));
- }
-
- // Compute left-over elements
- for(; x < window_end_x; ++x)
- {
- *(output_ptr + x) = *(input1_ptr + x) + static_cast<int16_t>(*(input2_ptr + x));
- }
- }
- else
- {
- // Compute S elements per iteration
- int x = window_start_x;
- for(; x <= (window_end_x - window_step_x); x += window_step_x)
- {
- const auto vin1 = wrapper::vloadq(input1_ptr + x);
- const auto vin2 = vreinterpretq_s16_u16(wrapper::vmovl(wrapper::vload(input2_ptr + x)));
- wrapper::vstore(output_ptr + x, wrapper::vqadd(vin1, vin2));
- }
-
- // Compute left-over elements
- for(; x < window_end_x; ++x)
- {
- *(output_ptr + x) = wrapper::add_sat(*(input1_ptr + x), static_cast<int16_t>(*(input2_ptr + x)));
- }
- }
- },
- input1, input2, output);
-}
-
-void add_u8_s16_s16_neon(const ITensor *src0, const ITensor *src1, ITensor *dst, const ConvertPolicy &policy, const Window &window)
-{
- // Simply swap the two input buffers:
- add_s16_u8_s16_neon(src1, src0, dst, policy, window);
-}
-} // namespace cpu
-} // namespace arm_compute
\ No newline at end of file
diff --git a/src/core/cpu/kernels/add/neon/list.h b/src/core/cpu/kernels/add/neon/list.h
index 3ab03dd40e..379bd32fb1 100644
--- a/src/core/cpu/kernels/add/neon/list.h
+++ b/src/core/cpu/kernels/add/neon/list.h
@@ -38,9 +38,6 @@ namespace cpu
DECLARE_ADD_KERNEL(add_qasymm8_neon);
DECLARE_ADD_KERNEL(add_qasymm8_signed_neon);
DECLARE_ADD_KERNEL(add_qsymm16_neon);
-DECLARE_ADD_KERNEL(add_s16_u8_s16_neon);
-DECLARE_ADD_KERNEL(add_u8_s16_s16_neon);
-DECLARE_ADD_KERNEL(add_u8_u8_s16_neon);
#undef DECLARE_ADD_KERNEL
diff --git a/src/core/cpu/kernels/add/sve/integer.cpp b/src/core/cpu/kernels/add/sve/integer.cpp
deleted file mode 100644
index bd8179205b..0000000000
--- a/src/core/cpu/kernels/add/sve/integer.cpp
+++ /dev/null
@@ -1,201 +0,0 @@
-/*
- * Copyright (c) 2020-2021 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#if defined(__ARM_FEATURE_SVE)
-#include "arm_compute/core/Helpers.h"
-#include "arm_compute/core/ITensor.h"
-#include "arm_compute/core/Types.h"
-#include "arm_compute/core/utils/misc/Traits.h"
-#include "src/core/NEON/SVEMath.h"
-#include "src/core/NEON/wrapper/intrinsics/intrinsics.h"
-#include <arm_sve.h>
-
-namespace arm_compute
-{
-namespace cpu
-{
-void add_u8_u8_s16_sve(const ITensor *src0, const ITensor *src1, ITensor *dst, const ConvertPolicy &policy, const Window &window)
-{
- // Create input windows
- Window win = window;
- Window input1_win = window.broadcast_if_dimension_le_one(src0->info()->tensor_shape());
- Window input2_win = window.broadcast_if_dimension_le_one(src1->info()->tensor_shape());
-
- // Clear X Dimension on execution window as we handle manually
- win.set(Window::DimX, Window::Dimension(0, 1, 1));
- input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
- input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
-
- Iterator input1(src0, input1_win);
- Iterator input2(src1, input2_win);
- Iterator output(dst, win);
-
- const auto window_start_x = static_cast<int>(window.x().start());
- const auto window_end_x = static_cast<int>(window.x().end());
- const auto all_true_pg = svptrue_b8();
-
- execute_window_loop(win, [&](const Coordinates &)
- {
- const auto input1_ptr = reinterpret_cast<const uint8_t *>(input1.ptr());
- const auto input2_ptr = reinterpret_cast<const uint8_t *>(input2.ptr());
- const auto output_ptr = reinterpret_cast<int16_t *>(output.ptr());
-
- if(policy == ConvertPolicy::WRAP)
- {
- int x = window_start_x;
- svbool_t pg_u = svwhilelt_b8(x, window_end_x);
- svbool_t pg_0 = svwhilelt_b16(x, window_end_x);
- svbool_t pg_1 = svwhilelt_b16(x, static_cast<int>(window_end_x + svcnth()));
- do
- {
- const auto vsrc0 = svld1(pg_u, input1_ptr + x);
- const auto vsrc1 = svld1(pg_u, input2_ptr + x);
-
- const auto vsrc0_lo = svreinterpret_s16_u16(svunpklo(vsrc0));
- const auto vsrc0_hi = svreinterpret_s16_u16(svunpkhi(vsrc0));
- const auto vsrc1_lo = svreinterpret_s16_u16(svunpklo(vsrc1));
- const auto vsrc1_hi = svreinterpret_s16_u16(svunpkhi(vsrc1));
- svst1(pg_0, output_ptr + x, svqadd(vsrc0_lo, vsrc1_lo));
- svst1(pg_1, output_ptr + x + svcnth(), svqadd(vsrc0_hi, vsrc1_hi));
-
- x += svcntb();
- pg_u = svwhilelt_b8(x, window_end_x);
- pg_0 = svwhilelt_b16(x, window_end_x);
- pg_1 = svwhilelt_b16(x, static_cast<int>(window_end_x + svcnth()));
- }
- while(svptest_any(all_true_pg, pg_u));
- }
- else
- {
- int x = window_start_x;
- svbool_t pg_u = svwhilelt_b8(x, window_end_x);
- svbool_t pg_0 = svwhilelt_b16(x, window_end_x);
- svbool_t pg_1 = svwhilelt_b16(x, static_cast<int>(window_end_x + svcnth()));
- do
- {
- const auto vsrc0 = svld1(pg_u, input1_ptr + x);
- const auto vsrc1 = svld1(pg_u, input2_ptr + x);
-
- const auto vsrc0_lo = svreinterpret_s16_u16(svunpklo(vsrc0));
- const auto vsrc0_hi = svreinterpret_s16_u16(svunpkhi(vsrc0));
- const auto vsrc1_lo = svreinterpret_s16_u16(svunpklo(vsrc1));
- const auto vsrc1_hi = svreinterpret_s16_u16(svunpkhi(vsrc1));
- svst1(pg_0, output_ptr + x, svqadd(vsrc0_lo, vsrc1_lo));
- svst1(pg_1, output_ptr + x + svcnth(), svqadd(vsrc0_hi, vsrc1_hi));
-
- x += svcntb();
- pg_u = svwhilelt_b8(x, window_end_x);
- pg_0 = svwhilelt_b16(x, window_end_x);
- pg_1 = svwhilelt_b16(x, static_cast<int>(window_end_x + svcnth()));
- }
- while(svptest_any(all_true_pg, pg_u));
- }
- },
- input1, input2, output);
-}
-
-void add_s16_u8_s16_sve(const ITensor *src0, const ITensor *src1, ITensor *dst, const ConvertPolicy &policy, const Window &window)
-{
- // Create input windows
- Window win = window;
- Window input1_win = window.broadcast_if_dimension_le_one(src0->info()->tensor_shape());
- Window input2_win = window.broadcast_if_dimension_le_one(src1->info()->tensor_shape());
-
- // Clear X Dimension on execution window as we handle manually
- win.set(Window::DimX, Window::Dimension(0, 1, 1));
- input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
- input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
-
- Iterator input1(src0, input1_win);
- Iterator input2(src1, input2_win);
- Iterator output(dst, win);
-
- const auto window_start_x = static_cast<int>(window.x().start());
- const auto window_end_x = static_cast<int>(window.x().end());
- const auto all_true_pg = svptrue_b8();
-
- execute_window_loop(win, [&](const Coordinates &)
- {
- const auto input1_ptr = reinterpret_cast<const int16_t *>(input1.ptr());
- const auto input2_ptr = reinterpret_cast<const uint8_t *>(input2.ptr());
- const auto output_ptr = reinterpret_cast<int16_t *>(output.ptr());
-
- if(policy == ConvertPolicy::WRAP)
- {
- int x = window_start_x;
- svbool_t pg_u = svwhilelt_b8(x, window_end_x);
- svbool_t pg_0 = svwhilelt_b16(x, window_end_x);
- svbool_t pg_1 = svwhilelt_b16(x + static_cast<int>(svcnth()), window_end_x);
- do
- {
- const auto vsrc0_0 = svld1_s16(pg_0, input1_ptr + x);
- const auto vsrc0_1 = svld1_s16(pg_1, input1_ptr + x + svcnth());
- const auto vsrc1_u8 = svld1_u8(pg_u, input2_ptr + x);
- const auto vsrc1_0 = svreinterpret_s16_u16(svunpklo(vsrc1_u8));
- const auto vsrc1_1 = svreinterpret_s16_u16(svunpkhi(vsrc1_u8));
- svst1_s16(pg_0, output_ptr + x, svadd_s16_z(pg_0, vsrc0_0, vsrc1_0));
- svst1_s16(pg_1, output_ptr + x + svcnth(), svadd_s16_z(pg_1, vsrc0_1, vsrc1_1));
-
- x += svcntb();
- pg_u = svwhilelt_b8(x, window_end_x);
- pg_0 = svwhilelt_b16(x, window_end_x);
- pg_1 = svwhilelt_b16(x + static_cast<int>(svcnth()), window_end_x);
- }
- while(svptest_any(all_true_pg, pg_u));
- }
- else
- {
- int x = window_start_x;
- svbool_t pg_u = svwhilelt_b8(x, window_end_x);
- svbool_t pg_0 = svwhilelt_b16(x, window_end_x);
- svbool_t pg_1 = svwhilelt_b16(x + static_cast<int>(svcnth()), window_end_x);
- do
- {
- const auto vsrc0_0 = svld1_s16(pg_0, input1_ptr + x);
- const auto vsrc0_1 = svld1_s16(pg_1, input1_ptr + x + svcnth());
- const auto vsrc1_u8 = svld1_u8(pg_u, input2_ptr + x);
- const auto vsrc1_0 = svreinterpret_s16_u16(svunpklo(vsrc1_u8));
- const auto vsrc1_1 = svreinterpret_s16_u16(svunpkhi(vsrc1_u8));
-
- svst1_s16(pg_0, output_ptr + x, svqadd(vsrc0_0, vsrc1_0));
- svst1_s16(pg_1, output_ptr + x + svcnth(), svqadd(vsrc0_1, vsrc1_1));
-
- x += svcntb();
- pg_u = svwhilelt_b8(x, window_end_x);
- pg_0 = svwhilelt_b16(x, window_end_x);
- pg_1 = svwhilelt_b16(x + static_cast<int>(svcnth()), window_end_x);
- }
- while(svptest_any(all_true_pg, pg_u));
- }
- },
- input1, input2, output);
-}
-
-void add_u8_s16_s16_sve(const ITensor *src0, const ITensor *src1, ITensor *dst, const ConvertPolicy &policy, const Window &window)
-{
- // Simply swap the two input buffers:
- add_s16_u8_s16_sve(src1, src0, dst, policy, window);
-}
-} // namespace cpu
-} // namespace arm_compute
-#endif /* defined(__ARM_FEATURE_SVE) */
\ No newline at end of file
diff --git a/src/core/cpu/kernels/add/sve/list.h b/src/core/cpu/kernels/add/sve/list.h
index 9e439497c9..4d29c2a8f1 100644
--- a/src/core/cpu/kernels/add/sve/list.h
+++ b/src/core/cpu/kernels/add/sve/list.h
@@ -42,9 +42,6 @@ namespace cpu
DECLARE_ADD_KERNEL(add_qasymm8_sve);
DECLARE_ADD_KERNEL(add_qasymm8_signed_sve);
DECLARE_ADD_KERNEL(add_qsymm16_sve);
-DECLARE_ADD_KERNEL(add_s16_u8_s16_sve);
-DECLARE_ADD_KERNEL(add_u8_s16_s16_sve);
-DECLARE_ADD_KERNEL(add_u8_u8_s16_sve);
#undef DECLARE_ADD_KERNEL
diff --git a/src/core/cpu/kernels/sub/neon/integer.cpp b/src/core/cpu/kernels/sub/neon/integer.cpp
deleted file mode 100644
index bba73df1e8..0000000000
--- a/src/core/cpu/kernels/sub/neon/integer.cpp
+++ /dev/null
@@ -1,183 +0,0 @@
-/*
- * Copyright (c) 2021 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-#include "arm_compute/core/Helpers.h"
-#include "arm_compute/core/ITensor.h"
-#include "arm_compute/core/Types.h"
-#include "arm_compute/core/utils/misc/Traits.h"
-#include "src/core/NEON/wrapper/wrapper.h"
-#include "src/core/helpers/WindowHelpers.h"
-
-namespace arm_compute
-{
-namespace cpu
-{
-namespace
-{
-void sub_s16_u8_s16_impl(const ITensor *src0, const ITensor *src1, ITensor *dst, const ConvertPolicy &policy, const Window &window, bool is_swapped)
-{
- // Create input windows
- Window win = window;
- Window input1_win = window.broadcast_if_dimension_le_one(src0->info()->tensor_shape());
- Window input2_win = window.broadcast_if_dimension_le_one(src1->info()->tensor_shape());
-
- // Clear X Dimension on execution window as we handle manually
- win.set(Window::DimX, Window::Dimension(0, 1, 1));
- input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
- input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
-
- Iterator input1(src0, input1_win);
- Iterator input2(src1, input2_win);
- Iterator output(dst, win);
-
- const int window_step_x = 8;
- const auto window_start_x = static_cast<int>(window.x().start());
- const auto window_end_x = static_cast<int>(window.x().end());
-
- execute_window_loop(win, [&](const Coordinates &)
- {
- const auto input1_ptr = reinterpret_cast<const int16_t *>(input1.ptr());
- const auto input2_ptr = reinterpret_cast<const uint8_t *>(input2.ptr());
- const auto output_ptr = reinterpret_cast<int16_t *>(output.ptr());
-
- if(policy == ConvertPolicy::WRAP)
- {
- // Compute S elements per iteration
- int x = window_start_x;
- for(; x <= (window_end_x - window_step_x); x += window_step_x)
- {
- const auto vin1 = wrapper::vloadq(input1_ptr + x);
- const auto vin2 = vreinterpretq_s16_u16(wrapper::vmovl(wrapper::vload(input2_ptr + x)));
- const auto res = is_swapped ? wrapper::vsub(vin2, vin1) : wrapper::vsub(vin1, vin2);
- wrapper::vstore(output_ptr + x, res);
- }
-
- // Compute left-over elements
- for(; x < window_end_x; ++x)
- {
- const auto res = is_swapped ? static_cast<int16_t>(*(input2_ptr + x)) - *(input1_ptr + x) : *(input1_ptr + x) - static_cast<int16_t>(*(input2_ptr + x));
- *(output_ptr + x) = res;
- }
- }
- else
- {
- // Compute S elements per iteration
- int x = window_start_x;
- for(; x <= (window_end_x - window_step_x); x += window_step_x)
- {
- const auto vin1 = wrapper::vloadq(input1_ptr + x);
- const auto vin2 = vreinterpretq_s16_u16(wrapper::vmovl(wrapper::vload(input2_ptr + x)));
- const auto res = is_swapped ? wrapper::vqsub(vin2, vin1) : wrapper::vqsub(vin1, vin2);
- wrapper::vstore(output_ptr + x, res);
- }
-
- // Compute left-over elements
- for(; x < window_end_x; ++x)
- {
- const auto res = is_swapped ? wrapper::sub_sat(static_cast<int16_t>(*(input2_ptr + x)), *(input1_ptr + x)) : wrapper::sub_sat(*(input1_ptr + x), static_cast<int16_t>(*(input2_ptr + x)));
- *(output_ptr + x) = res;
- }
- }
- },
- input1, input2, output);
-}
-} // namespace
-
-void sub_s16_u8_s16_neon(const ITensor *src0, const ITensor *src1, ITensor *dst, const ConvertPolicy &policy, const Window &window)
-{
- sub_s16_u8_s16_impl(src0, src1, dst, policy, window, false);
-}
-
-void sub_u8_s16_s16_neon(const ITensor *src0, const ITensor *src1, ITensor *dst, const ConvertPolicy &policy, const Window &window)
-{
- // Swap arguments
- sub_s16_u8_s16_impl(src1, src0, dst, policy, window, true);
-}
-
-void sub_u8_u8_s16_neon(const ITensor *src0, const ITensor *src1, ITensor *dst, const ConvertPolicy &policy, const Window &window)
-{
- // Create input windows
- Window win = window;
- Window input1_win = window.broadcast_if_dimension_le_one(src0->info()->tensor_shape());
- Window input2_win = window.broadcast_if_dimension_le_one(src1->info()->tensor_shape());
-
- // Clear X Dimension on execution window as we handle manually
- win.set(Window::DimX, Window::Dimension(0, 1, 1));
- input1_win.set(Window::DimX, Window::Dimension(0, 1, 1));
- input2_win.set(Window::DimX, Window::Dimension(0, 1, 1));
-
- Iterator input1(src0, input1_win);
- Iterator input2(src1, input2_win);
- Iterator output(dst, win);
-
- const int window_step_x = 8;
- const auto window_start_x = static_cast<int>(window.x().start());
- const auto window_end_x = static_cast<int>(window.x().end());
-
- execute_window_loop(win, [&](const Coordinates &)
- {
- const auto input1_ptr = reinterpret_cast<const uint8_t *>(input1.ptr());
- const auto input2_ptr = reinterpret_cast<const uint8_t *>(input2.ptr());
- const auto output_ptr = reinterpret_cast<int16_t *>(output.ptr());
-
- if(policy == ConvertPolicy::WRAP)
- {
- // Compute S elements per iteration
- int x = window_start_x;
- for(; x <= (window_end_x - window_step_x); x += window_step_x)
- {
- const auto vin1 = vreinterpretq_s16_u16(wrapper::vmovl(wrapper::vload(input1_ptr + x)));
- const auto vin2 = vreinterpretq_s16_u16(wrapper::vmovl(wrapper::vload(input2_ptr + x)));
- wrapper::vstore(output_ptr + x, wrapper::vsub(vin1, vin2));
- }
-
- // Compute left-over elements
- for(; x < window_end_x; ++x)
- {
- *(output_ptr + x) = static_cast<int16_t>(*(input1_ptr + x)) - static_cast<int16_t>(*(input2_ptr + x));
- }
- }
- else
- {
- // Compute S elements per iteration
- int x = window_start_x;
- for(; x <= (window_end_x - window_step_x); x += window_step_x)
- {
- const auto vin1 = vreinterpretq_s16_u16(wrapper::vmovl(wrapper::vload(input1_ptr + x)));
- const auto vin2 = vreinterpretq_s16_u16(wrapper::vmovl(wrapper::vload(input2_ptr + x)));
- wrapper::vstore(output_ptr + x, wrapper::vqsub(vin1, vin2));
- }
-
- // Compute left-over elements
- for(; x < window_end_x; ++x)
- {
- *(output_ptr + x) = wrapper::sub_sat(static_cast<int16_t>(*(input1_ptr + x)),
- static_cast<int16_t>(*(input2_ptr + x)));
- }
- }
- },
- input1, input2, output);
-}
-
-} // namespace cpu
-} // namespace arm_compute \ No newline at end of file
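The kernel removed above, like its surviving same-type siblings, branches on ConvertPolicy: WRAP uses plain modular subtraction (wrapper::vsub) while SATURATE clamps into range (wrapper::vqsub / wrapper::sub_sat). A scalar sketch of the two policies for S16, assuming nothing beyond standard C++:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    // ConvertPolicy::WRAP: modular arithmetic; compute in unsigned so the
    // wrap-around is well defined, then reinterpret as S16.
    int16_t sub_wrap(int16_t a, int16_t b)
    {
        return static_cast<int16_t>(static_cast<uint16_t>(a) - static_cast<uint16_t>(b));
    }

    // ConvertPolicy::SATURATE: widen, subtract, clamp to the S16 range
    // (the scalar analogue of wrapper::vqsub / wrapper::sub_sat).
    int16_t sub_saturate(int16_t a, int16_t b)
    {
        const int r = int{a} - int{b};
        return static_cast<int16_t>(std::clamp(r, -32768, 32767));
    }

    int main()
    {
        std::printf("%d %d\n", sub_wrap(-32768, 1), sub_saturate(-32768, 1)); // 32767 -32768
    }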
diff --git a/src/core/cpu/kernels/sub/neon/list.h b/src/core/cpu/kernels/sub/neon/list.h
index 1ab4e6367b..ac1346001a 100644
--- a/src/core/cpu/kernels/sub/neon/list.h
+++ b/src/core/cpu/kernels/sub/neon/list.h
@@ -38,9 +38,6 @@ namespace cpu
DECLARE_SUB_KERNEL(sub_qasymm8_neon);
DECLARE_SUB_KERNEL(sub_qasymm8_signed_neon);
DECLARE_SUB_KERNEL(sub_qsymm16_neon);
-DECLARE_SUB_KERNEL(sub_s16_u8_s16_neon);
-DECLARE_SUB_KERNEL(sub_u8_s16_s16_neon);
-DECLARE_SUB_KERNEL(sub_u8_u8_s16_neon);
#undef DECLARE_SUB_KERNEL
diff --git a/src/core/gpu/cl/kernels/ClElementwiseKernel.cpp b/src/core/gpu/cl/kernels/ClElementwiseKernel.cpp
index b645353dd6..f005e9226e 100644
--- a/src/core/gpu/cl/kernels/ClElementwiseKernel.cpp
+++ b/src/core/gpu/cl/kernels/ClElementwiseKernel.cpp
@@ -127,50 +127,29 @@ Status validate_arguments_with_arithmetic_rules(const ITensorInfo &src1, const I
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src1, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
DataType::S16, DataType::QSYMM16, DataType::F16,
DataType::S32, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(&src2);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src2, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
- DataType::S16, DataType::QSYMM16, DataType::F16,
- DataType::S32, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src1, &src2);
- const bool is_quantized = is_data_type_quantized(src1.data_type()) || is_data_type_quantized(src2.data_type());
- if(is_quantized)
+ if(is_data_type_quantized_symmetric(src1.data_type()))
{
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src1, &src2);
-
- if(is_data_type_quantized_symmetric(src1.data_type()))
- {
- const int32_t in1_offset = src1.quantization_info().uniform().offset;
- const int32_t in2_offset = src2.quantization_info().uniform().offset;
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(in1_offset != 0, "For quantized symmetric, offset must be zero");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(in2_offset != 0, "For quantized symmetric, offset must be zero");
- }
+ const int32_t in1_offset = src1.quantization_info().uniform().offset;
+ const int32_t in2_offset = src2.quantization_info().uniform().offset;
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(in1_offset != 0, "For quantized symmetric, offset must be zero");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(in2_offset != 0, "For quantized symmetric, offset must be zero");
}
const TensorShape out_shape = TensorShape::broadcast_shape(src1.tensor_shape(), src2.tensor_shape());
-
ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");
// Validate in case of configured dst
if(dst.total_size() > 0)
{
- ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(&dst);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&dst, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
- DataType::S16, DataType::QSYMM16, DataType::F16,
- DataType::S32, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_MSG((dst.data_type() == DataType::U8) && ((src1.data_type() != DataType::U8) || (src2.data_type() != DataType::U8)),
- "dst can only be U8 if both inputs are U8");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, dst.tensor_shape(), 0),
- "Wrong shape for dst");
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src1, &dst);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, dst.tensor_shape(), 0), "Wrong shape for dst");
- if(is_quantized)
+ if(is_data_type_quantized_symmetric(dst.data_type()))
{
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src1, &dst);
-
- if(is_data_type_quantized_symmetric(dst.data_type()))
- {
- const int32_t offset = dst.quantization_info().uniform().offset;
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(offset != 0, "For quantized symmetric, offset must be zero");
- }
+ const int32_t offset = dst.quantization_info().uniform().offset;
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(offset != 0, "For quantized symmetric, offset must be zero");
}
}
return Status{};
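To make the symmetric-offset rule above concrete: symmetric quantization pins the zero-point at 0, so a QSYMM16 tensor carries only a scale. A hedged sketch using the public TensorInfo/QuantizationInfo constructors:

    // Offset defaults to 0 -> passes the check above:
    const TensorInfo ok(TensorShape(32U, 13U, 2U), 1, DataType::QSYMM16, QuantizationInfo(1.f / 32768.f));
    // Non-zero offset on a symmetric type -> "For quantized symmetric, offset must be zero":
    const TensorInfo bad(TensorShape(32U, 13U, 2U), 1, DataType::QSYMM16, QuantizationInfo(1.f / 32768.f, 10));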
@@ -182,9 +161,7 @@ CLBuildOptions generate_build_options_with_arithmetic_rules(const ITensorInfo &s
const unsigned int num_elems_processed_per_iteration = adjust_vec_size(vector_size_byte_opencl / dst.element_size(), dst.dimension(0));
- build_opts.add_option("-DDATA_TYPE_IN1=" + get_cl_type_from_data_type(src1.data_type()));
- build_opts.add_option("-DDATA_TYPE_IN2=" + get_cl_type_from_data_type(src2.data_type()));
- build_opts.add_option("-DDATA_TYPE_OUT=" + get_cl_type_from_data_type(dst.data_type()));
+ build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(src1.data_type()));
build_opts.add_option("-DVEC_SIZE_IN1=" + support::cpp11::to_string(src1.dimension(0) == 1 ? 1 : num_elems_processed_per_iteration));
build_opts.add_option("-DVEC_SIZE_IN2=" + support::cpp11::to_string(src2.dimension(0) == 1 ? 1 : num_elems_processed_per_iteration));
build_opts.add_option("-DVEC_SIZE_OUT=" + support::cpp11::to_string(num_elems_processed_per_iteration));
@@ -220,32 +197,7 @@ std::pair<Status, Window> validate_and_configure_window_for_arithmetic_operators
const std::pair<TensorShape, ValidRegion> broadcast_pair = ITensorInfo::broadcast_shape_and_valid_region(src1, src2);
const TensorShape &out_shape = broadcast_pair.first;
- set_shape_if_empty(dst, out_shape);
-
- if(src1.data_type() == DataType::S16 || src2.data_type() == DataType::S16)
- {
- set_format_if_unknown(dst, Format::S16);
- }
- else if(src1.data_type() == DataType::F16 || src2.data_type() == DataType::F16)
- {
- set_format_if_unknown(dst, Format::F16);
- }
- else if(src1.data_type() == DataType::F32 || src2.data_type() == DataType::F32)
- {
- set_format_if_unknown(dst, Format::F32);
- }
- else if(src1.data_type() == DataType::QASYMM8 || src2.data_type() == DataType::QASYMM8)
- {
- set_data_type_if_unknown(dst, DataType::QASYMM8);
- }
- else if(src1.data_type() == DataType::QASYMM8_SIGNED || src2.data_type() == DataType::QASYMM8_SIGNED)
- {
- set_data_type_if_unknown(dst, DataType::QASYMM8_SIGNED);
- }
- else if(src1.data_type() == DataType::QSYMM16 || src2.data_type() == DataType::QSYMM16)
- {
- set_data_type_if_unknown(dst, DataType::QSYMM16);
- }
+ auto_init_if_empty(dst, out_shape, 1, src1.data_type());
return configure_window_arithmetic_common(dst);
}
@@ -258,7 +210,6 @@ std::pair<Status, Window> validate_and_configure_window_for_logical_binary_opera
set_shape_if_empty(dst, out_shape);
set_data_type_if_unknown(dst, DataType::U8);
- // The arithmetic utility functions can be share
return configure_window_arithmetic_common(dst);
}
@@ -266,7 +217,9 @@ std::pair<Status, Window> validate_and_configure_window_for_division(ITensorInfo
{
const std::pair<TensorShape, ValidRegion> broadcast_pair = ITensorInfo::broadcast_shape_and_valid_region(src1, src2);
const TensorShape &out_shape = broadcast_pair.first;
+
auto_init_if_empty(dst, out_shape, 1, src1.data_type());
+
return configure_window_arithmetic_common(dst);
}
} // namespace
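Net effect of the configure-window change: the case-by-case format ladder is gone, and dst is initialised straight from src1, which is safe now that both inputs are required to share a type. auto_init_if_empty behaves roughly as follows (paraphrased sketch; the real helper lives in the core auto-configuration headers):

    // Roughly what auto_init_if_empty(dst, out_shape, 1, src1.data_type()) does:
    // initialise dst only if it has no shape yet, otherwise leave it alone and
    // let the validate step check it against the expected shape and type.
    if(dst.tensor_shape().total_size() == 0)
    {
        dst.set_data_type(src1.data_type());
        dst.set_num_channels(1);
        dst.set_tensor_shape(out_shape);
    }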
diff --git a/src/runtime/cpu/operators/CpuAdd.h b/src/runtime/cpu/operators/CpuAdd.h
index febb79e4cd..3ff135fe41 100644
--- a/src/runtime/cpu/operators/CpuAdd.h
+++ b/src/runtime/cpu/operators/CpuAdd.h
@@ -39,9 +39,6 @@ public:
* Valid configurations (src0,src1) -> dst :
*
* - (U8,U8) -> U8
- * - (U8,U8) -> S16
- * - (S16,U8) -> S16
- * - (U8,S16) -> S16
* - (S16,S16) -> S16
* - (S32,S32) -> S32
* - (F16,F16) -> F16
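With the mixed-type rows removed, a caller can probe a configuration up front through the public validate entry point, as the test suites below do; a sketch (types and shapes chosen for illustration):

    const TensorInfo s16(TensorShape(32U, 13U, 2U), 1, DataType::S16);
    const TensorInfo u8(TensorShape(32U, 13U, 2U), 1, DataType::U8);
    TensorInfo       dst(TensorShape(32U, 13U, 2U), 1, DataType::S16);

    // (S16,S16) -> S16 is still a valid configuration:
    Status ok  = NEArithmeticAddition::validate(&s16, &s16, &dst, ConvertPolicy::SATURATE);
    // (S16,U8) -> S16 was removed by this patch and now fails validation:
    Status err = NEArithmeticAddition::validate(&s16, &u8, &dst, ConvertPolicy::SATURATE);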
diff --git a/src/runtime/cpu/operators/CpuSub.h b/src/runtime/cpu/operators/CpuSub.h
index aad01fe4dc..07f5be89cd 100644
--- a/src/runtime/cpu/operators/CpuSub.h
+++ b/src/runtime/cpu/operators/CpuSub.h
@@ -39,11 +39,8 @@ public:
* Valid configurations (src0,src1) -> dst :
*
* - (U8,U8) -> U8
- * - (U8,U8) -> S16
* - (QASYMM8, QASYMM8) -> QASYMM8
* - (QASYMM8_SIGNED, QASYMM8_SIGNED) -> QASYMM8_SIGNED
- * - (S16,U8) -> S16
- * - (U8,S16) -> S16
* - (S16,S16) -> S16
* - (S32,S32) -> S32
* - (F16,F16) -> F16
diff --git a/tests/validation/CL/ArithmeticAddition.cpp b/tests/validation/CL/ArithmeticAddition.cpp
index c74f6a3b23..9e3d9afc36 100644
--- a/tests/validation/CL/ArithmeticAddition.cpp
+++ b/tests/validation/CL/ArithmeticAddition.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2020 Arm Limited.
+ * Copyright (c) 2017-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -44,23 +44,6 @@ namespace validation
namespace
{
/** Input data sets **/
-const auto ArithmeticAdditionU8Dataset = combine(combine(framework::dataset::make("DataType", DataType::U8), framework::dataset::make("DataType", DataType::U8)), framework::dataset::make("DataType",
- DataType::U8));
-const auto ArithmeticAdditionQASYMM8Dataset = combine(combine(framework::dataset::make("DataType", DataType::QASYMM8), framework::dataset::make("DataType", DataType::QASYMM8)),
- framework::dataset::make("DataType",
- DataType::QASYMM8));
-const auto ArithmeticAdditionQASYMM8SignedDataset = combine(combine(framework::dataset::make("DataType", DataType::QASYMM8_SIGNED), framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
- framework::dataset::make("DataType",
- DataType::QASYMM8_SIGNED));
-const auto ArithmeticAdditionQSYMM16Dataset = combine(combine(framework::dataset::make("DataType", DataType::QSYMM16), framework::dataset::make("DataType", DataType::QSYMM16)),
- framework::dataset::make("DataType",
- DataType::QSYMM16));
-const auto ArithmeticAdditionS16Dataset = combine(combine(framework::dataset::make("DataType", { DataType::U8, DataType::S16 }), framework::dataset::make("DataType", DataType::S16)),
- framework::dataset::make("DataType", DataType::S16));
-const auto ArithmeticAdditionFP16Dataset = combine(combine(framework::dataset::make("DataType", DataType::F16), framework::dataset::make("DataType", DataType::F16)),
- framework::dataset::make("DataType", DataType::F16));
-const auto ArithmeticAdditionFP32Dataset = combine(combine(framework::dataset::make("DataType", DataType::F32), framework::dataset::make("DataType", DataType::F32)),
- framework::dataset::make("DataType", DataType::F32));
const auto EmptyActivationFunctionsDataset = framework::dataset::make("ActivationInfo",
{ ActivationLayerInfo() });
const auto ActivationFunctionsDataset = framework::dataset::make("ActivationInfo",
@@ -76,22 +59,19 @@ TEST_SUITE(ArithmeticAddition)
// *INDENT-OFF*
// clang-format off
DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
- framework::dataset::make("Input1Info", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
- TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
+ framework::dataset::make("Input1Info", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), // Invalid data type combination
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), // Mismatching shapes
}),
- framework::dataset::make("Input2Info",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
- TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
+ framework::dataset::make("Input2Info",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16),
TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
})),
- framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16),
- TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
+ framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
})),
- framework::dataset::make("Expected", { true, true, false, false})),
+ framework::dataset::make("Expected", { true, false, false})),
input1_info, input2_info, output_info, expected)
{
ARM_COMPUTE_EXPECT(bool(CLArithmeticAddition::validate(&input1_info.clone()->set_is_resizable(false), &input2_info.clone()->set_is_resizable(false), &output_info.clone()->set_is_resizable(false), ConvertPolicy::WRAP)) == expected, framework::LogLevel::ERRORS);
@@ -129,7 +109,7 @@ using CLArithmeticAdditionFixture = ArithmeticAdditionValidationFixture<CLTensor
TEST_SUITE(Integer)
TEST_SUITE(U8)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticAdditionFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), ArithmeticAdditionU8Dataset),
+FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticAdditionFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::U8)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
{
// Validate output
@@ -138,14 +118,14 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticAdditionFixture<uint8_t>, framework
TEST_SUITE_END() // U8
TEST_SUITE(S16)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticAdditionFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), ArithmeticAdditionS16Dataset),
+FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticAdditionFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::S16)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
{
// Validate output
validate(CLAccessor(_target), _reference);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, CLArithmeticAdditionFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), ArithmeticAdditionS16Dataset),
+FIXTURE_DATA_TEST_CASE(RunLarge, CLArithmeticAdditionFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::S16)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
{
// Validate output
@@ -160,7 +140,7 @@ using CLArithmeticAdditionQuantizedFixture = ArithmeticAdditionValidationQuantiz
TEST_SUITE(Quantized)
TEST_SUITE(QASYMM8)
FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticAdditionQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(datasets::SmallShapes(),
- ArithmeticAdditionQASYMM8Dataset),
+ framework::dataset::make("DataType", DataType::QASYMM8)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })),
framework::dataset::make("Src0QInfo", { QuantizationInfo(5.f / 255.f, 20) })),
framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 255.f, 10) })),
@@ -171,12 +151,13 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticAdditionQuantizedFixture<uint8_t>,
}
template <typename T>
using CLArithmeticAdditionBroadcastQuantizedFixture = ArithmeticAdditionValidationQuantizedBroadcastFixture<CLTensor, CLAccessor, CLArithmeticAddition, T>;
-FIXTURE_DATA_TEST_CASE(RunSmallBroadcast, CLArithmeticAdditionBroadcastQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(datasets::SmallShapesBroadcast(),
- ArithmeticAdditionQASYMM8Dataset),
- framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })),
- framework::dataset::make("Src0QInfo", { QuantizationInfo(5.f / 255.f, 20) })),
- framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 255.f, 10) })),
- framework::dataset::make("OutQInfo", { QuantizationInfo(1.f / 255.f, 5) })))
+FIXTURE_DATA_TEST_CASE(RunSmallBroadcast, CLArithmeticAdditionBroadcastQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT,
+ combine(combine(combine(combine(combine(datasets::SmallShapesBroadcast(),
+ framework::dataset::make("DataType", DataType::QASYMM8)),
+ framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })),
+ framework::dataset::make("Src0QInfo", { QuantizationInfo(5.f / 255.f, 20) })),
+ framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 255.f, 10) })),
+ framework::dataset::make("OutQInfo", { QuantizationInfo(1.f / 255.f, 5) })))
{
// Validate output
validate(CLAccessor(_target), _reference);
@@ -184,7 +165,7 @@ FIXTURE_DATA_TEST_CASE(RunSmallBroadcast, CLArithmeticAdditionBroadcastQuantized
TEST_SUITE_END() // QASYMM8
TEST_SUITE(QASYMM8_SIGNED)
FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticAdditionQuantizedFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(datasets::SmallShapes(),
- ArithmeticAdditionQASYMM8SignedDataset),
+ framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })),
framework::dataset::make("Src0QInfo", { QuantizationInfo(5.f / 255.f, 10) })),
framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 255.f, 10) })),
@@ -196,7 +177,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticAdditionQuantizedFixture<int8_t>, f
TEST_SUITE_END() // QASYMM8_SIGNED
TEST_SUITE(QSYMM16)
FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticAdditionQuantizedFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(datasets::SmallShapes(),
- ArithmeticAdditionQSYMM16Dataset),
+ framework::dataset::make("DataType", DataType::QSYMM16)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })),
framework::dataset::make("Src0QInfo", { QuantizationInfo(1.f / 32768.f, 0), QuantizationInfo(5.f / 32768.f, 0) })),
framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 32768.f, 0), QuantizationInfo(5.f / 32768.f, 0) })),
@@ -213,14 +194,16 @@ using CLArithmeticAdditionFloatFixture = ArithmeticAdditionValidationFloatFixtur
TEST_SUITE(Float)
TEST_SUITE(FP16)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticAdditionFloatFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallShapes(), ArithmeticAdditionFP16Dataset),
+FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticAdditionFloatFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType",
+ DataType::F16)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
EmptyActivationFunctionsDataset))
{
// Validate output
validate(CLAccessor(_target), _reference);
}
-FIXTURE_DATA_TEST_CASE(RunWithActivation, CLArithmeticAdditionFloatFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(datasets::TinyShapes(), ArithmeticAdditionFP16Dataset),
+FIXTURE_DATA_TEST_CASE(RunWithActivation, CLArithmeticAdditionFloatFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(datasets::TinyShapes(), framework::dataset::make("DataType",
+ DataType::F16)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
ActivationFunctionsDataset))
{
@@ -230,14 +213,16 @@ FIXTURE_DATA_TEST_CASE(RunWithActivation, CLArithmeticAdditionFloatFixture<half>
TEST_SUITE_END() // FP16
TEST_SUITE(FP32)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticAdditionFloatFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), ArithmeticAdditionFP32Dataset),
+FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticAdditionFloatFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType",
+ DataType::F32)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
EmptyActivationFunctionsDataset))
{
// Validate output
validate(CLAccessor(_target), _reference);
}
-FIXTURE_DATA_TEST_CASE(RunWithActivation, CLArithmeticAdditionFloatFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(datasets::TinyShapes(), ArithmeticAdditionFP32Dataset),
+FIXTURE_DATA_TEST_CASE(RunWithActivation, CLArithmeticAdditionFloatFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(datasets::TinyShapes(), framework::dataset::make("DataType",
+ DataType::F32)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
ActivationFunctionsDataset))
{
@@ -245,7 +230,8 @@ FIXTURE_DATA_TEST_CASE(RunWithActivation, CLArithmeticAdditionFloatFixture<float
validate(CLAccessor(_target), _reference);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, CLArithmeticAdditionFloatFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), ArithmeticAdditionFP32Dataset),
+FIXTURE_DATA_TEST_CASE(RunLarge, CLArithmeticAdditionFloatFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType",
+ DataType::F32)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
EmptyActivationFunctionsDataset))
{
@@ -257,7 +243,7 @@ template <typename T>
using CLArithmeticAdditionBroadcastFloatFixture = ArithmeticAdditionBroadcastValidationFloatFixture<CLTensor, CLAccessor, CLArithmeticAddition, T>;
FIXTURE_DATA_TEST_CASE(RunSmallBroadcast, CLArithmeticAdditionBroadcastFloatFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapesBroadcast(),
- ArithmeticAdditionFP32Dataset),
+ framework::dataset::make("DataType", DataType::F32)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
EmptyActivationFunctionsDataset))
{
@@ -265,7 +251,7 @@ FIXTURE_DATA_TEST_CASE(RunSmallBroadcast, CLArithmeticAdditionBroadcastFloatFixt
validate(CLAccessor(_target), _reference);
}
FIXTURE_DATA_TEST_CASE(RunWithActivationBroadcast, CLArithmeticAdditionBroadcastFloatFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(datasets::TinyShapesBroadcast(),
- ArithmeticAdditionFP32Dataset),
+ framework::dataset::make("DataType", DataType::F32)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
ActivationFunctionsDataset))
{
@@ -274,7 +260,7 @@ FIXTURE_DATA_TEST_CASE(RunWithActivationBroadcast, CLArithmeticAdditionBroadcast
}
FIXTURE_DATA_TEST_CASE(RunLargeBroadcast, CLArithmeticAdditionBroadcastFloatFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapesBroadcast(),
- ArithmeticAdditionFP32Dataset),
+ framework::dataset::make("DataType", DataType::F32)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
EmptyActivationFunctionsDataset))
{
diff --git a/tests/validation/CL/ArithmeticSubtraction.cpp b/tests/validation/CL/ArithmeticSubtraction.cpp
index 2709fcaedb..00eba7f92a 100644
--- a/tests/validation/CL/ArithmeticSubtraction.cpp
+++ b/tests/validation/CL/ArithmeticSubtraction.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2020 Arm Limited.
+ * Copyright (c) 2017-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -44,24 +44,6 @@ namespace validation
namespace
{
/** Input data sets **/
-const auto ArithmeticSubtractionU8Dataset = combine(combine(framework::dataset::make("DataType", DataType::U8), framework::dataset::make("DataType", DataType::U8)),
- framework::dataset::make("DataType",
- DataType::U8));
-const auto ArithmeticSubtractionQASYMM8Dataset = combine(combine(framework::dataset::make("DataType", DataType::QASYMM8), framework::dataset::make("DataType", DataType::QASYMM8)),
- framework::dataset::make("DataType",
- DataType::QASYMM8));
-const auto ArithmeticSubtractionQASYMM8SignedDataset = combine(combine(framework::dataset::make("DataType", DataType::QASYMM8_SIGNED), framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
- framework::dataset::make("DataType",
- DataType::QASYMM8_SIGNED));
-const auto ArithmeticSubtractionQSYMM16Dataset = combine(combine(framework::dataset::make("DataType", DataType::QSYMM16), framework::dataset::make("DataType", DataType::QSYMM16)),
- framework::dataset::make("DataType",
- DataType::QSYMM16));
-const auto ArithmeticSubtractionS16Dataset = combine(combine(framework::dataset::make("DataType", { DataType::U8, DataType::S16 }), framework::dataset::make("DataType", DataType::S16)),
- framework::dataset::make("DataType", DataType::S16));
-const auto ArithmeticSubtractionFP16Dataset = combine(combine(framework::dataset::make("DataType", DataType::F16), framework::dataset::make("DataType", DataType::F16)),
- framework::dataset::make("DataType", DataType::F16));
-const auto ArithmeticSubtractionFP32Dataset = combine(combine(framework::dataset::make("DataType", DataType::F32), framework::dataset::make("DataType", DataType::F32)),
- framework::dataset::make("DataType", DataType::F32));
const auto EmptyActivationFunctionsDataset = framework::dataset::make("ActivationInfo",
{ ActivationLayerInfo() });
const auto ActivationFunctionsDataset = framework::dataset::make("ActivationInfo",
@@ -79,22 +61,19 @@ TEST_SUITE(ArithmeticSubtraction)
// *INDENT-OFF*
// clang-format off
DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
- framework::dataset::make("Input1Info", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
- TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
+ framework::dataset::make("Input1Info", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), // Invalid data type combination
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), // Mismatching shapes
}),
- framework::dataset::make("Input2Info",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
- TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
+ framework::dataset::make("Input2Info",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16),
TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
})),
- framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16),
- TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
+ framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
})),
- framework::dataset::make("Expected", { true, true, false, false})),
+ framework::dataset::make("Expected", { true, false, false})),
input1_info, input2_info, output_info, expected)
{
ARM_COMPUTE_EXPECT(bool(CLArithmeticSubtraction::validate(&input1_info.clone()->set_is_resizable(false), &input2_info.clone()->set_is_resizable(false), &output_info.clone()->set_is_resizable(false), ConvertPolicy::WRAP)) == expected, framework::LogLevel::ERRORS);
@@ -159,7 +138,8 @@ using CLArithmeticSubtractionFixture = ArithmeticSubtractionValidationFixture<CL
TEST_SUITE(Integer)
TEST_SUITE(U8)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallShapes(), ArithmeticSubtractionU8Dataset),
+FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType",
+ DataType::U8)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
OutOfPlaceDataSet))
{
@@ -169,7 +149,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFixture<uint8_t>, framew
TEST_SUITE_END() // U8
TEST_SUITE(S16)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), ArithmeticSubtractionS16Dataset),
+FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType",
+ DataType::S16)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
OutOfPlaceDataSet))
{
@@ -177,7 +158,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFixture<int16_t>, framew
validate(CLAccessor(_target), _reference);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, CLArithmeticSubtractionFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), ArithmeticSubtractionS16Dataset),
+FIXTURE_DATA_TEST_CASE(RunLarge, CLArithmeticSubtractionFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType",
+ DataType::S16)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
OutOfPlaceDataSet))
{
@@ -193,7 +175,7 @@ using CLArithmeticSubtractionQuantizedFixture = ArithmeticSubtractionValidationQ
TEST_SUITE(Quantized)
TEST_SUITE(QASYMM8)
FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionQuantizedFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(datasets::SmallShapes(),
- ArithmeticSubtractionQASYMM8Dataset),
+ framework::dataset::make("DataType", DataType::QASYMM8)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })),
framework::dataset::make("Src0QInfo", { QuantizationInfo(5.f / 255.f, 20) })),
framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 255.f, 10) })),
@@ -206,7 +188,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionQuantizedFixture<uint8_t
TEST_SUITE_END() // QASYMM8
TEST_SUITE(QASYMM8_SIGNED)
FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionQuantizedFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(datasets::SmallShapes(),
- ArithmeticSubtractionQASYMM8SignedDataset),
+ framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })),
framework::dataset::make("Src0QInfo", { QuantizationInfo(5.f / 255.f, 10) })),
framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 255.f, 10) })),
@@ -219,7 +201,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionQuantizedFixture<int8_t>
TEST_SUITE_END() // QASYMM8_SIGNED
TEST_SUITE(QSYMM16)
FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionQuantizedFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(datasets::SmallShapes(),
- ArithmeticSubtractionQSYMM16Dataset),
+ framework::dataset::make("DataType", DataType::QSYMM16)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })),
framework::dataset::make("Src0QInfo", { QuantizationInfo(1.f / 32768.f, 0), QuantizationInfo(5.f / 32768.f, 0) })),
framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 32768.f, 0), QuantizationInfo(5.f / 32768.f, 0) })),
@@ -237,7 +219,8 @@ using CLArithmeticSubtractionFloatFixture = ArithmeticSubtractionValidationFloat
TEST_SUITE(Float)
TEST_SUITE(FP16)
-FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFloatFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallShapes(), ArithmeticSubtractionFP16Dataset),
+FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFloatFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType",
+ DataType::F16)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
EmptyActivationFunctionsDataset),
OutOfPlaceDataSet))
@@ -246,7 +229,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFloatFixture<half>, fram
validate(CLAccessor(_target), _reference);
}
FIXTURE_DATA_TEST_CASE(RunWithActivation, CLArithmeticSubtractionFloatFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::TinyShapes(),
- ArithmeticSubtractionFP16Dataset),
+ framework::dataset::make("DataType", DataType::F16)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
ActivationFunctionsDataset),
InPlaceDataSet))
@@ -258,7 +241,7 @@ TEST_SUITE_END() // FP16
TEST_SUITE(FP32)
FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFloatFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::SmallShapes(),
- ArithmeticSubtractionFP32Dataset),
+ framework::dataset::make("DataType", DataType::F32)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
EmptyActivationFunctionsDataset),
OutOfPlaceDataSet))
@@ -267,7 +250,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFloatFixture<float>, fra
validate(CLAccessor(_target), _reference);
}
FIXTURE_DATA_TEST_CASE(RunWithActivation, CLArithmeticSubtractionFloatFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::TinyShapes(),
- ArithmeticSubtractionFP32Dataset),
+ framework::dataset::make("DataType", DataType::F32)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
ActivationFunctionsDataset),
InPlaceDataSet))
@@ -277,7 +260,7 @@ FIXTURE_DATA_TEST_CASE(RunWithActivation, CLArithmeticSubtractionFloatFixture<fl
}
FIXTURE_DATA_TEST_CASE(RunLarge, CLArithmeticSubtractionFloatFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::LargeShapes(),
- ArithmeticSubtractionFP32Dataset),
+ framework::dataset::make("DataType", DataType::F32)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
EmptyActivationFunctionsDataset),
OutOfPlaceDataSet))
@@ -290,7 +273,7 @@ template <typename T>
using CLArithmeticSubtractionBroadcastFloatFixture = ArithmeticSubtractionBroadcastValidationFloatFixture<CLTensor, CLAccessor, CLArithmeticSubtraction, T>;
FIXTURE_DATA_TEST_CASE(RunSmallBroadcast, CLArithmeticSubtractionBroadcastFloatFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::SmallShapesBroadcast(),
- ArithmeticSubtractionFP32Dataset),
+ framework::dataset::make("DataType", DataType::F32)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
EmptyActivationFunctionsDataset),
OutOfPlaceDataSet))
@@ -299,7 +282,7 @@ FIXTURE_DATA_TEST_CASE(RunSmallBroadcast, CLArithmeticSubtractionBroadcastFloatF
validate(CLAccessor(_target), _reference);
}
FIXTURE_DATA_TEST_CASE(RunWithActivationBroadcast, CLArithmeticSubtractionBroadcastFloatFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::TinyShapesBroadcast(),
- ArithmeticSubtractionFP32Dataset),
+ framework::dataset::make("DataType", DataType::F32)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
ActivationFunctionsDataset),
OutOfPlaceDataSet))
@@ -309,7 +292,7 @@ FIXTURE_DATA_TEST_CASE(RunWithActivationBroadcast, CLArithmeticSubtractionBroadc
}
FIXTURE_DATA_TEST_CASE(RunLargeBroadcast, CLArithmeticSubtractionBroadcastFloatFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::LargeShapesBroadcast(),
- ArithmeticSubtractionFP32Dataset),
+ framework::dataset::make("DataType", DataType::F32)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
EmptyActivationFunctionsDataset),
OutOfPlaceDataSet))
diff --git a/tests/validation/CL/ElementwiseMax.cpp b/tests/validation/CL/ElementwiseMax.cpp
index b9444b2795..225968efd1 100644
--- a/tests/validation/CL/ElementwiseMax.cpp
+++ b/tests/validation/CL/ElementwiseMax.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -58,7 +58,7 @@ const auto ElementwiseMaxQASYMM8SignedDataset = combine(combine(framework::datas
const auto ElementwiseMaxQSYMM16Dataset = combine(combine(framework::dataset::make("DataType", DataType::QSYMM16), framework::dataset::make("DataType", DataType::QSYMM16)),
framework::dataset::make("DataType",
DataType::QSYMM16));
-const auto ElementwiseMaxS16Dataset = combine(combine(framework::dataset::make("DataType", { DataType::U8, DataType::S16 }), framework::dataset::make("DataType", DataType::S16)),
+const auto ElementwiseMaxS16Dataset = combine(combine(framework::dataset::make("DataType", { DataType::S16 }), framework::dataset::make("DataType", DataType::S16)),
framework::dataset::make("DataType", DataType::S16));
const auto ElementwiseMaxFP16Dataset = combine(combine(framework::dataset::make("DataType", DataType::F16), framework::dataset::make("DataType", DataType::F16)),
framework::dataset::make("DataType", DataType::F16));
@@ -80,21 +80,18 @@ TEST_SUITE(ElementwiseMax)
// clang-format off
DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
framework::dataset::make("Input1Info", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
- TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), // Invalid data type combination
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), // Mismatching shapes
}),
framework::dataset::make("Input2Info",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
- TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16),
TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
})),
- framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16),
- TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
+ framework::dataset::make("OutputInfo",{TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
})),
- framework::dataset::make("Expected", { true, true, false, false})),
+ framework::dataset::make("Expected", { true, false, false})),
input1_info, input2_info, output_info, expected)
{
ARM_COMPUTE_EXPECT(bool(CLElementwiseMax::validate(&input1_info.clone()->set_is_resizable(false), &input2_info.clone()->set_is_resizable(false), &output_info.clone()->set_is_resizable(false))) == expected, framework::LogLevel::ERRORS);
diff --git a/tests/validation/CL/ElementwiseMin.cpp b/tests/validation/CL/ElementwiseMin.cpp
index 8f53b241ab..2a066908fa 100644
--- a/tests/validation/CL/ElementwiseMin.cpp
+++ b/tests/validation/CL/ElementwiseMin.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -58,7 +58,7 @@ const auto ElementwiseMinQASYMM8SignedDataset = combine(combine(framework::datas
const auto ElementwiseMinQSYMM16Dataset = combine(combine(framework::dataset::make("DataType", DataType::QSYMM16), framework::dataset::make("DataType", DataType::QSYMM16)),
framework::dataset::make("DataType",
DataType::QSYMM16));
-const auto ElementwiseMinS16Dataset = combine(combine(framework::dataset::make("DataType", { DataType::U8, DataType::S16 }), framework::dataset::make("DataType", DataType::S16)),
+const auto ElementwiseMinS16Dataset = combine(combine(framework::dataset::make("DataType", { DataType::S16 }), framework::dataset::make("DataType", DataType::S16)),
framework::dataset::make("DataType", DataType::S16));
const auto ElementwiseMinFP16Dataset = combine(combine(framework::dataset::make("DataType", DataType::F16), framework::dataset::make("DataType", DataType::F16)),
framework::dataset::make("DataType", DataType::F16));
@@ -80,21 +80,18 @@ TEST_SUITE(ElementwiseMin)
// clang-format off
DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
framework::dataset::make("Input1Info", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
- TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), // Invalid data type combination
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), // Mismatching shapes
}),
framework::dataset::make("Input2Info",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
- TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16),
TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
})),
- framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16),
- TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
+ framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
})),
- framework::dataset::make("Expected", { true, true, false, false})),
+ framework::dataset::make("Expected", { true, false, false})),
input1_info, input2_info, output_info, expected)
{
ARM_COMPUTE_EXPECT(bool(CLElementwiseMin::validate(&input1_info.clone()->set_is_resizable(false), &input2_info.clone()->set_is_resizable(false), &output_info.clone()->set_is_resizable(false))) == expected, framework::LogLevel::ERRORS);
diff --git a/tests/validation/CL/ElementwiseSquaredDiff.cpp b/tests/validation/CL/ElementwiseSquaredDiff.cpp
index 0a4ab6627b..4c732b0885 100644
--- a/tests/validation/CL/ElementwiseSquaredDiff.cpp
+++ b/tests/validation/CL/ElementwiseSquaredDiff.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2020 Arm Limited.
+ * Copyright (c) 2018-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -57,7 +57,7 @@ const auto ElementwiseSquaredDiffQASYMM8Dataset = combine(combine(framework::dat
const auto ElementwiseSquaredDiffQSYMM16Dataset = combine(combine(framework::dataset::make("DataType", DataType::QSYMM16), framework::dataset::make("DataType", DataType::QSYMM16)),
framework::dataset::make("DataType",
DataType::QSYMM16));
-const auto ElementwiseSquaredDiffS16Dataset = combine(combine(framework::dataset::make("DataType", { DataType::U8, DataType::S16 }), framework::dataset::make("DataType", DataType::S16)),
+const auto ElementwiseSquaredDiffS16Dataset = combine(combine(framework::dataset::make("DataType", { DataType::S16 }), framework::dataset::make("DataType", DataType::S16)),
framework::dataset::make("DataType", DataType::S16));
const auto ElementwiseSquaredDiffFP16Dataset = combine(combine(framework::dataset::make("DataType", DataType::F16), framework::dataset::make("DataType", DataType::F16)),
framework::dataset::make("DataType", DataType::F16));
@@ -79,21 +79,18 @@ TEST_SUITE(ElementwiseSquaredDiff)
// clang-format off
DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
framework::dataset::make("Input1Info", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
- TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), // Invalid data type combination
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), // Mismatching shapes
}),
framework::dataset::make("Input2Info",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
- TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16),
TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
})),
- framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16),
- TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
+ framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
})),
- framework::dataset::make("Expected", { true, true, false, false})),
+ framework::dataset::make("Expected", { true, false, false})),
input1_info, input2_info, output_info, expected)
{
ARM_COMPUTE_EXPECT(bool(CLElementwiseSquaredDiff::validate(&input1_info.clone()->set_is_resizable(false), &input2_info.clone()->set_is_resizable(false), &output_info.clone()->set_is_resizable(false))) == expected, framework::LogLevel::ERRORS);
diff --git a/tests/validation/CL/PReluLayer.cpp b/tests/validation/CL/PReluLayer.cpp
index 043262d891..f3f1c8b1b8 100644
--- a/tests/validation/CL/PReluLayer.cpp
+++ b/tests/validation/CL/PReluLayer.cpp
@@ -56,7 +56,7 @@ const auto PReluLayerQASYMM8Dataset = combine(combine(framework::dataset::make("
const auto PReluLayerQASYMM8SIGNEDDataset = combine(combine(framework::dataset::make("DataType", DataType::QASYMM8_SIGNED), framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
framework::dataset::make("DataType",
DataType::QASYMM8_SIGNED));
-const auto PReluLayerS16Dataset = combine(combine(framework::dataset::make("DataType", { DataType::U8, DataType::S16 }), framework::dataset::make("DataType", DataType::S16)),
+const auto PReluLayerS16Dataset = combine(combine(framework::dataset::make("DataType", { DataType::S16 }), framework::dataset::make("DataType", DataType::S16)),
framework::dataset::make("DataType", DataType::S16));
const auto PReluLayerFP16Dataset = combine(combine(framework::dataset::make("DataType", DataType::F16), framework::dataset::make("DataType", DataType::F16)),
framework::dataset::make("DataType", DataType::F16));
@@ -71,21 +71,18 @@ TEST_SUITE(PReluLayer)
// clang-format off
DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
framework::dataset::make("Input1Info", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
- TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), // Invalid data type combination
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), // Mismatching shapes
}),
framework::dataset::make("Input2Info",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
- TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16),
TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
})),
- framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16),
- TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
+ framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
})),
- framework::dataset::make("Expected", { true, true, false, false})),
+ framework::dataset::make("Expected", { true, false, false})),
input1_info, input2_info, output_info, expected)
{
ARM_COMPUTE_EXPECT(bool(CLPReluLayer::validate(&input1_info.clone()->set_is_resizable(false), &input2_info.clone()->set_is_resizable(false), &output_info.clone()->set_is_resizable(false))) == expected, framework::LogLevel::ERRORS);
diff --git a/tests/validation/NEON/ArithmeticAddition.cpp b/tests/validation/NEON/ArithmeticAddition.cpp
index ea6656eefe..213dbc1f5e 100644
--- a/tests/validation/NEON/ArithmeticAddition.cpp
+++ b/tests/validation/NEON/ArithmeticAddition.cpp
@@ -48,26 +48,6 @@ constexpr AbsoluteTolerance<float> tolerance_quant(1); /**< Tolerance value for
#else // !defined(__aarch64__) || defined(ENABLE_SVE)
constexpr AbsoluteTolerance<float> tolerance_quant(0);
#endif // !defined(__aarch64__) || defined(ENABLE_SVE)
-
-/** Input data sets **/
-const auto ArithmeticAdditionU8Dataset = combine(combine(framework::dataset::make("DataType", DataType::U8), framework::dataset::make("DataType", DataType::U8)), framework::dataset::make("DataType",
- DataType::U8));
-const auto ArithmeticAdditionS16Dataset = combine(combine(framework::dataset::make("DataType", { DataType::U8, DataType::S16 }), framework::dataset::make("DataType", DataType::S16)),
- framework::dataset::make("DataType", DataType::S16));
-const auto ArithmeticAdditionS32Dataset = combine(combine(framework::dataset::make("DataType", { DataType::S32 }), framework::dataset::make("DataType", DataType::S32)),
- framework::dataset::make("DataType", DataType::S32));
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-const auto ArithmeticAdditionFP16Dataset = combine(combine(framework::dataset::make("DataType", DataType::F16), framework::dataset::make("DataType", DataType::F16)),
- framework::dataset::make("DataType", DataType::F16));
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
-const auto ArithmeticAdditionFP32Dataset = combine(combine(framework::dataset::make("DataType", DataType::F32), framework::dataset::make("DataType", DataType::F32)),
- framework::dataset::make("DataType", DataType::F32));
-const auto ArithmeticAdditionQASYMM8Dataset = combine(combine(framework::dataset::make("DataType", DataType::QASYMM8), framework::dataset::make("DataType", DataType::QASYMM8)),
- framework::dataset::make("DataType", DataType::QASYMM8));
-const auto ArithmeticAdditionQASYMM8SIGNEDDataset = combine(combine(framework::dataset::make("DataType", DataType::QASYMM8_SIGNED), framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
- framework::dataset::make("DataType", DataType::QASYMM8_SIGNED));
-const auto ArithmeticAdditionQSYMM16Dataset = combine(combine(framework::dataset::make("DataType", DataType::QSYMM16), framework::dataset::make("DataType", DataType::QSYMM16)),
- framework::dataset::make("DataType", DataType::QSYMM16));
} // namespace
TEST_SUITE(NEON)
@@ -79,25 +59,22 @@ using NEArithmeticAdditionFixture = ArithmeticAdditionValidationFixture<Tensor,
// *INDENT-OFF*
// clang-format off
DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
- framework::dataset::make("Input1Info", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
- TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
+ framework::dataset::make("Input1Info", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::U8), // Unsupported broadcast
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), // Invalid data type combination
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),// Mismatching shapes
}),
- framework::dataset::make("Input2Info",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
- TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
+ framework::dataset::make("Input2Info",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
TensorInfo(TensorShape(1U, 13U, 2U), 1, DataType::S16),
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16),
TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
})),
- framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16),
- TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
+ framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::S16),
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
})),
- framework::dataset::make("Expected", { true, true, false, false, false})),
+ framework::dataset::make("Expected", { true, false, false, false})),
input1_info, input2_info, output_info, expected)
{
Status s = NEArithmeticAddition::validate(&input1_info.clone()->set_is_resizable(false),
@@ -127,7 +104,7 @@ TEST_CASE(NoPaddingAdded, framework::DatasetMode::PRECOMMIT)
TEST_SUITE(Integer)
TEST_SUITE(U8)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticAdditionFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), ArithmeticAdditionU8Dataset),
+FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticAdditionFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::U8)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
{
// Validate output
@@ -136,14 +113,14 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticAdditionFixture<uint8_t>, framework
TEST_SUITE_END() // U8
TEST_SUITE(S16)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticAdditionFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), ArithmeticAdditionS16Dataset),
+FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticAdditionFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::S16)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
{
// Validate output
validate(Accessor(_target), _reference);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticAdditionFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), ArithmeticAdditionS16Dataset),
+FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticAdditionFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::S16)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
{
// Validate output
@@ -152,7 +129,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticAdditionFixture<int16_t>, framework
TEST_SUITE_END() // S16
TEST_SUITE(S32)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticAdditionFixture<int32_t>, framework::DatasetMode::ALL, combine(combine(datasets::SmallShapes(), ArithmeticAdditionS32Dataset),
+FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticAdditionFixture<int32_t>, framework::DatasetMode::ALL, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::S32)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
{
// Validate output
@@ -164,7 +141,7 @@ TEST_SUITE_END() // Integer
TEST_SUITE(Float)
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_SUITE(F16)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticAdditionFixture<half>, framework::DatasetMode::ALL, combine(combine(datasets::SmallShapes(), ArithmeticAdditionFP16Dataset),
+FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticAdditionFixture<half>, framework::DatasetMode::ALL, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F16)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
{
// Validate output
@@ -174,14 +151,14 @@ TEST_SUITE_END() // F16
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
TEST_SUITE(F32)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticAdditionFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), ArithmeticAdditionFP32Dataset),
+FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticAdditionFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F32)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
{
// Validate output
validate(Accessor(_target), _reference);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticAdditionFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), ArithmeticAdditionFP32Dataset),
+FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticAdditionFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F32)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
{
// Validate output
@@ -192,7 +169,7 @@ template <typename T>
using NEArithmeticAdditionBroadcastFixture = ArithmeticAdditionBroadcastValidationFixture<Tensor, Accessor, NEArithmeticAddition, T>;
FIXTURE_DATA_TEST_CASE(RunSmallBroadcast, NEArithmeticAdditionBroadcastFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapesBroadcast(),
- ArithmeticAdditionFP32Dataset),
+ framework::dataset::make("DataType", DataType::F32)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
{
// Validate output
@@ -200,7 +177,7 @@ FIXTURE_DATA_TEST_CASE(RunSmallBroadcast, NEArithmeticAdditionBroadcastFixture<f
}
FIXTURE_DATA_TEST_CASE(RunLargeBroadcast, NEArithmeticAdditionBroadcastFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapesBroadcast(),
- ArithmeticAdditionFP32Dataset),
+ framework::dataset::make("DataType", DataType::F32)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })))
{
// Validate output
@@ -220,7 +197,7 @@ TEST_SUITE(QASYMM8)
FIXTURE_DATA_TEST_CASE(RunSmall,
NEArithmeticAdditionQuantizedFixture<uint8_t>,
framework::DatasetMode::PRECOMMIT,
- combine(combine(combine(combine(combine(datasets::SmallShapes(), ArithmeticAdditionQASYMM8Dataset),
+ combine(combine(combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::QASYMM8)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })),
framework::dataset::make("Src0QInfo", { QuantizationInfo(5.f / 255.f, 20) })),
framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 255.f, 10) })),
@@ -235,7 +212,7 @@ TEST_SUITE(QASYMM8_SIGNED)
FIXTURE_DATA_TEST_CASE(RunSmall,
NEArithmeticAdditionQuantizedFixture<int8_t>,
framework::DatasetMode::ALL,
- combine(combine(combine(combine(combine(datasets::SmallShapes(), ArithmeticAdditionQASYMM8SIGNEDDataset),
+ combine(combine(combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })),
framework::dataset::make("Src0QInfo", { QuantizationInfo(0.5f, 20) })),
framework::dataset::make("Src1QInfo", { QuantizationInfo(0.5f, 10) })),
@@ -246,7 +223,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall,
}
FIXTURE_DATA_TEST_CASE(RunSmallBroadcast, NEArithmeticAdditionQuantizedBroadcastFixture<int8_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(
- datasets::SmallShapesBroadcast(), ArithmeticAdditionQASYMM8SIGNEDDataset),
+ datasets::SmallShapesBroadcast(), framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })),
framework::dataset::make("Src0QInfo", { QuantizationInfo(0.5f, 20) })),
framework::dataset::make("Src1QInfo", { QuantizationInfo(0.5f, 10) })),
@@ -261,7 +238,7 @@ TEST_SUITE(QSYMM16)
FIXTURE_DATA_TEST_CASE(RunSmall,
NEArithmeticAdditionQuantizedFixture<int16_t>,
framework::DatasetMode::PRECOMMIT,
- combine(combine(combine(combine(combine(datasets::SmallShapes(), ArithmeticAdditionQSYMM16Dataset),
+ combine(combine(combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::QSYMM16)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })),
framework::dataset::make("Src0QInfo", { QuantizationInfo(1.f / 32768.f, 0), QuantizationInfo(5.f / 32768.f, 0) })),
framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 32768.f, 0), QuantizationInfo(5.f / 32768.f, 0) })),
diff --git a/tests/validation/NEON/ArithmeticSubtraction.cpp b/tests/validation/NEON/ArithmeticSubtraction.cpp
index 7a36893445..68213fb51f 100644
--- a/tests/validation/NEON/ArithmeticSubtraction.cpp
+++ b/tests/validation/NEON/ArithmeticSubtraction.cpp
@@ -50,39 +50,7 @@ constexpr AbsoluteTolerance<float> tolerance_qasymm8(1); /**< Tolerance value fo
#endif //__aarch64__
constexpr AbsoluteTolerance<int16_t> tolerance_qsymm16(1); /**< Tolerance value for comparing reference's output against implementation's output for quantized data types */
-/** Input data sets **/
-const auto ArithmeticSubtractionQASYMM8Dataset = combine(combine(framework::dataset::make("DataType", DataType::QASYMM8),
- framework::dataset::make("DataType", DataType::QASYMM8)),
- framework::dataset::make("DataType", DataType::QASYMM8));
-
-const auto ArithmeticSubtractionQASYMM8SIGNEDDataset = combine(combine(framework::dataset::make("DataType", DataType::QASYMM8_SIGNED),
- framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
- framework::dataset::make("DataType", DataType::QASYMM8_SIGNED));
-
-const auto ArithmeticSubtractionQSYMM16Dataset = combine(combine(framework::dataset::make("DataType", DataType::QSYMM16),
- framework::dataset::make("DataType", DataType::QSYMM16)),
- framework::dataset::make("DataType", DataType::QSYMM16));
-
-const auto ArithmeticSubtractionU8Dataset = combine(combine(framework::dataset::make("DataType", DataType::U8),
- framework::dataset::make("DataType", DataType::U8)),
- framework::dataset::make("DataType", DataType::U8));
-
-const auto ArithmeticSubtractionS16Dataset = combine(combine(framework::dataset::make("DataType", { DataType::U8, DataType::S16 }),
- framework::dataset::make("DataType", DataType::S16)),
- framework::dataset::make("DataType", DataType::S16));
-
-const auto ArithmeticSubtractionS32Dataset = combine(combine(framework::dataset::make("DataType", DataType::S32),
- framework::dataset::make("DataType", DataType::S32)),
- framework::dataset::make("DataType", DataType::S32));
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-const auto ArithmeticSubtractionFP16Dataset = combine(combine(framework::dataset::make("DataType", DataType::F16),
- framework::dataset::make("DataType", DataType::F16)),
- framework::dataset::make("DataType", DataType::F16));
-#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
-const auto ArithmeticSubtractionFP32Dataset = combine(combine(framework::dataset::make("DataType", DataType::F32),
- framework::dataset::make("DataType", DataType::F32)),
- framework::dataset::make("DataType", DataType::F32));
-
+// Quantization Information Dataset
const auto ArithmeticSubtractionQuantizationInfoDataset = combine(combine(framework::dataset::make("QuantizationInfoIn1", { QuantizationInfo(10, 120) }),
framework::dataset::make("QuantizationInfoIn2", { QuantizationInfo(20, 110) })),
framework::dataset::make("QuantizationInfoOut", { QuantizationInfo(15, 125) }));
@@ -105,35 +73,31 @@ using NEArithmeticSubtractionFixture = ArithmeticSubtractionValidationFixture<Te
// *INDENT-OFF*
// clang-format off
DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(
- framework::dataset::make("Input1Info", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
- TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
+ framework::dataset::make("Input1Info", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), // Invalid data type combination
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), // Mismatching shapes
TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::QASYMM8), // Mismatching types
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QASYMM8), // Invalid convert policy
}),
- framework::dataset::make("Input2Info",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
- TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
+ framework::dataset::make("Input2Info",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16),
TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QASYMM8),
})),
- framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16),
- TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
+ framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32),
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QASYMM8),
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QASYMM8),
})),
- framework::dataset::make("ConvertPolicy",{ ConvertPolicy::WRAP,
- ConvertPolicy::SATURATE,
- ConvertPolicy::SATURATE,
- ConvertPolicy::WRAP,
- ConvertPolicy::WRAP,
- ConvertPolicy::WRAP,
+ framework::dataset::make("ConvertPolicy",{ ConvertPolicy::SATURATE,
+ ConvertPolicy::SATURATE,
+ ConvertPolicy::WRAP,
+ ConvertPolicy::WRAP,
+ ConvertPolicy::WRAP,
})),
- framework::dataset::make("Expected", { true, true, false, false, false, false})),
+ framework::dataset::make("Expected", { true, false, false, false, false})),
input1_info, input2_info, output_info, policy, expected)
{
ARM_COMPUTE_EXPECT(bool(NEArithmeticSubtraction::validate(&input1_info.clone()->set_is_resizable(false), &input2_info.clone()->set_is_resizable(false), &output_info.clone()->set_is_resizable(false), policy)) == expected, framework::LogLevel::ERRORS);
@@ -194,7 +158,8 @@ TEST_CASE(InvalidBroadcastBoth, framework::DatasetMode::ALL)
TEST_SUITE_END() // InPlaceValidate
TEST_SUITE(U8)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), ArithmeticSubtractionU8Dataset),
+FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType",
+ DataType::U8)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
OutOfPlaceDataSet))
{
@@ -210,7 +175,8 @@ using NEArithmeticSubtractionQSYMM16Fixture = ArithmeticSubtracti
TEST_SUITE(Quantized)
TEST_SUITE(QASYMM8)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionQASYMM8Fixture, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallShapes(), ArithmeticSubtractionQASYMM8Dataset),
+FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionQASYMM8Fixture, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType",
+ DataType::QASYMM8)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })),
ArithmeticSubtractionQuantizationInfoDataset),
InPlaceDataSet))
@@ -222,8 +188,7 @@ TEST_SUITE_END() // QASYMM8
TEST_SUITE(QASYMM8_SIGNED)
FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionQASYMM8SignedFixture, framework::DatasetMode::ALL, combine(combine(combine(combine(
- datasets::SmallShapes(),
- ArithmeticSubtractionQASYMM8SIGNEDDataset),
+ datasets::SmallShapes(), framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })),
ArithmeticSubtractionQuantizationInfoSignedDataset),
InPlaceDataSet))
@@ -234,7 +199,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionQASYMM8SignedFixture, fr
FIXTURE_DATA_TEST_CASE(RunSmallBroadcast, NEArithmeticSubtractionQASYMM8SignedBroadcastFixture, framework::DatasetMode::ALL, combine(combine(combine(combine(
datasets::SmallShapesBroadcast(),
- ArithmeticSubtractionQASYMM8SIGNEDDataset),
+ framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })),
ArithmeticSubtractionQuantizationInfoSignedDataset),
OutOfPlaceDataSet))
@@ -247,7 +212,7 @@ TEST_SUITE_END() // QASYMM8_SIGNED
TEST_SUITE(QSYMM16)
FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionQSYMM16Fixture, framework::DatasetMode::ALL, combine(combine(combine(combine(
datasets::SmallShapes(),
- ArithmeticSubtractionQSYMM16Dataset),
+ framework::dataset::make("DataType", DataType::QSYMM16)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })),
ArithmeticSubtractionQuantizationInfoSymmetric),
OutOfPlaceDataSet))
@@ -259,7 +224,8 @@ TEST_SUITE_END() // QSYMM16
TEST_SUITE_END() // Quantized
TEST_SUITE(S16)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), ArithmeticSubtractionS16Dataset),
+FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixture<int16_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType",
+ DataType::S16)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
OutOfPlaceDataSet))
{
@@ -267,7 +233,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixture<int16_t>, framew
validate(Accessor(_target), _reference);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), ArithmeticSubtractionS16Dataset),
+FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionFixture<int16_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType",
+ DataType::S16)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
OutOfPlaceDataSet))
{
@@ -277,7 +244,8 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionFixture<int16_t>, framew
TEST_SUITE_END() // S16
TEST_SUITE(S32)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixture<int32_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), ArithmeticSubtractionS32Dataset),
+FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixture<int32_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType",
+ DataType::S32)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
OutOfPlaceDataSet))
{
@@ -285,7 +253,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixture<int32_t>, framew
validate(Accessor(_target), _reference);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionFixture<int32_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), ArithmeticSubtractionS32Dataset),
+FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionFixture<int32_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType",
+ DataType::S32)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
OutOfPlaceDataSet))
{
@@ -297,7 +266,8 @@ TEST_SUITE_END() // S32
TEST_SUITE(Float)
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
TEST_SUITE(F16)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallShapes(), ArithmeticSubtractionFP16Dataset),
+FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixture<half>, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType",
+ DataType::F16)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
OutOfPlaceDataSet))
{
@@ -308,7 +278,8 @@ TEST_SUITE_END() // F16
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
TEST_SUITE(F32)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), ArithmeticSubtractionFP32Dataset),
+FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType",
+ DataType::F32)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
InPlaceDataSet))
{
@@ -316,7 +287,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixture<float>, framewor
validate(Accessor(_target), _reference);
}
-FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), ArithmeticSubtractionFP32Dataset),
+FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType",
+ DataType::F32)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
OutOfPlaceDataSet))
{
@@ -328,7 +300,7 @@ template <typename T>
using NEArithmeticSubtractionBroadcastFixture = ArithmeticSubtractionBroadcastValidationFixture<Tensor, Accessor, NEArithmeticSubtraction, T>;
FIXTURE_DATA_TEST_CASE(RunSmallBroadcast, NEArithmeticSubtractionBroadcastFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapesBroadcast(),
- ArithmeticSubtractionFP32Dataset),
+ framework::dataset::make("DataType", DataType::F32)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
OutOfPlaceDataSet))
{
@@ -337,7 +309,7 @@ FIXTURE_DATA_TEST_CASE(RunSmallBroadcast, NEArithmeticSubtractionBroadcastFixtur
}
FIXTURE_DATA_TEST_CASE(RunLargeBroadcast, NEArithmeticSubtractionBroadcastFixture<float>, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapesBroadcast(),
- ArithmeticSubtractionFP32Dataset),
+ framework::dataset::make("DataType", DataType::F32)),
framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
OutOfPlaceDataSet))
{
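With the mixed-type kernels removed, NEArithmeticSubtraction::validate (like its addition counterpart) rejects any configuration whose operands differ in data type, which is what the updated Expected columns above encode. A minimal standalone sketch of that behaviour, with shapes and types taken from the test rows; the set_is_resizable(false) plumbing the real tests use is omitted for brevity:

#include "arm_compute/runtime/NEON/functions/NEArithmeticSubtraction.h"

using namespace arm_compute;

void validate_examples()
{
    // Mixed U8/S16 operands are now an invalid data type combination:
    const TensorInfo input1(TensorShape(32U, 13U, 2U), 1, DataType::U8);
    const TensorInfo input2(TensorShape(32U, 13U, 2U), 1, DataType::S16);
    const TensorInfo output(TensorShape(32U, 13U, 2U), 1, DataType::U8);
    const Status mixed = NEArithmeticSubtraction::validate(&input1, &input2, &output, ConvertPolicy::SATURATE);
    // bool(mixed) == false

    // Operands sharing one data type remain valid:
    const TensorInfo f32(TensorShape(32U, 13U, 2U), 1, DataType::F32);
    const Status same = NEArithmeticSubtraction::validate(&f32, &f32, &f32, ConvertPolicy::SATURATE);
    // bool(same) == true
}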
diff --git a/tests/validation/fixtures/ArithmeticOperationsFixture.h b/tests/validation/fixtures/ArithmeticOperationsFixture.h
index 1dfc2ce579..7aa716d676 100644
--- a/tests/validation/fixtures/ArithmeticOperationsFixture.h
+++ b/tests/validation/fixtures/ArithmeticOperationsFixture.h
@@ -46,15 +46,14 @@ class ArithmeticOperationGenericFixture : public framework::Fixture
{
public:
template <typename...>
- void setup(reference::ArithmeticOperation op, const TensorShape &shape0, const TensorShape &shape1,
- DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy,
+ void setup(reference::ArithmeticOperation op, const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy,
QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out, ActivationLayerInfo act_info, bool in_place)
{
_op = op;
_act_info = act_info;
_in_place = in_place;
- _target = compute_target(shape0, shape1, data_type0, data_type1, output_data_type, convert_policy, qinfo0, qinfo1, qinfo_out);
- _reference = compute_reference(shape0, shape1, data_type0, data_type1, output_data_type, convert_policy, qinfo0, qinfo1, qinfo_out);
+ _target = compute_target(shape0, shape1, data_type, convert_policy, qinfo0, qinfo1, qinfo_out);
+ _reference = compute_reference(shape0, shape1, data_type, convert_policy, qinfo0, qinfo1, qinfo_out);
}
protected:
@@ -64,13 +63,13 @@ protected:
library->fill_tensor_uniform(tensor, i);
}
- TensorType compute_target(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy,
+ TensorType compute_target(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy,
QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
{
// Create tensors
- TensorType ref_src1 = create_tensor<TensorType>(shape0, data_type0, 1, qinfo0);
- TensorType ref_src2 = create_tensor<TensorType>(shape1, data_type1, 1, qinfo1);
- TensorType dst = create_tensor<TensorType>(TensorShape::broadcast_shape(shape0, shape1), output_data_type, 1, qinfo_out);
+ TensorType ref_src1 = create_tensor<TensorType>(shape0, data_type, 1, qinfo0);
+ TensorType ref_src2 = create_tensor<TensorType>(shape1, data_type, 1, qinfo1);
+ TensorType dst = create_tensor<TensorType>(TensorShape::broadcast_shape(shape0, shape1), data_type, 1, qinfo_out);
TensorType *dst_to_use = _in_place ? &ref_src1 : &dst;
// Create and configure function
@@ -104,8 +103,7 @@ protected:
return dst;
}
- SimpleTensor<T> compute_reference(const TensorShape &shape0, const TensorShape &shape1,
- DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy,
+ SimpleTensor<T> compute_reference(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy,
QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
{
// The current in-place implementation only supports input and output tensors with the same metadata.
@@ -113,9 +111,9 @@ protected:
QuantizationInfo output_qinfo = _in_place ? qinfo0 : qinfo_out;
// Create reference
- SimpleTensor<T> ref_src1{ shape0, data_type0, 1, qinfo0 };
- SimpleTensor<T> ref_src2{ shape1, data_type1, 1, qinfo1 };
- SimpleTensor<T> ref_dst{ TensorShape::broadcast_shape(shape0, shape1), output_data_type, 1, output_qinfo };
+ SimpleTensor<T> ref_src1{ shape0, data_type, 1, qinfo0 };
+ SimpleTensor<T> ref_src2{ shape1, data_type, 1, qinfo1 };
+ SimpleTensor<T> ref_dst{ TensorShape::broadcast_shape(shape0, shape1), data_type, 1, output_qinfo };
// Fill reference
fill(ref_src1, 0);
@@ -137,10 +135,10 @@ class ArithmeticAdditionBroadcastValidationFixture : public ArithmeticOperationG
{
public:
template <typename...>
- void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy)
+ void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy)
{
- ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape0, shape1, data_type0, data_type1,
- output_data_type, convert_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo(), false);
+ ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape0, shape1, data_type, convert_policy,
+ QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo(), false);
}
};
@@ -149,10 +147,10 @@ class ArithmeticAdditionValidationFixture : public ArithmeticOperationGenericFix
{
public:
template <typename...>
- void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy)
+ void setup(const TensorShape &shape, DataType data_type, ConvertPolicy convert_policy)
{
- ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape, shape, data_type0, data_type1,
- output_data_type, convert_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo(), false);
+ ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape, shape, data_type, convert_policy,
+ QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo(), false);
}
};
@@ -161,10 +159,10 @@ class ArithmeticAdditionBroadcastValidationFloatFixture : public ArithmeticOpera
{
public:
template <typename...>
- void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info)
+ void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info)
{
- ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape0, shape1, data_type0, data_type1,
- output_data_type, convert_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info, false);
+ ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape0, shape1, data_type, convert_policy,
+ QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info, false);
}
};
@@ -173,10 +171,10 @@ class ArithmeticAdditionValidationFloatFixture : public ArithmeticOperationGener
{
public:
template <typename...>
- void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info)
+ void setup(const TensorShape &shape, DataType data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info)
{
- ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape, shape, data_type0, data_type1,
- output_data_type, convert_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info, false);
+ ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape, shape, data_type, convert_policy,
+ QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info, false);
}
};
@@ -185,12 +183,11 @@ class ArithmeticAdditionValidationQuantizedFixture : public ArithmeticOperationG
{
public:
template <typename...>
- void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy,
- QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
+ void setup(const TensorShape &shape, DataType data_type, ConvertPolicy convert_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
{
- ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape, shape, data_type0, data_type1,
- output_data_type, convert_policy, qinfo0, qinfo1, qinfo_out, ActivationLayerInfo(), false);
+ ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape, shape, data_type, convert_policy,
+ qinfo0, qinfo1, qinfo_out, ActivationLayerInfo(), false);
}
};
@@ -199,11 +196,9 @@ class ArithmeticAdditionValidationQuantizedBroadcastFixture : public ArithmeticO
{
public:
template <typename...>
- void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type,
- ConvertPolicy convert_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
+ void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
{
- ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape0, shape1,
- data_type0, data_type1, output_data_type, convert_policy,
+ ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape0, shape1, data_type, convert_policy,
qinfo0, qinfo1, qinfo_out, ActivationLayerInfo(), false);
}
};
@@ -213,10 +208,9 @@ class ArithmeticSubtractionBroadcastValidationFixture : public ArithmeticOperati
{
public:
template <typename...>
- void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, bool in_place)
+ void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy, bool in_place)
{
- ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape0, shape1,
- data_type0, data_type1, output_data_type, convert_policy,
+ ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape0, shape1, data_type, convert_policy,
QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo(), in_place);
}
};
@@ -226,11 +220,10 @@ class ArithmeticSubtractionBroadcastValidationFloatFixture : public ArithmeticOp
{
public:
template <typename...>
- void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info,
+ void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info,
bool in_place)
{
- ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape0, shape1,
- data_type0, data_type1, output_data_type, convert_policy,
+ ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape0, shape1, data_type, convert_policy,
QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info, in_place);
}
};
@@ -240,10 +233,9 @@ class ArithmeticSubtractionValidationFixture : public ArithmeticOperationGeneric
{
public:
template <typename...>
- void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, bool in_place)
+ void setup(const TensorShape &shape, DataType data_type, ConvertPolicy convert_policy, bool in_place)
{
- ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape, shape,
- data_type0, data_type1, output_data_type, convert_policy,
+ ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape, shape, data_type, convert_policy,
QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo(), in_place);
}
};
@@ -253,10 +245,9 @@ class ArithmeticSubtractionValidationFloatFixture : public ArithmeticOperationGe
{
public:
template <typename...>
- void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info, bool in_place)
+ void setup(const TensorShape &shape, DataType data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info, bool in_place)
{
- ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape, shape,
- data_type0, data_type1, output_data_type, convert_policy,
+ ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape, shape, data_type, convert_policy,
QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info, in_place);
}
};
@@ -266,13 +257,11 @@ class ArithmeticSubtractionValidationQuantizedFixture : public ArithmeticOperati
{
public:
template <typename...>
- void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy,
- QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out, bool in_place)
+ void setup(const TensorShape &shape, DataType data_type, ConvertPolicy convert_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out, bool in_place)
{
- ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape, shape,
- data_type0, data_type1, output_data_type,
- convert_policy, qinfo0, qinfo1, qinfo_out, ActivationLayerInfo(), in_place);
+ ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape, shape, data_type, convert_policy,
+ qinfo0, qinfo1, qinfo_out, ActivationLayerInfo(), in_place);
}
};
@@ -281,11 +270,10 @@ class ArithmeticSubtractionValidationQuantizedBroadcastFixture : public Arithmet
{
public:
template <typename...>
- void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type,
- ConvertPolicy convert_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out, bool in_place)
+ void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out,
+ bool in_place)
{
- ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape0, shape1,
- data_type0, data_type1, output_data_type, convert_policy,
+ ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape0, shape1, data_type, convert_policy,
qinfo0, qinfo1, qinfo_out, ActivationLayerInfo(), in_place);
}
};
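The net effect on the fixtures is that every setup overload now takes one DataType in place of the old data_type0/data_type1/output_data_type triple, and the generic fixture creates all three tensors from it. A minimal sketch of the simplified interface; FIXTURE_DATA_TEST_CASE normally generates this call, so the direct invocation below is illustrative only:

void run_fixture_example()
{
    // One DataType now covers src0, src1 and dst:
    ArithmeticAdditionValidationFixture<Tensor, Accessor, NEArithmeticAddition, float> fixture;
    fixture.setup(TensorShape(32U, 13U, 2U), DataType::F32, ConvertPolicy::SATURATE);

    // Internally this forwards to the generic fixture as:
    //   setup(reference::ArithmeticOperation::ADD, shape, shape, DataType::F32,
    //         ConvertPolicy::SATURATE, QuantizationInfo(), QuantizationInfo(),
    //         QuantizationInfo(), ActivationLayerInfo(), /*in_place=*/false);
}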