diff options
author | Isabella Gottardi <isabella.gottardi@arm.com> | 2017-10-30 15:28:13 +0000 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:35:24 +0000 |
commit | b5908c257d554009a00de3aaa95b3721000ed185 (patch) | |
tree | 2c9fecbb241061e10bd9886bfe36056c4e6cf211 /tests/validation/CPP/ArithmeticSubtraction.cpp | |
parent | 05078ec491da8f282f4597b4cf1fe79cc16f4b22 (diff) | |
download | ComputeLibrary-b5908c257d554009a00de3aaa95b3721000ed185.tar.gz |
COMPMID-653 - Arithmetic Subtraction, add support different datatype
Change-Id: I2b3d65c8d8a85ad67b9972713d06f047f5bcd1ae
Reviewed-on: http://mpd-gerrit.cambridge.arm.com/93693
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
Diffstat (limited to 'tests/validation/CPP/ArithmeticSubtraction.cpp')
-rw-r--r-- | tests/validation/CPP/ArithmeticSubtraction.cpp | 13 |
1 files changed, 8 insertions, 5 deletions
diff --git a/tests/validation/CPP/ArithmeticSubtraction.cpp b/tests/validation/CPP/ArithmeticSubtraction.cpp index 80bdb15a49..bed2d37090 100644 --- a/tests/validation/CPP/ArithmeticSubtraction.cpp +++ b/tests/validation/CPP/ArithmeticSubtraction.cpp @@ -34,23 +34,26 @@ namespace validation { namespace reference { -template <typename T> -SimpleTensor<T> arithmetic_subtraction(const SimpleTensor<T> &src1, const SimpleTensor<T> &src2, DataType dst_data_type, ConvertPolicy convert_policy) +template <typename T1, typename T2, typename T3> +SimpleTensor<T3> arithmetic_subtraction(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, DataType dst_data_type, ConvertPolicy convert_policy) { - SimpleTensor<T> result(src1.shape(), dst_data_type); + SimpleTensor<T3> result(src1.shape(), dst_data_type); - using intermediate_type = typename common_promoted_signed_type<T>::intermediate_type; + using intermediate_type = typename common_promoted_signed_type<typename std::conditional<sizeof(T1) >= sizeof(T2), T1, T2>::type >::intermediate_type; for(int i = 0; i < src1.num_elements(); ++i) { intermediate_type val = static_cast<intermediate_type>(src1[i]) - static_cast<intermediate_type>(src2[i]); - result[i] = (convert_policy == ConvertPolicy::SATURATE) ? saturate_cast<T>(val) : static_cast<T>(val); + result[i] = (convert_policy == ConvertPolicy::SATURATE) ? saturate_cast<T3>(val) : static_cast<T3>(val); } return result; } template SimpleTensor<uint8_t> arithmetic_subtraction(const SimpleTensor<uint8_t> &src1, const SimpleTensor<uint8_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy); +template SimpleTensor<int16_t> arithmetic_subtraction(const SimpleTensor<uint8_t> &src1, const SimpleTensor<uint8_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy); +template SimpleTensor<int16_t> arithmetic_subtraction(const SimpleTensor<uint8_t> &src1, const SimpleTensor<int16_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy); +template SimpleTensor<int16_t> arithmetic_subtraction(const SimpleTensor<int16_t> &src1, const SimpleTensor<uint8_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy); template SimpleTensor<int16_t> arithmetic_subtraction(const SimpleTensor<int16_t> &src1, const SimpleTensor<int16_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy); template SimpleTensor<int8_t> arithmetic_subtraction(const SimpleTensor<int8_t> &src1, const SimpleTensor<int8_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy); template SimpleTensor<half> arithmetic_subtraction(const SimpleTensor<half> &src1, const SimpleTensor<half> &src2, DataType dst_data_type, ConvertPolicy convert_policy); |