diff options
author | Michele Di Giorgio <michele.digiorgio@arm.com> | 2019-06-18 10:23:22 +0100 |
---|---|---|
committer | Georgios Pinitas <georgios.pinitas@arm.com> | 2019-06-25 09:37:00 +0000 |
commit | 6997fc951e48a1bf8f7591f3b2c4c8d721331b96 (patch) | |
tree | 1cc2b28f5b2a5dbb8d7eb32755df4e8f28a1901d /tests/validation/reference | |
parent | 944170e1591ff23c9e6ede2201f0f6aba0f3439b (diff) | |
download | ComputeLibrary-6997fc951e48a1bf8f7591f3b2c4c8d721331b96.tar.gz |
COMPMID-2412: Add QSYMM16 support for ElementwiseAddition for CL
Arithmetic addition uses the same code as other element-wise operations.
Hence, adding QSYMM16 support for addition automatically adds the same
support for:
- arithmetic subtraction
- element-wise min
- element-wise max
- squared difference
Change-Id: If986102844f62e29dd23c03f9245910db43f9043
Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-on: https://review.mlplatform.org/c/1384
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Manuel Bottini <manuel.bottini@arm.com>
Reviewed-by: Giuseppe Rossini <giuseppe.rossini@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Diffstat (limited to 'tests/validation/reference')
-rw-r--r-- | tests/validation/reference/ElementwiseOperations.cpp | 35 |
1 files changed, 32 insertions, 3 deletions
diff --git a/tests/validation/reference/ElementwiseOperations.cpp b/tests/validation/reference/ElementwiseOperations.cpp index 44eb417969..d5a37a0fae 100644 --- a/tests/validation/reference/ElementwiseOperations.cpp +++ b/tests/validation/reference/ElementwiseOperations.cpp @@ -184,10 +184,39 @@ SimpleTensor<uint8_t> arithmetic_operation(ArithmeticOperation op, const SimpleT } } +template <> +SimpleTensor<int16_t> arithmetic_operation(ArithmeticOperation op, const SimpleTensor<int16_t> &src1, const SimpleTensor<int16_t> &src2, SimpleTensor<int16_t> &dst, ConvertPolicy convert_policy) +{ + if(dst.data_type() == DataType::QSYMM16) + { + SimpleTensor<float> src1_tmp = convert_from_symmetric<int16_t>(src1); + SimpleTensor<float> src2_tmp = convert_from_symmetric<int16_t>(src2); + SimpleTensor<float> dst_tmp(TensorShape::broadcast_shape(src1.shape(), src2.shape()), dst.data_type()); + + Coordinates id_src1{}; + Coordinates id_src2{}; + Coordinates id_dst{}; + + BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(op, src1_tmp, src2_tmp, dst_tmp, convert_policy, id_src1, id_src2, id_dst); + + dst = convert_to_symmetric<int16_t>(dst_tmp, dst.quantization_info()); + return dst; + } + else + { + // DataType::S16 + Coordinates id_src1{}; + Coordinates id_src2{}; + Coordinates id_dst{}; + + BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(op, src1, src2, dst, convert_policy, id_src1, id_src2, id_dst); + + return dst; + } +} + template SimpleTensor<int32_t> arithmetic_operation(ArithmeticOperation op, const SimpleTensor<int32_t> &src1, const SimpleTensor<int32_t> &src2, SimpleTensor<int32_t> &dst, ConvertPolicy convert_policy); -template SimpleTensor<int16_t> arithmetic_operation(ArithmeticOperation op, const SimpleTensor<int16_t> &src1, const SimpleTensor<int16_t> &src2, SimpleTensor<int16_t> &dst, - ConvertPolicy convert_policy); template SimpleTensor<int8_t> arithmetic_operation(ArithmeticOperation op, const SimpleTensor<int8_t> &src1, const SimpleTensor<int8_t> &src2, SimpleTensor<int8_t> &dst, ConvertPolicy convert_policy); template SimpleTensor<half> arithmetic_operation(ArithmeticOperation op, const SimpleTensor<half> &src1, const SimpleTensor<half> &src2, SimpleTensor<half> &dst, ConvertPolicy convert_policy); @@ -196,7 +225,7 @@ template SimpleTensor<float> arithmetic_operation(ArithmeticOperation op, const template <typename T> SimpleTensor<T> arithmetic_operation(ArithmeticOperation op, const SimpleTensor<T> &src1, const SimpleTensor<T> &src2, DataType dst_data_type, ConvertPolicy convert_policy) { - ARM_COMPUTE_ERROR_ON_MSG(dst_data_type == DataType::QASYMM8, "For QASYMM8, the quantized output tensor should be passed directly."); + ARM_COMPUTE_ERROR_ON_MSG(is_data_type_quantized(dst_data_type), "For quantized data types, the quantized output tensor should be passed directly."); SimpleTensor<T> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), dst_data_type); arithmetic_operation<T>(op, src1, src2, dst, convert_policy); |