aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/reference/PixelWiseMultiplication.cpp
diff options
context:
space:
mode:
authorMichele Di Giorgio <michele.digiorgio@arm.com>2018-01-17 17:29:33 +0000
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:45:42 +0000
commit6259e5f9204abf31b811b1d002f68ce6504197bd (patch)
tree2dac943b3c794b66ccd90c8dc8e15d47699c5ea8 /tests/validation/reference/PixelWiseMultiplication.cpp
parent19d0547aa8c60b95766c195822769c7fea78aeaa (diff)
downloadComputeLibrary-6259e5f9204abf31b811b1d002f68ce6504197bd.tar.gz
COMPMID-787: Add CL support for broadcast multiply
Change-Id: I71f67789648ef05ccdedce77c7427bc0127b3a69 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/116741 Tested-by: Jenkins <bsgcomp@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'tests/validation/reference/PixelWiseMultiplication.cpp')
-rw-r--r--tests/validation/reference/PixelWiseMultiplication.cpp119
1 files changed, 89 insertions, 30 deletions
diff --git a/tests/validation/reference/PixelWiseMultiplication.cpp b/tests/validation/reference/PixelWiseMultiplication.cpp
index b3647fc9ce..546a886ac9 100644
--- a/tests/validation/reference/PixelWiseMultiplication.cpp
+++ b/tests/validation/reference/PixelWiseMultiplication.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -41,46 +41,105 @@ struct is_floating_point
{
};
+namespace
+{
+/** Compute the result of `src1 * src2 * scale`. The result type always matches the type of @p src2.
+ *
+ * @param[in] src1 An input value. Data types supported: U8/QS8/QS16/S16/F16/F32.
+ * @param[in] src2 An input value. Data types supported: same as @p src1.
+ * @param[in] scale Scale to apply after multiplication.
+ * Scale must be positive and its value must be either 1/255 or 1/2^n where n is between 0 and 15. For QS8 and QS16 scale must be 1.
+ * @param[in] convert_policy Overflow policy. Supported overflow policies: Wrap, Saturate
+ * @param[in] rounding_policy Rounding policy. Supported rounding modes: to zero, to nearest even.
+ */
template <typename T1, typename T2>
-SimpleTensor<T2> pixel_wise_multiplication(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy)
+T2 mul(const T1 src1, const T2 src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy)
{
- SimpleTensor<T2> dst(src2.shape(), src2.data_type());
+ using intermediate_type = typename common_promoted_signed_type<T1, T2, T2>::intermediate_type;
- if(scale < 0)
- {
- ARM_COMPUTE_ERROR("Scale of pixel-wise multiplication must be non-negative");
- }
+ const double val = static_cast<intermediate_type>(src1) * static_cast<intermediate_type>(src2) * static_cast<double>(scale);
- using intermediate_type = typename common_promoted_signed_type<T1, T2, T2>::intermediate_type;
+ if(is_floating_point<T2>::value)
+ {
+ const auto result = static_cast<T2>(val);
- for(int i = 0; i < src1.num_elements(); ++i)
+ return result;
+ }
+ else
{
- double val = static_cast<intermediate_type>(src1[i]) * static_cast<intermediate_type>(src2[i]) * static_cast<double>(scale);
- if(is_floating_point<T2>::value)
+ double rounded_val = 0;
+ switch(rounding_policy)
{
- dst[i] = val;
+ case(RoundingPolicy::TO_ZERO):
+ rounded_val = support::cpp11::trunc(val);
+ break;
+ case(RoundingPolicy::TO_NEAREST_UP):
+ rounded_val = round_half_up(val);
+ break;
+ case(RoundingPolicy::TO_NEAREST_EVEN):
+ rounded_val = round_half_even(val);
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Unsupported rounding policy");
}
- else
+
+ const auto result = static_cast<T2>((convert_policy == ConvertPolicy::SATURATE) ? saturate_cast<T2>(rounded_val) : rounded_val);
+
+ return result;
+ }
+}
+
+template <size_t dim>
+struct BroadcastUnroll
+{
+ template <typename T1, typename T2>
+ static void unroll(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, SimpleTensor<T2> &dst,
+ float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
+ Coordinates &id_src1, Coordinates &id_src2, Coordinates &id_dst)
+ {
+ const bool src1_is_broadcast = (src1.shape()[dim - 1] != dst.shape()[dim - 1]);
+ const bool src2_is_broadcast = (src2.shape()[dim - 1] != dst.shape()[dim - 1]);
+
+ id_src1.set(dim - 1, 0);
+ id_src2.set(dim - 1, 0);
+ id_dst.set(dim - 1, 0);
+
+ for(size_t i = 0; i < dst.shape()[dim - 1]; ++i, ++id_dst[dim - 1])
{
- double rounded_val = 0;
- switch(rounding_policy)
- {
- case(RoundingPolicy::TO_ZERO):
- rounded_val = support::cpp11::trunc(val);
- break;
- case(RoundingPolicy::TO_NEAREST_UP):
- rounded_val = round_half_up(val);
- break;
- case(RoundingPolicy::TO_NEAREST_EVEN):
- rounded_val = round_half_even(val);
- break;
- default:
- ARM_COMPUTE_ERROR("Unsupported rounding policy");
- }
-
- dst[i] = (convert_policy == ConvertPolicy::SATURATE) ? saturate_cast<T2>(rounded_val) : static_cast<T2>(rounded_val);
+ BroadcastUnroll < dim - 1 >::unroll(src1, src2, dst, scale, convert_policy, rounding_policy, id_src1, id_src2, id_dst);
+
+ id_src1[dim - 1] += !src1_is_broadcast;
+ id_src2[dim - 1] += !src2_is_broadcast;
}
}
+};
+
+template <>
+struct BroadcastUnroll<0>
+{
+ template <typename T1, typename T2>
+ static void unroll(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, SimpleTensor<T2> &dst,
+ float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
+ Coordinates &id_src1, Coordinates &id_src2, Coordinates &id_dst)
+ {
+ dst[coord2index(dst.shape(), id_dst)] = mul(src1[coord2index(src1.shape(), id_src1)], src2[coord2index(src2.shape(), id_src2)], scale, convert_policy, rounding_policy);
+ }
+};
+} // namespace
+
+template <typename T1, typename T2>
+SimpleTensor<T2> pixel_wise_multiplication(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy)
+{
+ SimpleTensor<T2> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), src2.data_type());
+
+ if(scale < 0)
+ {
+ ARM_COMPUTE_ERROR("Scale of pixel-wise multiplication must be non-negative");
+ }
+
+ Coordinates id_src1, id_src2, id_dst;
+
+ BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(src1, src2, dst, scale, convert_policy, rounding_policy, id_src1, id_src2, id_dst);
return dst;
}