diff options
author | Michele Di Giorgio <michele.digiorgio@arm.com> | 2018-01-17 17:29:33 +0000 |
---|---|---|
committer | Anthony Barbier <anthony.barbier@arm.com> | 2018-11-02 16:45:42 +0000 |
commit | 6259e5f9204abf31b811b1d002f68ce6504197bd (patch) | |
tree | 2dac943b3c794b66ccd90c8dc8e15d47699c5ea8 | |
parent | 19d0547aa8c60b95766c195822769c7fea78aeaa (diff) | |
download | ComputeLibrary-6259e5f9204abf31b811b1d002f68ce6504197bd.tar.gz |
COMPMID-787: Add CL support for broadcast multiply
Change-Id: I71f67789648ef05ccdedce77c7427bc0127b3a69
Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/116741
Tested-by: Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
7 files changed, 227 insertions, 92 deletions
diff --git a/arm_compute/core/CL/kernels/CLPixelWiseMultiplicationKernel.h b/arm_compute/core/CL/kernels/CLPixelWiseMultiplicationKernel.h index 6746a49dde..1ecd9be8cd 100644 --- a/arm_compute/core/CL/kernels/CLPixelWiseMultiplicationKernel.h +++ b/arm_compute/core/CL/kernels/CLPixelWiseMultiplicationKernel.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016, 2017 ARM Limited. + * Copyright (c) 2016-2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -76,6 +76,7 @@ public: // Inherited methods overridden: void run(const Window &window, cl::CommandQueue &queue) override; + BorderSize border_size() const override; private: const ICLTensor *_input1; diff --git a/arm_compute/runtime/CL/functions/CLPixelWiseMultiplication.h b/arm_compute/runtime/CL/functions/CLPixelWiseMultiplication.h index d57bfda2c1..75b67cd17c 100644 --- a/arm_compute/runtime/CL/functions/CLPixelWiseMultiplication.h +++ b/arm_compute/runtime/CL/functions/CLPixelWiseMultiplication.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016, 2017 ARM Limited. + * Copyright (c) 2016-2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -37,15 +37,17 @@ class CLPixelWiseMultiplication : public ICLSimpleFunction public: /** Initialise the kernel's inputs, output and convertion policy. * - * @param[in] input1 An input tensor. Data types supported: U8/QS8/QS16/S16/F16/F32. - * @param[in] input2 An input tensor. Data types supported: same as @p input1. - * @param[out] output The output tensor, Data types supported: same as @p input1. Note: U8 (QS8, QS16) requires both inputs to be U8 (QS8, QS16). - * @param[in] scale Scale to apply after multiplication. - * Scale must be positive and its value must be either 1/255 or 1/2^n where n is between 0 and 15. For QS8 and QS16 scale must be 1. - * @param[in] overflow_policy Overflow policy. Supported overflow policies: Wrap, Saturate - * @param[in] rounding_policy Rounding policy. Supported rounding modes: to zero, to nearest even. + * @param[in, out] input1 An input tensor. Data types supported: U8/QS8/QS16/S16/F16/F32. + * The input tensor is [in, out] because its TensorInfo might be modified inside the kernel in case of broadcasting of dimension 0. + * @param[in, out] input2 An input tensor. Data types supported: same as @p input1. + * The input tensor is [in, out] because its TensorInfo might be modified inside the kernel in case of broadcasting of dimension 0. + * @param[out] output The output tensor, Data types supported: same as @p input1. Note: U8 (QS8, QS16) requires both inputs to be U8 (QS8, QS16). + * @param[in] scale Scale to apply after multiplication. + * Scale must be positive and its value must be either 1/255 or 1/2^n where n is between 0 and 15. For QS8 and QS16 scale must be 1. + * @param[in] overflow_policy Overflow policy. Supported overflow policies: Wrap, Saturate + * @param[in] rounding_policy Rounding policy. Supported rounding modes: to zero, to nearest even. */ - void configure(const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, float scale, + void configure(ICLTensor *input1, ICLTensor *input2, ICLTensor *output, float scale, ConvertPolicy overflow_policy, RoundingPolicy rounding_policy); /** Static function to check if given info will lead to a valid configuration of @ref CLPixelWiseMultiplication * diff --git a/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp b/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp index 6dba9c0f95..f30ba61b9a 100644 --- a/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp +++ b/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016, 2017 ARM Limited. + * Copyright (c) 2016-2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -42,6 +42,8 @@ using namespace arm_compute; namespace { +constexpr unsigned int num_elems_processed_per_iteration = 16; + Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float scale, ConvertPolicy overflow_policy, RoundingPolicy rounding_policy) { @@ -50,10 +52,13 @@ Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *input2, ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, input2); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, input2); ARM_COMPUTE_RETURN_ERROR_ON_MSG(scale < 0, "Scale cannot be negative."); + const TensorShape &out_shape = TensorShape::broadcast_shape(input1->tensor_shape(), input2->tensor_shape()); + + ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible"); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, input2); + if(is_data_type_fixed_point(input1->data_type())) { // All data types must be all QS8 or all QS16 @@ -62,12 +67,12 @@ Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *input2, } // Validate in case of configured output - if((output != nullptr) && (output->total_size() != 0)) + if(output->total_size() > 0) { ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::U8 && (input1->data_type() != DataType::U8 || input2->data_type() != DataType::U8), "Output can only be U8 if both inputs are U8"); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, output); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output->tensor_shape(), 0), "Wrong shape for output"); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, output); if(is_data_type_fixed_point(input1->data_type())) { @@ -80,18 +85,36 @@ Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *input2, std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output) { - constexpr unsigned int num_elems_processed_per_iteration = 16; + const std::pair<TensorShape, ValidRegion> broadcast_pair = ITensorInfo::broadcast_shape_and_valid_region(*input1, *input2); + const TensorShape &out_shape = broadcast_pair.first; + const ValidRegion &valid_region = broadcast_pair.second; - Window win = calculate_max_window(*input1, Steps(num_elems_processed_per_iteration)); + // Auto initialize output if not initialized + { + set_shape_if_empty(*output, out_shape); + + if(input1->data_type() == DataType::S16 || input2->data_type() == DataType::S16) + { + set_format_if_unknown(*output, Format::S16); + } + else if(input1->data_type() == DataType::F32 || input2->data_type() == DataType::F32) + { + set_format_if_unknown(*output, Format::F32); + } + } + + Window win = calculate_max_window(valid_region, Steps(num_elems_processed_per_iteration)); + Window win_input1 = win.broadcast_if_dimension_le_one(*input1); + Window win_input2 = win.broadcast_if_dimension_le_one(*input2); AccessWindowHorizontal input1_access(input1, 0, num_elems_processed_per_iteration); AccessWindowHorizontal input2_access(input2, 0, num_elems_processed_per_iteration); AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration); - bool window_changed = update_window_and_padding(win, input1_access, input2_access, output_access); + bool window_changed = update_window_and_padding(win_input1, input1_access) + || update_window_and_padding(win_input2, input2_access) + || update_window_and_padding(win, output_access); - ValidRegion valid_region = intersect_valid_regions(input1->valid_region(), - input2->valid_region()); output_access.set_valid_region(win, valid_region); Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{}; @@ -108,24 +131,13 @@ void CLPixelWiseMultiplicationKernel::configure(const ICLTensor *input1, const I ConvertPolicy overflow_policy, RoundingPolicy rounding_policy) { ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output); - - // Auto initialize output if not initialized - { - set_shape_if_empty(*output->info(), input1->info()->tensor_shape()); - - if(input1->info()->data_type() == DataType::S16 || input2->info()->data_type() == DataType::S16) - { - set_format_if_unknown(*output->info(), Format::S16); - } - else if(input1->info()->data_type() == DataType::F32 || input2->info()->data_type() == DataType::F32) - { - set_format_if_unknown(*output->info(), Format::F32); - } - } - ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input1->info(), input2->info(), output->info(), scale, overflow_policy, rounding_policy)); + // Configure kernel window + auto win_config = validate_and_configure_window(input1->info(), input2->info(), output->info()); + ARM_COMPUTE_ERROR_THROW_ON(win_config.first); + _input1 = input1; _input2 = input2; _output = output; @@ -207,15 +219,13 @@ void CLPixelWiseMultiplicationKernel::configure(const ICLTensor *input1, const I _kernel.setArg(idx++, scale); } - // Configure kernel window - auto win_config = validate_and_configure_window(input1->info(), input2->info(), output->info()); - ARM_COMPUTE_ERROR_THROW_ON(win_config.first); ICLKernel::configure(win_config.second); } Status CLPixelWiseMultiplicationKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float scale, ConvertPolicy overflow_policy, RoundingPolicy rounding_policy) { + ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output); ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input1, input2, output, scale, overflow_policy, rounding_policy)); ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input1->clone().get(), input2->clone().get(), output->clone().get()).first); @@ -227,16 +237,47 @@ void CLPixelWiseMultiplicationKernel::run(const Window &window, cl::CommandQueue ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICLKernel::window(), window); - Window collapsed = window.collapse_if_possible(ICLKernel::window(), Window::DimZ); - Window slice = collapsed.first_slice_window_3D(); + const TensorShape &in_shape1 = _input1->info()->tensor_shape(); + const TensorShape &in_shape2 = _input2->info()->tensor_shape(); + const TensorShape &out_shape = _output->info()->tensor_shape(); + + bool can_collapse = true; + if(std::min(in_shape1.total_size(), in_shape2.total_size()) > 1) + { + can_collapse = (std::min(in_shape1.num_dimensions(), in_shape2.num_dimensions()) > Window::DimZ); + for(size_t d = Window::DimZ; can_collapse && (d < out_shape.num_dimensions()); ++d) + { + can_collapse = (in_shape1[d] == in_shape2[d]); + } + } + + bool has_collapsed = false; + Window collapsed = can_collapse ? window.collapse_if_possible(ICLKernel::window(), Window::DimZ, &has_collapsed) : window; + + const TensorShape &in_shape1_collapsed = has_collapsed ? in_shape1.collapsed_from(Window::DimZ) : in_shape1; + const TensorShape &in_shape2_collapsed = has_collapsed ? in_shape2.collapsed_from(Window::DimZ) : in_shape2; + + Window slice = collapsed.first_slice_window_3D(); + Window slice_input1 = slice.broadcast_if_dimension_le_one(in_shape1_collapsed); + Window slice_input2 = slice.broadcast_if_dimension_le_one(in_shape2_collapsed); do { unsigned int idx = 0; - add_3D_tensor_argument(idx, _input1, slice); - add_3D_tensor_argument(idx, _input2, slice); + add_3D_tensor_argument(idx, _input1, slice_input1); + add_3D_tensor_argument(idx, _input2, slice_input2); add_3D_tensor_argument(idx, _output, slice); enqueue(queue, *this, slice); + + collapsed.slide_window_slice_3D(slice_input1); + collapsed.slide_window_slice_3D(slice_input2); } while(collapsed.slide_window_slice_3D(slice)); } + +BorderSize CLPixelWiseMultiplicationKernel::border_size() const +{ + const unsigned int replicateSize = _output->info()->dimension(0) - std::min(_input1->info()->dimension(0), _input2->info()->dimension(0)); + const unsigned int border = std::min<unsigned int>(num_elems_processed_per_iteration - 1U, replicateSize); + return BorderSize(0, border, 0, 0); +} diff --git a/src/runtime/CL/functions/CLPixelWiseMultiplication.cpp b/src/runtime/CL/functions/CLPixelWiseMultiplication.cpp index c78f94476e..b4c20db3da 100644 --- a/src/runtime/CL/functions/CLPixelWiseMultiplication.cpp +++ b/src/runtime/CL/functions/CLPixelWiseMultiplication.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016, 2017 ARM Limited. + * Copyright (c) 2016-2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -23,6 +23,7 @@ */ #include "arm_compute/runtime/CL/functions/CLPixelWiseMultiplication.h" +#include "arm_compute/core/CL/ICLTensor.h" #include "arm_compute/core/CL/kernels/CLPixelWiseMultiplicationKernel.h" #include "support/ToolchainSupport.h" @@ -30,16 +31,26 @@ using namespace arm_compute; -void CLPixelWiseMultiplication::configure(const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, float scale, +void CLPixelWiseMultiplication::configure(ICLTensor *input1, ICLTensor *input2, ICLTensor *output, float scale, ConvertPolicy overflow_policy, RoundingPolicy rounding_policy) { auto k = arm_compute::support::cpp14::make_unique<CLPixelWiseMultiplicationKernel>(); k->configure(input1, input2, output, scale, overflow_policy, rounding_policy); _kernel = std::move(k); + + if(output->info()->dimension(0) > 1) + { + ICLTensor *broadcasted_info = (input1->info()->dimension(0) == 1) ? input1 : input2; + + if(broadcasted_info->info()->dimension(0) == 1) + { + _border_handler.configure(broadcasted_info, _kernel->border_size(), BorderMode::REPLICATE); + } + } } Status CLPixelWiseMultiplication::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float scale, ConvertPolicy overflow_policy, RoundingPolicy rounding_policy) { return CLPixelWiseMultiplicationKernel::validate(input1, input2, output, scale, overflow_policy, rounding_policy); -}
\ No newline at end of file +} diff --git a/tests/validation/CL/PixelWiseMultiplication.cpp b/tests/validation/CL/PixelWiseMultiplication.cpp index 45f57af3fc..6a71175f51 100644 --- a/tests/validation/CL/PixelWiseMultiplication.cpp +++ b/tests/validation/CL/PixelWiseMultiplication.cpp @@ -86,6 +86,8 @@ template <typename T> using CLPixelWiseMultiplicationToQS16Fixture = PixelWiseMultiplicationValidationFixture<CLTensor, CLAccessor, CLPixelWiseMultiplication, T, qint16_t>; template <typename T> using CLFixedPointPixelWiseMultiplicationFixture = FixedPointPixelWiseMultiplicationValidationFixture<CLTensor, CLAccessor, CLPixelWiseMultiplication, T>; +template <typename T> +using CLPixelWiseMultiplicationBroadcastFixture = PixelWiseMultiplicationBroadcastValidationFixture<CLTensor, CLAccessor, CLPixelWiseMultiplication, T, float>; TEST_SUITE(CL) TEST_SUITE(PixelWiseMultiplication) @@ -169,6 +171,10 @@ TEST_SUITE_END() // ScaleUnity TEST_SUITE_END() // QS16 +TEST_SUITE(Broadcast) +PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunSmall, BroadcastFixture<float>, PRECOMMIT, SmallShapesBroadcast(), F32, F32, scale_255, TO_NEAREST_UP, VALIDATE(float, 1.f)) +TEST_SUITE_END() // Broadcast + TEST_SUITE_END() // FixedPointPixelWiseMultiplication TEST_SUITE_END() } // namespace validation diff --git a/tests/validation/fixtures/PixelWiseMultiplicationFixture.h b/tests/validation/fixtures/PixelWiseMultiplicationFixture.h index 7428fb5cb7..b9f19f3e77 100644 --- a/tests/validation/fixtures/PixelWiseMultiplicationFixture.h +++ b/tests/validation/fixtures/PixelWiseMultiplicationFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017 ARM Limited. + * Copyright (c) 2017-2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -40,19 +40,20 @@ namespace test namespace validation { template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2> -class PixelWiseMultiplicationValidationFixture : public framework::Fixture +class PixelWiseMultiplicationBroadcastValidationFixture : public framework::Fixture { public: template <typename...> - void setup(TensorShape shape, - DataType dt_in1, - DataType dt_in2, - float scale, - ConvertPolicy convert_policy, - RoundingPolicy rounding_policy) + void setup(const TensorShape &shape0, + const TensorShape &shape1, + DataType dt_in1, + DataType dt_in2, + float scale, + ConvertPolicy convert_policy, + RoundingPolicy rounding_policy) { - _target = compute_target(shape, dt_in1, dt_in2, scale, convert_policy, rounding_policy); - _reference = compute_reference(shape, dt_in1, dt_in2, scale, convert_policy, rounding_policy); + _target = compute_target(shape0, shape1, dt_in1, dt_in2, scale, convert_policy, rounding_policy); + _reference = compute_reference(shape0, shape1, dt_in1, dt_in2, scale, convert_policy, rounding_policy); } protected: @@ -62,12 +63,13 @@ protected: library->fill_tensor_uniform(tensor, seed_offset); } - TensorType compute_target(const TensorShape &shape, DataType dt_in1, DataType dt_in2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy) + TensorType compute_target(const TensorShape &shape0, const TensorShape &shape1, DataType dt_in1, DataType dt_in2, + float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy) { // Create tensors - TensorType src1 = create_tensor<TensorType>(shape, dt_in1); - TensorType src2 = create_tensor<TensorType>(shape, dt_in2); - TensorType dst = create_tensor<TensorType>(shape, dt_in2); + TensorType src1 = create_tensor<TensorType>(shape0, dt_in1); + TensorType src2 = create_tensor<TensorType>(shape1, dt_in2); + TensorType dst = create_tensor<TensorType>(TensorShape::broadcast_shape(shape0, shape1), dt_in2); // Create and configure function FunctionType multiply; @@ -96,11 +98,12 @@ protected: return dst; } - SimpleTensor<T2> compute_reference(const TensorShape &shape, DataType dt_in1, DataType dt_in2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy) + SimpleTensor<T2> compute_reference(const TensorShape &shape0, const TensorShape &shape1, DataType dt_in1, DataType dt_in2, + float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy) { // Create reference - SimpleTensor<T1> src1{ shape, dt_in1 }; - SimpleTensor<T2> src2{ shape, dt_in2 }; + SimpleTensor<T1> src1{ shape0, dt_in1 }; + SimpleTensor<T2> src2{ shape1, dt_in2 }; // Fill reference fill(src1, 0); @@ -112,6 +115,18 @@ protected: TensorType _target{}; SimpleTensor<T2> _reference{}; }; + +template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2> +class PixelWiseMultiplicationValidationFixture : public PixelWiseMultiplicationBroadcastValidationFixture<TensorType, AccessorType, FunctionType, T1, T2> +{ +public: + template <typename...> + void setup(const TensorShape &shape, DataType dt_in1, DataType dt_in2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy) + { + PixelWiseMultiplicationBroadcastValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>::setup(shape, shape, dt_in1, dt_in2, scale, convert_policy, rounding_policy); + } +}; + } // namespace validation } // namespace test } // namespace arm_compute diff --git a/tests/validation/reference/PixelWiseMultiplication.cpp b/tests/validation/reference/PixelWiseMultiplication.cpp index b3647fc9ce..546a886ac9 100644 --- a/tests/validation/reference/PixelWiseMultiplication.cpp +++ b/tests/validation/reference/PixelWiseMultiplication.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017 ARM Limited. + * Copyright (c) 2017-2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -41,46 +41,105 @@ struct is_floating_point { }; +namespace +{ +/** Compute the result of `src1 * src2 * scale`. The result type always matches the type of @p src2. + * + * @param[in] src1 An input value. Data types supported: U8/QS8/QS16/S16/F16/F32. + * @param[in] src2 An input value. Data types supported: same as @p src1. + * @param[in] scale Scale to apply after multiplication. + * Scale must be positive and its value must be either 1/255 or 1/2^n where n is between 0 and 15. For QS8 and QS16 scale must be 1. + * @param[in] convert_policy Overflow policy. Supported overflow policies: Wrap, Saturate + * @param[in] rounding_policy Rounding policy. Supported rounding modes: to zero, to nearest even. + */ template <typename T1, typename T2> -SimpleTensor<T2> pixel_wise_multiplication(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy) +T2 mul(const T1 src1, const T2 src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy) { - SimpleTensor<T2> dst(src2.shape(), src2.data_type()); + using intermediate_type = typename common_promoted_signed_type<T1, T2, T2>::intermediate_type; - if(scale < 0) - { - ARM_COMPUTE_ERROR("Scale of pixel-wise multiplication must be non-negative"); - } + const double val = static_cast<intermediate_type>(src1) * static_cast<intermediate_type>(src2) * static_cast<double>(scale); - using intermediate_type = typename common_promoted_signed_type<T1, T2, T2>::intermediate_type; + if(is_floating_point<T2>::value) + { + const auto result = static_cast<T2>(val); - for(int i = 0; i < src1.num_elements(); ++i) + return result; + } + else { - double val = static_cast<intermediate_type>(src1[i]) * static_cast<intermediate_type>(src2[i]) * static_cast<double>(scale); - if(is_floating_point<T2>::value) + double rounded_val = 0; + switch(rounding_policy) { - dst[i] = val; + case(RoundingPolicy::TO_ZERO): + rounded_val = support::cpp11::trunc(val); + break; + case(RoundingPolicy::TO_NEAREST_UP): + rounded_val = round_half_up(val); + break; + case(RoundingPolicy::TO_NEAREST_EVEN): + rounded_val = round_half_even(val); + break; + default: + ARM_COMPUTE_ERROR("Unsupported rounding policy"); } - else + + const auto result = static_cast<T2>((convert_policy == ConvertPolicy::SATURATE) ? saturate_cast<T2>(rounded_val) : rounded_val); + + return result; + } +} + +template <size_t dim> +struct BroadcastUnroll +{ + template <typename T1, typename T2> + static void unroll(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, SimpleTensor<T2> &dst, + float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, + Coordinates &id_src1, Coordinates &id_src2, Coordinates &id_dst) + { + const bool src1_is_broadcast = (src1.shape()[dim - 1] != dst.shape()[dim - 1]); + const bool src2_is_broadcast = (src2.shape()[dim - 1] != dst.shape()[dim - 1]); + + id_src1.set(dim - 1, 0); + id_src2.set(dim - 1, 0); + id_dst.set(dim - 1, 0); + + for(size_t i = 0; i < dst.shape()[dim - 1]; ++i, ++id_dst[dim - 1]) { - double rounded_val = 0; - switch(rounding_policy) - { - case(RoundingPolicy::TO_ZERO): - rounded_val = support::cpp11::trunc(val); - break; - case(RoundingPolicy::TO_NEAREST_UP): - rounded_val = round_half_up(val); - break; - case(RoundingPolicy::TO_NEAREST_EVEN): - rounded_val = round_half_even(val); - break; - default: - ARM_COMPUTE_ERROR("Unsupported rounding policy"); - } - - dst[i] = (convert_policy == ConvertPolicy::SATURATE) ? saturate_cast<T2>(rounded_val) : static_cast<T2>(rounded_val); + BroadcastUnroll < dim - 1 >::unroll(src1, src2, dst, scale, convert_policy, rounding_policy, id_src1, id_src2, id_dst); + + id_src1[dim - 1] += !src1_is_broadcast; + id_src2[dim - 1] += !src2_is_broadcast; } } +}; + +template <> +struct BroadcastUnroll<0> +{ + template <typename T1, typename T2> + static void unroll(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, SimpleTensor<T2> &dst, + float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy, + Coordinates &id_src1, Coordinates &id_src2, Coordinates &id_dst) + { + dst[coord2index(dst.shape(), id_dst)] = mul(src1[coord2index(src1.shape(), id_src1)], src2[coord2index(src2.shape(), id_src2)], scale, convert_policy, rounding_policy); + } +}; +} // namespace + +template <typename T1, typename T2> +SimpleTensor<T2> pixel_wise_multiplication(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy) +{ + SimpleTensor<T2> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), src2.data_type()); + + if(scale < 0) + { + ARM_COMPUTE_ERROR("Scale of pixel-wise multiplication must be non-negative"); + } + + Coordinates id_src1, id_src2, id_dst; + + BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(src1, src2, dst, scale, convert_policy, rounding_policy, id_src1, id_src2, id_dst); return dst; } |