From da816752cad76c8e1b367e8e9c648994a1af599a Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Fri, 2 Jul 2021 09:22:14 +0100 Subject: Remove redundant implementations of Add/Sub operators Allows only implementations where inputs/output are of the same data type and removes legacy Computer Vision ones. Signed-off-by: Georgios Pinitas Change-Id: Ia2b3d23a04236aab682f0c36a1110a30f7c06d1c Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5900 Tested-by: Arm Jenkins Reviewed-by: Michele Di Giorgio Comments-Addressed: Arm Jenkins --- Android.bp | 3 - .../runtime/NEON/functions/NEArithmeticAddition.h | 3 - .../NEON/functions/NEArithmeticSubtraction.h | 3 - filelist.json | 9 - src/core/CL/cl_kernels/elementwise_operation.cl | 35 ++-- .../cl_kernels/elementwise_operation_quantized.cl | 24 +-- src/core/cpu/kernels/CpuAddKernel.cpp | 141 +++------------ src/core/cpu/kernels/CpuAddKernel.h | 3 - src/core/cpu/kernels/CpuSubKernel.cpp | 87 ++------- src/core/cpu/kernels/CpuSubKernel.h | 3 - src/core/cpu/kernels/add/neon/integer.cpp | 170 ----------------- src/core/cpu/kernels/add/neon/list.h | 3 - src/core/cpu/kernels/add/sve/integer.cpp | 201 --------------------- src/core/cpu/kernels/add/sve/list.h | 3 - src/core/cpu/kernels/sub/neon/integer.cpp | 183 ------------------- src/core/cpu/kernels/sub/neon/list.h | 3 - src/core/gpu/cl/kernels/ClElementwiseKernel.cpp | 77 ++------ src/runtime/cpu/operators/CpuAdd.h | 3 - src/runtime/cpu/operators/CpuSub.h | 3 - tests/validation/CL/ArithmeticAddition.cpp | 76 ++++---- tests/validation/CL/ArithmeticSubtraction.cpp | 63 +++---- tests/validation/CL/ElementwiseMax.cpp | 11 +- tests/validation/CL/ElementwiseMin.cpp | 11 +- tests/validation/CL/ElementwiseSquaredDiff.cpp | 11 +- tests/validation/CL/PReluLayer.cpp | 9 +- tests/validation/NEON/ArithmeticAddition.cpp | 57 ++---- tests/validation/NEON/ArithmeticSubtraction.cpp | 94 ++++------ .../fixtures/ArithmeticOperationsFixture.h | 96 +++++----- 28 files changed, 249 insertions(+), 1136 deletions(-) delete mode 100644 src/core/cpu/kernels/add/neon/integer.cpp delete mode 100644 src/core/cpu/kernels/add/sve/integer.cpp delete mode 100644 src/core/cpu/kernels/sub/neon/integer.cpp diff --git a/Android.bp b/Android.bp index 90facab3ca..6a76e4c92a 100644 --- a/Android.bp +++ b/Android.bp @@ -295,12 +295,10 @@ cc_library_static { "src/core/cpu/kernels/activation/sve/qasymm8.cpp", "src/core/cpu/kernels/activation/sve/qasymm8_signed.cpp", "src/core/cpu/kernels/activation/sve/qsymm16.cpp", - "src/core/cpu/kernels/add/neon/integer.cpp", "src/core/cpu/kernels/add/neon/qasymm8.cpp", "src/core/cpu/kernels/add/neon/qasymm8_signed.cpp", "src/core/cpu/kernels/add/neon/qsymm16.cpp", "src/core/cpu/kernels/add/sve/impl.cpp", - "src/core/cpu/kernels/add/sve/integer.cpp", "src/core/cpu/kernels/add/sve/qasymm8.cpp", "src/core/cpu/kernels/add/sve/qasymm8_signed.cpp", "src/core/cpu/kernels/add/sve/qsymm16.cpp", @@ -325,7 +323,6 @@ cc_library_static { "src/core/cpu/kernels/scale/sve/qasymm8.cpp", "src/core/cpu/kernels/scale/sve/qasymm8_signed.cpp", "src/core/cpu/kernels/softmax/impl/sve/impl.cpp", - "src/core/cpu/kernels/sub/neon/integer.cpp", "src/core/cpu/kernels/sub/neon/qasymm8.cpp", "src/core/cpu/kernels/sub/neon/qasymm8_signed.cpp", "src/core/cpu/kernels/sub/neon/qsymm16.cpp", diff --git a/arm_compute/runtime/NEON/functions/NEArithmeticAddition.h b/arm_compute/runtime/NEON/functions/NEArithmeticAddition.h index b8e46ff36e..b9012b02a9 100644 --- a/arm_compute/runtime/NEON/functions/NEArithmeticAddition.h +++ 
b/arm_compute/runtime/NEON/functions/NEArithmeticAddition.h @@ -62,9 +62,6 @@ public: * |QSYMM16 |QSYMM16 |QASYMM16 | * |QSYMM16 |QSYMM16 |S32 | * |U8 |U8 |U8 | - * |U8 |U8 |S16 | - * |U8 |S16 |S16 | - * |S16 |U8 |S16 | * |S16 |S16 |S16 | * |S32 |S32 |S32 | * |F16 |F16 |F16 | diff --git a/arm_compute/runtime/NEON/functions/NEArithmeticSubtraction.h b/arm_compute/runtime/NEON/functions/NEArithmeticSubtraction.h index 0c72e946f6..0b4db61d29 100644 --- a/arm_compute/runtime/NEON/functions/NEArithmeticSubtraction.h +++ b/arm_compute/runtime/NEON/functions/NEArithmeticSubtraction.h @@ -68,9 +68,6 @@ public: * |QSYMM16 |QSYMM16 |QASYMM16 | * |QSYMM16 |QSYMM16 |S32 | * |U8 |U8 |U8 | - * |U8 |U8 |S16 | - * |U8 |S16 |S16 | - * |S16 |U8 |S16 | * |S16 |S16 |S16 | * |S32 |S32 |S32 | * |F16 |F16 |F16 | diff --git a/filelist.json b/filelist.json index 56274954c8..9d07492e6a 100644 --- a/filelist.json +++ b/filelist.json @@ -708,9 +708,6 @@ ], "qasymm8_signed": [ "src/core/cpu/kernels/add/sve/qasymm8_signed.cpp" - ], - "integer": [ - "src/core/cpu/kernels/add/sve/integer.cpp" ] }, "neon": { @@ -722,9 +719,6 @@ ], "qasymm8_signed": [ "src/core/cpu/kernels/add/neon/qasymm8_signed.cpp" - ], - "integer": [ - "src/core/cpu/kernels/add/neon/integer.cpp" ] } } @@ -1661,9 +1655,6 @@ ], "qasymm8_signed": [ "src/core/cpu/kernels/sub/neon/qasymm8_signed.cpp" - ], - "integer": [ - "src/core/cpu/kernels/sub/neon/integer.cpp" ] } } diff --git a/src/core/CL/cl_kernels/elementwise_operation.cl b/src/core/CL/cl_kernels/elementwise_operation.cl index c8250045dc..99f725645d 100644 --- a/src/core/CL/cl_kernels/elementwise_operation.cl +++ b/src/core/CL/cl_kernels/elementwise_operation.cl @@ -23,7 +23,7 @@ */ #include "helpers.h" -#if defined(OP) && defined(VEC_SIZE_IN1) && defined(VEC_SIZE_IN2) && defined(VEC_SIZE_OUT) && defined(DATA_TYPE_IN1) && defined(DATA_TYPE_IN2) && defined(DATA_TYPE_OUT) +#if defined(OP) && defined(VEC_SIZE_IN1) && defined(VEC_SIZE_IN2) && defined(VEC_SIZE_OUT) && defined(DATA_TYPE) /** List of all the operations supported by this kernel. * @note ADD and SUB operations, when executed on integers, support saturation */ @@ -43,17 +43,17 @@ #if VEC_SIZE_OUT == 1 #define PRELU(x, y) (x > 0 ? 
x : x * y) #else // VEC_SIZE_OUT == 1 -#define PRELU(x, y) (select(y * x, x, CONVERT((x > (DATA_TYPE_OUT)0), SELECT_VEC_DATA_TYPE(DATA_TYPE_OUT, VEC_SIZE_OUT)))) +#define PRELU(x, y) (select(y * x, x, CONVERT((x > (DATA_TYPE)0), SELECT_VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE_OUT)))) #endif // VEC_SIZE_OUT == 1 #if defined(S32) -#define DIV(x, y) CONVERT(floor(CONVERT(x, VEC_DATA_TYPE(float, VEC_SIZE_OUT)) / CONVERT(y, VEC_DATA_TYPE(float, VEC_SIZE_OUT))), VEC_DATA_TYPE(DATA_TYPE_OUT, VEC_SIZE_OUT)); +#define DIV(x, y) CONVERT(floor(CONVERT(x, VEC_DATA_TYPE(float, VEC_SIZE_OUT)) / CONVERT(y, VEC_DATA_TYPE(float, VEC_SIZE_OUT))), VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE_OUT)); #else /* S32 */ #define DIV(x, y) (x / y) #endif /* S32 */ -#define AND(x, y) (CONVERT((x && y), VEC_DATA_TYPE(DATA_TYPE_OUT, VEC_SIZE_OUT)) & ((VEC_DATA_TYPE(DATA_TYPE_OUT, VEC_SIZE_OUT))1)) -#define OR(x, y) (CONVERT((x || y), VEC_DATA_TYPE(DATA_TYPE_OUT, VEC_SIZE_OUT)) & ((VEC_DATA_TYPE(DATA_TYPE_OUT, VEC_SIZE_OUT))1)) +#define AND(x, y) (CONVERT((x && y), VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE_OUT)) & ((VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE_OUT))1)) +#define OR(x, y) (CONVERT((x || y), VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE_OUT)) & ((VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE_OUT))1)) #define OP_FUN_NAME_STR(op) elementwise_operation_##op #define OP_FUN_NAME(op) OP_FUN_NAME_STR(op) @@ -66,8 +66,7 @@ * * @note Vector sizes of inputs and output have to be passed at compile time using -DVEC_SIZE_IN1, -DVEC_SIZE_IN2, -DVEC_SIZE_OUT. * @note Leftover vector size has to be passed at compile time using -DVEC_SIZE_LEFTOVER. e.g. -DVEC_SIZE_OUT=3. It is defined as the remainder between the input's first dimension and VEC_SIZE_OUT - * @note The input and output data_types need to be passed at compile time using -DDATA_TYPE_IN1, -DDATA_TYPE_IN2 and -DDATA_TYPE_OUT: - * e.g. -DDATA_TYPE_IN1=uchar -DDATA_TYPE_IN2=uchar -DDATA_TYPE_OUT=short + * @note The input and output data_types need to be passed at compile time using -DDATA_TYPE: e.g. -DDATA_TYPE=uchar * @note To perform saturating operation -DSATURATE has to be passed to the compiler otherwise wrapping policy will be used. 
* @note The element-wise operation to be executed has to be passed at compile time using -DOP (e.g., -DOP=ADD) * @@ -114,23 +113,23 @@ __kernel void OP_FUN_NAME(OP)( uint out_x_offs = max((int)(get_global_id(0) * VEC_SIZE_OUT - (VEC_SIZE_OUT - VEC_SIZE_LEFTOVER) % VEC_SIZE_OUT), 0); // Get pixels pointer - __global uchar *in1_addr = in1_ptr + in1_offset_first_element_in_bytes + in1_x_offs * sizeof(DATA_TYPE_IN1) + get_global_id(1) * in1_step_y + get_global_id(2) * in1_step_z; - __global uchar *in2_addr = in2_ptr + in2_offset_first_element_in_bytes + in2_x_offs * sizeof(DATA_TYPE_IN2) + get_global_id(1) * in2_step_y + get_global_id(2) * in2_step_z; - __global uchar *out_addr = out_ptr + out_offset_first_element_in_bytes + out_x_offs * sizeof(DATA_TYPE_OUT) + get_global_id(1) * out_step_y + get_global_id(2) * out_step_z; + __global uchar *in1_addr = in1_ptr + in1_offset_first_element_in_bytes + in1_x_offs * sizeof(DATA_TYPE) + get_global_id(1) * in1_step_y + get_global_id(2) * in1_step_z; + __global uchar *in2_addr = in2_ptr + in2_offset_first_element_in_bytes + in2_x_offs * sizeof(DATA_TYPE) + get_global_id(1) * in2_step_y + get_global_id(2) * in2_step_z; + __global uchar *out_addr = out_ptr + out_offset_first_element_in_bytes + out_x_offs * sizeof(DATA_TYPE) + get_global_id(1) * out_step_y + get_global_id(2) * out_step_z; // Load values - VEC_DATA_TYPE(DATA_TYPE_OUT, VEC_SIZE_OUT) - in_a = CONVERT((VEC_DATA_TYPE(DATA_TYPE_IN1, VEC_SIZE_OUT))(VLOAD(VEC_SIZE_IN1)(0, (__global DATA_TYPE_IN1 *)in1_addr)), VEC_DATA_TYPE(DATA_TYPE_OUT, VEC_SIZE_OUT)); - VEC_DATA_TYPE(DATA_TYPE_OUT, VEC_SIZE_OUT) - in_b = CONVERT((VEC_DATA_TYPE(DATA_TYPE_IN2, VEC_SIZE_OUT))(VLOAD(VEC_SIZE_IN2)(0, (__global DATA_TYPE_IN2 *)in2_addr)), VEC_DATA_TYPE(DATA_TYPE_OUT, VEC_SIZE_OUT)); + VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE_OUT) + in_a = CONVERT((VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE_OUT))(VLOAD(VEC_SIZE_IN1)(0, (__global DATA_TYPE *)in1_addr)), VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE_OUT)); + VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE_OUT) + in_b = CONVERT((VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE_OUT))(VLOAD(VEC_SIZE_IN2)(0, (__global DATA_TYPE *)in2_addr)), VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE_OUT)); // Calculate and store result - VEC_DATA_TYPE(DATA_TYPE_OUT, VEC_SIZE_OUT) + VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE_OUT) res0 = OP(in_a, in_b); #if defined(ACTIVATION_TYPE) - res0 = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE_OUT, VEC_SIZE_OUT, res0, A_VAL, B_VAL); + res0 = ACTIVATION(ACTIVATION_TYPE, DATA_TYPE, VEC_SIZE_OUT, res0, A_VAL, B_VAL); #endif // defined(ACTIVATION_TYPE) - STORE_VECTOR_SELECT(res, DATA_TYPE_OUT, out_addr, VEC_SIZE_OUT, VEC_SIZE_LEFTOVER, VEC_SIZE_LEFTOVER != 0 && get_global_id(0) == 0) + STORE_VECTOR_SELECT(res, DATA_TYPE, out_addr, VEC_SIZE_OUT, VEC_SIZE_LEFTOVER, VEC_SIZE_LEFTOVER != 0 && get_global_id(0) == 0) } -#endif /* defined(OP) && defined(VEC_SIZE_IN1) && defined(VEC_SIZE_IN2) && defined(VEC_SIZE_OUT) && defined(DATA_TYPE_IN1) && defined(DATA_TYPE_IN2) && defined(DATA_TYPE_OUT) */ +#endif /* defined(OP) && defined(VEC_SIZE_IN1) && defined(VEC_SIZE_IN2) && defined(VEC_SIZE_OUT) && defined(DATA_TYPE) */ diff --git a/src/core/CL/cl_kernels/elementwise_operation_quantized.cl b/src/core/CL/cl_kernels/elementwise_operation_quantized.cl index a08c3b2d47..0051babf03 100644 --- a/src/core/CL/cl_kernels/elementwise_operation_quantized.cl +++ b/src/core/CL/cl_kernels/elementwise_operation_quantized.cl @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020 Arm Limited. + * Copyright (c) 2018-2021 Arm Limited. 
* * SPDX-License-Identifier: MIT * @@ -28,7 +28,7 @@ #define MAX(x, y) max((x), (y)) #define MIN(x, y) min((x), (y)) #define SQUARED_DIFF(x, y) (x - y) * (x - y) -#define PRELU(x, y) (select(y * x, x, CONVERT((x > (DATA_TYPE_OUT)0), SELECT_VEC_DATA_TYPE(float, VEC_SIZE_OUT)))) +#define PRELU(x, y) (select(y * x, x, CONVERT((x > (DATA_TYPE)0), SELECT_VEC_DATA_TYPE(float, VEC_SIZE_OUT)))) #define DIV(x, y) (x / y) #define CONVERT_RTE(x, type) (convert_##type##_rte((x))) @@ -37,11 +37,11 @@ #define OP_FUN_NAME_STR(op) elementwise_operation_##op##_quantized #define OP_FUN_NAME(op) OP_FUN_NAME_STR(op) -#if defined(OP) && defined(VEC_SIZE_IN1) && defined(VEC_SIZE_IN2) && defined(VEC_SIZE_OUT) && defined(OFFSET_IN1) && defined(OFFSET_IN2) && defined(OFFSET_OUT) && defined(SCALE_IN1) && defined(SCALE_IN2) && defined(SCALE_OUT) && defined(DATA_TYPE_OUT) +#if defined(OP) && defined(VEC_SIZE_IN1) && defined(VEC_SIZE_IN2) && defined(VEC_SIZE_OUT) && defined(OFFSET_IN1) && defined(OFFSET_IN2) && defined(OFFSET_OUT) && defined(SCALE_IN1) && defined(SCALE_IN2) && defined(SCALE_OUT) && defined(DATA_TYPE) #define VEC_FLOAT VEC_DATA_TYPE(float, VEC_SIZE_OUT) #define VEC_INT VEC_DATA_TYPE(int, VEC_SIZE_OUT) -#define VEC_TYPE VEC_DATA_TYPE(DATA_TYPE_OUT, VEC_SIZE_OUT) +#define VEC_TYPE VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE_OUT) /** This function executes an element-wise operation among two tensors. * @@ -57,7 +57,7 @@ * @note To perform saturating operation -DSATURATE has to be passed to the compiler otherwise wrapping policy will be used. * @note The element-wise operation to be executed has to be passed at compile time using -DOP (e.g., -DOP=ADD) * @note For QSYMM16 operations OFFSET_IN1, OFFSET_IN2 and OFFSET_OUT must be set to zero - * @note The data type must be passed at compile time using -DDATA_TYPE_OUT, i.e. -DDATA_TYPE_OUT=uchar + * @note The data type must be passed at compile time using -DDATA_TYPE, i.e. -DDATA_TYPE=uchar * * @param[in] in1_ptr Pointer to the source tensor. 
Supported data types: QASYMM8/QSYMM16 * @param[in] in1_stride_x Stride of the source tensor in X dimension (in bytes) @@ -102,12 +102,12 @@ __kernel void OP_FUN_NAME(OP)( uint out_x_offs = max((int)(get_global_id(0) * VEC_SIZE_OUT - (VEC_SIZE_OUT - VEC_SIZE_LEFTOVER) % VEC_SIZE_OUT), 0); // Get pixels pointer - __global uchar *in1_addr = in1_ptr + in1_offset_first_element_in_bytes + in1_x_offs * sizeof(DATA_TYPE_OUT) + get_global_id(1) * in1_step_y + get_global_id(2) * in1_step_z; - __global uchar *in2_addr = in2_ptr + in2_offset_first_element_in_bytes + in2_x_offs * sizeof(DATA_TYPE_OUT) + get_global_id(1) * in2_step_y + get_global_id(2) * in2_step_z; - __global uchar *out_addr = out_ptr + out_offset_first_element_in_bytes + out_x_offs * sizeof(DATA_TYPE_OUT) + get_global_id(1) * out_step_y + get_global_id(2) * out_step_z; + __global uchar *in1_addr = in1_ptr + in1_offset_first_element_in_bytes + in1_x_offs * sizeof(DATA_TYPE) + get_global_id(1) * in1_step_y + get_global_id(2) * in1_step_z; + __global uchar *in2_addr = in2_ptr + in2_offset_first_element_in_bytes + in2_x_offs * sizeof(DATA_TYPE) + get_global_id(1) * in2_step_y + get_global_id(2) * in2_step_z; + __global uchar *out_addr = out_ptr + out_offset_first_element_in_bytes + out_x_offs * sizeof(DATA_TYPE) + get_global_id(1) * out_step_y + get_global_id(2) * out_step_z; - VEC_INT in_a = CONVERT((VEC_DATA_TYPE(DATA_TYPE_OUT, VEC_SIZE_OUT))(VLOAD(VEC_SIZE_IN1)(0, (__global DATA_TYPE_OUT *)in1_addr)), VEC_INT); - VEC_INT in_b = CONVERT((VEC_DATA_TYPE(DATA_TYPE_OUT, VEC_SIZE_OUT))(VLOAD(VEC_SIZE_IN2)(0, (__global DATA_TYPE_OUT *)in2_addr)), VEC_INT); + VEC_INT in_a = CONVERT((VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE_OUT))(VLOAD(VEC_SIZE_IN1)(0, (__global DATA_TYPE *)in1_addr)), VEC_INT); + VEC_INT in_b = CONVERT((VEC_DATA_TYPE(DATA_TYPE, VEC_SIZE_OUT))(VLOAD(VEC_SIZE_IN2)(0, (__global DATA_TYPE *)in2_addr)), VEC_INT); in_a = SUB(in_a, (VEC_INT)((int)OFFSET_IN1)); in_b = SUB(in_b, (VEC_INT)((int)OFFSET_IN2)); @@ -118,6 +118,6 @@ __kernel void OP_FUN_NAME(OP)( const VEC_TYPE res0 = CONVERT_SAT(CONVERT_DOWN(qresf32, VEC_INT), VEC_TYPE); // Store result - STORE_VECTOR_SELECT(res, DATA_TYPE_OUT, out_addr, VEC_SIZE_OUT, VEC_SIZE_LEFTOVER, VEC_SIZE_LEFTOVER != 0 && get_global_id(0) == 0) + STORE_VECTOR_SELECT(res, DATA_TYPE, out_addr, VEC_SIZE_OUT, VEC_SIZE_LEFTOVER, VEC_SIZE_LEFTOVER != 0 && get_global_id(0) == 0) } -#endif /* defined(OP) && defined(VEC_SIZE_IN1) && defined(VEC_SIZE_IN2) && defined(VEC_SIZE_OUT) && defined(OFFSET_IN1) && defined(OFFSET_IN2) && defined(OFFSET_OUT) && defined(SCALE_IN1) && defined(SCALE_IN2) && defined(SCALE_OUT) && defined(DATA_TYPE_OUT) */ +#endif /* defined(OP) && defined(VEC_SIZE_IN1) && defined(VEC_SIZE_IN2) && defined(VEC_SIZE_OUT) && defined(OFFSET_IN1) && defined(OFFSET_IN2) && defined(OFFSET_OUT) && defined(SCALE_IN1) && defined(SCALE_IN2) && defined(SCALE_OUT) && defined(DATA_TYPE) */ diff --git a/src/core/cpu/kernels/CpuAddKernel.cpp b/src/core/cpu/kernels/CpuAddKernel.cpp index 12766037a7..61b7b19443 100644 --- a/src/core/cpu/kernels/CpuAddKernel.cpp +++ b/src/core/cpu/kernels/CpuAddKernel.cpp @@ -45,14 +45,7 @@ namespace { struct AddSelectorData { - /* Data types for all ITensorInfos: - dt1 -> src0 - dt2 -> src1 - dt3 -> dst - */ - DataType dt1; - DataType dt2; - DataType dt3; + DataType dt; const CPUInfo &ci; }; @@ -72,7 +65,7 @@ static const AddKernel available_kernels[] = "sve2_qu8_add", [](const AddSelectorData & data) { - return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QASYMM8)) && 
data.ci.has_sve(); + return (data.dt == DataType::QASYMM8) && data.ci.has_sve(); }, REGISTER_QASYMM8_SVE(arm_compute::cpu::add_qasymm8_sve) }, @@ -80,7 +73,7 @@ static const AddKernel available_kernels[] = "sve2_qs8_add", [](const AddSelectorData & data) { - return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QASYMM8_SIGNED)) && data.ci.has_sve(); + return (data.dt == DataType::QASYMM8_SIGNED) && data.ci.has_sve(); }, REGISTER_QASYMM8_SIGNED_SVE(arm_compute::cpu::add_qasymm8_signed_sve) }, @@ -88,7 +81,7 @@ static const AddKernel available_kernels[] = "sve2_qs16_add", [](const AddSelectorData & data) { - return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QSYMM16)) && data.ci.has_sve(); + return (data.dt == DataType::QSYMM16) && data.ci.has_sve(); }, REGISTER_QSYMM16_SVE(arm_compute::cpu::add_qsymm16_sve) }, @@ -98,7 +91,7 @@ static const AddKernel available_kernels[] = "sve_fp32_add", [](const AddSelectorData & data) { - return ((data.dt1 == data.dt2) && (data.dt1 == DataType::F32)) && data.ci.has_sve(); + return (data.dt == DataType::F32) && data.ci.has_sve(); }, REGISTER_FP32_SVE(arm_compute::cpu::add_same_sve) }, @@ -106,7 +99,7 @@ static const AddKernel available_kernels[] = "sve_fp16_add", [](const AddSelectorData & data) { - return ((data.dt1 == data.dt2) && (data.dt1 == DataType::F16)) && data.ci.has_sve(); + return (data.dt == DataType::F16) && data.ci.has_sve(); }, REGISTER_FP16_SVE(arm_compute::cpu::add_same_sve) }, @@ -114,7 +107,7 @@ static const AddKernel available_kernels[] = "sve_u8_add", [](const AddSelectorData & data) { - return ((data.dt1 == data.dt2) && (data.dt1 == data.dt3) && (data.dt1 == DataType::U8)) && data.ci.has_sve(); + return (data.dt == DataType::U8) && data.ci.has_sve(); }, REGISTER_INTEGER_SVE(arm_compute::cpu::add_same_sve) }, @@ -122,7 +115,7 @@ static const AddKernel available_kernels[] = "sve_s16_add", [](const AddSelectorData & data) { - return ((data.dt1 == data.dt2) && (data.dt1 == data.dt3) && (data.dt1 == DataType::S16)) && data.ci.has_sve(); + return (data.dt == DataType::S16) && data.ci.has_sve(); }, REGISTER_INTEGER_SVE(arm_compute::cpu::add_same_sve) }, @@ -130,39 +123,15 @@ static const AddKernel available_kernels[] = "sve_s32_add", [](const AddSelectorData & data) { - return ((data.dt1 == data.dt2) && (data.dt1 == data.dt3) && (data.dt1 == DataType::S32)) && data.ci.has_sve(); + return (data.dt == DataType::S32) && data.ci.has_sve(); }, REGISTER_INTEGER_SVE(arm_compute::cpu::add_same_sve) }, - { - "sve_u8_s16_s16_add", - [](const AddSelectorData & data) - { - return ((data.dt1 == DataType::U8) && (data.dt2 == DataType::S16)) && data.ci.has_sve(); - }, - REGISTER_INTEGER_SVE(arm_compute::cpu::add_u8_s16_s16_sve) - }, - { - "sve_s16_u8_s16_add", - [](const AddSelectorData & data) - { - return ((data.dt1 == DataType::S16) && (data.dt2 == DataType::U8)) && data.ci.has_sve(); - }, - REGISTER_INTEGER_SVE(arm_compute::cpu::add_s16_u8_s16_sve) - }, - { - "sve_u8_u8_s16_add", - [](const AddSelectorData & data) - { - return ((data.dt1 == data.dt2) && (data.dt3 == DataType::S16)) && data.ci.has_sve(); - }, - REGISTER_INTEGER_SVE(arm_compute::cpu::add_u8_u8_s16_sve) - }, #endif /* defined(ARM_COMPUTE_ENABLE_SVE) */ #if defined(ARM_COMPUTE_ENABLE_NEON) { "neon_fp32_add", - [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::F32)); }, + [](const AddSelectorData & data) { return (data.dt == DataType::F32); }, REGISTER_FP32_NEON(arm_compute::cpu::add_same_neon) }, #if 
defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) @@ -170,56 +139,41 @@ static const AddKernel available_kernels[] = "neon_fp16_add", [](const AddSelectorData & data) { - return ((data.dt1 == data.dt2) && (data.dt1 == DataType::F16)) && data.ci.has_fp16(); + return (data.dt == DataType::F16) && data.ci.has_fp16(); }, REGISTER_FP16_NEON(arm_compute::cpu::add_same_neon) }, #endif /* defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) */ { "neon_u8_add", - [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == data.dt3) && (data.dt1 == DataType::U8)); }, + [](const AddSelectorData & data) { return (data.dt == DataType::U8); }, REGISTER_INTEGER_NEON(arm_compute::cpu::add_same_neon) }, { "neon_s16_add", - [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == data.dt3) && (data.dt1 == DataType::S16)); }, + [](const AddSelectorData & data) { return (data.dt == DataType::S16); }, REGISTER_INTEGER_NEON(arm_compute::cpu::add_same_neon) }, { "neon_s32_add", - [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == data.dt3) && (data.dt1 == DataType::S32)); }, + [](const AddSelectorData & data) { return (data.dt == DataType::S32); }, REGISTER_INTEGER_NEON(arm_compute::cpu::add_same_neon) }, - { - "neon_u8_s16_s16_add", - [](const AddSelectorData & data) { return ((data.dt1 == DataType::U8) && (data.dt2 == DataType::S16)); }, - REGISTER_INTEGER_NEON(arm_compute::cpu::add_u8_s16_s16_neon) - }, - { - "neon_s16_u8_s16_add", - [](const AddSelectorData & data) { return ((data.dt1 == DataType::S16) && (data.dt2 == DataType::U8)); }, - REGISTER_INTEGER_NEON(arm_compute::cpu::add_s16_u8_s16_neon) - }, - { - "neon_u8_u8_s16_add", - [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt3 == DataType::S16)); }, - REGISTER_INTEGER_NEON(arm_compute::cpu::add_u8_u8_s16_neon) - }, #endif /* defined(ARM_COMPUTE_ENABLE_NEON) */ #if defined(ARM_COMPUTE_ENABLE_NEON) || defined(ARM_COMPUTE_ENABLE_SVE) { "neon_qu8_add", - [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QASYMM8)); }, + [](const AddSelectorData & data) { return (data.dt == DataType::QASYMM8); }, REGISTER_QASYMM8_NEON(arm_compute::cpu::add_qasymm8_neon) }, { "neon_qs8_add", - [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QASYMM8_SIGNED)); }, + [](const AddSelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED); }, REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::add_qasymm8_signed_neon) }, { "neon_qs16_add", - [](const AddSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QSYMM16)); }, + [](const AddSelectorData & data) { return (data.dt == DataType::QSYMM16); }, REGISTER_QSYMM16_NEON(arm_compute::cpu::add_qsymm16_neon) }, #endif /* defined(ARM_COMPUTE_ENABLE_NEON) || defined(ARM_COMPUTE_ENABLE_SVE) */ @@ -231,11 +185,11 @@ static const AddKernel available_kernels[] = * * @return A matching micro-kernel else nullptr */ -const AddKernel *get_implementation(const CPUInfo &cpuinfo, DataType dt1, DataType dt2, DataType dt3) +const AddKernel *get_implementation(const CPUInfo &cpuinfo, DataType dt) { for(const auto &uk : available_kernels) { - if(uk.is_selected({ dt1, dt2, dt3, cpuinfo })) + if(uk.is_selected({ dt, cpuinfo })) { return &uk; } @@ -251,9 +205,7 @@ Status validate_arguments(const ITensorInfo &src0, const ITensorInfo &src1, cons ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src0, 1, DataType::U8, DataType::QASYMM8, 
DataType::QASYMM8_SIGNED, DataType::S16, DataType::QSYMM16, DataType::F16, DataType::S32, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src1, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, - DataType::S16, DataType::QSYMM16, DataType::F16, - DataType::S32, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src0, &src1); const TensorShape out_shape = TensorShape::broadcast_shape(src0.tensor_shape(), src1.tensor_shape()); @@ -265,25 +217,12 @@ Status validate_arguments(const ITensorInfo &src0, const ITensorInfo &src1, cons // Validate in case of configured dst if(dst.total_size() > 0) { - ARM_COMPUTE_RETURN_ERROR_ON_MSG( - !(src0.data_type() == DataType::U8 && src1.data_type() == DataType::U8 && dst.data_type() == DataType::U8) - && !(src0.data_type() == DataType::U8 && src1.data_type() == DataType::U8 && dst.data_type() == DataType::S16) - && !(src0.data_type() == DataType::U8 && src1.data_type() == DataType::S16 && dst.data_type() == DataType::S16) - && !(src0.data_type() == DataType::S16 && src1.data_type() == DataType::U8 && dst.data_type() == DataType::S16) - && !(src0.data_type() == DataType::S16 && src1.data_type() == DataType::S16 && dst.data_type() == DataType::S16) - && !(src0.data_type() == DataType::S32 && src1.data_type() == DataType::S32 && dst.data_type() == DataType::S32) - && !(src0.data_type() == DataType::F32 && src1.data_type() == DataType::F32 && dst.data_type() == DataType::F32) - && !(src0.data_type() == DataType::F16 && src1.data_type() == DataType::F16 && dst.data_type() == DataType::F16) - && !(src0.data_type() == DataType::QASYMM8 && src1.data_type() == DataType::QASYMM8 && dst.data_type() == DataType::QASYMM8) - && !(src0.data_type() == DataType::QASYMM8_SIGNED && src1.data_type() == DataType::QASYMM8_SIGNED && dst.data_type() == DataType::QASYMM8_SIGNED) - && !(src0.data_type() == DataType::QSYMM16 && src1.data_type() == DataType::QSYMM16 && dst.data_type() == DataType::QSYMM16), - "You called addition with the wrong image formats"); - + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src0, &dst); ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, dst.tensor_shape(), 0), "Wrong shape for dst"); } - const auto *uk = get_implementation(CPUInfo::get(), src0.data_type(), src1.data_type(), dst.data_type()); + const auto *uk = get_implementation(CPUInfo::get(), src0.data_type()); ARM_COMPUTE_RETURN_ERROR_ON(uk == nullptr || uk->ukernel == nullptr); return Status{}; @@ -294,38 +233,8 @@ std::pair validate_and_configure_window(const ITensorInfo &src0, const TensorShape &out_shape = TensorShape::broadcast_shape(src0.tensor_shape(), src1.tensor_shape()); // Auto initialize dst if not initialized - { - set_shape_if_empty(dst, out_shape); - - if(src0.data_type() == DataType::S16 || src1.data_type() == DataType::S16) - { - set_format_if_unknown(dst, Format::S16); - } - if(src0.data_type() == DataType::S32 || src1.data_type() == DataType::S32) - { - set_format_if_unknown(dst, Format::S32); - } - else if(src0.data_type() == DataType::F16 || src1.data_type() == DataType::F16) - { - set_format_if_unknown(dst, Format::F16); - } - else if(src0.data_type() == DataType::F32 || src1.data_type() == DataType::F32) - { - set_format_if_unknown(dst, Format::F32); - } - else if(src0.data_type() == DataType::QASYMM8 || src1.data_type() == DataType::QASYMM8) - { - set_data_type_if_unknown(dst, DataType::QASYMM8); - } - else if(src0.data_type() == DataType::QASYMM8_SIGNED || src1.data_type() == 
DataType::QASYMM8_SIGNED) - { - set_data_type_if_unknown(dst, DataType::QASYMM8_SIGNED); - } - else if(src0.data_type() == DataType::QSYMM16 || src1.data_type() == DataType::QSYMM16) - { - set_data_type_if_unknown(dst, DataType::QSYMM16); - } - } + set_shape_if_empty(dst, out_shape); + set_data_type_if_unknown(dst, src0.data_type()); Window win = calculate_max_window(out_shape, Steps()); @@ -339,7 +248,7 @@ void CpuAddKernel::configure(const ITensorInfo *src0, const ITensorInfo *src1, I ARM_COMPUTE_ERROR_ON_NULLPTR(src0, src1, dst); ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*src0, *src1, *dst, policy)); - const auto uk = get_implementation(CPUInfo::get(), src0->data_type(), src1->data_type(), dst->data_type()); + const auto uk = get_implementation(CPUInfo::get(), src0->data_type()); ARM_COMPUTE_ERROR_ON_NULLPTR(uk); _policy = policy; diff --git a/src/core/cpu/kernels/CpuAddKernel.h b/src/core/cpu/kernels/CpuAddKernel.h index 717d0132c6..1205b45dfb 100644 --- a/src/core/cpu/kernels/CpuAddKernel.h +++ b/src/core/cpu/kernels/CpuAddKernel.h @@ -44,9 +44,6 @@ public: * Valid configurations (src0,src1) -> dst : * * - (U8,U8) -> U8 - * - (U8,U8) -> S16 - * - (S16,U8) -> S16 - * - (U8,S16) -> S16 * - (S16,S16) -> S16 * - (S32,S32) -> S32 * - (F16,F16) -> F16 diff --git a/src/core/cpu/kernels/CpuSubKernel.cpp b/src/core/cpu/kernels/CpuSubKernel.cpp index 098a324377..fa7a55805e 100644 --- a/src/core/cpu/kernels/CpuSubKernel.cpp +++ b/src/core/cpu/kernels/CpuSubKernel.cpp @@ -41,9 +41,7 @@ namespace { struct SubSelectorData { - DataType dt1; - DataType dt2; - DataType dt3; + DataType dt; }; using SubSelectorPtr = std::add_pointer::type; @@ -60,59 +58,44 @@ static const SubKernel available_kernels[] = { { "neon_fp32_sub", - [](const SubSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::F32)); }, + [](const SubSelectorData & data) { return (data.dt == DataType::F32); }, REGISTER_FP32_NEON(arm_compute::cpu::sub_same_neon) }, #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) { "neon_fp16_sub", - [](const SubSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::F16)); }, + [](const SubSelectorData & data) { return (data.dt == DataType::F16); }, REGISTER_FP16_NEON(arm_compute::cpu::sub_same_neon) }, #endif /* defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) */ { "neon_u8_sub", - [](const SubSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == data.dt3) && (data.dt1 == DataType::U8)); }, + [](const SubSelectorData & data) { return (data.dt == DataType::U8); }, REGISTER_INTEGER_NEON(arm_compute::cpu::sub_same_neon) }, { "neon_s16_sub", - [](const SubSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == data.dt3) && (data.dt1 == DataType::S16)); }, + [](const SubSelectorData & data) { return (data.dt == DataType::S16); }, REGISTER_INTEGER_NEON(arm_compute::cpu::sub_same_neon) }, { "neon_s32_sub", - [](const SubSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == data.dt3) && (data.dt1 == DataType::S32)); }, + [](const SubSelectorData & data) { return (data.dt == DataType::S32); }, REGISTER_INTEGER_NEON(arm_compute::cpu::sub_same_neon) }, - { - "neon_u8_s16_s16_sub", - [](const SubSelectorData & data) { return ((data.dt1 == DataType::U8) && (data.dt2 == DataType::S16)); }, - REGISTER_INTEGER_NEON(arm_compute::cpu::sub_u8_s16_s16_neon) - }, - { - "neon_s16_u8_s16_sub", - [](const SubSelectorData & data) { return ((data.dt1 == DataType::S16) && (data.dt2 == DataType::U8)); }, - 
REGISTER_INTEGER_NEON(arm_compute::cpu::sub_s16_u8_s16_neon) - }, - { - "neon_u8_u8_s16_sub", - [](const SubSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt3 == DataType::S16)); }, - REGISTER_INTEGER_NEON(arm_compute::cpu::sub_u8_u8_s16_neon) - }, { "neon_qu8_sub", - [](const SubSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QASYMM8)); }, + [](const SubSelectorData & data) { return (data.dt == DataType::QASYMM8); }, REGISTER_QASYMM8_NEON(arm_compute::cpu::sub_qasymm8_neon) }, { "neon_qs8_sub", - [](const SubSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QASYMM8_SIGNED)); }, + [](const SubSelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED); }, REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::sub_qasymm8_signed_neon) }, { - "neon_s16_sub", - [](const SubSelectorData & data) { return ((data.dt1 == data.dt2) && (data.dt1 == DataType::QSYMM16)); }, + "neon_qs16_sub", + [](const SubSelectorData & data) { return (data.dt == DataType::QSYMM16); }, REGISTER_QSYMM16_NEON(arm_compute::cpu::sub_qsymm16_neon) }, }; @@ -123,11 +106,11 @@ static const SubKernel available_kernels[] = * * @return A matching micro-kernel else nullptr */ -const SubKernel *get_implementation(DataType dt1, DataType dt2, DataType dt3) +const SubKernel *get_implementation(DataType dt) { for(const auto &uk : available_kernels) { - if(uk.is_selected({ dt1, dt2, dt3 })) + if(uk.is_selected({ dt })) { return &uk; } @@ -141,54 +124,21 @@ inline Status validate_arguments(const ITensorInfo &src0, const ITensorInfo &src ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&src0); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src0, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM16, DataType::S16, DataType::S32, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src1, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM16, DataType::S16, DataType::S32, DataType::F16, - DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&dst, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM16, DataType::S16, DataType::S32, DataType::F16, - DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src0, &src1); - const auto *uk = get_implementation(src0.data_type(), src1.data_type(), dst.data_type()); + const auto *uk = get_implementation(src0.data_type()); ARM_COMPUTE_RETURN_ERROR_ON(uk == nullptr || uk->ukernel == nullptr); const TensorShape out_shape = TensorShape::broadcast_shape(src0.tensor_shape(), src1.tensor_shape()); ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG( - !(src0.data_type() == DataType::U8 && src1.data_type() == DataType::U8) - && !(src0.data_type() == DataType::QASYMM8 && src1.data_type() == DataType::QASYMM8) - && !(src0.data_type() == DataType::QASYMM8_SIGNED && src1.data_type() == DataType::QASYMM8_SIGNED) - && !(src0.data_type() == DataType::QSYMM16 && src1.data_type() == DataType::QSYMM16) - && !(src0.data_type() == DataType::U8 && src1.data_type() == DataType::U8) - && !(src0.data_type() == DataType::U8 && src1.data_type() == DataType::S16) - && !(src0.data_type() == DataType::S16 && src1.data_type() == DataType::U8) - && !(src0.data_type() == DataType::S16 && src1.data_type() == DataType::S16) - && !(src0.data_type() == DataType::S32 && src1.data_type() == DataType::S32) - && !(src0.data_type() 
== DataType::F32 && src1.data_type() == DataType::F32) - && !(src0.data_type() == DataType::F16 && src1.data_type() == DataType::F16), - "You called subtract with the wrong image formats"); - - ARM_COMPUTE_RETURN_ERROR_ON_MSG( - (src0.data_type() == DataType::QASYMM8_SIGNED && src1.data_type() == DataType::QASYMM8_SIGNED && policy == ConvertPolicy::WRAP) - || (src0.data_type() == DataType::QASYMM8 && src1.data_type() == DataType::QASYMM8 && policy == ConvertPolicy::WRAP) - || (src0.data_type() == DataType::QSYMM16 && src1.data_type() == DataType::QSYMM16 && policy == ConvertPolicy::WRAP), - "Convert policy cannot be WRAP if datatype is quantized"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(is_data_type_quantized(src0.data_type()) && (policy == ConvertPolicy::WRAP), + "Convert policy cannot be WRAP if datatype is quantized"); // Validate in case of configured dst if(dst.total_size() > 0) { - ARM_COMPUTE_RETURN_ERROR_ON_MSG( - !(src0.data_type() == DataType::U8 && src1.data_type() == DataType::U8 && dst.data_type() == DataType::U8) - && !(src0.data_type() == DataType::QASYMM8 && src1.data_type() == DataType::QASYMM8 && dst.data_type() == DataType::QASYMM8) - && !(src0.data_type() == DataType::QASYMM8_SIGNED && src1.data_type() == DataType::QASYMM8_SIGNED && dst.data_type() == DataType::QASYMM8_SIGNED) - && !(src0.data_type() == DataType::QSYMM16 && src1.data_type() == DataType::QSYMM16 && dst.data_type() == DataType::QSYMM16) - && !(src0.data_type() == DataType::U8 && src1.data_type() == DataType::U8 && dst.data_type() == DataType::S16) - && !(src0.data_type() == DataType::U8 && src1.data_type() == DataType::S16 && dst.data_type() == DataType::S16) - && !(src0.data_type() == DataType::S16 && src1.data_type() == DataType::U8 && dst.data_type() == DataType::S16) - && !(src0.data_type() == DataType::S16 && src1.data_type() == DataType::S16 && dst.data_type() == DataType::S16) - && !(src0.data_type() == DataType::S32 && src1.data_type() == DataType::S32 && dst.data_type() == DataType::S32) - && !(src0.data_type() == DataType::F32 && src1.data_type() == DataType::F32 && dst.data_type() == DataType::F32) - && !(src0.data_type() == DataType::F16 && src1.data_type() == DataType::F16 && dst.data_type() == DataType::F16), - "You called subtract with the wrong image formats"); - + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src0, &dst); ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, dst.tensor_shape(), 0), "Wrong shape for dst"); } @@ -205,8 +155,9 @@ void CpuSubKernel::configure(const ITensorInfo *src0, const ITensorInfo *src1, I // Auto initialize dst if not initialized set_shape_if_empty(*dst, out_shape); + set_data_type_if_unknown(*dst, src0->data_type()); - const auto *uk = get_implementation(src0->data_type(), src1->data_type(), dst->data_type()); + const auto *uk = get_implementation(src0->data_type()); ARM_COMPUTE_ERROR_ON_NULLPTR(uk); _policy = policy; diff --git a/src/core/cpu/kernels/CpuSubKernel.h b/src/core/cpu/kernels/CpuSubKernel.h index b9160bd150..cb64e64cfa 100644 --- a/src/core/cpu/kernels/CpuSubKernel.h +++ b/src/core/cpu/kernels/CpuSubKernel.h @@ -45,11 +45,8 @@ public: * Valid configurations (src0,src1) -> dst : * * - (U8,U8) -> U8 - * - (U8,U8) -> S16 * - (QASYMM8, QASYMM8) -> QASYMM8 * - (QASYMM8_SIGNED, QASYMM8_SIGNED) -> QASYMM8_SIGNED - * - (S16,U8) -> S16 - * - (U8,S16) -> S16 * - (S16,S16) -> S16 * - (S32,S32) -> S32 * - (F16,F16) -> F16 diff --git a/src/core/cpu/kernels/add/neon/integer.cpp b/src/core/cpu/kernels/add/neon/integer.cpp deleted 
file mode 100644 index 24a0ac3b7c..0000000000 --- a/src/core/cpu/kernels/add/neon/integer.cpp +++ /dev/null @@ -1,170 +0,0 @@ -/* - * Copyright (c) 2020-2021 Arm Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#include "arm_compute/core/Helpers.h" -#include "arm_compute/core/ITensor.h" -#include "arm_compute/core/Types.h" -#include "arm_compute/core/utils/misc/Traits.h" -#include "src/core/NEON/wrapper/wrapper.h" -#include "src/core/helpers/WindowHelpers.h" - -namespace arm_compute -{ -namespace cpu -{ -void add_u8_u8_s16_neon(const ITensor *src0, const ITensor *src1, ITensor *dst, const ConvertPolicy &policy, const Window &window) -{ - // Create input windows - Window win = window; - Window input1_win = window.broadcast_if_dimension_le_one(src0->info()->tensor_shape()); - Window input2_win = window.broadcast_if_dimension_le_one(src1->info()->tensor_shape()); - - // Clear X Dimension on execution window as we handle manually - win.set(Window::DimX, Window::Dimension(0, 1, 1)); - input1_win.set(Window::DimX, Window::Dimension(0, 1, 1)); - input2_win.set(Window::DimX, Window::Dimension(0, 1, 1)); - - Iterator input1(src0, input1_win); - Iterator input2(src1, input2_win); - Iterator output(dst, win); - - const int window_step_x = 8; - const auto window_start_x = static_cast(window.x().start()); - const auto window_end_x = static_cast(window.x().end()); - - execute_window_loop(win, [&](const Coordinates &) - { - const auto input1_ptr = reinterpret_cast(input1.ptr()); - const auto input2_ptr = reinterpret_cast(input2.ptr()); - const auto output_ptr = reinterpret_cast(output.ptr()); - - if(policy == ConvertPolicy::WRAP) - { - // Compute S elements per iteration - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - const auto vin1 = vreinterpretq_s16_u16(wrapper::vmovl(wrapper::vload(input1_ptr + x))); - const auto vin2 = vreinterpretq_s16_u16(wrapper::vmovl(wrapper::vload(input2_ptr + x))); - wrapper::vstore(output_ptr + x, wrapper::vadd(vin1, vin2)); - } - - // Compute left-over elements - for(; x < window_end_x; ++x) - { - *(output_ptr + x) = static_cast(*(input1_ptr + x)) + static_cast(*(input2_ptr + x)); - } - } - else - { - // Compute S elements per iteration - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - const auto vin1 = vreinterpretq_s16_u16(wrapper::vmovl(wrapper::vload(input1_ptr + x))); - const auto vin2 = 
vreinterpretq_s16_u16(wrapper::vmovl(wrapper::vload(input2_ptr + x))); - wrapper::vstore(output_ptr + x, wrapper::vqadd(vin1, vin2)); - } - - // Compute left-over elements - for(; x < window_end_x; ++x) - { - *(output_ptr + x) = wrapper::add_sat(static_cast(*(input1_ptr + x)), - static_cast(*(input2_ptr + x))); - } - } - }, - input1, input2, output); -} - -void add_s16_u8_s16_neon(const ITensor *src0, const ITensor *src1, ITensor *dst, const ConvertPolicy &policy, const Window &window) -{ - // Create input windows - Window win = window; - Window input1_win = window.broadcast_if_dimension_le_one(src0->info()->tensor_shape()); - Window input2_win = window.broadcast_if_dimension_le_one(src1->info()->tensor_shape()); - - // Clear X Dimension on execution window as we handle manually - win.set(Window::DimX, Window::Dimension(0, 1, 1)); - input1_win.set(Window::DimX, Window::Dimension(0, 1, 1)); - input2_win.set(Window::DimX, Window::Dimension(0, 1, 1)); - - Iterator input1(src0, input1_win); - Iterator input2(src1, input2_win); - Iterator output(dst, win); - - const int window_step_x = 8; - const auto window_start_x = static_cast(window.x().start()); - const auto window_end_x = static_cast(window.x().end()); - - execute_window_loop(win, [&](const Coordinates &) - { - const auto input1_ptr = reinterpret_cast(input1.ptr()); - const auto input2_ptr = reinterpret_cast(input2.ptr()); - const auto output_ptr = reinterpret_cast(output.ptr()); - - if(policy == ConvertPolicy::WRAP) - { - // Compute S elements per iteration - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - const auto vin1 = wrapper::vloadq(input1_ptr + x); - const auto vin2 = vreinterpretq_s16_u16(wrapper::vmovl(wrapper::vload(input2_ptr + x))); - wrapper::vstore(output_ptr + x, wrapper::vadd(vin1, vin2)); - } - - // Compute left-over elements - for(; x < window_end_x; ++x) - { - *(output_ptr + x) = *(input1_ptr + x) + static_cast(*(input2_ptr + x)); - } - } - else - { - // Compute S elements per iteration - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - const auto vin1 = wrapper::vloadq(input1_ptr + x); - const auto vin2 = vreinterpretq_s16_u16(wrapper::vmovl(wrapper::vload(input2_ptr + x))); - wrapper::vstore(output_ptr + x, wrapper::vqadd(vin1, vin2)); - } - - // Compute left-over elements - for(; x < window_end_x; ++x) - { - *(output_ptr + x) = wrapper::add_sat(*(input1_ptr + x), static_cast(*(input2_ptr + x))); - } - } - }, - input1, input2, output); -} - -void add_u8_s16_s16_neon(const ITensor *src0, const ITensor *src1, ITensor *dst, const ConvertPolicy &policy, const Window &window) -{ - // Simply swap the two input buffers: - add_s16_u8_s16_neon(src1, src0, dst, policy, window); -} -} // namespace cpu -} // namespace arm_compute \ No newline at end of file diff --git a/src/core/cpu/kernels/add/neon/list.h b/src/core/cpu/kernels/add/neon/list.h index 3ab03dd40e..379bd32fb1 100644 --- a/src/core/cpu/kernels/add/neon/list.h +++ b/src/core/cpu/kernels/add/neon/list.h @@ -38,9 +38,6 @@ namespace cpu DECLARE_ADD_KERNEL(add_qasymm8_neon); DECLARE_ADD_KERNEL(add_qasymm8_signed_neon); DECLARE_ADD_KERNEL(add_qsymm16_neon); -DECLARE_ADD_KERNEL(add_s16_u8_s16_neon); -DECLARE_ADD_KERNEL(add_u8_s16_s16_neon); -DECLARE_ADD_KERNEL(add_u8_u8_s16_neon); #undef DECLARE_ADD_KERNEL diff --git a/src/core/cpu/kernels/add/sve/integer.cpp b/src/core/cpu/kernels/add/sve/integer.cpp deleted file mode 100644 index bd8179205b..0000000000 --- 
a/src/core/cpu/kernels/add/sve/integer.cpp +++ /dev/null @@ -1,201 +0,0 @@ -/* - * Copyright (c) 2020-2021 Arm Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - */ -#if defined(__ARM_FEATURE_SVE) -#include "arm_compute/core/Helpers.h" -#include "arm_compute/core/ITensor.h" -#include "arm_compute/core/Types.h" -#include "arm_compute/core/utils/misc/Traits.h" -#include "src/core/NEON/SVEMath.h" -#include "src/core/NEON/wrapper/intrinsics/intrinsics.h" -#include - -namespace arm_compute -{ -namespace cpu -{ -void add_u8_u8_s16_sve(const ITensor *src0, const ITensor *src1, ITensor *dst, const ConvertPolicy &policy, const Window &window) -{ - // Create input windows - Window win = window; - Window input1_win = window.broadcast_if_dimension_le_one(src0->info()->tensor_shape()); - Window input2_win = window.broadcast_if_dimension_le_one(src1->info()->tensor_shape()); - - // Clear X Dimension on execution window as we handle manually - win.set(Window::DimX, Window::Dimension(0, 1, 1)); - input1_win.set(Window::DimX, Window::Dimension(0, 1, 1)); - input2_win.set(Window::DimX, Window::Dimension(0, 1, 1)); - - Iterator input1(src0, input1_win); - Iterator input2(src1, input2_win); - Iterator output(dst, win); - - const auto window_start_x = static_cast(window.x().start()); - const auto window_end_x = static_cast(window.x().end()); - const auto all_true_pg = svptrue_b8(); - - execute_window_loop(win, [&](const Coordinates &) - { - const auto input1_ptr = reinterpret_cast(input1.ptr()); - const auto input2_ptr = reinterpret_cast(input2.ptr()); - const auto output_ptr = reinterpret_cast(output.ptr()); - - if(policy == ConvertPolicy::WRAP) - { - int x = window_start_x; - svbool_t pg_u = svwhilelt_b8(x, window_end_x); - svbool_t pg_0 = svwhilelt_b16(x, window_end_x); - svbool_t pg_1 = svwhilelt_b16(x, static_cast(window_end_x + svcnth())); - do - { - const auto vsrc0 = svld1(pg_u, input1_ptr + x); - const auto vsrc1 = svld1(pg_u, input2_ptr + x); - - const auto vsrc0_lo = svreinterpret_s16_u16(svunpklo(vsrc0)); - const auto vsrc0_hi = svreinterpret_s16_u16(svunpkhi(vsrc0)); - const auto vsrc1_lo = svreinterpret_s16_u16(svunpklo(vsrc1)); - const auto vsrc1_hi = svreinterpret_s16_u16(svunpkhi(vsrc1)); - svst1(pg_0, output_ptr + x, svqadd(vsrc0_lo, vsrc1_lo)); - svst1(pg_1, output_ptr + x + svcnth(), svqadd(vsrc0_hi, vsrc1_hi)); - - x += svcntb(); - pg_u = svwhilelt_b8(x, window_end_x); - pg_0 = svwhilelt_b16(x, window_end_x); - pg_1 = 
svwhilelt_b16(x, static_cast(window_end_x + svcnth())); - } - while(svptest_any(all_true_pg, pg_u)); - } - else - { - int x = window_start_x; - svbool_t pg_u = svwhilelt_b8(x, window_end_x); - svbool_t pg_0 = svwhilelt_b16(x, window_end_x); - svbool_t pg_1 = svwhilelt_b16(x, static_cast(window_end_x + svcnth())); - do - { - const auto vsrc0 = svld1(pg_u, input1_ptr + x); - const auto vsrc1 = svld1(pg_u, input2_ptr + x); - - const auto vsrc0_lo = svreinterpret_s16_u16(svunpklo(vsrc0)); - const auto vsrc0_hi = svreinterpret_s16_u16(svunpkhi(vsrc0)); - const auto vsrc1_lo = svreinterpret_s16_u16(svunpklo(vsrc1)); - const auto vsrc1_hi = svreinterpret_s16_u16(svunpkhi(vsrc1)); - svst1(pg_0, output_ptr + x, svqadd(vsrc0_lo, vsrc1_lo)); - svst1(pg_1, output_ptr + x + svcnth(), svqadd(vsrc0_hi, vsrc1_hi)); - - x += svcntb(); - pg_u = svwhilelt_b8(x, window_end_x); - pg_0 = svwhilelt_b16(x, window_end_x); - pg_1 = svwhilelt_b16(x, static_cast(window_end_x + svcnth())); - } - while(svptest_any(all_true_pg, pg_u)); - } - }, - input1, input2, output); -} - -void add_s16_u8_s16_sve(const ITensor *src0, const ITensor *src1, ITensor *dst, const ConvertPolicy &policy, const Window &window) -{ - // Create input windows - Window win = window; - Window input1_win = window.broadcast_if_dimension_le_one(src0->info()->tensor_shape()); - Window input2_win = window.broadcast_if_dimension_le_one(src1->info()->tensor_shape()); - - // Clear X Dimension on execution window as we handle manually - win.set(Window::DimX, Window::Dimension(0, 1, 1)); - input1_win.set(Window::DimX, Window::Dimension(0, 1, 1)); - input2_win.set(Window::DimX, Window::Dimension(0, 1, 1)); - - Iterator input1(src0, input1_win); - Iterator input2(src1, input2_win); - Iterator output(dst, win); - - const auto window_start_x = static_cast(window.x().start()); - const auto window_end_x = static_cast(window.x().end()); - const auto all_true_pg = svptrue_b8(); - - execute_window_loop(win, [&](const Coordinates &) - { - const auto input1_ptr = reinterpret_cast(input1.ptr()); - const auto input2_ptr = reinterpret_cast(input2.ptr()); - const auto output_ptr = reinterpret_cast(output.ptr()); - - if(policy == ConvertPolicy::WRAP) - { - int x = window_start_x; - svbool_t pg_u = svwhilelt_b8(x, window_end_x); - svbool_t pg_0 = svwhilelt_b16(x, window_end_x); - svbool_t pg_1 = svwhilelt_b16(x + static_cast(svcnth()), window_end_x); - do - { - const auto vsrc0_0 = svld1_s16(pg_0, input1_ptr + x); - const auto vsrc0_1 = svld1_s16(pg_1, input1_ptr + x + svcnth()); - const auto vsrc1_u8 = svld1_u8(pg_u, input2_ptr + x); - const auto vsrc1_0 = svreinterpret_s16_u16(svunpklo(vsrc1_u8)); - const auto vsrc1_1 = svreinterpret_s16_u16(svunpkhi(vsrc1_u8)); - svst1_s16(pg_0, output_ptr + x, svadd_s16_z(pg_0, vsrc0_0, vsrc1_0)); - svst1_s16(pg_1, output_ptr + x + svcnth(), svadd_s16_z(pg_1, vsrc0_1, vsrc1_1)); - - x += svcntb(); - pg_u = svwhilelt_b8(x, window_end_x); - pg_0 = svwhilelt_b16(x, window_end_x); - pg_1 = svwhilelt_b16(x + static_cast(svcnth()), window_end_x); - } - while(svptest_any(all_true_pg, pg_u)); - } - else - { - int x = window_start_x; - svbool_t pg_u = svwhilelt_b8(x, window_end_x); - svbool_t pg_0 = svwhilelt_b16(x, window_end_x); - svbool_t pg_1 = svwhilelt_b16(x + static_cast(svcnth()), window_end_x); - do - { - const auto vsrc0_0 = svld1_s16(pg_0, input1_ptr + x); - const auto vsrc0_1 = svld1_s16(pg_1, input1_ptr + x + svcnth()); - const auto vsrc1_u8 = svld1_u8(pg_u, input2_ptr + x); - const auto vsrc1_0 = 
svreinterpret_s16_u16(svunpklo(vsrc1_u8)); - const auto vsrc1_1 = svreinterpret_s16_u16(svunpkhi(vsrc1_u8)); - - svst1_s16(pg_0, output_ptr + x, svqadd(vsrc0_0, vsrc1_0)); - svst1_s16(pg_1, output_ptr + x + svcnth(), svqadd(vsrc0_1, vsrc1_1)); - - x += svcntb(); - pg_u = svwhilelt_b8(x, window_end_x); - pg_0 = svwhilelt_b16(x, window_end_x); - pg_1 = svwhilelt_b16(x + static_cast(svcnth()), window_end_x); - } - while(svptest_any(all_true_pg, pg_u)); - } - }, - input1, input2, output); -} - -void add_u8_s16_s16_sve(const ITensor *src0, const ITensor *src1, ITensor *dst, const ConvertPolicy &policy, const Window &window) -{ - // Simply swap the two input buffers: - add_s16_u8_s16_sve(src1, src0, dst, policy, window); -} -} // namespace cpu -} // namespace arm_compute -#endif /* defined(__ARM_FEATURE_SVE) */ \ No newline at end of file diff --git a/src/core/cpu/kernels/add/sve/list.h b/src/core/cpu/kernels/add/sve/list.h index 9e439497c9..4d29c2a8f1 100644 --- a/src/core/cpu/kernels/add/sve/list.h +++ b/src/core/cpu/kernels/add/sve/list.h @@ -42,9 +42,6 @@ namespace cpu DECLARE_ADD_KERNEL(add_qasymm8_sve); DECLARE_ADD_KERNEL(add_qasymm8_signed_sve); DECLARE_ADD_KERNEL(add_qsymm16_sve); -DECLARE_ADD_KERNEL(add_s16_u8_s16_sve); -DECLARE_ADD_KERNEL(add_u8_s16_s16_sve); -DECLARE_ADD_KERNEL(add_u8_u8_s16_sve); #undef DECLARE_ADD_KERNEL diff --git a/src/core/cpu/kernels/sub/neon/integer.cpp b/src/core/cpu/kernels/sub/neon/integer.cpp deleted file mode 100644 index bba73df1e8..0000000000 --- a/src/core/cpu/kernels/sub/neon/integer.cpp +++ /dev/null @@ -1,183 +0,0 @@ -/* - * Copyright (c) 2021 Arm Limited. - * - * SPDX-License-Identifier: MIT - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to - * deal in the Software without restriction, including without limitation the - * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or - * sell copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. 
- */ -#include "arm_compute/core/Helpers.h" -#include "arm_compute/core/ITensor.h" -#include "arm_compute/core/Types.h" -#include "arm_compute/core/utils/misc/Traits.h" -#include "src/core/NEON/wrapper/wrapper.h" -#include "src/core/helpers/WindowHelpers.h" - -namespace arm_compute -{ -namespace cpu -{ -namespace -{ -void sub_s16_u8_s16_impl(const ITensor *src0, const ITensor *src1, ITensor *dst, const ConvertPolicy &policy, const Window &window, bool is_swapped) -{ - // Create input windows - Window win = window; - Window input1_win = window.broadcast_if_dimension_le_one(src0->info()->tensor_shape()); - Window input2_win = window.broadcast_if_dimension_le_one(src1->info()->tensor_shape()); - - // Clear X Dimension on execution window as we handle manually - win.set(Window::DimX, Window::Dimension(0, 1, 1)); - input1_win.set(Window::DimX, Window::Dimension(0, 1, 1)); - input2_win.set(Window::DimX, Window::Dimension(0, 1, 1)); - - Iterator input1(src0, input1_win); - Iterator input2(src1, input2_win); - Iterator output(dst, win); - - const int window_step_x = 8; - const auto window_start_x = static_cast(window.x().start()); - const auto window_end_x = static_cast(window.x().end()); - - execute_window_loop(win, [&](const Coordinates &) - { - const auto input1_ptr = reinterpret_cast(input1.ptr()); - const auto input2_ptr = reinterpret_cast(input2.ptr()); - const auto output_ptr = reinterpret_cast(output.ptr()); - - if(policy == ConvertPolicy::WRAP) - { - // Compute S elements per iteration - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - const auto vin1 = wrapper::vloadq(input1_ptr + x); - const auto vin2 = vreinterpretq_s16_u16(wrapper::vmovl(wrapper::vload(input2_ptr + x))); - const auto res = is_swapped ? wrapper::vsub(vin2, vin1) : wrapper::vsub(vin1, vin2); - wrapper::vstore(output_ptr + x, res); - } - - // Compute left-over elements - for(; x < window_end_x; ++x) - { - const auto res = is_swapped ? static_cast(*(input2_ptr + x)) - *(input1_ptr + x) : *(input1_ptr + x) - static_cast(*(input2_ptr + x)); - *(output_ptr + x) = res; - } - } - else - { - // Compute S elements per iteration - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - const auto vin1 = wrapper::vloadq(input1_ptr + x); - const auto vin2 = vreinterpretq_s16_u16(wrapper::vmovl(wrapper::vload(input2_ptr + x))); - const auto res = is_swapped ? wrapper::vqsub(vin2, vin1) : wrapper::vqsub(vin1, vin2); - wrapper::vstore(output_ptr + x, res); - } - - // Compute left-over elements - for(; x < window_end_x; ++x) - { - const auto res = is_swapped ? 
wrapper::sub_sat(static_cast<int16_t>(*(input2_ptr + x)), *(input1_ptr + x)) : wrapper::sub_sat(*(input1_ptr + x), static_cast<int16_t>(*(input2_ptr + x))); - *(output_ptr + x) = res; - } - } - }, - input1, input2, output); -} -} - -void sub_s16_u8_s16_neon(const ITensor *src0, const ITensor *src1, ITensor *dst, const ConvertPolicy &policy, const Window &window) -{ - sub_s16_u8_s16_impl(src0, src1, dst, policy, window, false); -} - -void sub_u8_s16_s16_neon(const ITensor *src0, const ITensor *src1, ITensor *dst, const ConvertPolicy &policy, const Window &window) -{ - // Swap arguments - sub_s16_u8_s16_impl(src1, src0, dst, policy, window, true); -} - -void sub_u8_u8_s16_neon(const ITensor *src0, const ITensor *src1, ITensor *dst, const ConvertPolicy &policy, const Window &window) -{ - // Create input windows - Window win = window; - Window input1_win = window.broadcast_if_dimension_le_one(src0->info()->tensor_shape()); - Window input2_win = window.broadcast_if_dimension_le_one(src1->info()->tensor_shape()); - - // Clear X Dimension on execution window as we handle manually - win.set(Window::DimX, Window::Dimension(0, 1, 1)); - input1_win.set(Window::DimX, Window::Dimension(0, 1, 1)); - input2_win.set(Window::DimX, Window::Dimension(0, 1, 1)); - - Iterator input1(src0, input1_win); - Iterator input2(src1, input2_win); - Iterator output(dst, win); - - const int window_step_x = 8; - const auto window_start_x = static_cast<int>(window.x().start()); - const auto window_end_x = static_cast<int>(window.x().end()); - - execute_window_loop(win, [&](const Coordinates &) - { - const auto input1_ptr = reinterpret_cast<const uint8_t *>(input1.ptr()); - const auto input2_ptr = reinterpret_cast<const uint8_t *>(input2.ptr()); - const auto output_ptr = reinterpret_cast<int16_t *>(output.ptr()); - - if(policy == ConvertPolicy::WRAP) - { - // Compute S elements per iteration - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - const auto vin1 = vreinterpretq_s16_u16(wrapper::vmovl(wrapper::vload(input1_ptr + x))); - const auto vin2 = vreinterpretq_s16_u16(wrapper::vmovl(wrapper::vload(input2_ptr + x))); - wrapper::vstore(output_ptr + x, wrapper::vsub(vin1, vin2)); - } - - // Compute left-over elements - for(; x < window_end_x; ++x) - { - *(output_ptr + x) = static_cast<int16_t>(*(input1_ptr + x)) - static_cast<int16_t>(*(input2_ptr + x)); - } - } - else - { - // Compute S elements per iteration - int x = window_start_x; - for(; x <= (window_end_x - window_step_x); x += window_step_x) - { - const auto vin1 = vreinterpretq_s16_u16(wrapper::vmovl(wrapper::vload(input1_ptr + x))); - const auto vin2 = vreinterpretq_s16_u16(wrapper::vmovl(wrapper::vload(input2_ptr + x))); - wrapper::vstore(output_ptr + x, wrapper::vqsub(vin1, vin2)); - } - - // Compute left-over elements - for(; x < window_end_x; ++x) - { - *(output_ptr + x) = wrapper::sub_sat(static_cast<int16_t>(*(input1_ptr + x)), - static_cast<int16_t>(*(input2_ptr + x))); - } - } - }, - input1, input2, output); -} - -} // namespace cpu -} // namespace arm_compute \ No newline at end of file
diff --git a/src/core/cpu/kernels/sub/neon/list.h b/src/core/cpu/kernels/sub/neon/list.h index 1ab4e6367b..ac1346001a 100644 --- a/src/core/cpu/kernels/sub/neon/list.h +++ b/src/core/cpu/kernels/sub/neon/list.h @@ -38,9 +38,6 @@ namespace cpu DECLARE_SUB_KERNEL(sub_qasymm8_neon); DECLARE_SUB_KERNEL(sub_qasymm8_signed_neon); DECLARE_SUB_KERNEL(sub_qsymm16_neon); -DECLARE_SUB_KERNEL(sub_s16_u8_s16_neon); -DECLARE_SUB_KERNEL(sub_u8_s16_s16_neon); -DECLARE_SUB_KERNEL(sub_u8_u8_s16_neon); #undef DECLARE_SUB_KERNEL
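[Editor's note: the NEON kernels removed above used the equivalent trick, widening u8 lanes to s16 with vmovl before subtracting, and vqsub under the SATURATE policy. A minimal self-contained sketch of that pattern in plain NEON intrinsics follows; the name and signature are illustrative, not Compute Library API.]

#include <arm_neon.h>
#include <cstdint>

// Sketch only: roughly what the removed sub_u8_u8_s16 NEON kernel computed
// under ConvertPolicy::SATURATE. Not part of the library.
void sub_u8_u8_to_s16_saturating(const uint8_t *a, const uint8_t *b, int16_t *out, int n)
{
    int x = 0;
    for(; x <= n - 8; x += 8)
    {
        // Widen 8 u8 lanes to s16, then subtract with saturation
        const int16x8_t va = vreinterpretq_s16_u16(vmovl_u8(vld1_u8(a + x)));
        const int16x8_t vb = vreinterpretq_s16_u16(vmovl_u8(vld1_u8(b + x)));
        vst1q_s16(out + x, vqsubq_s16(va, vb));
    }
    for(; x < n; ++x) // scalar left-over elements
    {
        out[x] = static_cast<int16_t>(a[x]) - static_cast<int16_t>(b[x]);
    }
}

diff --git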
a/src/core/gpu/cl/kernels/ClElementwiseKernel.cpp b/src/core/gpu/cl/kernels/ClElementwiseKernel.cpp index b645353dd6..f005e9226e 100644 --- a/src/core/gpu/cl/kernels/ClElementwiseKernel.cpp +++ b/src/core/gpu/cl/kernels/ClElementwiseKernel.cpp @@ -127,50 +127,29 @@ Status validate_arguments_with_arithmetic_rules(const ITensorInfo &src1, const I ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src1, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::QSYMM16, DataType::F16, DataType::S32, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(&src2); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src2, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, - DataType::S16, DataType::QSYMM16, DataType::F16, - DataType::S32, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src1, &src2); - const bool is_quantized = is_data_type_quantized(src1.data_type()) || is_data_type_quantized(src2.data_type()); - if(is_quantized) + if(is_data_type_quantized_symmetric(src1.data_type())) { - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src1, &src2); - - if(is_data_type_quantized_symmetric(src1.data_type())) - { - const int32_t in1_offset = src1.quantization_info().uniform().offset; - const int32_t in2_offset = src2.quantization_info().uniform().offset; - ARM_COMPUTE_RETURN_ERROR_ON_MSG(in1_offset != 0, "For quantized symmetric, offset must be zero"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(in2_offset != 0, "For quantized symmetric, offset must be zero"); - } + const int32_t in1_offset = src1.quantization_info().uniform().offset; + const int32_t in2_offset = src2.quantization_info().uniform().offset; + ARM_COMPUTE_RETURN_ERROR_ON_MSG(in1_offset != 0, "For quantized symmetric, offset must be zero"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(in2_offset != 0, "For quantized symmetric, offset must be zero"); } const TensorShape out_shape = TensorShape::broadcast_shape(src1.tensor_shape(), src2.tensor_shape()); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible"); // Validate in case of configured dst if(dst.total_size() > 0) { - ARM_COMPUTE_RETURN_ERROR_ON_F16_UNSUPPORTED(&dst); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&dst, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, - DataType::S16, DataType::QSYMM16, DataType::F16, - DataType::S32, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_MSG((dst.data_type() == DataType::U8) && ((src1.data_type() != DataType::U8) || (src2.data_type() != DataType::U8)), - "dst can only be U8 if both inputs are U8"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, dst.tensor_shape(), 0), - "Wrong shape for dst"); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src1, &dst); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, dst.tensor_shape(), 0), "Wrong shape for dst"); - if(is_quantized) + if(is_data_type_quantized_symmetric(dst.data_type())) { - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src1, &dst); - - if(is_data_type_quantized_symmetric(dst.data_type())) - { - const int32_t offset = dst.quantization_info().uniform().offset; - ARM_COMPUTE_RETURN_ERROR_ON_MSG(offset != 0, "For quantized symmetric, offset must be zero"); - } + const int32_t offset = dst.quantization_info().uniform().offset; + ARM_COMPUTE_RETURN_ERROR_ON_MSG(offset != 0, "For quantized symmetric, offset must be zero"); } } return Status{}; @@ -182,9 +161,7 @@ CLBuildOptions 
generate_build_options_with_arithmetic_rules(const ITensorInfo &s const unsigned int num_elems_processed_per_iteration = adjust_vec_size(vector_size_byte_opencl / dst.element_size(), dst.dimension(0)); - build_opts.add_option("-DDATA_TYPE_IN1=" + get_cl_type_from_data_type(src1.data_type())); - build_opts.add_option("-DDATA_TYPE_IN2=" + get_cl_type_from_data_type(src2.data_type())); - build_opts.add_option("-DDATA_TYPE_OUT=" + get_cl_type_from_data_type(dst.data_type())); + build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(src1.data_type())); build_opts.add_option("-DVEC_SIZE_IN1=" + support::cpp11::to_string(src1.dimension(0) == 1 ? 1 : num_elems_processed_per_iteration)); build_opts.add_option("-DVEC_SIZE_IN2=" + support::cpp11::to_string(src2.dimension(0) == 1 ? 1 : num_elems_processed_per_iteration)); build_opts.add_option("-DVEC_SIZE_OUT=" + support::cpp11::to_string(num_elems_processed_per_iteration)); @@ -220,32 +197,7 @@ std::pair<Status, Window> validate_and_configure_window_for_arithmetic_operators const std::pair<TensorShape, ValidRegion> broadcast_pair = ITensorInfo::broadcast_shape_and_valid_region(src1, src2); const TensorShape &out_shape = broadcast_pair.first; - set_shape_if_empty(dst, out_shape); - - if(src1.data_type() == DataType::S16 || src2.data_type() == DataType::S16) - { - set_format_if_unknown(dst, Format::S16); - } - else if(src1.data_type() == DataType::F16 || src2.data_type() == DataType::F16) - { - set_format_if_unknown(dst, Format::F16); - } - else if(src1.data_type() == DataType::F32 || src2.data_type() == DataType::F32) - { - set_format_if_unknown(dst, Format::F32); - } - else if(src1.data_type() == DataType::QASYMM8 || src2.data_type() == DataType::QASYMM8) - { - set_data_type_if_unknown(dst, DataType::QASYMM8); - } - else if(src1.data_type() == DataType::QASYMM8_SIGNED || src2.data_type() == DataType::QASYMM8_SIGNED) - { - set_data_type_if_unknown(dst, DataType::QASYMM8_SIGNED); - } - else if(src1.data_type() == DataType::QSYMM16 || src2.data_type() == DataType::QSYMM16) - { - set_data_type_if_unknown(dst, DataType::QSYMM16); - } + auto_init_if_empty(dst, out_shape, 1, src1.data_type()); return configure_window_arithmetic_common(dst); } @@ -258,7 +210,6 @@ std::pair<Status, Window> validate_and_configure_window_for_logical_binary_opera set_shape_if_empty(dst, out_shape); set_data_type_if_unknown(dst, DataType::U8); - // The arithmetic utility functions can be share return configure_window_arithmetic_common(dst); } @@ -266,7 +217,9 @@ std::pair<Status, Window> validate_and_configure_window_for_division(ITensorInfo { const std::pair<TensorShape, ValidRegion> broadcast_pair = ITensorInfo::broadcast_shape_and_valid_region(src1, src2); const TensorShape &out_shape = broadcast_pair.first; + auto_init_if_empty(dst, out_shape, 1, src1.data_type()); + return configure_window_arithmetic_common(dst); } } // namespace
diff --git a/src/runtime/cpu/operators/CpuAdd.h b/src/runtime/cpu/operators/CpuAdd.h index febb79e4cd..3ff135fe41 100644 --- a/src/runtime/cpu/operators/CpuAdd.h +++ b/src/runtime/cpu/operators/CpuAdd.h @@ -39,9 +39,6 @@ public: * Valid configurations (src0,src1) -> dst : * * - (U8,U8) -> U8 - * - (U8,U8) -> S16 - * - (S16,U8) -> S16 - * - (U8,S16) -> S16 * - (S16,S16) -> S16 * - (S32,S32) -> S32 * - (F16,F16) -> F16
diff --git a/src/runtime/cpu/operators/CpuSub.h b/src/runtime/cpu/operators/CpuSub.h index aad01fe4dc..07f5be89cd 100644 --- a/src/runtime/cpu/operators/CpuSub.h +++ b/src/runtime/cpu/operators/CpuSub.h @@ -39,11 +39,8 @@ public: * Valid configurations (src0,src1) -> dst : * * - (U8,U8) -> U8 - * - (U8,U8) -> S16 * -
(QASYMM8, QASYMM8) -> QASYMM8 * - (QASYMM8_SIGNED, QASYMM8_SIGNED) -> QASYMM8_SIGNED - * - (S16,U8) -> S16 - * - (U8,S16) -> S16 * - (S16,S16) -> S16 * - (S32,S32) -> S32 * - (F16,F16) -> F16 diff --git a/tests/validation/CL/ArithmeticAddition.cpp b/tests/validation/CL/ArithmeticAddition.cpp index c74f6a3b23..9e3d9afc36 100644 --- a/tests/validation/CL/ArithmeticAddition.cpp +++ b/tests/validation/CL/ArithmeticAddition.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2020 Arm Limited. + * Copyright (c) 2017-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -44,23 +44,6 @@ namespace validation namespace { /** Input data sets **/ -const auto ArithmeticAdditionU8Dataset = combine(combine(framework::dataset::make("DataType", DataType::U8), framework::dataset::make("DataType", DataType::U8)), framework::dataset::make("DataType", - DataType::U8)); -const auto ArithmeticAdditionQASYMM8Dataset = combine(combine(framework::dataset::make("DataType", DataType::QASYMM8), framework::dataset::make("DataType", DataType::QASYMM8)), - framework::dataset::make("DataType", - DataType::QASYMM8)); -const auto ArithmeticAdditionQASYMM8SignedDataset = combine(combine(framework::dataset::make("DataType", DataType::QASYMM8_SIGNED), framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)), - framework::dataset::make("DataType", - DataType::QASYMM8_SIGNED)); -const auto ArithmeticAdditionQSYMM16Dataset = combine(combine(framework::dataset::make("DataType", DataType::QSYMM16), framework::dataset::make("DataType", DataType::QSYMM16)), - framework::dataset::make("DataType", - DataType::QSYMM16)); -const auto ArithmeticAdditionS16Dataset = combine(combine(framework::dataset::make("DataType", { DataType::U8, DataType::S16 }), framework::dataset::make("DataType", DataType::S16)), - framework::dataset::make("DataType", DataType::S16)); -const auto ArithmeticAdditionFP16Dataset = combine(combine(framework::dataset::make("DataType", DataType::F16), framework::dataset::make("DataType", DataType::F16)), - framework::dataset::make("DataType", DataType::F16)); -const auto ArithmeticAdditionFP32Dataset = combine(combine(framework::dataset::make("DataType", DataType::F32), framework::dataset::make("DataType", DataType::F32)), - framework::dataset::make("DataType", DataType::F32)); const auto EmptyActivationFunctionsDataset = framework::dataset::make("ActivationInfo", { ActivationLayerInfo() }); const auto ActivationFunctionsDataset = framework::dataset::make("ActivationInfo", @@ -76,22 +59,19 @@ TEST_SUITE(ArithmeticAddition) // *INDENT-OFF* // clang-format off DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip( - framework::dataset::make("Input1Info", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), - TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), + framework::dataset::make("Input1Info", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), // Invalid data type combination TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), // Mismatching shapes }), - framework::dataset::make("Input2Info",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), - TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), + framework::dataset::make("Input2Info",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16), TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32), })), - framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16), - 
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), + framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32), })), - framework::dataset::make("Expected", { true, true, false, false})), + framework::dataset::make("Expected", { true, false, false})), input1_info, input2_info, output_info, expected) { ARM_COMPUTE_EXPECT(bool(CLArithmeticAddition::validate(&input1_info.clone()->set_is_resizable(false), &input2_info.clone()->set_is_resizable(false), &output_info.clone()->set_is_resizable(false), ConvertPolicy::WRAP)) == expected, framework::LogLevel::ERRORS); @@ -129,7 +109,7 @@ using CLArithmeticAdditionFixture = ArithmeticAdditionValidationFixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), ArithmeticAdditionU8Dataset), +FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticAdditionFixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::U8)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP }))) { // Validate output @@ -138,14 +118,14 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticAdditionFixture, framework TEST_SUITE_END() // U8 TEST_SUITE(S16) -FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticAdditionFixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), ArithmeticAdditionS16Dataset), +FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticAdditionFixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::S16)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP }))) { // Validate output validate(CLAccessor(_target), _reference); } -FIXTURE_DATA_TEST_CASE(RunLarge, CLArithmeticAdditionFixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), ArithmeticAdditionS16Dataset), +FIXTURE_DATA_TEST_CASE(RunLarge, CLArithmeticAdditionFixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::S16)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP }))) { // Validate output @@ -160,7 +140,7 @@ using CLArithmeticAdditionQuantizedFixture = ArithmeticAdditionValidationQuantiz TEST_SUITE(Quantized) TEST_SUITE(QASYMM8) FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticAdditionQuantizedFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(datasets::SmallShapes(), - ArithmeticAdditionQASYMM8Dataset), + framework::dataset::make("DataType", DataType::QASYMM8)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })), framework::dataset::make("Src0QInfo", { QuantizationInfo(5.f / 255.f, 20) })), framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 255.f, 10) })), @@ -171,12 +151,13 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticAdditionQuantizedFixture, } template using CLArithmeticAdditionBroadcastQuantizedFixture = ArithmeticAdditionValidationQuantizedBroadcastFixture; -FIXTURE_DATA_TEST_CASE(RunSmallBroadcast, CLArithmeticAdditionBroadcastQuantizedFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(datasets::SmallShapesBroadcast(), - ArithmeticAdditionQASYMM8Dataset), - framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })), - 
framework::dataset::make("Src0QInfo", { QuantizationInfo(5.f / 255.f, 20) })), - framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 255.f, 10) })), - framework::dataset::make("OutQInfo", { QuantizationInfo(1.f / 255.f, 5) }))) +FIXTURE_DATA_TEST_CASE(RunSmallBroadcast, CLArithmeticAdditionBroadcastQuantizedFixture, framework::DatasetMode::PRECOMMIT, + combine(combine(combine(combine(combine(datasets::SmallShapesBroadcast(), + framework::dataset::make("DataType", DataType::QASYMM8)), + framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })), + framework::dataset::make("Src0QInfo", { QuantizationInfo(5.f / 255.f, 20) })), + framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 255.f, 10) })), + framework::dataset::make("OutQInfo", { QuantizationInfo(1.f / 255.f, 5) }))) { // Validate output validate(CLAccessor(_target), _reference); @@ -184,7 +165,7 @@ FIXTURE_DATA_TEST_CASE(RunSmallBroadcast, CLArithmeticAdditionBroadcastQuantized TEST_SUITE_END() // QASYMM8 TEST_SUITE(QASYMM8_SIGNED) FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticAdditionQuantizedFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(datasets::SmallShapes(), - ArithmeticAdditionQASYMM8SignedDataset), + framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })), framework::dataset::make("Src0QInfo", { QuantizationInfo(5.f / 255.f, 10) })), framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 255.f, 10) })), @@ -196,7 +177,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticAdditionQuantizedFixture, f TEST_SUITE_END() // QASYMM8_SIGNED TEST_SUITE(QSYMM16) FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticAdditionQuantizedFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(datasets::SmallShapes(), - ArithmeticAdditionQSYMM16Dataset), + framework::dataset::make("DataType", DataType::QSYMM16)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })), framework::dataset::make("Src0QInfo", { QuantizationInfo(1.f / 32768.f, 0), QuantizationInfo(5.f / 32768.f, 0) })), framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 32768.f, 0), QuantizationInfo(5.f / 32768.f, 0) })), @@ -213,14 +194,16 @@ using CLArithmeticAdditionFloatFixture = ArithmeticAdditionValidationFloatFixtur TEST_SUITE(Float) TEST_SUITE(FP16) -FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticAdditionFloatFixture, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallShapes(), ArithmeticAdditionFP16Dataset), +FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticAdditionFloatFixture, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", + DataType::F16)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })), EmptyActivationFunctionsDataset)) { // Validate output validate(CLAccessor(_target), _reference); } -FIXTURE_DATA_TEST_CASE(RunWithActivation, CLArithmeticAdditionFloatFixture, framework::DatasetMode::ALL, combine(combine(combine(datasets::TinyShapes(), ArithmeticAdditionFP16Dataset), +FIXTURE_DATA_TEST_CASE(RunWithActivation, CLArithmeticAdditionFloatFixture, framework::DatasetMode::ALL, combine(combine(combine(datasets::TinyShapes(), framework::dataset::make("DataType", + DataType::F16)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })), ActivationFunctionsDataset)) { @@ -230,14 +213,16 @@ 
FIXTURE_DATA_TEST_CASE(RunWithActivation, CLArithmeticAdditionFloatFixture TEST_SUITE_END() // FP16 TEST_SUITE(FP32) -FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticAdditionFloatFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), ArithmeticAdditionFP32Dataset), +FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticAdditionFloatFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", + DataType::F32)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })), EmptyActivationFunctionsDataset)) { // Validate output validate(CLAccessor(_target), _reference); } -FIXTURE_DATA_TEST_CASE(RunWithActivation, CLArithmeticAdditionFloatFixture, framework::DatasetMode::ALL, combine(combine(combine(datasets::TinyShapes(), ArithmeticAdditionFP32Dataset), +FIXTURE_DATA_TEST_CASE(RunWithActivation, CLArithmeticAdditionFloatFixture, framework::DatasetMode::ALL, combine(combine(combine(datasets::TinyShapes(), framework::dataset::make("DataType", + DataType::F32)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })), ActivationFunctionsDataset)) { @@ -245,7 +230,8 @@ FIXTURE_DATA_TEST_CASE(RunWithActivation, CLArithmeticAdditionFloatFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), ArithmeticAdditionFP32Dataset), +FIXTURE_DATA_TEST_CASE(RunLarge, CLArithmeticAdditionFloatFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", + DataType::F32)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })), EmptyActivationFunctionsDataset)) { @@ -257,7 +243,7 @@ template using CLArithmeticAdditionBroadcastFloatFixture = ArithmeticAdditionBroadcastValidationFloatFixture; FIXTURE_DATA_TEST_CASE(RunSmallBroadcast, CLArithmeticAdditionBroadcastFloatFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapesBroadcast(), - ArithmeticAdditionFP32Dataset), + framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })), EmptyActivationFunctionsDataset)) { @@ -265,7 +251,7 @@ FIXTURE_DATA_TEST_CASE(RunSmallBroadcast, CLArithmeticAdditionBroadcastFloatFixt validate(CLAccessor(_target), _reference); } FIXTURE_DATA_TEST_CASE(RunWithActivationBroadcast, CLArithmeticAdditionBroadcastFloatFixture, framework::DatasetMode::ALL, combine(combine(combine(datasets::TinyShapesBroadcast(), - ArithmeticAdditionFP32Dataset), + framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })), ActivationFunctionsDataset)) { @@ -274,7 +260,7 @@ FIXTURE_DATA_TEST_CASE(RunWithActivationBroadcast, CLArithmeticAdditionBroadcast } FIXTURE_DATA_TEST_CASE(RunLargeBroadcast, CLArithmeticAdditionBroadcastFloatFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapesBroadcast(), - ArithmeticAdditionFP32Dataset), + framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })), EmptyActivationFunctionsDataset)) { diff --git a/tests/validation/CL/ArithmeticSubtraction.cpp b/tests/validation/CL/ArithmeticSubtraction.cpp index 2709fcaedb..00eba7f92a 100644 --- a/tests/validation/CL/ArithmeticSubtraction.cpp +++ 
b/tests/validation/CL/ArithmeticSubtraction.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2020 Arm Limited. + * Copyright (c) 2017-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -44,24 +44,6 @@ namespace validation namespace { /** Input data sets **/ -const auto ArithmeticSubtractionU8Dataset = combine(combine(framework::dataset::make("DataType", DataType::U8), framework::dataset::make("DataType", DataType::U8)), - framework::dataset::make("DataType", - DataType::U8)); -const auto ArithmeticSubtractionQASYMM8Dataset = combine(combine(framework::dataset::make("DataType", DataType::QASYMM8), framework::dataset::make("DataType", DataType::QASYMM8)), - framework::dataset::make("DataType", - DataType::QASYMM8)); -const auto ArithmeticSubtractionQASYMM8SignedDataset = combine(combine(framework::dataset::make("DataType", DataType::QASYMM8_SIGNED), framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)), - framework::dataset::make("DataType", - DataType::QASYMM8_SIGNED)); -const auto ArithmeticSubtractionQSYMM16Dataset = combine(combine(framework::dataset::make("DataType", DataType::QSYMM16), framework::dataset::make("DataType", DataType::QSYMM16)), - framework::dataset::make("DataType", - DataType::QSYMM16)); -const auto ArithmeticSubtractionS16Dataset = combine(combine(framework::dataset::make("DataType", { DataType::U8, DataType::S16 }), framework::dataset::make("DataType", DataType::S16)), - framework::dataset::make("DataType", DataType::S16)); -const auto ArithmeticSubtractionFP16Dataset = combine(combine(framework::dataset::make("DataType", DataType::F16), framework::dataset::make("DataType", DataType::F16)), - framework::dataset::make("DataType", DataType::F16)); -const auto ArithmeticSubtractionFP32Dataset = combine(combine(framework::dataset::make("DataType", DataType::F32), framework::dataset::make("DataType", DataType::F32)), - framework::dataset::make("DataType", DataType::F32)); const auto EmptyActivationFunctionsDataset = framework::dataset::make("ActivationInfo", { ActivationLayerInfo() }); const auto ActivationFunctionsDataset = framework::dataset::make("ActivationInfo", @@ -79,22 +61,19 @@ TEST_SUITE(ArithmeticSubtraction) // *INDENT-OFF* // clang-format off DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip( - framework::dataset::make("Input1Info", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), - TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), + framework::dataset::make("Input1Info", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), // Invalid data type combination TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), // Mismatching shapes }), - framework::dataset::make("Input2Info",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), - TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), + framework::dataset::make("Input2Info",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16), TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32), })), - framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16), - TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), + framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32), })), - framework::dataset::make("Expected", { true, true, false, false})), + 
framework::dataset::make("Expected", { true, false, false})), input1_info, input2_info, output_info, expected) { ARM_COMPUTE_EXPECT(bool(CLArithmeticSubtraction::validate(&input1_info.clone()->set_is_resizable(false), &input2_info.clone()->set_is_resizable(false), &output_info.clone()->set_is_resizable(false), ConvertPolicy::WRAP)) == expected, framework::LogLevel::ERRORS); @@ -159,7 +138,8 @@ using CLArithmeticSubtractionFixture = ArithmeticSubtractionValidationFixture, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallShapes(), ArithmeticSubtractionU8Dataset), +FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFixture, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", + DataType::U8)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })), OutOfPlaceDataSet)) { @@ -169,7 +149,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFixture, framew TEST_SUITE_END() // U8 TEST_SUITE(S16) -FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), ArithmeticSubtractionS16Dataset), +FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", + DataType::S16)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })), OutOfPlaceDataSet)) { @@ -177,7 +158,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFixture, framew validate(CLAccessor(_target), _reference); } -FIXTURE_DATA_TEST_CASE(RunLarge, CLArithmeticSubtractionFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), ArithmeticSubtractionS16Dataset), +FIXTURE_DATA_TEST_CASE(RunLarge, CLArithmeticSubtractionFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", + DataType::S16)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })), OutOfPlaceDataSet)) { @@ -193,7 +175,7 @@ using CLArithmeticSubtractionQuantizedFixture = ArithmeticSubtractionValidationQ TEST_SUITE(Quantized) TEST_SUITE(QASYMM8) FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionQuantizedFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(datasets::SmallShapes(), - ArithmeticSubtractionQASYMM8Dataset), + framework::dataset::make("DataType", DataType::QASYMM8)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })), framework::dataset::make("Src0QInfo", { QuantizationInfo(5.f / 255.f, 20) })), framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 255.f, 10) })), @@ -206,7 +188,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionQuantizedFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(datasets::SmallShapes(), - ArithmeticSubtractionQASYMM8SignedDataset), + framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })), framework::dataset::make("Src0QInfo", { QuantizationInfo(5.f / 255.f, 10) })), framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 255.f, 10) })), @@ -219,7 +201,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionQuantizedFixture TEST_SUITE_END() // QASYMM8_SIGNED TEST_SUITE(QSYMM16) 
FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionQuantizedFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(datasets::SmallShapes(), - ArithmeticSubtractionQSYMM16Dataset), + framework::dataset::make("DataType", DataType::QSYMM16)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })), framework::dataset::make("Src0QInfo", { QuantizationInfo(1.f / 32768.f, 0), QuantizationInfo(5.f / 32768.f, 0) })), framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 32768.f, 0), QuantizationInfo(5.f / 32768.f, 0) })), @@ -237,7 +219,8 @@ using CLArithmeticSubtractionFloatFixture = ArithmeticSubtractionValidationFloat TEST_SUITE(Float) TEST_SUITE(FP16) -FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFloatFixture, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallShapes(), ArithmeticSubtractionFP16Dataset), +FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFloatFixture, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", + DataType::F16)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })), EmptyActivationFunctionsDataset), OutOfPlaceDataSet)) @@ -246,7 +229,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFloatFixture, fram validate(CLAccessor(_target), _reference); } FIXTURE_DATA_TEST_CASE(RunWithActivation, CLArithmeticSubtractionFloatFixture, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::TinyShapes(), - ArithmeticSubtractionFP16Dataset), + framework::dataset::make("DataType", DataType::F16)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })), ActivationFunctionsDataset), InPlaceDataSet)) @@ -258,7 +241,7 @@ TEST_SUITE_END() // FP16 TEST_SUITE(FP32) FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFloatFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::SmallShapes(), - ArithmeticSubtractionFP32Dataset), + framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })), EmptyActivationFunctionsDataset), OutOfPlaceDataSet)) @@ -267,7 +250,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, CLArithmeticSubtractionFloatFixture, fra validate(CLAccessor(_target), _reference); } FIXTURE_DATA_TEST_CASE(RunWithActivation, CLArithmeticSubtractionFloatFixture, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::TinyShapes(), - ArithmeticSubtractionFP32Dataset), + framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })), ActivationFunctionsDataset), InPlaceDataSet)) @@ -277,7 +260,7 @@ FIXTURE_DATA_TEST_CASE(RunWithActivation, CLArithmeticSubtractionFloatFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::LargeShapes(), - ArithmeticSubtractionFP32Dataset), + framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })), EmptyActivationFunctionsDataset), OutOfPlaceDataSet)) @@ -290,7 +273,7 @@ template using CLArithmeticSubtractionBroadcastFloatFixture = ArithmeticSubtractionBroadcastValidationFloatFixture; FIXTURE_DATA_TEST_CASE(RunSmallBroadcast, CLArithmeticSubtractionBroadcastFloatFixture, framework::DatasetMode::PRECOMMIT, 
combine(combine(combine(combine(datasets::SmallShapesBroadcast(), - ArithmeticSubtractionFP32Dataset), + framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })), EmptyActivationFunctionsDataset), OutOfPlaceDataSet)) @@ -299,7 +282,7 @@ FIXTURE_DATA_TEST_CASE(RunSmallBroadcast, CLArithmeticSubtractionBroadcastFloatF validate(CLAccessor(_target), _reference); } FIXTURE_DATA_TEST_CASE(RunWithActivationBroadcast, CLArithmeticSubtractionBroadcastFloatFixture, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::TinyShapesBroadcast(), - ArithmeticSubtractionFP32Dataset), + framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })), ActivationFunctionsDataset), OutOfPlaceDataSet)) @@ -309,7 +292,7 @@ FIXTURE_DATA_TEST_CASE(RunWithActivationBroadcast, CLArithmeticSubtractionBroadc } FIXTURE_DATA_TEST_CASE(RunLargeBroadcast, CLArithmeticSubtractionBroadcastFloatFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(datasets::LargeShapesBroadcast(), - ArithmeticSubtractionFP32Dataset), + framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })), EmptyActivationFunctionsDataset), OutOfPlaceDataSet)) diff --git a/tests/validation/CL/ElementwiseMax.cpp b/tests/validation/CL/ElementwiseMax.cpp index b9444b2795..225968efd1 100644 --- a/tests/validation/CL/ElementwiseMax.cpp +++ b/tests/validation/CL/ElementwiseMax.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020 Arm Limited. + * Copyright (c) 2018-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -58,7 +58,7 @@ const auto ElementwiseMaxQASYMM8SignedDataset = combine(combine(framework::datas const auto ElementwiseMaxQSYMM16Dataset = combine(combine(framework::dataset::make("DataType", DataType::QSYMM16), framework::dataset::make("DataType", DataType::QSYMM16)), framework::dataset::make("DataType", DataType::QSYMM16)); -const auto ElementwiseMaxS16Dataset = combine(combine(framework::dataset::make("DataType", { DataType::U8, DataType::S16 }), framework::dataset::make("DataType", DataType::S16)), +const auto ElementwiseMaxS16Dataset = combine(combine(framework::dataset::make("DataType", { DataType::S16 }), framework::dataset::make("DataType", DataType::S16)), framework::dataset::make("DataType", DataType::S16)); const auto ElementwiseMaxFP16Dataset = combine(combine(framework::dataset::make("DataType", DataType::F16), framework::dataset::make("DataType", DataType::F16)), framework::dataset::make("DataType", DataType::F16)); @@ -80,21 +80,18 @@ TEST_SUITE(ElementwiseMax) // clang-format off DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip( framework::dataset::make("Input1Info", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), - TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), // Invalid data type combination TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), // Mismatching shapes }), framework::dataset::make("Input2Info",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), - TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16), TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32), })), - framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16), - 
TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), + framework::dataset::make("OutputInfo",{TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32), })), - framework::dataset::make("Expected", { true, true, false, false})), + framework::dataset::make("Expected", { true, false, false})), input1_info, input2_info, output_info, expected) { ARM_COMPUTE_EXPECT(bool(CLElementwiseMax::validate(&input1_info.clone()->set_is_resizable(false), &input2_info.clone()->set_is_resizable(false), &output_info.clone()->set_is_resizable(false))) == expected, framework::LogLevel::ERRORS); diff --git a/tests/validation/CL/ElementwiseMin.cpp b/tests/validation/CL/ElementwiseMin.cpp index 8f53b241ab..2a066908fa 100644 --- a/tests/validation/CL/ElementwiseMin.cpp +++ b/tests/validation/CL/ElementwiseMin.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020 Arm Limited. + * Copyright (c) 2018-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -58,7 +58,7 @@ const auto ElementwiseMinQASYMM8SignedDataset = combine(combine(framework::datas const auto ElementwiseMinQSYMM16Dataset = combine(combine(framework::dataset::make("DataType", DataType::QSYMM16), framework::dataset::make("DataType", DataType::QSYMM16)), framework::dataset::make("DataType", DataType::QSYMM16)); -const auto ElementwiseMinS16Dataset = combine(combine(framework::dataset::make("DataType", { DataType::U8, DataType::S16 }), framework::dataset::make("DataType", DataType::S16)), +const auto ElementwiseMinS16Dataset = combine(combine(framework::dataset::make("DataType", { DataType::S16 }), framework::dataset::make("DataType", DataType::S16)), framework::dataset::make("DataType", DataType::S16)); const auto ElementwiseMinFP16Dataset = combine(combine(framework::dataset::make("DataType", DataType::F16), framework::dataset::make("DataType", DataType::F16)), framework::dataset::make("DataType", DataType::F16)); @@ -80,21 +80,18 @@ TEST_SUITE(ElementwiseMin) // clang-format off DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip( framework::dataset::make("Input1Info", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), - TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), // Invalid data type combination TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), // Mismatching shapes }), framework::dataset::make("Input2Info",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), - TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16), TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32), })), - framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16), - TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), + framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32), })), - framework::dataset::make("Expected", { true, true, false, false})), + framework::dataset::make("Expected", { true, false, false})), input1_info, input2_info, output_info, expected) { ARM_COMPUTE_EXPECT(bool(CLElementwiseMin::validate(&input1_info.clone()->set_is_resizable(false), &input2_info.clone()->set_is_resizable(false), &output_info.clone()->set_is_resizable(false))) == expected, framework::LogLevel::ERRORS); diff --git 
a/tests/validation/CL/ElementwiseSquaredDiff.cpp b/tests/validation/CL/ElementwiseSquaredDiff.cpp index 0a4ab6627b..4c732b0885 100644 --- a/tests/validation/CL/ElementwiseSquaredDiff.cpp +++ b/tests/validation/CL/ElementwiseSquaredDiff.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2020 Arm Limited. + * Copyright (c) 2018-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -57,7 +57,7 @@ const auto ElementwiseSquaredDiffQASYMM8Dataset = combine(combine(framework::dat const auto ElementwiseSquaredDiffQSYMM16Dataset = combine(combine(framework::dataset::make("DataType", DataType::QSYMM16), framework::dataset::make("DataType", DataType::QSYMM16)), framework::dataset::make("DataType", DataType::QSYMM16)); -const auto ElementwiseSquaredDiffS16Dataset = combine(combine(framework::dataset::make("DataType", { DataType::U8, DataType::S16 }), framework::dataset::make("DataType", DataType::S16)), +const auto ElementwiseSquaredDiffS16Dataset = combine(combine(framework::dataset::make("DataType", { DataType::S16 }), framework::dataset::make("DataType", DataType::S16)), framework::dataset::make("DataType", DataType::S16)); const auto ElementwiseSquaredDiffFP16Dataset = combine(combine(framework::dataset::make("DataType", DataType::F16), framework::dataset::make("DataType", DataType::F16)), framework::dataset::make("DataType", DataType::F16)); @@ -79,21 +79,18 @@ TEST_SUITE(ElementwiseSquaredDiff) // clang-format off DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip( framework::dataset::make("Input1Info", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), - TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), // Invalid data type combination TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), // Mismatching shapes }), framework::dataset::make("Input2Info",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), - TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16), TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32), })), - framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16), - TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), + framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32), })), - framework::dataset::make("Expected", { true, true, false, false})), + framework::dataset::make("Expected", { true, false, false})), input1_info, input2_info, output_info, expected) { ARM_COMPUTE_EXPECT(bool(CLElementwiseSquaredDiff::validate(&input1_info.clone()->set_is_resizable(false), &input2_info.clone()->set_is_resizable(false), &output_info.clone()->set_is_resizable(false))) == expected, framework::LogLevel::ERRORS); diff --git a/tests/validation/CL/PReluLayer.cpp b/tests/validation/CL/PReluLayer.cpp index 043262d891..f3f1c8b1b8 100644 --- a/tests/validation/CL/PReluLayer.cpp +++ b/tests/validation/CL/PReluLayer.cpp @@ -56,7 +56,7 @@ const auto PReluLayerQASYMM8Dataset = combine(combine(framework::dataset::make(" const auto PReluLayerQASYMM8SIGNEDDataset = combine(combine(framework::dataset::make("DataType", DataType::QASYMM8_SIGNED), framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)), framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)); -const auto PReluLayerS16Dataset = combine(combine(framework::dataset::make("DataType", { 
DataType::U8, DataType::S16 }), framework::dataset::make("DataType", DataType::S16)), +const auto PReluLayerS16Dataset = combine(combine(framework::dataset::make("DataType", { DataType::S16 }), framework::dataset::make("DataType", DataType::S16)), framework::dataset::make("DataType", DataType::S16)); const auto PReluLayerFP16Dataset = combine(combine(framework::dataset::make("DataType", DataType::F16), framework::dataset::make("DataType", DataType::F16)), framework::dataset::make("DataType", DataType::F16)); @@ -71,21 +71,18 @@ TEST_SUITE(PReluLayer) // clang-format off DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip( framework::dataset::make("Input1Info", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), - TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), // Invalid data type combination TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), // Mismatching shapes }), framework::dataset::make("Input2Info",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), - TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16), TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32), })), - framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16), - TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), + framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32), })), - framework::dataset::make("Expected", { true, true, false, false})), + framework::dataset::make("Expected", { true, false, false})), input1_info, input2_info, output_info, expected) { ARM_COMPUTE_EXPECT(bool(CLPReluLayer::validate(&input1_info.clone()->set_is_resizable(false), &input2_info.clone()->set_is_resizable(false), &output_info.clone()->set_is_resizable(false))) == expected, framework::LogLevel::ERRORS); diff --git a/tests/validation/NEON/ArithmeticAddition.cpp b/tests/validation/NEON/ArithmeticAddition.cpp index ea6656eefe..213dbc1f5e 100644 --- a/tests/validation/NEON/ArithmeticAddition.cpp +++ b/tests/validation/NEON/ArithmeticAddition.cpp @@ -48,26 +48,6 @@ constexpr AbsoluteTolerance tolerance_quant(1); /**< Tolerance value for #else // !defined(__aarch64__) || defined(ENABLE_SVE) constexpr AbsoluteTolerance tolerance_quant(0); #endif // !defined(__aarch64__) || defined(ENABLE_SVE) - -/** Input data sets **/ -const auto ArithmeticAdditionU8Dataset = combine(combine(framework::dataset::make("DataType", DataType::U8), framework::dataset::make("DataType", DataType::U8)), framework::dataset::make("DataType", - DataType::U8)); -const auto ArithmeticAdditionS16Dataset = combine(combine(framework::dataset::make("DataType", { DataType::U8, DataType::S16 }), framework::dataset::make("DataType", DataType::S16)), - framework::dataset::make("DataType", DataType::S16)); -const auto ArithmeticAdditionS32Dataset = combine(combine(framework::dataset::make("DataType", { DataType::S32 }), framework::dataset::make("DataType", DataType::S32)), - framework::dataset::make("DataType", DataType::S32)); -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -const auto ArithmeticAdditionFP16Dataset = combine(combine(framework::dataset::make("DataType", DataType::F16), framework::dataset::make("DataType", DataType::F16)), - framework::dataset::make("DataType", DataType::F16)); -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC 
*/ -const auto ArithmeticAdditionFP32Dataset = combine(combine(framework::dataset::make("DataType", DataType::F32), framework::dataset::make("DataType", DataType::F32)), - framework::dataset::make("DataType", DataType::F32)); -const auto ArithmeticAdditionQASYMM8Dataset = combine(combine(framework::dataset::make("DataType", DataType::QASYMM8), framework::dataset::make("DataType", DataType::QASYMM8)), - framework::dataset::make("DataType", DataType::QASYMM8)); -const auto ArithmeticAdditionQASYMM8SIGNEDDataset = combine(combine(framework::dataset::make("DataType", DataType::QASYMM8_SIGNED), framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)), - framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)); -const auto ArithmeticAdditionQSYMM16Dataset = combine(combine(framework::dataset::make("DataType", DataType::QSYMM16), framework::dataset::make("DataType", DataType::QSYMM16)), - framework::dataset::make("DataType", DataType::QSYMM16)); } // namespace TEST_SUITE(NEON) @@ -79,25 +59,22 @@ using NEArithmeticAdditionFixture = ArithmeticAdditionValidationFixtureset_is_resizable(false), @@ -127,7 +104,7 @@ TEST_CASE(NoPaddingAdded, framework::DatasetMode::PRECOMMIT) TEST_SUITE(Integer) TEST_SUITE(U8) -FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticAdditionFixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), ArithmeticAdditionU8Dataset), +FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticAdditionFixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::U8)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP }))) { // Validate output @@ -136,14 +113,14 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticAdditionFixture, framework TEST_SUITE_END() // U8 TEST_SUITE(S16) -FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticAdditionFixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), ArithmeticAdditionS16Dataset), +FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticAdditionFixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::S16)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP }))) { // Validate output validate(Accessor(_target), _reference); } -FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticAdditionFixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), ArithmeticAdditionS16Dataset), +FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticAdditionFixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::S16)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP }))) { // Validate output @@ -152,7 +129,7 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticAdditionFixture, framework TEST_SUITE_END() // S16 TEST_SUITE(S32) -FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticAdditionFixture, framework::DatasetMode::ALL, combine(combine(datasets::SmallShapes(), ArithmeticAdditionS32Dataset), +FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticAdditionFixture, framework::DatasetMode::ALL, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::S32)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP }))) { // Validate output @@ -164,7 +141,7 @@ TEST_SUITE_END() // Integer TEST_SUITE(Float) #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC 
TEST_SUITE(F16) -FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticAdditionFixture, framework::DatasetMode::ALL, combine(combine(datasets::SmallShapes(), ArithmeticAdditionFP16Dataset), +FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticAdditionFixture, framework::DatasetMode::ALL, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F16)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP }))) { // Validate output @@ -174,14 +151,14 @@ TEST_SUITE_END() // F16 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ TEST_SUITE(F32) -FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticAdditionFixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), ArithmeticAdditionFP32Dataset), +FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticAdditionFixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP }))) { // Validate output validate(Accessor(_target), _reference); } -FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticAdditionFixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), ArithmeticAdditionFP32Dataset), +FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticAdditionFixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP }))) { // Validate output @@ -192,7 +169,7 @@ template using NEArithmeticAdditionBroadcastFixture = ArithmeticAdditionBroadcastValidationFixture; FIXTURE_DATA_TEST_CASE(RunSmallBroadcast, NEArithmeticAdditionBroadcastFixture, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapesBroadcast(), - ArithmeticAdditionFP32Dataset), + framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP }))) { // Validate output @@ -200,7 +177,7 @@ FIXTURE_DATA_TEST_CASE(RunSmallBroadcast, NEArithmeticAdditionBroadcastFixture, framework::DatasetMode::NIGHTLY, combine(combine(datasets::LargeShapesBroadcast(), - ArithmeticAdditionFP32Dataset), + framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP }))) { // Validate output @@ -220,7 +197,7 @@ TEST_SUITE(QASYMM8) FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticAdditionQuantizedFixture, framework::DatasetMode::PRECOMMIT, - combine(combine(combine(combine(combine(datasets::SmallShapes(), ArithmeticAdditionQASYMM8Dataset), + combine(combine(combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::QASYMM8)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })), framework::dataset::make("Src0QInfo", { QuantizationInfo(5.f / 255.f, 20) })), framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 255.f, 10) })), @@ -235,7 +212,7 @@ TEST_SUITE(QASYMM8_SIGNED) FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticAdditionQuantizedFixture, framework::DatasetMode::ALL, - combine(combine(combine(combine(combine(datasets::SmallShapes(), ArithmeticAdditionQASYMM8SIGNEDDataset), + combine(combine(combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })), 
framework::dataset::make("Src0QInfo", { QuantizationInfo(0.5f, 20) })), framework::dataset::make("Src1QInfo", { QuantizationInfo(0.5f, 10) })), @@ -246,7 +223,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, } FIXTURE_DATA_TEST_CASE(RunSmallBroadcast, NEArithmeticAdditionQuantizedBroadcastFixture, framework::DatasetMode::ALL, combine(combine(combine(combine(combine( - datasets::SmallShapesBroadcast(), ArithmeticAdditionQASYMM8SIGNEDDataset), + datasets::SmallShapesBroadcast(), framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })), framework::dataset::make("Src0QInfo", { QuantizationInfo(0.5f, 20) })), framework::dataset::make("Src1QInfo", { QuantizationInfo(0.5f, 10) })), @@ -261,7 +238,7 @@ TEST_SUITE(QSYMM16) FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticAdditionQuantizedFixture, framework::DatasetMode::PRECOMMIT, - combine(combine(combine(combine(combine(datasets::SmallShapes(), ArithmeticAdditionQSYMM16Dataset), + combine(combine(combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", DataType::QSYMM16)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })), framework::dataset::make("Src0QInfo", { QuantizationInfo(1.f / 32768.f, 0), QuantizationInfo(5.f / 32768.f, 0) })), framework::dataset::make("Src1QInfo", { QuantizationInfo(2.f / 32768.f, 0), QuantizationInfo(5.f / 32768.f, 0) })), diff --git a/tests/validation/NEON/ArithmeticSubtraction.cpp b/tests/validation/NEON/ArithmeticSubtraction.cpp index 7a36893445..68213fb51f 100644 --- a/tests/validation/NEON/ArithmeticSubtraction.cpp +++ b/tests/validation/NEON/ArithmeticSubtraction.cpp @@ -50,39 +50,7 @@ constexpr AbsoluteTolerance tolerance_qasymm8(1); /**< Tolerance value fo #endif //__aarch64__ constexpr AbsoluteTolerance tolerance_qsymm16(1); /**< Tolerance value for comparing reference's output against implementation's output for quantized data types */ -/** Input data sets **/ -const auto ArithmeticSubtractionQASYMM8Dataset = combine(combine(framework::dataset::make("DataType", DataType::QASYMM8), - framework::dataset::make("DataType", DataType::QASYMM8)), - framework::dataset::make("DataType", DataType::QASYMM8)); - -const auto ArithmeticSubtractionQASYMM8SIGNEDDataset = combine(combine(framework::dataset::make("DataType", DataType::QASYMM8_SIGNED), - framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)), - framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)); - -const auto ArithmeticSubtractionQSYMM16Dataset = combine(combine(framework::dataset::make("DataType", DataType::QSYMM16), - framework::dataset::make("DataType", DataType::QSYMM16)), - framework::dataset::make("DataType", DataType::QSYMM16)); - -const auto ArithmeticSubtractionU8Dataset = combine(combine(framework::dataset::make("DataType", DataType::U8), - framework::dataset::make("DataType", DataType::U8)), - framework::dataset::make("DataType", DataType::U8)); - -const auto ArithmeticSubtractionS16Dataset = combine(combine(framework::dataset::make("DataType", { DataType::U8, DataType::S16 }), - framework::dataset::make("DataType", DataType::S16)), - framework::dataset::make("DataType", DataType::S16)); - -const auto ArithmeticSubtractionS32Dataset = combine(combine(framework::dataset::make("DataType", DataType::S32), - framework::dataset::make("DataType", DataType::S32)), - framework::dataset::make("DataType", DataType::S32)); -#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -const auto ArithmeticSubtractionFP16Dataset = 
combine(combine(framework::dataset::make("DataType", DataType::F16), - framework::dataset::make("DataType", DataType::F16)), - framework::dataset::make("DataType", DataType::F16)); -#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ -const auto ArithmeticSubtractionFP32Dataset = combine(combine(framework::dataset::make("DataType", DataType::F32), - framework::dataset::make("DataType", DataType::F32)), - framework::dataset::make("DataType", DataType::F32)); - +// Quantization Information DataSet const auto ArithmeticSubtractionQuantizationInfoDataset = combine(combine(framework::dataset::make("QuantizationInfoIn1", { QuantizationInfo(10, 120) }), framework::dataset::make("QuantizationInfoIn2", { QuantizationInfo(20, 110) })), framework::dataset::make("QuantizationInfoOut", { QuantizationInfo(15, 125) })); @@ -105,35 +73,31 @@ using NEArithmeticSubtractionFixture = ArithmeticSubtractionValidationFixtureset_is_resizable(false), &input2_info.clone()->set_is_resizable(false), &output_info.clone()->set_is_resizable(false), policy)) == expected, framework::LogLevel::ERRORS); @@ -194,7 +158,8 @@ TEST_CASE(InvalidBroadcastBoth, framework::DatasetMode::ALL) TEST_SUITE_END() // InPlaceValidate TEST_SUITE(U8) -FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), ArithmeticSubtractionU8Dataset), +FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", + DataType::U8)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })), OutOfPlaceDataSet)) { @@ -210,7 +175,8 @@ using NEArithmeticSubtractionQSYMM16Fixture = ArithmeticSubtracti TEST_SUITE(Quantized) TEST_SUITE(QASYMM8) -FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionQASYMM8Fixture, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallShapes(), ArithmeticSubtractionQASYMM8Dataset), +FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionQASYMM8Fixture, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", + DataType::QASYMM8)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })), ArithmeticSubtractionQuantizationInfoDataset), InPlaceDataSet)) @@ -222,8 +188,7 @@ TEST_SUITE_END() // QASYMM8 TEST_SUITE(QASYMM8_SIGNED) FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionQASYMM8SignedFixture, framework::DatasetMode::ALL, combine(combine(combine(combine( - datasets::SmallShapes(), - ArithmeticSubtractionQASYMM8SIGNEDDataset), + datasets::SmallShapes(), framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })), ArithmeticSubtractionQuantizationInfoSignedDataset), InPlaceDataSet)) @@ -234,7 +199,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionQASYMM8SignedFixture, fr FIXTURE_DATA_TEST_CASE(RunSmallBroadcast, NEArithmeticSubtractionQASYMM8SignedBroadcastFixture, framework::DatasetMode::ALL, combine(combine(combine(combine( datasets::SmallShapesBroadcast(), - ArithmeticSubtractionQASYMM8SIGNEDDataset), + framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })), ArithmeticSubtractionQuantizationInfoSignedDataset), OutOfPlaceDataSet)) @@ -247,7 +212,7 @@ TEST_SUITE_END() // QASYMM8_SIGNED TEST_SUITE(QSYMM16)
FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionQSYMM16Fixture, framework::DatasetMode::ALL, combine(combine(combine(combine( datasets::SmallShapes(), - ArithmeticSubtractionQSYMM16Dataset), + framework::dataset::make("DataType", DataType::QSYMM16)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE })), ArithmeticSubtractionQuantizationInfoSymmetric), OutOfPlaceDataSet)) @@ -259,7 +224,8 @@ TEST_SUITE_END() // QSYMM16 TEST_SUITE_END() // Quantized TEST_SUITE(S16) -FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), ArithmeticSubtractionS16Dataset), +FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", + DataType::S16)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })), OutOfPlaceDataSet)) { @@ -267,7 +233,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixture, framew validate(Accessor(_target), _reference); } -FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), ArithmeticSubtractionS16Dataset), +FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", + DataType::S16)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })), OutOfPlaceDataSet)) { @@ -277,7 +244,8 @@ FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionFixture, framew TEST_SUITE_END() // S16 TEST_SUITE(S32) -FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), ArithmeticSubtractionS32Dataset), +FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", + DataType::S32)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })), OutOfPlaceDataSet)) { @@ -285,7 +253,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixture, framew validate(Accessor(_target), _reference); } -FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), ArithmeticSubtractionS32Dataset), +FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", + DataType::S32)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })), OutOfPlaceDataSet)) { @@ -297,7 +266,8 @@ TEST_SUITE_END() // S32 TEST_SUITE(Float) #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC TEST_SUITE(F16) -FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixture, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallShapes(), ArithmeticSubtractionFP16Dataset), +FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixture, framework::DatasetMode::ALL, combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", + DataType::F16)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })), OutOfPlaceDataSet)) { @@ -308,7 +278,8 @@ 
TEST_SUITE_END() // F16 #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ TEST_SUITE(F32) -FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), ArithmeticSubtractionFP32Dataset), +FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapes(), framework::dataset::make("DataType", + DataType::F32)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })), InPlaceDataSet)) { @@ -316,7 +287,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEArithmeticSubtractionFixture, framewor validate(Accessor(_target), _reference); } -FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), ArithmeticSubtractionFP32Dataset), +FIXTURE_DATA_TEST_CASE(RunLarge, NEArithmeticSubtractionFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapes(), framework::dataset::make("DataType", + DataType::F32)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })), OutOfPlaceDataSet)) { @@ -328,7 +300,7 @@ template using NEArithmeticSubtractionBroadcastFixture = ArithmeticSubtractionBroadcastValidationFixture; FIXTURE_DATA_TEST_CASE(RunSmallBroadcast, NEArithmeticSubtractionBroadcastFixture, framework::DatasetMode::PRECOMMIT, combine(combine(combine(datasets::SmallShapesBroadcast(), - ArithmeticSubtractionFP32Dataset), + framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })), OutOfPlaceDataSet)) { @@ -337,7 +309,7 @@ FIXTURE_DATA_TEST_CASE(RunSmallBroadcast, NEArithmeticSubtractionBroadcastFixtur } FIXTURE_DATA_TEST_CASE(RunLargeBroadcast, NEArithmeticSubtractionBroadcastFixture, framework::DatasetMode::NIGHTLY, combine(combine(combine(datasets::LargeShapesBroadcast(), - ArithmeticSubtractionFP32Dataset), + framework::dataset::make("DataType", DataType::F32)), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })), OutOfPlaceDataSet)) { diff --git a/tests/validation/fixtures/ArithmeticOperationsFixture.h b/tests/validation/fixtures/ArithmeticOperationsFixture.h index 1dfc2ce579..7aa716d676 100644 --- a/tests/validation/fixtures/ArithmeticOperationsFixture.h +++ b/tests/validation/fixtures/ArithmeticOperationsFixture.h @@ -46,15 +46,14 @@ class ArithmeticOperationGenericFixture : public framework::Fixture { public: template - void setup(reference::ArithmeticOperation op, const TensorShape &shape0, const TensorShape &shape1, - DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, + void setup(reference::ArithmeticOperation op, const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out, ActivationLayerInfo act_info, bool in_place) { _op = op; _act_info = act_info; _in_place = in_place; - _target = compute_target(shape0, shape1, data_type0, data_type1, output_data_type, convert_policy, qinfo0, qinfo1, qinfo_out); - _reference = compute_reference(shape0, shape1, data_type0, data_type1, output_data_type, convert_policy, qinfo0, qinfo1, qinfo_out); + _target = compute_target(shape0, shape1, data_type, convert_policy, qinfo0, qinfo1, qinfo_out); + _reference = 
compute_reference(shape0, shape1, data_type, convert_policy, qinfo0, qinfo1, qinfo_out); } protected: @@ -64,13 +63,13 @@ protected: library->fill_tensor_uniform(tensor, i); } - TensorType compute_target(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, + TensorType compute_target(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out) { // Create tensors - TensorType ref_src1 = create_tensor(shape0, data_type0, 1, qinfo0); - TensorType ref_src2 = create_tensor(shape1, data_type1, 1, qinfo1); - TensorType dst = create_tensor(TensorShape::broadcast_shape(shape0, shape1), output_data_type, 1, qinfo_out); + TensorType ref_src1 = create_tensor(shape0, data_type, 1, qinfo0); + TensorType ref_src2 = create_tensor(shape1, data_type, 1, qinfo1); + TensorType dst = create_tensor(TensorShape::broadcast_shape(shape0, shape1), data_type, 1, qinfo_out); TensorType *dst_to_use = _in_place ? &ref_src1 : &dst; // Create and configure function @@ -104,8 +103,7 @@ protected: return dst; } - SimpleTensor compute_reference(const TensorShape &shape0, const TensorShape &shape1, - DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, + SimpleTensor compute_reference(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out) { // current in-place implementation only supports same metadata of input and output tensors. @@ -113,9 +111,9 @@ protected: QuantizationInfo output_qinfo = _in_place ? 
qinfo0 : qinfo_out; // Create reference - SimpleTensor ref_src1{ shape0, data_type0, 1, qinfo0 }; - SimpleTensor ref_src2{ shape1, data_type1, 1, qinfo1 }; - SimpleTensor ref_dst{ TensorShape::broadcast_shape(shape0, shape1), output_data_type, 1, output_qinfo }; + SimpleTensor ref_src1{ shape0, data_type, 1, qinfo0 }; + SimpleTensor ref_src2{ shape1, data_type, 1, qinfo1 }; + SimpleTensor ref_dst{ TensorShape::broadcast_shape(shape0, shape1), data_type, 1, output_qinfo }; // Fill reference fill(ref_src1, 0); @@ -137,10 +135,10 @@ class ArithmeticAdditionBroadcastValidationFixture : public ArithmeticOperationG { public: template - void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy) + void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy) { - ArithmeticOperationGenericFixture::setup(reference::ArithmeticOperation::ADD, shape0, shape1, data_type0, data_type1, - output_data_type, convert_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo(), false); + ArithmeticOperationGenericFixture::setup(reference::ArithmeticOperation::ADD, shape0, shape1, data_type, convert_policy, + QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo(), false); } }; @@ -149,10 +147,10 @@ class ArithmeticAdditionValidationFixture : public ArithmeticOperationGenericFix { public: template - void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy) + void setup(const TensorShape &shape, DataType data_type, ConvertPolicy convert_policy) { - ArithmeticOperationGenericFixture::setup(reference::ArithmeticOperation::ADD, shape, shape, data_type0, data_type1, - output_data_type, convert_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo(), false); + ArithmeticOperationGenericFixture::setup(reference::ArithmeticOperation::ADD, shape, shape, data_type, convert_policy, + QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo(), false); } }; @@ -161,10 +159,10 @@ class ArithmeticAdditionBroadcastValidationFloatFixture : public ArithmeticOpera { public: template - void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info) + void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info) { - ArithmeticOperationGenericFixture::setup(reference::ArithmeticOperation::ADD, shape0, shape1, data_type0, data_type1, - output_data_type, convert_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info, false); + ArithmeticOperationGenericFixture::setup(reference::ArithmeticOperation::ADD, shape0, shape1, data_type, convert_policy, + QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info, false); } }; @@ -173,10 +171,10 @@ class ArithmeticAdditionValidationFloatFixture : public ArithmeticOperationGener { public: template - void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info) + void setup(const TensorShape &shape, DataType data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info) { - 
ArithmeticOperationGenericFixture::setup(reference::ArithmeticOperation::ADD, shape, shape, data_type0, data_type1, - output_data_type, convert_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info, false); + ArithmeticOperationGenericFixture::setup(reference::ArithmeticOperation::ADD, shape, shape, data_type, convert_policy, + QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info, false); } }; @@ -185,12 +183,11 @@ class ArithmeticAdditionValidationQuantizedFixture : public ArithmeticOperationG { public: template - void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, - QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out) + void setup(const TensorShape &shape, DataType data_type, ConvertPolicy convert_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out) { - ArithmeticOperationGenericFixture::setup(reference::ArithmeticOperation::ADD, shape, shape, data_type0, data_type1, - output_data_type, convert_policy, qinfo0, qinfo1, qinfo_out, ActivationLayerInfo(), false); + ArithmeticOperationGenericFixture::setup(reference::ArithmeticOperation::ADD, shape, shape, data_type, convert_policy, + qinfo0, qinfo1, qinfo_out, ActivationLayerInfo(), false); } }; @@ -199,11 +196,9 @@ class ArithmeticAdditionValidationQuantizedBroadcastFixture : public ArithmeticO { public: template - void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type, - ConvertPolicy convert_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out) + void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out) { - ArithmeticOperationGenericFixture::setup(reference::ArithmeticOperation::ADD, shape0, shape1, - data_type0, data_type1, output_data_type, convert_policy, + ArithmeticOperationGenericFixture::setup(reference::ArithmeticOperation::ADD, shape0, shape1, data_type, convert_policy, qinfo0, qinfo1, qinfo_out, ActivationLayerInfo(), false); } }; @@ -213,10 +208,9 @@ class ArithmeticSubtractionBroadcastValidationFixture : public ArithmeticOperati { public: template - void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, bool in_place) + void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy, bool in_place) { - ArithmeticOperationGenericFixture::setup(reference::ArithmeticOperation::SUB, shape0, shape1, - data_type0, data_type1, output_data_type, convert_policy, + ArithmeticOperationGenericFixture::setup(reference::ArithmeticOperation::SUB, shape0, shape1, data_type, convert_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo(), in_place); } }; @@ -226,11 +220,10 @@ class ArithmeticSubtractionBroadcastValidationFloatFixture : public ArithmeticOp { public: template - void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info, + void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info, bool in_place) { - 
ArithmeticOperationGenericFixture::setup(reference::ArithmeticOperation::SUB, shape0, shape1, - data_type0, data_type1, output_data_type, convert_policy, + ArithmeticOperationGenericFixture::setup(reference::ArithmeticOperation::SUB, shape0, shape1, data_type, convert_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info, in_place); } }; @@ -240,10 +233,9 @@ class ArithmeticSubtractionValidationFixture : public ArithmeticOperationGeneric { public: template - void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, bool in_place) + void setup(const TensorShape &shape, DataType data_type, ConvertPolicy convert_policy, bool in_place) { - ArithmeticOperationGenericFixture::setup(reference::ArithmeticOperation::SUB, shape, shape, - data_type0, data_type1, output_data_type, convert_policy, + ArithmeticOperationGenericFixture::setup(reference::ArithmeticOperation::SUB, shape, shape, data_type, convert_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo(), in_place); } }; @@ -253,10 +245,9 @@ class ArithmeticSubtractionValidationFloatFixture : public ArithmeticOperationGe { public: template - void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info, bool in_place) + void setup(const TensorShape &shape, DataType data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info, bool in_place) { - ArithmeticOperationGenericFixture::setup(reference::ArithmeticOperation::SUB, shape, shape, - data_type0, data_type1, output_data_type, convert_policy, + ArithmeticOperationGenericFixture::setup(reference::ArithmeticOperation::SUB, shape, shape, data_type, convert_policy, QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info, in_place); } }; @@ -266,13 +257,11 @@ class ArithmeticSubtractionValidationQuantizedFixture : public ArithmeticOperati { public: template - void setup(const TensorShape &shape, DataType data_type0, DataType data_type1, DataType output_data_type, ConvertPolicy convert_policy, - QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out, bool in_place) + void setup(const TensorShape &shape, DataType data_type, ConvertPolicy convert_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out, bool in_place) { - ArithmeticOperationGenericFixture::setup(reference::ArithmeticOperation::SUB, shape, shape, - data_type0, data_type1, output_data_type, - convert_policy, qinfo0, qinfo1, qinfo_out, ActivationLayerInfo(), in_place); + ArithmeticOperationGenericFixture::setup(reference::ArithmeticOperation::SUB, shape, shape, data_type, convert_policy, + qinfo0, qinfo1, qinfo_out, ActivationLayerInfo(), in_place); } }; @@ -281,11 +270,10 @@ class ArithmeticSubtractionValidationQuantizedBroadcastFixture : public Arithmet { public: template - void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type0, DataType data_type1, DataType output_data_type, - ConvertPolicy convert_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out, bool in_place) + void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out, + bool in_place) { - ArithmeticOperationGenericFixture::setup(reference::ArithmeticOperation::SUB, shape0, 
shape1, - data_type0, data_type1, output_data_type, convert_policy, + ArithmeticOperationGenericFixture::setup(reference::ArithmeticOperation::SUB, shape0, shape1, data_type, convert_policy, qinfo0, qinfo1, qinfo_out, ActivationLayerInfo(), in_place); } }; -- cgit v1.2.1
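For illustration, a minimal sketch (not part of the patch) of what the simplification means at the API level, assuming Compute Library's usual Tensor/allocator workflow: after this change both inputs and the output of NEArithmeticAddition must share a single data type, and a legacy widening combination such as U8 + U8 -> S16, whose mixed-type datasets are deleted from the tests above, is expected to be rejected by validate(). Shapes and values below are illustrative only.

#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/NEON/functions/NEArithmeticAddition.h"
#include "arm_compute/runtime/Tensor.h"

#include <iostream>

using namespace arm_compute;

int main()
{
    const TensorShape shape(16U, 16U);

    // Same-type path: F32 + F32 -> F32 remains supported.
    Tensor src0, src1, dst;
    src0.allocator()->init(TensorInfo(shape, 1, DataType::F32));
    src1.allocator()->init(TensorInfo(shape, 1, DataType::F32));
    dst.allocator()->init(TensorInfo(shape, 1, DataType::F32));

    NEArithmeticAddition add;
    add.configure(&src0, &src1, &dst, ConvertPolicy::SATURATE);

    src0.allocator()->allocate();
    src1.allocator()->allocate();
    dst.allocator()->allocate();
    add.run(); // tensors are left unfilled here; a real caller would fill them first

    // Mixed-type path: U8 + U8 -> S16 is one of the removed legacy
    // combinations, so validate() should now return an error status.
    const TensorInfo u8_info(shape, 1, DataType::U8);
    const TensorInfo s16_info(shape, 1, DataType::S16);
    const Status     status = NEArithmeticAddition::validate(&u8_info, &u8_info, &s16_info, ConvertPolicy::SATURATE);
    std::cout << "U8 + U8 -> S16 supported: " << std::boolalpha << bool(status) << std::endl;

    return 0;
}

Callers that relied on the removed widening paths can widen the U8 inputs to S16 themselves (for example with NECast) and then use the same-type S16 addition; that migration route is an assumption on my part, not something stated by this patch.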