author     Michele Di Giorgio <michele.digiorgio@arm.com>  2020-06-04 15:05:38 +0100
committer  Michele Di Giorgio <michele.digiorgio@arm.com>  2020-06-15 13:59:04 +0000
commit     4a61653202afb018f4f259d3c144a735d73f0a20 (patch)
tree       082fd42e91cc0914dcacc0746bbe3e117d74210c /src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp
parent     ccd94966cc58ef5148577e71ba1a4ff5aae1f3bb (diff)
COMPMID-3480: Perform in-place computations in NEArithmeticAdditionKernel
Change-Id: I0089657dd95d7c7b8592984def8e8de1d7e6d085
Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3308
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
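The patch lets the kernel write its result back into the first input when no output tensor is provided. A minimal sketch of the two call patterns, assuming src0, src1 and dst are pre-allocated arm_compute::Tensor objects with broadcast-compatible shapes and matching data types (the helper function, tensor names and include path are illustrative, not part of this patch):

    #include "arm_compute/core/NEON/kernels/NEArithmeticAdditionKernel.h" // assumed header location for this version
    #include "arm_compute/runtime/Tensor.h"

    using namespace arm_compute;

    // Illustrative helper: src0, src1 and dst are assumed to be allocated
    // tensors with broadcast-compatible shapes and matching data types.
    void configure_addition(Tensor &src0, const Tensor &src1, Tensor &dst)
    {
        // Out-of-place: the sum is written to dst, as before this patch.
        NEArithmeticAdditionKernel out_of_place;
        out_of_place.configure(&src0, &src1, &dst, ConvertPolicy::SATURATE);

        // In-place (new with this patch): passing nullptr as the output makes
        // the kernel reuse src0 as the destination. Note that src0 is now taken
        // as a non-const pointer, and a broadcast input must be src1 (input2).
        NEArithmeticAdditionKernel in_place;
        in_place.configure(&src0, &src1, nullptr, ConvertPolicy::SATURATE);
    }

Execution of the configured kernel is unchanged; only configure() and the internal validate_arguments()/validate_and_configure_window() helpers gain a nullable output, as the diff below shows.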
Diffstat (limited to 'src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp')
-rw-r--r--  src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp  |  121
1 file changed, 71 insertions(+), 50 deletions(-)
diff --git a/src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp b/src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp
index 3878c764a6..a1c263a836 100644
--- a/src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp
+++ b/src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp
@@ -810,95 +810,110 @@ void add_U8_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, Convert
input1, input2, output);
}
-Status validate_arguments(const ITensorInfo &input1, const ITensorInfo &input2, const ITensorInfo &output, ConvertPolicy policy)
+Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy)
{
ARM_COMPUTE_UNUSED(policy);
- ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&input1);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
+ ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input1);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
DataType::S16, DataType::QSYMM16, DataType::F16,
DataType::S32, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input2, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
DataType::S16, DataType::QSYMM16, DataType::F16,
DataType::S32, DataType::F32);
- const TensorShape out_shape = TensorShape::broadcast_shape(input1.tensor_shape(), input2.tensor_shape());
+ const TensorShape out_shape = TensorShape::broadcast_shape(input1->tensor_shape(), input2->tensor_shape());
ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG((input1.tensor_shape().x() != input2.tensor_shape().x()) && ((input1.data_type() != input2.data_type()) || (input1.data_type() != output.data_type())
- || (input2.data_type() != output.data_type())),
- "Broadcasting across width is supported on configurations where all tensors have the same data type");
// Validate in case of configured output
- if(output.total_size() > 0)
+ if((output != nullptr) && (output->total_size() > 0))
{
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- !(input1.data_type() == DataType::U8 && input2.data_type() == DataType::U8 && output.data_type() == DataType::U8)
- && !(input1.data_type() == DataType::U8 && input2.data_type() == DataType::U8 && output.data_type() == DataType::S16)
- && !(input1.data_type() == DataType::U8 && input2.data_type() == DataType::S16 && output.data_type() == DataType::S16)
- && !(input1.data_type() == DataType::S16 && input2.data_type() == DataType::U8 && output.data_type() == DataType::S16)
- && !(input1.data_type() == DataType::S16 && input2.data_type() == DataType::S16 && output.data_type() == DataType::S16)
- && !(input1.data_type() == DataType::S32 && input2.data_type() == DataType::S32 && output.data_type() == DataType::S32)
- && !(input1.data_type() == DataType::F32 && input2.data_type() == DataType::F32 && output.data_type() == DataType::F32)
- && !(input1.data_type() == DataType::F16 && input2.data_type() == DataType::F16 && output.data_type() == DataType::F16)
- && !(input1.data_type() == DataType::QASYMM8 && input2.data_type() == DataType::QASYMM8 && output.data_type() == DataType::QASYMM8)
- && !(input1.data_type() == DataType::QASYMM8_SIGNED && input2.data_type() == DataType::QASYMM8_SIGNED && output.data_type() == DataType::QASYMM8_SIGNED)
- && !(input1.data_type() == DataType::QSYMM16 && input2.data_type() == DataType::QSYMM16 && output.data_type() == DataType::QSYMM16),
+ !(input1->data_type() == DataType::U8 && input2->data_type() == DataType::U8 && output->data_type() == DataType::U8)
+ && !(input1->data_type() == DataType::U8 && input2->data_type() == DataType::U8 && output->data_type() == DataType::S16)
+ && !(input1->data_type() == DataType::U8 && input2->data_type() == DataType::S16 && output->data_type() == DataType::S16)
+ && !(input1->data_type() == DataType::S16 && input2->data_type() == DataType::U8 && output->data_type() == DataType::S16)
+ && !(input1->data_type() == DataType::S16 && input2->data_type() == DataType::S16 && output->data_type() == DataType::S16)
+ && !(input1->data_type() == DataType::S32 && input2->data_type() == DataType::S32 && output->data_type() == DataType::S32)
+ && !(input1->data_type() == DataType::F32 && input2->data_type() == DataType::F32 && output->data_type() == DataType::F32)
+ && !(input1->data_type() == DataType::F16 && input2->data_type() == DataType::F16 && output->data_type() == DataType::F16)
+ && !(input1->data_type() == DataType::QASYMM8 && input2->data_type() == DataType::QASYMM8 && output->data_type() == DataType::QASYMM8)
+ && !(input1->data_type() == DataType::QASYMM8_SIGNED && input2->data_type() == DataType::QASYMM8_SIGNED && output->data_type() == DataType::QASYMM8_SIGNED)
+ && !(input1->data_type() == DataType::QSYMM16 && input2->data_type() == DataType::QSYMM16 && output->data_type() == DataType::QSYMM16),
"You called addition with the wrong image formats");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output.tensor_shape(), 0),
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output->tensor_shape(), 0),
"Wrong shape for output");
+
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG((input1->tensor_shape().x() != input2->tensor_shape().x())
+ && ((input1->data_type() != input2->data_type()) || (input1->data_type() != output->data_type())
+ || (input2->data_type() != output->data_type())),
+ "Broadcasting across width is supported on configurations where all tensors have the same data type");
+ }
+ else
+ {
+ // Either auto-initialized output or in-place computation
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG((input1->tensor_shape().x() != input2->tensor_shape().x()) && (input1->data_type() != input2->data_type()),
+ "Broadcasting across width is supported on configurations where all tensors have the same data type");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG((output == nullptr) && detail::have_different_dimensions(out_shape, input1->tensor_shape(), 0),
+ "In case of in-place computation the broadcast input must be input2");
}
return Status{};
}
-std::pair<Status, Window> validate_and_configure_window(ITensorInfo &input1, ITensorInfo &input2, ITensorInfo &output)
+std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output)
{
- const std::pair<TensorShape, ValidRegion> broadcast_pair = ITensorInfo::broadcast_shape_and_valid_region(input1, input2);
+ const std::pair<TensorShape, ValidRegion> broadcast_pair = ITensorInfo::broadcast_shape_and_valid_region(*input1, *input2);
const TensorShape &out_shape = broadcast_pair.first;
const ValidRegion &valid_region = broadcast_pair.second;
+ ITensorInfo *output_to_use = input1;
+
// Auto initialize output if not initialized
+ if(output != nullptr)
{
- set_shape_if_empty(output, out_shape);
+ set_shape_if_empty(*output, out_shape);
- if(input1.data_type() == DataType::S16 || input2.data_type() == DataType::S16)
+ if(input1->data_type() == DataType::S16 || input2->data_type() == DataType::S16)
{
- set_format_if_unknown(output, Format::S16);
+ set_format_if_unknown(*output, Format::S16);
}
- if(input1.data_type() == DataType::S32 || input2.data_type() == DataType::S32)
+ if(input1->data_type() == DataType::S32 || input2->data_type() == DataType::S32)
{
- set_format_if_unknown(output, Format::S32);
+ set_format_if_unknown(*output, Format::S32);
}
- else if(input1.data_type() == DataType::F16 || input2.data_type() == DataType::F16)
+ else if(input1->data_type() == DataType::F16 || input2->data_type() == DataType::F16)
{
- set_format_if_unknown(output, Format::F16);
+ set_format_if_unknown(*output, Format::F16);
}
- else if(input1.data_type() == DataType::F32 || input2.data_type() == DataType::F32)
+ else if(input1->data_type() == DataType::F32 || input2->data_type() == DataType::F32)
{
- set_format_if_unknown(output, Format::F32);
+ set_format_if_unknown(*output, Format::F32);
}
- else if(input1.data_type() == DataType::QASYMM8 || input2.data_type() == DataType::QASYMM8)
+ else if(input1->data_type() == DataType::QASYMM8 || input2->data_type() == DataType::QASYMM8)
{
- set_data_type_if_unknown(output, DataType::QASYMM8);
+ set_data_type_if_unknown(*output, DataType::QASYMM8);
}
- else if(input1.data_type() == DataType::QASYMM8_SIGNED || input2.data_type() == DataType::QASYMM8_SIGNED)
+ else if(input1->data_type() == DataType::QASYMM8_SIGNED || input2->data_type() == DataType::QASYMM8_SIGNED)
{
- set_data_type_if_unknown(output, DataType::QASYMM8_SIGNED);
+ set_data_type_if_unknown(*output, DataType::QASYMM8_SIGNED);
}
- else if(input1.data_type() == DataType::QSYMM16 || input2.data_type() == DataType::QSYMM16)
+ else if(input1->data_type() == DataType::QSYMM16 || input2->data_type() == DataType::QSYMM16)
{
- set_data_type_if_unknown(output, DataType::QSYMM16);
+ set_data_type_if_unknown(*output, DataType::QSYMM16);
}
+
+ output_to_use = output;
}
Window win = calculate_max_window(valid_region, Steps());
// NEArithmeticAdditionKernel doesn't need padding so update_window_and_padding() can be skipped
Coordinates coord;
- coord.set_num_dimensions(output.num_dimensions());
- output.set_valid_region(valid_region);
+ coord.set_num_dimensions(output_to_use->num_dimensions());
+ output_to_use->set_valid_region(valid_region);
return std::make_pair(Status{}, win);
}
} // namespace
@@ -908,13 +923,15 @@ NEArithmeticAdditionKernel::NEArithmeticAdditionKernel()
{
}
-void NEArithmeticAdditionKernel::configure(const ITensor *input1, const ITensor *input2, ITensor *output, ConvertPolicy policy)
+void NEArithmeticAdditionKernel::configure(ITensor *input1, const ITensor *input2, ITensor *output, ConvertPolicy policy)
{
- ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
- ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*input1->info(), *input2->info(), *output->info(), policy));
+ ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2);
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input1->info(), input2->info(),
+ (output != nullptr) ? output->info() : nullptr,
+ policy));
// Configure kernel window
- auto win_config = validate_and_configure_window(*input1->info(), *input2->info(), *output->info());
+ auto win_config = validate_and_configure_window(input1->info(), input2->info(), (output != nullptr) ? output->info() : nullptr);
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
static std::map<std::string, AddFunction *> map_function =
@@ -947,14 +964,20 @@ void NEArithmeticAdditionKernel::configure(const ITensor *input1, const ITensor
_input1 = input1;
_input2 = input2;
- _output = output;
+ _output = input1;
_policy = policy;
+ // Out-of-place calculation
+ if(output != nullptr)
+ {
+ _output = output;
+ }
+
std::string function_to_call("add_");
function_to_call += policy == ConvertPolicy::WRAP ? "wrap_" : "saturate_";
function_to_call += string_from_data_type(input1->info()->data_type()) + "_";
function_to_call += string_from_data_type(input2->info()->data_type()) + "_";
- function_to_call += string_from_data_type(output->info()->data_type());
+ function_to_call += string_from_data_type(_output->info()->data_type());
auto it = map_function.find(function_to_call);
@@ -968,10 +991,8 @@ void NEArithmeticAdditionKernel::configure(const ITensor *input1, const ITensor
Status NEArithmeticAdditionKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy)
{
- ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input1, input2, output);
-
- ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*input1, *input2, *output, policy));
- ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(*input1->clone(), *input2->clone(), *output->clone()).first);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input1, input2, (output != nullptr) ? output : nullptr, policy));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input1->clone().get(), input2->clone().get(), (output != nullptr) ? output->clone().get() : nullptr).first);
return Status{};
}