From 6259e5f9204abf31b811b1d002f68ce6504197bd Mon Sep 17 00:00:00 2001 From: Michele Di Giorgio Date: Wed, 17 Jan 2018 17:29:33 +0000 Subject: COMPMID-787: Add CL support for broadcast multiply Change-Id: I71f67789648ef05ccdedce77c7427bc0127b3a69 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/116741 Tested-by: Jenkins Reviewed-by: Georgios Pinitas Reviewed-by: Anthony Barbier --- tests/validation/CL/PixelWiseMultiplication.cpp | 6 ++ .../fixtures/PixelWiseMultiplicationFixture.h | 49 ++++++--- .../reference/PixelWiseMultiplication.cpp | 119 +++++++++++++++------ 3 files changed, 127 insertions(+), 47 deletions(-) (limited to 'tests') diff --git a/tests/validation/CL/PixelWiseMultiplication.cpp b/tests/validation/CL/PixelWiseMultiplication.cpp index 45f57af3fc..6a71175f51 100644 --- a/tests/validation/CL/PixelWiseMultiplication.cpp +++ b/tests/validation/CL/PixelWiseMultiplication.cpp @@ -86,6 +86,8 @@ template using CLPixelWiseMultiplicationToQS16Fixture = PixelWiseMultiplicationValidationFixture; template using CLFixedPointPixelWiseMultiplicationFixture = FixedPointPixelWiseMultiplicationValidationFixture; +template +using CLPixelWiseMultiplicationBroadcastFixture = PixelWiseMultiplicationBroadcastValidationFixture; TEST_SUITE(CL) TEST_SUITE(PixelWiseMultiplication) @@ -169,6 +171,10 @@ TEST_SUITE_END() // ScaleUnity TEST_SUITE_END() // QS16 +TEST_SUITE(Broadcast) +PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunSmall, BroadcastFixture, PRECOMMIT, SmallShapesBroadcast(), F32, F32, scale_255, TO_NEAREST_UP, VALIDATE(float, 1.f)) +TEST_SUITE_END() // Broadcast + TEST_SUITE_END() // FixedPointPixelWiseMultiplication TEST_SUITE_END() } // namespace validation diff --git a/tests/validation/fixtures/PixelWiseMultiplicationFixture.h b/tests/validation/fixtures/PixelWiseMultiplicationFixture.h index 7428fb5cb7..b9f19f3e77 100644 --- a/tests/validation/fixtures/PixelWiseMultiplicationFixture.h +++ b/tests/validation/fixtures/PixelWiseMultiplicationFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017 ARM Limited. + * Copyright (c) 2017-2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -40,19 +40,20 @@ namespace test namespace validation { template -class PixelWiseMultiplicationValidationFixture : public framework::Fixture +class PixelWiseMultiplicationBroadcastValidationFixture : public framework::Fixture { public: template - void setup(TensorShape shape, - DataType dt_in1, - DataType dt_in2, - float scale, - ConvertPolicy convert_policy, - RoundingPolicy rounding_policy) + void setup(const TensorShape &shape0, + const TensorShape &shape1, + DataType dt_in1, + DataType dt_in2, + float scale, + ConvertPolicy convert_policy, + RoundingPolicy rounding_policy) { - _target = compute_target(shape, dt_in1, dt_in2, scale, convert_policy, rounding_policy); - _reference = compute_reference(shape, dt_in1, dt_in2, scale, convert_policy, rounding_policy); + _target = compute_target(shape0, shape1, dt_in1, dt_in2, scale, convert_policy, rounding_policy); + _reference = compute_reference(shape0, shape1, dt_in1, dt_in2, scale, convert_policy, rounding_policy); } protected: @@ -62,12 +63,13 @@ protected: library->fill_tensor_uniform(tensor, seed_offset); } - TensorType compute_target(const TensorShape &shape, DataType dt_in1, DataType dt_in2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy) + TensorType compute_target(const TensorShape &shape0, const TensorShape &shape1, DataType dt_in1, DataType dt_in2, + float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy) { // Create tensors - TensorType src1 = create_tensor(shape, dt_in1); - TensorType src2 = create_tensor(shape, dt_in2); - TensorType dst = create_tensor(shape, dt_in2); + TensorType src1 = create_tensor(shape0, dt_in1); + TensorType src2 = create_tensor(shape1, dt_in2); + TensorType dst = create_tensor(TensorShape::broadcast_shape(shape0, shape1), dt_in2); // Create and configure function FunctionType multiply; @@ -96,11 +98,12 @@ protected: return dst; } - SimpleTensor compute_reference(const TensorShape &shape, DataType dt_in1, DataType dt_in2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy) + SimpleTensor compute_reference(const TensorShape &shape0, const TensorShape &shape1, DataType dt_in1, DataType dt_in2, + float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy) { // Create reference - SimpleTensor src1{ shape, dt_in1 }; - SimpleTensor src2{ shape, dt_in2 }; + SimpleTensor src1{ shape0, dt_in1 }; + SimpleTensor src2{ shape1, dt_in2 }; // Fill reference fill(src1, 0); @@ -112,6 +115,18 @@ protected: TensorType _target{}; SimpleTensor _reference{}; }; + +template +class PixelWiseMultiplicationValidationFixture : public PixelWiseMultiplicationBroadcastValidationFixture +{ +public: + template + void setup(const TensorShape &shape, DataType dt_in1, DataType dt_in2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy) + { + PixelWiseMultiplicationBroadcastValidationFixture::setup(shape, shape, dt_in1, dt_in2, scale, convert_policy, rounding_policy); + } +}; + } // namespace validation } // namespace test } // namespace arm_compute diff --git a/tests/validation/reference/PixelWiseMultiplication.cpp b/tests/validation/reference/PixelWiseMultiplication.cpp index b3647fc9ce..546a886ac9 100644 --- a/tests/validation/reference/PixelWiseMultiplication.cpp +++ b/tests/validation/reference/PixelWiseMultiplication.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017 ARM Limited. + * Copyright (c) 2017-2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -41,46 +41,105 @@ struct is_floating_point { }; +namespace +{ +/** Compute the result of `src1 * src2 * scale`. The result type always matches the type of @p src2. + * + * @param[in] src1 An input value. Data types supported: U8/QS8/QS16/S16/F16/F32. + * @param[in] src2 An input value. Data types supported: same as @p src1. + * @param[in] scale Scale to apply after multiplication. + * Scale must be positive and its value must be either 1/255 or 1/2^n where n is between 0 and 15. For QS8 and QS16 scale must be 1. + * @param[in] convert_policy Overflow policy. Supported overflow policies: Wrap, Saturate + * @param[in] rounding_policy Rounding policy. Supported rounding modes: to zero, to nearest even. + */ template -SimpleTensor pixel_wise_multiplication(const SimpleTensor &src1, const SimpleTensor &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy) +T2 mul(const T1 src1, const T2 src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy) { - SimpleTensor dst(src2.shape(), src2.data_type()); + using intermediate_type = typename common_promoted_signed_type::intermediate_type; - if(scale < 0) - { - ARM_COMPUTE_ERROR("Scale of pixel-wise multiplication must be non-negative"); - } + const double val = static_cast(src1) * static_cast(src2) * static_cast(scale); - using intermediate_type = typename common_promoted_signed_type::intermediate_type; + if(is_floating_point::value) + { + const auto result = static_cast(val); - for(int i = 0; i < src1.num_elements(); ++i) + return result; + } + else { - double val = static_cast(src1[i]) * static_cast(src2[i]) * static_cast(scale); - if(is_floating_point::value) + double rounded_val = 0; + switch(rounding_policy) { - dst[i] = val; + case(RoundingPolicy::TO_ZERO): + rounded_val = support::cpp11::trunc(val); + break; + case(RoundingPolicy::TO_NEAREST_UP): + rounded_val = round_half_up(val); + break; + case(RoundingPolicy::TO_NEAREST_EVEN): + rounded_val = round_half_even(val); + break; + default: + ARM_COMPUTE_ERROR("Unsupported rounding policy"); } - else + + const auto result = static_cast((convert_policy == ConvertPolicy::SATURATE) ? saturate_cast(rounded_val) : rounded_val); + + return result; + } +} + +template +struct BroadcastUnroll +{ + template + static void unroll(const SimpleTensor &src1, const SimpleTensor &src2, SimpleTensor &dst, + float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, + Coordinates &id_src1, Coordinates &id_src2, Coordinates &id_dst) + { + const bool src1_is_broadcast = (src1.shape()[dim - 1] != dst.shape()[dim - 1]); + const bool src2_is_broadcast = (src2.shape()[dim - 1] != dst.shape()[dim - 1]); + + id_src1.set(dim - 1, 0); + id_src2.set(dim - 1, 0); + id_dst.set(dim - 1, 0); + + for(size_t i = 0; i < dst.shape()[dim - 1]; ++i, ++id_dst[dim - 1]) { - double rounded_val = 0; - switch(rounding_policy) - { - case(RoundingPolicy::TO_ZERO): - rounded_val = support::cpp11::trunc(val); - break; - case(RoundingPolicy::TO_NEAREST_UP): - rounded_val = round_half_up(val); - break; - case(RoundingPolicy::TO_NEAREST_EVEN): - rounded_val = round_half_even(val); - break; - default: - ARM_COMPUTE_ERROR("Unsupported rounding policy"); - } - - dst[i] = (convert_policy == ConvertPolicy::SATURATE) ? saturate_cast(rounded_val) : static_cast(rounded_val); + BroadcastUnroll < dim - 1 >::unroll(src1, src2, dst, scale, convert_policy, rounding_policy, id_src1, id_src2, id_dst); + + id_src1[dim - 1] += !src1_is_broadcast; + id_src2[dim - 1] += !src2_is_broadcast; } } +}; + +template <> +struct BroadcastUnroll<0> +{ + template + static void unroll(const SimpleTensor &src1, const SimpleTensor &src2, SimpleTensor &dst, + float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, + Coordinates &id_src1, Coordinates &id_src2, Coordinates &id_dst) + { + dst[coord2index(dst.shape(), id_dst)] = mul(src1[coord2index(src1.shape(), id_src1)], src2[coord2index(src2.shape(), id_src2)], scale, convert_policy, rounding_policy); + } +}; +} // namespace + +template +SimpleTensor pixel_wise_multiplication(const SimpleTensor &src1, const SimpleTensor &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy) +{ + SimpleTensor dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), src2.data_type()); + + if(scale < 0) + { + ARM_COMPUTE_ERROR("Scale of pixel-wise multiplication must be non-negative"); + } + + Coordinates id_src1, id_src2, id_dst; + + BroadcastUnroll::unroll(src1, src2, dst, scale, convert_policy, rounding_policy, id_src1, id_src2, id_dst); return dst; } -- cgit v1.2.1