aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--arm_compute/core/CL/kernels/CLPixelWiseMultiplicationKernel.h3
-rw-r--r--arm_compute/runtime/CL/functions/CLPixelWiseMultiplication.h20
-rw-r--r--src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp105
-rw-r--r--src/runtime/CL/functions/CLPixelWiseMultiplication.cpp17
-rw-r--r--tests/validation/CL/PixelWiseMultiplication.cpp6
-rw-r--r--tests/validation/fixtures/PixelWiseMultiplicationFixture.h49
-rw-r--r--tests/validation/reference/PixelWiseMultiplication.cpp119
7 files changed, 227 insertions, 92 deletions
diff --git a/arm_compute/core/CL/kernels/CLPixelWiseMultiplicationKernel.h b/arm_compute/core/CL/kernels/CLPixelWiseMultiplicationKernel.h
index 6746a49dde..1ecd9be8cd 100644
--- a/arm_compute/core/CL/kernels/CLPixelWiseMultiplicationKernel.h
+++ b/arm_compute/core/CL/kernels/CLPixelWiseMultiplicationKernel.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -76,6 +76,7 @@ public:
// Inherited methods overridden:
void run(const Window &window, cl::CommandQueue &queue) override;
+ BorderSize border_size() const override;
private:
const ICLTensor *_input1;
diff --git a/arm_compute/runtime/CL/functions/CLPixelWiseMultiplication.h b/arm_compute/runtime/CL/functions/CLPixelWiseMultiplication.h
index d57bfda2c1..75b67cd17c 100644
--- a/arm_compute/runtime/CL/functions/CLPixelWiseMultiplication.h
+++ b/arm_compute/runtime/CL/functions/CLPixelWiseMultiplication.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -37,15 +37,17 @@ class CLPixelWiseMultiplication : public ICLSimpleFunction
public:
/** Initialise the kernel's inputs, output and convertion policy.
*
- * @param[in] input1 An input tensor. Data types supported: U8/QS8/QS16/S16/F16/F32.
- * @param[in] input2 An input tensor. Data types supported: same as @p input1.
- * @param[out] output The output tensor, Data types supported: same as @p input1. Note: U8 (QS8, QS16) requires both inputs to be U8 (QS8, QS16).
- * @param[in] scale Scale to apply after multiplication.
- * Scale must be positive and its value must be either 1/255 or 1/2^n where n is between 0 and 15. For QS8 and QS16 scale must be 1.
- * @param[in] overflow_policy Overflow policy. Supported overflow policies: Wrap, Saturate
- * @param[in] rounding_policy Rounding policy. Supported rounding modes: to zero, to nearest even.
+ * @param[in, out] input1 An input tensor. Data types supported: U8/QS8/QS16/S16/F16/F32.
+ * The input tensor is [in, out] because its TensorInfo might be modified inside the kernel in case of broadcasting of dimension 0.
+ * @param[in, out] input2 An input tensor. Data types supported: same as @p input1.
+ * The input tensor is [in, out] because its TensorInfo might be modified inside the kernel in case of broadcasting of dimension 0.
+ * @param[out] output The output tensor, Data types supported: same as @p input1. Note: U8 (QS8, QS16) requires both inputs to be U8 (QS8, QS16).
+ * @param[in] scale Scale to apply after multiplication.
+ * Scale must be positive and its value must be either 1/255 or 1/2^n where n is between 0 and 15. For QS8 and QS16 scale must be 1.
+ * @param[in] overflow_policy Overflow policy. Supported overflow policies: Wrap, Saturate
+ * @param[in] rounding_policy Rounding policy. Supported rounding modes: to zero, to nearest even.
*/
- void configure(const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, float scale,
+ void configure(ICLTensor *input1, ICLTensor *input2, ICLTensor *output, float scale,
ConvertPolicy overflow_policy, RoundingPolicy rounding_policy);
/** Static function to check if given info will lead to a valid configuration of @ref CLPixelWiseMultiplication
*
diff --git a/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp b/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp
index 6dba9c0f95..f30ba61b9a 100644
--- a/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp
+++ b/src/core/CL/kernels/CLPixelWiseMultiplicationKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -42,6 +42,8 @@ using namespace arm_compute;
namespace
{
+constexpr unsigned int num_elems_processed_per_iteration = 16;
+
Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float scale,
ConvertPolicy overflow_policy, RoundingPolicy rounding_policy)
{
@@ -50,10 +52,13 @@ Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *input2,
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, input2);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, input2);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(scale < 0, "Scale cannot be negative.");
+ const TensorShape &out_shape = TensorShape::broadcast_shape(input1->tensor_shape(), input2->tensor_shape());
+
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, input2);
+
if(is_data_type_fixed_point(input1->data_type()))
{
// All data types must be all QS8 or all QS16
@@ -62,12 +67,12 @@ Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *input2,
}
// Validate in case of configured output
- if((output != nullptr) && (output->total_size() != 0))
+ if(output->total_size() > 0)
{
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::QS8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->data_type() == DataType::U8 && (input1->data_type() != DataType::U8 || input2->data_type() != DataType::U8),
"Output can only be U8 if both inputs are U8");
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output->tensor_shape(), 0), "Wrong shape for output");
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, output);
if(is_data_type_fixed_point(input1->data_type()))
{
@@ -80,18 +85,36 @@ Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *input2,
std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output)
{
- constexpr unsigned int num_elems_processed_per_iteration = 16;
+ const std::pair<TensorShape, ValidRegion> broadcast_pair = ITensorInfo::broadcast_shape_and_valid_region(*input1, *input2);
+ const TensorShape &out_shape = broadcast_pair.first;
+ const ValidRegion &valid_region = broadcast_pair.second;
- Window win = calculate_max_window(*input1, Steps(num_elems_processed_per_iteration));
+ // Auto initialize output if not initialized
+ {
+ set_shape_if_empty(*output, out_shape);
+
+ if(input1->data_type() == DataType::S16 || input2->data_type() == DataType::S16)
+ {
+ set_format_if_unknown(*output, Format::S16);
+ }
+ else if(input1->data_type() == DataType::F32 || input2->data_type() == DataType::F32)
+ {
+ set_format_if_unknown(*output, Format::F32);
+ }
+ }
+
+ Window win = calculate_max_window(valid_region, Steps(num_elems_processed_per_iteration));
+ Window win_input1 = win.broadcast_if_dimension_le_one(*input1);
+ Window win_input2 = win.broadcast_if_dimension_le_one(*input2);
AccessWindowHorizontal input1_access(input1, 0, num_elems_processed_per_iteration);
AccessWindowHorizontal input2_access(input2, 0, num_elems_processed_per_iteration);
AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
- bool window_changed = update_window_and_padding(win, input1_access, input2_access, output_access);
+ bool window_changed = update_window_and_padding(win_input1, input1_access)
+ || update_window_and_padding(win_input2, input2_access)
+ || update_window_and_padding(win, output_access);
- ValidRegion valid_region = intersect_valid_regions(input1->valid_region(),
- input2->valid_region());
output_access.set_valid_region(win, valid_region);
Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
@@ -108,24 +131,13 @@ void CLPixelWiseMultiplicationKernel::configure(const ICLTensor *input1, const I
ConvertPolicy overflow_policy, RoundingPolicy rounding_policy)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
-
- // Auto initialize output if not initialized
- {
- set_shape_if_empty(*output->info(), input1->info()->tensor_shape());
-
- if(input1->info()->data_type() == DataType::S16 || input2->info()->data_type() == DataType::S16)
- {
- set_format_if_unknown(*output->info(), Format::S16);
- }
- else if(input1->info()->data_type() == DataType::F32 || input2->info()->data_type() == DataType::F32)
- {
- set_format_if_unknown(*output->info(), Format::F32);
- }
- }
-
ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input1->info(), input2->info(), output->info(),
scale, overflow_policy, rounding_policy));
+ // Configure kernel window
+ auto win_config = validate_and_configure_window(input1->info(), input2->info(), output->info());
+ ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
+
_input1 = input1;
_input2 = input2;
_output = output;
@@ -207,15 +219,13 @@ void CLPixelWiseMultiplicationKernel::configure(const ICLTensor *input1, const I
_kernel.setArg(idx++, scale);
}
- // Configure kernel window
- auto win_config = validate_and_configure_window(input1->info(), input2->info(), output->info());
- ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
ICLKernel::configure(win_config.second);
}
Status CLPixelWiseMultiplicationKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float scale,
ConvertPolicy overflow_policy, RoundingPolicy rounding_policy)
{
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input1, input2, output, scale, overflow_policy, rounding_policy));
ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input1->clone().get(), input2->clone().get(), output->clone().get()).first);
@@ -227,16 +237,47 @@ void CLPixelWiseMultiplicationKernel::run(const Window &window, cl::CommandQueue
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICLKernel::window(), window);
- Window collapsed = window.collapse_if_possible(ICLKernel::window(), Window::DimZ);
- Window slice = collapsed.first_slice_window_3D();
+ const TensorShape &in_shape1 = _input1->info()->tensor_shape();
+ const TensorShape &in_shape2 = _input2->info()->tensor_shape();
+ const TensorShape &out_shape = _output->info()->tensor_shape();
+
+ bool can_collapse = true;
+ if(std::min(in_shape1.total_size(), in_shape2.total_size()) > 1)
+ {
+ can_collapse = (std::min(in_shape1.num_dimensions(), in_shape2.num_dimensions()) > Window::DimZ);
+ for(size_t d = Window::DimZ; can_collapse && (d < out_shape.num_dimensions()); ++d)
+ {
+ can_collapse = (in_shape1[d] == in_shape2[d]);
+ }
+ }
+
+ bool has_collapsed = false;
+ Window collapsed = can_collapse ? window.collapse_if_possible(ICLKernel::window(), Window::DimZ, &has_collapsed) : window;
+
+ const TensorShape &in_shape1_collapsed = has_collapsed ? in_shape1.collapsed_from(Window::DimZ) : in_shape1;
+ const TensorShape &in_shape2_collapsed = has_collapsed ? in_shape2.collapsed_from(Window::DimZ) : in_shape2;
+
+ Window slice = collapsed.first_slice_window_3D();
+ Window slice_input1 = slice.broadcast_if_dimension_le_one(in_shape1_collapsed);
+ Window slice_input2 = slice.broadcast_if_dimension_le_one(in_shape2_collapsed);
do
{
unsigned int idx = 0;
- add_3D_tensor_argument(idx, _input1, slice);
- add_3D_tensor_argument(idx, _input2, slice);
+ add_3D_tensor_argument(idx, _input1, slice_input1);
+ add_3D_tensor_argument(idx, _input2, slice_input2);
add_3D_tensor_argument(idx, _output, slice);
enqueue(queue, *this, slice);
+
+ collapsed.slide_window_slice_3D(slice_input1);
+ collapsed.slide_window_slice_3D(slice_input2);
}
while(collapsed.slide_window_slice_3D(slice));
}
+
+BorderSize CLPixelWiseMultiplicationKernel::border_size() const
+{
+ const unsigned int replicateSize = _output->info()->dimension(0) - std::min(_input1->info()->dimension(0), _input2->info()->dimension(0));
+ const unsigned int border = std::min<unsigned int>(num_elems_processed_per_iteration - 1U, replicateSize);
+ return BorderSize(0, border, 0, 0);
+}
diff --git a/src/runtime/CL/functions/CLPixelWiseMultiplication.cpp b/src/runtime/CL/functions/CLPixelWiseMultiplication.cpp
index c78f94476e..b4c20db3da 100644
--- a/src/runtime/CL/functions/CLPixelWiseMultiplication.cpp
+++ b/src/runtime/CL/functions/CLPixelWiseMultiplication.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016, 2017 ARM Limited.
+ * Copyright (c) 2016-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -23,6 +23,7 @@
*/
#include "arm_compute/runtime/CL/functions/CLPixelWiseMultiplication.h"
+#include "arm_compute/core/CL/ICLTensor.h"
#include "arm_compute/core/CL/kernels/CLPixelWiseMultiplicationKernel.h"
#include "support/ToolchainSupport.h"
@@ -30,16 +31,26 @@
using namespace arm_compute;
-void CLPixelWiseMultiplication::configure(const ICLTensor *input1, const ICLTensor *input2, ICLTensor *output, float scale,
+void CLPixelWiseMultiplication::configure(ICLTensor *input1, ICLTensor *input2, ICLTensor *output, float scale,
ConvertPolicy overflow_policy, RoundingPolicy rounding_policy)
{
auto k = arm_compute::support::cpp14::make_unique<CLPixelWiseMultiplicationKernel>();
k->configure(input1, input2, output, scale, overflow_policy, rounding_policy);
_kernel = std::move(k);
+
+ if(output->info()->dimension(0) > 1)
+ {
+ ICLTensor *broadcasted_info = (input1->info()->dimension(0) == 1) ? input1 : input2;
+
+ if(broadcasted_info->info()->dimension(0) == 1)
+ {
+ _border_handler.configure(broadcasted_info, _kernel->border_size(), BorderMode::REPLICATE);
+ }
+ }
}
Status CLPixelWiseMultiplication::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, float scale,
ConvertPolicy overflow_policy, RoundingPolicy rounding_policy)
{
return CLPixelWiseMultiplicationKernel::validate(input1, input2, output, scale, overflow_policy, rounding_policy);
-} \ No newline at end of file
+}
diff --git a/tests/validation/CL/PixelWiseMultiplication.cpp b/tests/validation/CL/PixelWiseMultiplication.cpp
index 45f57af3fc..6a71175f51 100644
--- a/tests/validation/CL/PixelWiseMultiplication.cpp
+++ b/tests/validation/CL/PixelWiseMultiplication.cpp
@@ -86,6 +86,8 @@ template <typename T>
using CLPixelWiseMultiplicationToQS16Fixture = PixelWiseMultiplicationValidationFixture<CLTensor, CLAccessor, CLPixelWiseMultiplication, T, qint16_t>;
template <typename T>
using CLFixedPointPixelWiseMultiplicationFixture = FixedPointPixelWiseMultiplicationValidationFixture<CLTensor, CLAccessor, CLPixelWiseMultiplication, T>;
+template <typename T>
+using CLPixelWiseMultiplicationBroadcastFixture = PixelWiseMultiplicationBroadcastValidationFixture<CLTensor, CLAccessor, CLPixelWiseMultiplication, T, float>;
TEST_SUITE(CL)
TEST_SUITE(PixelWiseMultiplication)
@@ -169,6 +171,10 @@ TEST_SUITE_END() // ScaleUnity
TEST_SUITE_END() // QS16
+TEST_SUITE(Broadcast)
+PIXEL_WISE_MULTIPLICATION_FIXTURE_DATA_TEST_CASE(RunSmall, BroadcastFixture<float>, PRECOMMIT, SmallShapesBroadcast(), F32, F32, scale_255, TO_NEAREST_UP, VALIDATE(float, 1.f))
+TEST_SUITE_END() // Broadcast
+
TEST_SUITE_END() // FixedPointPixelWiseMultiplication
TEST_SUITE_END()
} // namespace validation
diff --git a/tests/validation/fixtures/PixelWiseMultiplicationFixture.h b/tests/validation/fixtures/PixelWiseMultiplicationFixture.h
index 7428fb5cb7..b9f19f3e77 100644
--- a/tests/validation/fixtures/PixelWiseMultiplicationFixture.h
+++ b/tests/validation/fixtures/PixelWiseMultiplicationFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -40,19 +40,20 @@ namespace test
namespace validation
{
template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2>
-class PixelWiseMultiplicationValidationFixture : public framework::Fixture
+class PixelWiseMultiplicationBroadcastValidationFixture : public framework::Fixture
{
public:
template <typename...>
- void setup(TensorShape shape,
- DataType dt_in1,
- DataType dt_in2,
- float scale,
- ConvertPolicy convert_policy,
- RoundingPolicy rounding_policy)
+ void setup(const TensorShape &shape0,
+ const TensorShape &shape1,
+ DataType dt_in1,
+ DataType dt_in2,
+ float scale,
+ ConvertPolicy convert_policy,
+ RoundingPolicy rounding_policy)
{
- _target = compute_target(shape, dt_in1, dt_in2, scale, convert_policy, rounding_policy);
- _reference = compute_reference(shape, dt_in1, dt_in2, scale, convert_policy, rounding_policy);
+ _target = compute_target(shape0, shape1, dt_in1, dt_in2, scale, convert_policy, rounding_policy);
+ _reference = compute_reference(shape0, shape1, dt_in1, dt_in2, scale, convert_policy, rounding_policy);
}
protected:
@@ -62,12 +63,13 @@ protected:
library->fill_tensor_uniform(tensor, seed_offset);
}
- TensorType compute_target(const TensorShape &shape, DataType dt_in1, DataType dt_in2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy)
+ TensorType compute_target(const TensorShape &shape0, const TensorShape &shape1, DataType dt_in1, DataType dt_in2,
+ float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy)
{
// Create tensors
- TensorType src1 = create_tensor<TensorType>(shape, dt_in1);
- TensorType src2 = create_tensor<TensorType>(shape, dt_in2);
- TensorType dst = create_tensor<TensorType>(shape, dt_in2);
+ TensorType src1 = create_tensor<TensorType>(shape0, dt_in1);
+ TensorType src2 = create_tensor<TensorType>(shape1, dt_in2);
+ TensorType dst = create_tensor<TensorType>(TensorShape::broadcast_shape(shape0, shape1), dt_in2);
// Create and configure function
FunctionType multiply;
@@ -96,11 +98,12 @@ protected:
return dst;
}
- SimpleTensor<T2> compute_reference(const TensorShape &shape, DataType dt_in1, DataType dt_in2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy)
+ SimpleTensor<T2> compute_reference(const TensorShape &shape0, const TensorShape &shape1, DataType dt_in1, DataType dt_in2,
+ float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy)
{
// Create reference
- SimpleTensor<T1> src1{ shape, dt_in1 };
- SimpleTensor<T2> src2{ shape, dt_in2 };
+ SimpleTensor<T1> src1{ shape0, dt_in1 };
+ SimpleTensor<T2> src2{ shape1, dt_in2 };
// Fill reference
fill(src1, 0);
@@ -112,6 +115,18 @@ protected:
TensorType _target{};
SimpleTensor<T2> _reference{};
};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T1, typename T2>
+class PixelWiseMultiplicationValidationFixture : public PixelWiseMultiplicationBroadcastValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>
+{
+public:
+ template <typename...>
+ void setup(const TensorShape &shape, DataType dt_in1, DataType dt_in2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy)
+ {
+ PixelWiseMultiplicationBroadcastValidationFixture<TensorType, AccessorType, FunctionType, T1, T2>::setup(shape, shape, dt_in1, dt_in2, scale, convert_policy, rounding_policy);
+ }
+};
+
} // namespace validation
} // namespace test
} // namespace arm_compute
diff --git a/tests/validation/reference/PixelWiseMultiplication.cpp b/tests/validation/reference/PixelWiseMultiplication.cpp
index b3647fc9ce..546a886ac9 100644
--- a/tests/validation/reference/PixelWiseMultiplication.cpp
+++ b/tests/validation/reference/PixelWiseMultiplication.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -41,46 +41,105 @@ struct is_floating_point
{
};
+namespace
+{
+/** Compute the result of `src1 * src2 * scale`. The result type always matches the type of @p src2.
+ *
+ * @param[in] src1 An input value. Data types supported: U8/QS8/QS16/S16/F16/F32.
+ * @param[in] src2 An input value. Data types supported: same as @p src1.
+ * @param[in] scale Scale to apply after multiplication.
+ * Scale must be positive and its value must be either 1/255 or 1/2^n where n is between 0 and 15. For QS8 and QS16 scale must be 1.
+ * @param[in] convert_policy Overflow policy. Supported overflow policies: Wrap, Saturate
+ * @param[in] rounding_policy Rounding policy. Supported rounding modes: to zero, to nearest even.
+ */
template <typename T1, typename T2>
-SimpleTensor<T2> pixel_wise_multiplication(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy)
+T2 mul(const T1 src1, const T2 src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy)
{
- SimpleTensor<T2> dst(src2.shape(), src2.data_type());
+ using intermediate_type = typename common_promoted_signed_type<T1, T2, T2>::intermediate_type;
- if(scale < 0)
- {
- ARM_COMPUTE_ERROR("Scale of pixel-wise multiplication must be non-negative");
- }
+ const double val = static_cast<intermediate_type>(src1) * static_cast<intermediate_type>(src2) * static_cast<double>(scale);
- using intermediate_type = typename common_promoted_signed_type<T1, T2, T2>::intermediate_type;
+ if(is_floating_point<T2>::value)
+ {
+ const auto result = static_cast<T2>(val);
- for(int i = 0; i < src1.num_elements(); ++i)
+ return result;
+ }
+ else
{
- double val = static_cast<intermediate_type>(src1[i]) * static_cast<intermediate_type>(src2[i]) * static_cast<double>(scale);
- if(is_floating_point<T2>::value)
+ double rounded_val = 0;
+ switch(rounding_policy)
{
- dst[i] = val;
+ case(RoundingPolicy::TO_ZERO):
+ rounded_val = support::cpp11::trunc(val);
+ break;
+ case(RoundingPolicy::TO_NEAREST_UP):
+ rounded_val = round_half_up(val);
+ break;
+ case(RoundingPolicy::TO_NEAREST_EVEN):
+ rounded_val = round_half_even(val);
+ break;
+ default:
+ ARM_COMPUTE_ERROR("Unsupported rounding policy");
}
- else
+
+ const auto result = static_cast<T2>((convert_policy == ConvertPolicy::SATURATE) ? saturate_cast<T2>(rounded_val) : rounded_val);
+
+ return result;
+ }
+}
+
+template <size_t dim>
+struct BroadcastUnroll
+{
+ template <typename T1, typename T2>
+ static void unroll(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, SimpleTensor<T2> &dst,
+ float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
+ Coordinates &id_src1, Coordinates &id_src2, Coordinates &id_dst)
+ {
+ const bool src1_is_broadcast = (src1.shape()[dim - 1] != dst.shape()[dim - 1]);
+ const bool src2_is_broadcast = (src2.shape()[dim - 1] != dst.shape()[dim - 1]);
+
+ id_src1.set(dim - 1, 0);
+ id_src2.set(dim - 1, 0);
+ id_dst.set(dim - 1, 0);
+
+ for(size_t i = 0; i < dst.shape()[dim - 1]; ++i, ++id_dst[dim - 1])
{
- double rounded_val = 0;
- switch(rounding_policy)
- {
- case(RoundingPolicy::TO_ZERO):
- rounded_val = support::cpp11::trunc(val);
- break;
- case(RoundingPolicy::TO_NEAREST_UP):
- rounded_val = round_half_up(val);
- break;
- case(RoundingPolicy::TO_NEAREST_EVEN):
- rounded_val = round_half_even(val);
- break;
- default:
- ARM_COMPUTE_ERROR("Unsupported rounding policy");
- }
-
- dst[i] = (convert_policy == ConvertPolicy::SATURATE) ? saturate_cast<T2>(rounded_val) : static_cast<T2>(rounded_val);
+ BroadcastUnroll < dim - 1 >::unroll(src1, src2, dst, scale, convert_policy, rounding_policy, id_src1, id_src2, id_dst);
+
+ id_src1[dim - 1] += !src1_is_broadcast;
+ id_src2[dim - 1] += !src2_is_broadcast;
}
}
+};
+
+template <>
+struct BroadcastUnroll<0>
+{
+ template <typename T1, typename T2>
+ static void unroll(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, SimpleTensor<T2> &dst,
+ float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy,
+ Coordinates &id_src1, Coordinates &id_src2, Coordinates &id_dst)
+ {
+ dst[coord2index(dst.shape(), id_dst)] = mul(src1[coord2index(src1.shape(), id_src1)], src2[coord2index(src2.shape(), id_src2)], scale, convert_policy, rounding_policy);
+ }
+};
+} // namespace
+
+template <typename T1, typename T2>
+SimpleTensor<T2> pixel_wise_multiplication(const SimpleTensor<T1> &src1, const SimpleTensor<T2> &src2, float scale, ConvertPolicy convert_policy, RoundingPolicy rounding_policy)
+{
+ SimpleTensor<T2> dst(TensorShape::broadcast_shape(src1.shape(), src2.shape()), src2.data_type());
+
+ if(scale < 0)
+ {
+ ARM_COMPUTE_ERROR("Scale of pixel-wise multiplication must be non-negative");
+ }
+
+ Coordinates id_src1, id_src2, id_dst;
+
+ BroadcastUnroll<Coordinates::num_max_dimensions>::unroll(src1, src2, dst, scale, convert_policy, rounding_policy, id_src1, id_src2, id_dst);
return dst;
}