aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/reference/PixelWiseMultiplication.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/reference/PixelWiseMultiplication.cpp')
-rw-r--r--tests/validation/reference/PixelWiseMultiplication.cpp57
1 files changed, 29 insertions, 28 deletions
diff --git a/tests/validation/reference/PixelWiseMultiplication.cpp b/tests/validation/reference/PixelWiseMultiplication.cpp
index 2b4c849c39..3e21fca72a 100644
--- a/tests/validation/reference/PixelWiseMultiplication.cpp
+++ b/tests/validation/reference/PixelWiseMultiplication.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -52,16 +52,16 @@ namespace
* @param[in] convert_policy Overflow policy. Supported overflow policies: Wrap, Saturate
* @param[in] rounding_policy Rounding policy. Supported rounding modes: to zero, to nearest even.
*/
-template <typename T1, typename T2>
-T2 mul(const T1 src1, const T2 src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy)
+template <typename T1, typename T2, typename T3>
+T3 mul(const T1 src1, const T2 src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy)
{
- using intermediate_type = typename common_promoted_signed_type<T1, T2, T2>::intermediate_type;
+ using intermediate_type = typename common_promoted_signed_type<T1, T2, T3>::intermediate_type;
const double val = static_cast<intermediate_type>(src1) * static_cast<intermediate_type>(src2) * static_cast<double>(scale);
- if(is_floating_point<T2>::value)
+ if(is_floating_point<T3>::value)
{
- const auto result = static_cast<T2>(val);
+ const auto result = static_cast<T3>(val);
return result;
}
@@ -83,7 +83,7 @@ T2 mul(const T1 src1, const T2 src2, float scale, ConvertPolicy convert_policy,
ARM_COMPUTE_ERROR("Unsupported rounding policy");
}
- const auto result = static_cast<T2>((convert_policy == ConvertPolicy::SATURATE) ? saturate_cast<T2>(rounded_val) : rounded_val);
+ const auto result = static_cast<T3>((convert_policy == ConvertPolicy::SATURATE) ? saturate_cast<T3>(rounded_val) : rounded_val);
return result;
}
@@ -92,8 +92,8 @@ T2 mul(const T1 src1, const T2 src2, float scale, ConvertPolicy convert_policy,
template <size_t dim>
struct BroadcastUnroll
{
- template <typename T1, typename T2>
- static void unroll(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, SimpleTensor<T2> &dst,
+ template <typename T1, typename T2, typename T3>
+ static void unroll(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, SimpleTensor<T3> &dst,
float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
Coordinates &id_src1, Coordinates &id_src2, Coordinates &id_dst)
{
@@ -117,23 +117,23 @@ struct BroadcastUnroll
template <>
struct BroadcastUnroll<0>
{
- template <typename T1, typename T2>
- static void unroll(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, SimpleTensor<T2> &dst,
+ template <typename T1, typename T2, typename T3>
+ static void unroll(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, SimpleTensor<T3> &dst,
float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
Coordinates &id_src1, Coordinates &id_src2, Coordinates &id_dst)
{
- dst[coord2index(dst.shape(), id_dst)] = mul(src1[coord2index(src1.shape(), id_src1)], src2[coord2index(src2.shape(), id_src2)], scale, convert_policy, rounding_policy);
+ dst[coord2index(dst.shape(), id_dst)] = mul<T1, T2, T3>(src1[coord2index(src1.shape(), id_src1)], src2[coord2index(src2.shape(), id_src2)], scale, convert_policy, rounding_policy);
}
};
} // namespace
-template <typename T1, typename T2>
-SimpleTensor<T2> pixel_wise_multiplication(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
- const QuantizationInfo &qout)
+template <typename T1, typename T2, typename T3>
+SimpleTensor<T3> pixel_wise_multiplication(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
+ DataType dt_out, const QuantizationInfo &qout)
{
ARM_COMPUTE_UNUSED(qout);
- SimpleTensor<T2> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), src2.data_type());
+ SimpleTensor<T3> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), dt_out);
if(scale < 0)
{
@@ -151,15 +151,15 @@ SimpleTensor<T2> pixel_wise_multiplication(const SimpleTensor<T1> &src1, const S
template <>
SimpleTensor<uint8_t> pixel_wise_multiplication(const SimpleTensor<uint8_t> &src1, const SimpleTensor<uint8_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
- const QuantizationInfo &qout)
+ DataType dt_out, const QuantizationInfo &qout)
{
- SimpleTensor<uint8_t> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), src2.data_type(), 1, qout);
+ SimpleTensor<uint8_t> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), dt_out, 1, qout);
if(src1.data_type() == DataType::QASYMM8 && src2.data_type() == DataType::QASYMM8)
{
SimpleTensor<float> src1_tmp = convert_from_asymmetric(src1);
SimpleTensor<float> src2_tmp = convert_from_asymmetric(src2);
- SimpleTensor<float> dst_tmp = pixel_wise_multiplication<float>(src1_tmp, src2_tmp, scale, convert_policy, rounding_policy, qout);
+ SimpleTensor<float> dst_tmp = pixel_wise_multiplication<float, float, float>(src1_tmp, src2_tmp, scale, convert_policy, rounding_policy, DataType::F32, qout);
dst = convert_to_asymmetric<uint8_t>(dst_tmp, qout);
}
else
@@ -179,15 +179,15 @@ SimpleTensor<uint8_t> pixel_wise_multiplication(const SimpleTensor<uint8_t> &src
template <>
SimpleTensor<int8_t> pixel_wise_multiplication(const SimpleTensor<int8_t> &src1, const SimpleTensor<int8_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
- const QuantizationInfo &qout)
+ DataType dt_out, const QuantizationInfo &qout)
{
- SimpleTensor<int8_t> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), src2.data_type(), 1, qout);
+ SimpleTensor<int8_t> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), dt_out, 1, qout);
if(src1.data_type() == DataType::QASYMM8_SIGNED && src2.data_type() == DataType::QASYMM8_SIGNED)
{
SimpleTensor<float> src1_tmp = convert_from_asymmetric(src1);
SimpleTensor<float> src2_tmp = convert_from_asymmetric(src2);
- SimpleTensor<float> dst_tmp = pixel_wise_multiplication<float>(src1_tmp, src2_tmp, scale, convert_policy, rounding_policy, qout);
+ SimpleTensor<float> dst_tmp = pixel_wise_multiplication<float, float, float>(src1_tmp, src2_tmp, scale, convert_policy, rounding_policy, DataType::F32, qout);
dst = convert_to_asymmetric<int8_t>(dst_tmp, qout);
}
else
@@ -207,15 +207,15 @@ SimpleTensor<int8_t> pixel_wise_multiplication(const SimpleTensor<int8_t> &src1,
template <>
SimpleTensor<int16_t> pixel_wise_multiplication(const SimpleTensor<int16_t> &src1, const SimpleTensor<int16_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
- const QuantizationInfo &qout)
+ DataType dt_out, const QuantizationInfo &qout)
{
- SimpleTensor<int16_t> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), src2.data_type(), 1, qout);
+ SimpleTensor<int16_t> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), dt_out, 1, qout);
if(src1.data_type() == DataType::QSYMM16 && src2.data_type() == DataType::QSYMM16)
{
SimpleTensor<float> src1_tmp = convert_from_symmetric<int16_t>(src1);
SimpleTensor<float> src2_tmp = convert_from_symmetric<int16_t>(src2);
- SimpleTensor<float> dst_tmp = pixel_wise_multiplication<float>(src1_tmp, src2_tmp, scale, convert_policy, rounding_policy, qout);
+ SimpleTensor<float> dst_tmp = pixel_wise_multiplication<float, float, float>(src1_tmp, src2_tmp, scale, convert_policy, rounding_policy, DataType::F32, qout);
dst = convert_to_symmetric<int16_t>(dst_tmp, qout);
}
else
@@ -234,9 +234,10 @@ SimpleTensor<int16_t> pixel_wise_multiplication(const SimpleTensor<int16_t> &src
}
// *INDENT-OFF*
// clang-format off
-template SimpleTensor<int16_t> pixel_wise_multiplication(const SimpleTensor<uint8_t> &src1, const SimpleTensor<int16_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, const QuantizationInfo &qout);
-template SimpleTensor<float> pixel_wise_multiplication(const SimpleTensor<float> &src1, const SimpleTensor<float> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, const QuantizationInfo &qout);
-template SimpleTensor<half_float::half> pixel_wise_multiplication(const SimpleTensor<half_float::half> &src1, const SimpleTensor<half_float::half> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, const QuantizationInfo &qout);
+template SimpleTensor<int16_t> pixel_wise_multiplication(const SimpleTensor<uint8_t> &src1, const SimpleTensor<int16_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, DataType dt_out, const QuantizationInfo &qout);
+template SimpleTensor<int32_t> pixel_wise_multiplication(const SimpleTensor<int16_t> &src1, const SimpleTensor<int16_t> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, DataType dt_out, const QuantizationInfo &qout);
+template SimpleTensor<float> pixel_wise_multiplication(const SimpleTensor<float> &src1, const SimpleTensor<float> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, DataType dt_out, const QuantizationInfo &qout);
+template SimpleTensor<half_float::half> pixel_wise_multiplication(const SimpleTensor<half_float::half> &src1, const SimpleTensor<half_float::half> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, DataType dt_out, const QuantizationInfo &qout);
// clang-format on
// *INDENT-ON*
} // namespace reference