aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/CPP
diff options
context:
space:
mode:
authorIsabella Gottardi <isabella.gottardi@arm.com>2017-10-30 15:28:13 +0000
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:35:24 +0000
commitb5908c257d554009a00de3aaa95b3721000ed185 (patch)
tree2c9fecbb241061e10bd9886bfe36056c4e6cf211 /tests/validation/CPP
parent05078ec491da8f282f4597b4cf1fe79cc16f4b22 (diff)
downloadComputeLibrary-b5908c257d554009a00de3aaa95b3721000ed185.tar.gz
COMPMID-653 - Arithmetic Subtraction, add support different datatype
Change-Id: I2b3d65c8d8a85ad67b9972713d06f047f5bcd1ae Reviewed-on: http://mpd-gerrit.cambridge.arm.com/93693 Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com> Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
Diffstat (limited to 'tests/validation/CPP')
-rw-r--r--tests/validation/CPP/ArithmeticSubtraction.cpp13
-rw-r--r--tests/validation/CPP/ArithmeticSubtraction.h4
2 files changed, 10 insertions, 7 deletions
diff --git a/tests/validation/CPP/ArithmeticSubtraction.cpp b/tests/validation/CPP/ArithmeticSubtraction.cpp
index 80bdb15a49..bed2d37090 100644
--- a/tests/validation/CPP/ArithmeticSubtraction.cpp
+++ b/tests/validation/CPP/ArithmeticSubtraction.cpp
@@ -34,23 +34,26 @@ namespace validation
{
namespace reference
{
-template <typename T>
-SimpleTensor<T> arithmetic_subtraction(const SimpleTensor<T> &src1, const SimpleTensor<T> &src2, DataType dst_data_type, ConvertPolicy convert_policy)
+template <typename T1, typename T2, typename T3>
+SimpleTensor<T3> arithmetic_subtraction(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, DataType dst_data_type, ConvertPolicy convert_policy)
{
- SimpleTensor<T> result(src1.shape(), dst_data_type);
+ SimpleTensor<T3> result(src1.shape(), dst_data_type);
- using intermediate_type = typename common_promoted_signed_type<T>::intermediate_type;
+ using intermediate_type = typename common_promoted_signed_type<typename std::conditional<sizeof(T1) >= sizeof(T2), T1, T2>::type >::intermediate_type;
for(int i = 0; i < src1.num_elements(); ++i)
{
intermediate_type val = static_cast<intermediate_type>(src1[i]) - static_cast<intermediate_type>(src2[i]);
- result[i] = (convert_policy == ConvertPolicy::SATURATE) ? saturate_cast<T>(val) : static_cast<T>(val);
+ result[i] = (convert_policy == ConvertPolicy::SATURATE) ? saturate_cast<T3>(val) : static_cast<T3>(val);
}
return result;
}
template SimpleTensor<uint8_t> arithmetic_subtraction(const SimpleTensor<uint8_t> &src1, const SimpleTensor<uint8_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy);
+template SimpleTensor<int16_t> arithmetic_subtraction(const SimpleTensor<uint8_t> &src1, const SimpleTensor<uint8_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy);
+template SimpleTensor<int16_t> arithmetic_subtraction(const SimpleTensor<uint8_t> &src1, const SimpleTensor<int16_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy);
+template SimpleTensor<int16_t> arithmetic_subtraction(const SimpleTensor<int16_t> &src1, const SimpleTensor<uint8_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy);
template SimpleTensor<int16_t> arithmetic_subtraction(const SimpleTensor<int16_t> &src1, const SimpleTensor<int16_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy);
template SimpleTensor<int8_t> arithmetic_subtraction(const SimpleTensor<int8_t> &src1, const SimpleTensor<int8_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy);
template SimpleTensor<half> arithmetic_subtraction(const SimpleTensor<half> &src1, const SimpleTensor<half> &src2, DataType dst_data_type, ConvertPolicy convert_policy);
diff --git a/tests/validation/CPP/ArithmeticSubtraction.h b/tests/validation/CPP/ArithmeticSubtraction.h
index 18b0d121a0..9308314bda 100644
--- a/tests/validation/CPP/ArithmeticSubtraction.h
+++ b/tests/validation/CPP/ArithmeticSubtraction.h
@@ -35,8 +35,8 @@ namespace validation
{
namespace reference
{
-template <typename T>
-SimpleTensor<T> arithmetic_subtraction(const SimpleTensor<T> &src1, const SimpleTensor<T> &src2, DataType dst_data_type, ConvertPolicy convert_policy);
+template <typename T1, typename T2, typename T3>
+SimpleTensor<T3> arithmetic_subtraction(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, DataType dst_data_type, ConvertPolicy convert_policy);
} // namespace reference
} // namespace validation
} // namespace test