aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp
diff options
context:
space:
mode:
authorGeorgios Pinitas <georgios.pinitas@arm.com>2018-09-10 15:07:45 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-11-02 16:54:54 +0000
commitcbf39c63a6eb89a2c80b2338afc374081803d79d (patch)
treeafe1c55d5e3bbf0e111ec0dce9a564304844a55f /src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp
parentd7647d4ebd0f0b5253b7f31ffcd48a851ba62947 (diff)
downloadComputeLibrary-cbf39c63a6eb89a2c80b2338afc374081803d79d.tar.gz
COMPMID-1566: Add broadcast to CLArithmeticSubtraction
Change-Id: I05d21f9a92013ecfd1128d12cf1561cfd6e5c5e9 Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/147983 Tested-by: bsgcomp <bsgcomp@arm.com> Reviewed-by: Anthony Barbier <anthony.barbier@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp')
-rw-r--r--src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp171
1 files changed, 96 insertions, 75 deletions
diff --git a/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp b/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp
index 3c76548b0a..ff8fb84958 100644
--- a/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp
+++ b/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp
@@ -46,10 +46,12 @@ class Coordinates;
namespace
{
+constexpr unsigned int num_elems_processed_per_iteration = 16;
+
void sub_wrap_U8_U8_U8(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
- Iterator input1(in1, window);
- Iterator input2(in2, window);
+ Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
+ Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
Iterator output(out, window);
execute_window_loop(window, [&](const Coordinates & id)
@@ -64,8 +66,8 @@ void sub_wrap_U8_U8_U8(const ITensor *in1, const ITensor *in2, ITensor *out, con
void sub_saturate_U8_U8_U8(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
- Iterator input1(in1, window);
- Iterator input2(in2, window);
+ Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
+ Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
Iterator output(out, window);
execute_window_loop(window, [&](const Coordinates & id)
@@ -80,8 +82,8 @@ void sub_saturate_U8_U8_U8(const ITensor *in1, const ITensor *in2, ITensor *out,
void sub_wrap_S16_S16_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
- Iterator input1(in1, window);
- Iterator input2(in2, window);
+ Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
+ Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
Iterator output(out, window);
execute_window_loop(window, [&](const Coordinates & id)
@@ -104,8 +106,8 @@ void sub_wrap_S16_S16_S16(const ITensor *in1, const ITensor *in2, ITensor *out,
void sub_saturate_S16_S16_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
- Iterator input1(in1, window);
- Iterator input2(in2, window);
+ Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
+ Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
Iterator output(out, window);
execute_window_loop(window, [&](const Coordinates & id)
@@ -144,8 +146,8 @@ inline float16x8x2_t vsub2q_f16(const float16x8x2_t &a, const float16x8x2_t &b)
void sub_F16_F16_F16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
- Iterator input1(in1, window);
- Iterator input2(in2, window);
+ Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
+ Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
Iterator output(out, window);
execute_window_loop(window, [&](const Coordinates & id)
@@ -167,8 +169,8 @@ void sub_F16_F16_F16(const ITensor *in1, const ITensor *in2, ITensor *out, const
void sub_F32_F32_F32(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
- Iterator input1(in1, window);
- Iterator input2(in2, window);
+ Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
+ Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
Iterator output(out, window);
execute_window_loop(window, [&](const Coordinates & id)
@@ -192,8 +194,8 @@ void sub_F32_F32_F32(const ITensor *in1, const ITensor *in2, ITensor *out, const
}
void sub_wrap_S16_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
- Iterator input1(in1, window);
- Iterator input2(in2, window);
+ Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
+ Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
Iterator output(out, window);
execute_window_loop(window, [&](const Coordinates & id)
@@ -213,8 +215,8 @@ void sub_wrap_S16_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, c
void sub_saturate_S16_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
- Iterator input1(in1, window);
- Iterator input2(in2, window);
+ Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
+ Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
Iterator output(out, window);
execute_window_loop(window, [&](const Coordinates & id)
@@ -234,8 +236,8 @@ void sub_saturate_S16_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *ou
void sub_wrap_U8_S16_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
- Iterator input1(in1, window);
- Iterator input2(in2, window);
+ Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
+ Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
Iterator output(out, window);
execute_window_loop(window, [&](const Coordinates & id)
@@ -255,8 +257,8 @@ void sub_wrap_U8_S16_S16(const ITensor *in1, const ITensor *in2, ITensor *out, c
void sub_saturate_U8_S16_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
- Iterator input1(in1, window);
- Iterator input2(in2, window);
+ Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
+ Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
Iterator output(out, window);
execute_window_loop(window, [&](const Coordinates & id)
@@ -276,8 +278,8 @@ void sub_saturate_U8_S16_S16(const ITensor *in1, const ITensor *in2, ITensor *ou
void sub_wrap_U8_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
- Iterator input1(in1, window);
- Iterator input2(in2, window);
+ Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
+ Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
Iterator output(out, window);
execute_window_loop(window, [&](const Coordinates & id)
@@ -298,8 +300,8 @@ void sub_wrap_U8_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, co
void sub_saturate_U8_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, const Window &window)
{
- Iterator input1(in1, window);
- Iterator input2(in2, window);
+ Iterator input1(in1, window.broadcast_if_dimension_le_one(in1->info()->tensor_shape()));
+ Iterator input2(in2, window.broadcast_if_dimension_le_one(in2->info()->tensor_shape()));
Iterator output(out, window);
execute_window_loop(window, [&](const Coordinates & id)
@@ -318,43 +320,71 @@ void sub_saturate_U8_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out
input1, input2, output);
}
-inline Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy)
+inline Status validate_arguments(const ITensorInfo &input1, const ITensorInfo &input2, const ITensorInfo &output, ConvertPolicy policy)
{
ARM_COMPUTE_UNUSED(policy);
- ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input1);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, input2, output);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
-
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- !(input1->data_type() == DataType::U8 && input2->data_type() == DataType::U8 && output->data_type() == DataType::U8)
- && !(input1->data_type() == DataType::U8 && input2->data_type() == DataType::U8 && output->data_type() == DataType::S16)
- && !(input1->data_type() == DataType::U8 && input2->data_type() == DataType::S16 && output->data_type() == DataType::S16)
- && !(input1->data_type() == DataType::S16 && input2->data_type() == DataType::U8 && output->data_type() == DataType::S16)
- && !(input1->data_type() == DataType::S16 && input2->data_type() == DataType::S16 && output->data_type() == DataType::S16)
- && !(input1->data_type() == DataType::F32 && input2->data_type() == DataType::F32 && output->data_type() == DataType::F32)
- && !(input1->data_type() == DataType::F16 && input2->data_type() == DataType::F16 && output->data_type() == DataType::F16),
- "You called subtract with the wrong image formats");
+ ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&input1);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input2, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&output, 1, DataType::U8, DataType::S16, DataType::F16, DataType::F32);
+
+ const TensorShape out_shape = TensorShape::broadcast_shape(input1.tensor_shape(), input2.tensor_shape());
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");
+ // Validate in case of configured output
+ if(output.total_size() > 0)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(
+ !(input1.data_type() == DataType::U8 && input2.data_type() == DataType::U8 && output.data_type() == DataType::U8)
+ && !(input1.data_type() == DataType::U8 && input2.data_type() == DataType::U8 && output.data_type() == DataType::S16)
+ && !(input1.data_type() == DataType::U8 && input2.data_type() == DataType::S16 && output.data_type() == DataType::S16)
+ && !(input1.data_type() == DataType::S16 && input2.data_type() == DataType::U8 && output.data_type() == DataType::S16)
+ && !(input1.data_type() == DataType::S16 && input2.data_type() == DataType::S16 && output.data_type() == DataType::S16)
+ && !(input1.data_type() == DataType::F32 && input2.data_type() == DataType::F32 && output.data_type() == DataType::F32)
+ && !(input1.data_type() == DataType::F16 && input2.data_type() == DataType::F16 && output.data_type() == DataType::F16),
+ "You called subtract with the wrong image formats");
+
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output.tensor_shape(), 0),
+ "Wrong shape for output");
+ }
return Status{};
}
-inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output)
+inline std::pair<Status, Window> validate_and_configure_window(ITensorInfo &input1, ITensorInfo &input2, ITensorInfo &output)
{
- constexpr unsigned int num_elems_processed_per_iteration = 16;
+ const std::pair<TensorShape, ValidRegion> broadcast_pair = ITensorInfo::broadcast_shape_and_valid_region(input1, input2);
+ const TensorShape &out_shape = broadcast_pair.first;
+ const ValidRegion &valid_region = broadcast_pair.second;
- // Configure kernel window
- Window win = calculate_max_window(*input1, Steps(num_elems_processed_per_iteration));
- AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
+ // Auto initialize output if not initialized
+ {
+ set_shape_if_empty(output, out_shape);
- bool window_changed = update_window_and_padding(win,
- AccessWindowHorizontal(input1, 0, num_elems_processed_per_iteration),
- AccessWindowHorizontal(input2, 0, num_elems_processed_per_iteration),
- output_access);
+ if(input1.data_type() == DataType::S16 || input2.data_type() == DataType::S16)
+ {
+ set_format_if_unknown(output, Format::S16);
+ }
+ else if(input1.data_type() == DataType::F16 && input2.data_type() == DataType::F16)
+ {
+ set_format_if_unknown(output, Format::F16);
+ }
+ else if(input1.data_type() == DataType::F32 || input2.data_type() == DataType::F32)
+ {
+ set_format_if_unknown(output, Format::F32);
+ }
+ }
+
+ Window win = calculate_max_window(valid_region, Steps(num_elems_processed_per_iteration));
+ Window win_input1 = win.broadcast_if_dimension_le_one(input1);
+ Window win_input2 = win.broadcast_if_dimension_le_one(input2);
- ValidRegion valid_region = intersect_valid_regions(input1->valid_region(),
- input2->valid_region());
+ AccessWindowHorizontal input1_access(&input1, 0, num_elems_processed_per_iteration);
+ AccessWindowHorizontal input2_access(&input2, 0, num_elems_processed_per_iteration);
+ AccessWindowHorizontal output_access(&output, 0, num_elems_processed_per_iteration);
+
+ bool window_changed = update_window_and_padding(win_input1, input1_access)
+ || update_window_and_padding(win_input2, input2_access)
+ || update_window_and_padding(win, output_access);
output_access.set_valid_region(win, valid_region);
@@ -371,26 +401,11 @@ NEArithmeticSubtractionKernel::NEArithmeticSubtractionKernel()
void NEArithmeticSubtractionKernel::configure(const ITensor *input1, const ITensor *input2, ITensor *output, ConvertPolicy policy)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*input1->info(), *input2->info(), *output->info(), policy));
- // Auto initialize output if not initialized
- {
- set_shape_if_empty(*output->info(), input1->info()->tensor_shape());
-
- if(input1->info()->data_type() == DataType::S16 || input2->info()->data_type() == DataType::S16)
- {
- set_format_if_unknown(*output->info(), Format::S16);
- }
- else if(input1->info()->data_type() == DataType::F16 || input2->info()->data_type() == DataType::F16)
- {
- set_format_if_unknown(*output->info(), Format::F16);
- }
- else if(input1->info()->data_type() == DataType::F32 || input2->info()->data_type() == DataType::F32)
- {
- set_format_if_unknown(*output->info(), Format::F32);
- }
- }
-
- ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input1->info(), input2->info(), output->info(), policy));
+ // Configure kernel window
+ auto win_config = validate_and_configure_window(*input1->info(), *input2->info(), *output->info());
+ ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
static std::map<std::string, NEArithmeticSubtractionKernel::SubFunction *> map_function =
{
@@ -427,16 +442,15 @@ void NEArithmeticSubtractionKernel::configure(const ITensor *input1, const ITens
_func = it->second;
}
- // Configure kernel window
- auto win_config = validate_and_configure_window(input1->info(), input2->info(), output->info());
- ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
INEKernel::configure(win_config.second);
}
Status NEArithmeticSubtractionKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy)
{
- ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input1, input2, output, policy));
- ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input1->clone().get(), input2->clone().get(), output->clone().get()).first);
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input1, input2, output);
+
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*input1, *input2, *output, policy));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(*input1->clone(), *input2->clone(), *output->clone()).first);
return Status{};
}
@@ -450,3 +464,10 @@ void NEArithmeticSubtractionKernel::run(const Window &window, const ThreadInfo &
(*_func)(_input1, _input2, _output, window);
}
+
+BorderSize NEArithmeticSubtractionKernel::border_size() const
+{
+ const unsigned int replicateSize = _output->info()->dimension(0) - std::min(_input1->info()->dimension(0), _input2->info()->dimension(0));
+ const unsigned int border = std::min<unsigned int>(num_elems_processed_per_iteration - 1U, replicateSize);
+ return BorderSize(0, border, 0, 0);
+} \ No newline at end of file