diff options
Diffstat (limited to 'tests/validation/CPP')
-rw-r--r-- | tests/validation/CPP/ArithmeticSubtraction.cpp | 13 | ||||
-rw-r--r-- | tests/validation/CPP/ArithmeticSubtraction.h | 4 |
2 files changed, 10 insertions, 7 deletions
diff --git a/tests/validation/CPP/ArithmeticSubtraction.cpp b/tests/validation/CPP/ArithmeticSubtraction.cpp index 80bdb15a49..bed2d37090 100644 --- a/tests/validation/CPP/ArithmeticSubtraction.cpp +++ b/tests/validation/CPP/ArithmeticSubtraction.cpp @@ -34,23 +34,26 @@ namespace validation { namespace reference { -template <typename T> -SimpleTensor<T> arithmetic_subtraction(const SimpleTensor<T> &src1, const SimpleTensor<T> &src2, DataType dst_data_type, ConvertPolicy convert_policy) +template <typename T1, typename T2, typename T3> +SimpleTensor<T3> arithmetic_subtraction(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, DataType dst_data_type, ConvertPolicy convert_policy) { - SimpleTensor<T> result(src1.shape(), dst_data_type); + SimpleTensor<T3> result(src1.shape(), dst_data_type); - using intermediate_type = typename common_promoted_signed_type<T>::intermediate_type; + using intermediate_type = typename common_promoted_signed_type<typename std::conditional<sizeof(T1) >= sizeof(T2), T1, T2>::type >::intermediate_type; for(int i = 0; i < src1.num_elements(); ++i) { intermediate_type val = static_cast<intermediate_type>(src1[i]) - static_cast<intermediate_type>(src2[i]); - result[i] = (convert_policy == ConvertPolicy::SATURATE) ? saturate_cast<T>(val) : static_cast<T>(val); + result[i] = (convert_policy == ConvertPolicy::SATURATE) ? saturate_cast<T3>(val) : static_cast<T3>(val); } return result; } template SimpleTensor<uint8_t> arithmetic_subtraction(const SimpleTensor<uint8_t> &src1, const SimpleTensor<uint8_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy); +template SimpleTensor<int16_t> arithmetic_subtraction(const SimpleTensor<uint8_t> &src1, const SimpleTensor<uint8_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy); +template SimpleTensor<int16_t> arithmetic_subtraction(const SimpleTensor<uint8_t> &src1, const SimpleTensor<int16_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy); +template SimpleTensor<int16_t> arithmetic_subtraction(const SimpleTensor<int16_t> &src1, const SimpleTensor<uint8_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy); template SimpleTensor<int16_t> arithmetic_subtraction(const SimpleTensor<int16_t> &src1, const SimpleTensor<int16_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy); template SimpleTensor<int8_t> arithmetic_subtraction(const SimpleTensor<int8_t> &src1, const SimpleTensor<int8_t> &src2, DataType dst_data_type, ConvertPolicy convert_policy); template SimpleTensor<half> arithmetic_subtraction(const SimpleTensor<half> &src1, const SimpleTensor<half> &src2, DataType dst_data_type, ConvertPolicy convert_policy); diff --git a/tests/validation/CPP/ArithmeticSubtraction.h b/tests/validation/CPP/ArithmeticSubtraction.h index 18b0d121a0..9308314bda 100644 --- a/tests/validation/CPP/ArithmeticSubtraction.h +++ b/tests/validation/CPP/ArithmeticSubtraction.h @@ -35,8 +35,8 @@ namespace validation { namespace reference { -template <typename T> -SimpleTensor<T> arithmetic_subtraction(const SimpleTensor<T> &src1, const SimpleTensor<T> &src2, DataType dst_data_type, ConvertPolicy convert_policy); +template <typename T1, typename T2, typename T3> +SimpleTensor<T3> arithmetic_subtraction(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, DataType dst_data_type, ConvertPolicy convert_policy); } // namespace reference } // namespace validation } // namespace test |