From 19023835fa5a73dea2823edf667c711b03bc5060 Mon Sep 17 00:00:00 2001
From: Michele Di Giorgio
Date: Wed, 17 Jun 2020 16:08:10 +0000
Subject: Revert "COMPMID-3480: Perform in-place computations in NEArithmeticAdditionKernel"

This reverts commit 4a61653202afb018f4f259d3c144a735d73f0a20.

Reason for revert: We will allow in-place computations by providing the
same input1 (or input2) as output, thus avoiding changes in the
interface.

Change-Id: I7c8669e207e15731dc26dc366150bf960508a879
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3035
Tested-by: Arm Jenkins
Reviewed-by: Georgios Pinitas
Comments-Addressed: Arm Jenkins
---
 .../NEON/kernels/NEArithmeticAdditionKernel.cpp | 121 +++++++++------------
 1 file changed, 50 insertions(+), 71 deletions(-)

(limited to 'src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp')

diff --git a/src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp b/src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp
index a1c263a836..3878c764a6 100644
--- a/src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp
+++ b/src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp
@@ -810,110 +810,95 @@ void add_U8_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, Convert
     input1, input2, output);
 }
 
-Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy)
+Status validate_arguments(const ITensorInfo &input1, const ITensorInfo &input2, const ITensorInfo &output, ConvertPolicy policy)
 {
     ARM_COMPUTE_UNUSED(policy);
 
-    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input1);
-    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
+    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&input1);
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
                                                          DataType::S16, DataType::QSYMM16, DataType::F16, DataType::S32, DataType::F32);
-    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input2, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
                                                          DataType::S16, DataType::QSYMM16, DataType::F16, DataType::S32, DataType::F32);
 
-    const TensorShape out_shape = TensorShape::broadcast_shape(input1->tensor_shape(), input2->tensor_shape());
+    const TensorShape out_shape = TensorShape::broadcast_shape(input1.tensor_shape(), input2.tensor_shape());
 
     ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");
+    ARM_COMPUTE_RETURN_ERROR_ON_MSG((input1.tensor_shape().x() != input2.tensor_shape().x()) && ((input1.data_type() != input2.data_type()) || (input1.data_type() != output.data_type())
+                                    || (input2.data_type() != output.data_type())),
+                                    "Broadcasting across width is supported on configurations where all tensors have the same data type");
 
     // Validate in case of configured output
-    if((output != nullptr) && (output->total_size() > 0))
+    if(output.total_size() > 0)
     {
         ARM_COMPUTE_RETURN_ERROR_ON_MSG(
-            !(input1->data_type() == DataType::U8 && input2->data_type() == DataType::U8 && output->data_type() == DataType::U8)
-            && !(input1->data_type() == DataType::U8 && input2->data_type() == DataType::U8 && output->data_type() == DataType::S16)
-            && !(input1->data_type() == DataType::U8 && input2->data_type() == DataType::S16 && output->data_type() == DataType::S16)
-            && !(input1->data_type() == DataType::S16 && input2->data_type() == DataType::U8 && output->data_type() == DataType::S16)
-            && !(input1->data_type() == DataType::S16 && input2->data_type() == DataType::S16 && output->data_type() == DataType::S16)
-            && !(input1->data_type() == DataType::S32 && input2->data_type() == DataType::S32 && output->data_type() == DataType::S32)
-            && !(input1->data_type() == DataType::F32 && input2->data_type() == DataType::F32 && output->data_type() == DataType::F32)
-            && !(input1->data_type() == DataType::F16 && input2->data_type() == DataType::F16 && output->data_type() == DataType::F16)
-            && !(input1->data_type() == DataType::QASYMM8 && input2->data_type() == DataType::QASYMM8 && output->data_type() == DataType::QASYMM8)
-            && !(input1->data_type() == DataType::QASYMM8_SIGNED && input2->data_type() == DataType::QASYMM8_SIGNED && output->data_type() == DataType::QASYMM8_SIGNED)
-            && !(input1->data_type() == DataType::QSYMM16 && input2->data_type() == DataType::QSYMM16 && output->data_type() == DataType::QSYMM16),
+            !(input1.data_type() == DataType::U8 && input2.data_type() == DataType::U8 && output.data_type() == DataType::U8)
+            && !(input1.data_type() == DataType::U8 && input2.data_type() == DataType::U8 && output.data_type() == DataType::S16)
+            && !(input1.data_type() == DataType::U8 && input2.data_type() == DataType::S16 && output.data_type() == DataType::S16)
+            && !(input1.data_type() == DataType::S16 && input2.data_type() == DataType::U8 && output.data_type() == DataType::S16)
+            && !(input1.data_type() == DataType::S16 && input2.data_type() == DataType::S16 && output.data_type() == DataType::S16)
+            && !(input1.data_type() == DataType::S32 && input2.data_type() == DataType::S32 && output.data_type() == DataType::S32)
+            && !(input1.data_type() == DataType::F32 && input2.data_type() == DataType::F32 && output.data_type() == DataType::F32)
+            && !(input1.data_type() == DataType::F16 && input2.data_type() == DataType::F16 && output.data_type() == DataType::F16)
+            && !(input1.data_type() == DataType::QASYMM8 && input2.data_type() == DataType::QASYMM8 && output.data_type() == DataType::QASYMM8)
+            && !(input1.data_type() == DataType::QASYMM8_SIGNED && input2.data_type() == DataType::QASYMM8_SIGNED && output.data_type() == DataType::QASYMM8_SIGNED)
+            && !(input1.data_type() == DataType::QSYMM16 && input2.data_type() == DataType::QSYMM16 && output.data_type() == DataType::QSYMM16),
             "You called addition with the wrong image formats");
 
-        ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output->tensor_shape(), 0),
+        ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output.tensor_shape(), 0),
                                         "Wrong shape for output");
-
-        ARM_COMPUTE_RETURN_ERROR_ON_MSG((input1->tensor_shape().x() != input2->tensor_shape().x())
-                                        && ((input1->data_type() != input2->data_type()) || (input1->data_type() != output->data_type())
-                                            || (input2->data_type() != output->data_type())),
-                                        "Broadcasting across width is supported on configurations where all tensors have the same data type");
-    }
-    else
-    {
-        // Either auto-initialized output or in-place computation
-        ARM_COMPUTE_RETURN_ERROR_ON_MSG((input1->tensor_shape().x() != input2->tensor_shape().x()) && (input1->data_type() != input2->data_type()),
-                                        "Broadcasting across width is supported on configurations where all tensors have the same data type");
-        ARM_COMPUTE_RETURN_ERROR_ON_MSG((output == nullptr) && detail::have_different_dimensions(out_shape, input1->tensor_shape(), 0),
-                                        "In case of in-place computation the broadcast input must be input2");
     }
 
     return Status{};
 }
 
-std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output)
+std::pair<Status, Window> validate_and_configure_window(ITensorInfo &input1, ITensorInfo &input2, ITensorInfo &output)
 {
-    const std::pair<TensorShape, ValidRegion> broadcast_pair = ITensorInfo::broadcast_shape_and_valid_region(*input1, *input2);
+    const std::pair<TensorShape, ValidRegion> broadcast_pair = ITensorInfo::broadcast_shape_and_valid_region(input1, input2);
     const TensorShape &out_shape    = broadcast_pair.first;
     const ValidRegion &valid_region = broadcast_pair.second;
 
-    ITensorInfo *output_to_use = input1;
-
     // Auto initialize output if not initialized
-    if(output != nullptr)
     {
-        set_shape_if_empty(*output, out_shape);
+        set_shape_if_empty(output, out_shape);
 
-        if(input1->data_type() == DataType::S16 || input2->data_type() == DataType::S16)
+        if(input1.data_type() == DataType::S16 || input2.data_type() == DataType::S16)
         {
-            set_format_if_unknown(*output, Format::S16);
+            set_format_if_unknown(output, Format::S16);
         }
-        if(input1->data_type() == DataType::S32 || input2->data_type() == DataType::S32)
+        if(input1.data_type() == DataType::S32 || input2.data_type() == DataType::S32)
         {
-            set_format_if_unknown(*output, Format::S32);
+            set_format_if_unknown(output, Format::S32);
         }
-        else if(input1->data_type() == DataType::F16 || input2->data_type() == DataType::F16)
+        else if(input1.data_type() == DataType::F16 || input2.data_type() == DataType::F16)
         {
-            set_format_if_unknown(*output, Format::F16);
+            set_format_if_unknown(output, Format::F16);
         }
-        else if(input1->data_type() == DataType::F32 || input2->data_type() == DataType::F32)
+        else if(input1.data_type() == DataType::F32 || input2.data_type() == DataType::F32)
         {
-            set_format_if_unknown(*output, Format::F32);
+            set_format_if_unknown(output, Format::F32);
         }
-        else if(input1->data_type() == DataType::QASYMM8 || input2->data_type() == DataType::QASYMM8)
+        else if(input1.data_type() == DataType::QASYMM8 || input2.data_type() == DataType::QASYMM8)
         {
-            set_data_type_if_unknown(*output, DataType::QASYMM8);
+            set_data_type_if_unknown(output, DataType::QASYMM8);
         }
-        else if(input1->data_type() == DataType::QASYMM8_SIGNED || input2->data_type() == DataType::QASYMM8_SIGNED)
+        else if(input1.data_type() == DataType::QASYMM8_SIGNED || input2.data_type() == DataType::QASYMM8_SIGNED)
        {
-            set_data_type_if_unknown(*output, DataType::QASYMM8_SIGNED);
+            set_data_type_if_unknown(output, DataType::QASYMM8_SIGNED);
         }
-        else if(input1->data_type() == DataType::QSYMM16 || input2->data_type() == DataType::QSYMM16)
+        else if(input1.data_type() == DataType::QSYMM16 || input2.data_type() == DataType::QSYMM16)
         {
-            set_data_type_if_unknown(*output, DataType::QSYMM16);
+            set_data_type_if_unknown(output, DataType::QSYMM16);
         }
-
-        output_to_use = output;
     }
 
     Window win = calculate_max_window(valid_region, Steps());
 
     // NEArithmeticAdditionKernel doesn't need padding so update_window_and_padding() can be skipped
     Coordinates coord;
-    coord.set_num_dimensions(output_to_use->num_dimensions());
-    output_to_use->set_valid_region(valid_region);
+    coord.set_num_dimensions(output.num_dimensions());
+    output.set_valid_region(valid_region);
 
     return std::make_pair(Status{}, win);
 }
 } // namespace
@@ -923,15 +908,13 @@ NEArithmeticAdditionKernel::NEArithmeticAdditionKernel()
 {
 }
 
-void NEArithmeticAdditionKernel::configure(ITensor *input1, const ITensor *input2, ITensor *output, ConvertPolicy policy)
+void NEArithmeticAdditionKernel::configure(const ITensor *input1, const ITensor *input2, ITensor *output, ConvertPolicy policy)
 {
-    ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2);
-    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input1->info(), input2->info(),
-                                                  (output != nullptr) ? output->info() : nullptr,
-                                                  policy));
+    ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
+    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*input1->info(), *input2->info(), *output->info(), policy));
 
     // Configure kernel window
-    auto win_config = validate_and_configure_window(input1->info(), input2->info(), (output != nullptr) ? output->info() : nullptr);
+    auto win_config = validate_and_configure_window(*input1->info(), *input2->info(), *output->info());
     ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
 
     static std::map<std::string, AddFunction *> map_function =
@@ -964,20 +947,14 @@ void NEArithmeticAdditionKernel::configure(ITensor *input1, const ITensor *input
 
     _input1 = input1;
     _input2 = input2;
-    _output = input1;
+    _output = output;
     _policy = policy;
 
-    // Out-of-place calculation
-    if(output != nullptr)
-    {
-        _output = output;
-    }
-
     std::string function_to_call("add_");
     function_to_call += policy == ConvertPolicy::WRAP ? "wrap_" : "saturate_";
     function_to_call += string_from_data_type(input1->info()->data_type()) + "_";
    function_to_call += string_from_data_type(input2->info()->data_type()) + "_";
-    function_to_call += string_from_data_type(_output->info()->data_type());
+    function_to_call += string_from_data_type(output->info()->data_type());
 
     auto it = map_function.find(function_to_call);
 
@@ -991,8 +968,10 @@ void NEArithmeticAdditionKernel::configure(ITensor *input1, const ITensor *input
 
 Status NEArithmeticAdditionKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy)
 {
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input1, input2, (output != nullptr) ? output : nullptr, policy));
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input1->clone().get(), input2->clone().get(), (output != nullptr) ? output->clone().get() : nullptr).first);
+    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input1, input2, output);
+
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*input1, *input2, *output, policy));
+    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(*input1->clone(), *input2->clone(), *output->clone()).first);
 
     return Status{};
 }
--
cgit v1.2.1
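
Usage note (illustrative, not part of the patch): a minimal sketch of the in-place call pattern the revert message describes, assuming the public NEArithmeticAddition runtime function and the Tensor/TensorInfo APIs from the same ComputeLibrary release; the shape and data type below are arbitrary placeholders.

// Sketch only: in-place addition by passing the same tensor as input1 and
// output, which is the approach the revert message says will be used instead
// of changing the kernel interface.
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/NEON/NEFunctions.h"
#include "arm_compute/runtime/Tensor.h"

using namespace arm_compute;

int main()
{
    Tensor a{};
    Tensor b{};

    // Arbitrary example shape/type; any combination validate() accepts works.
    a.allocator()->init(TensorInfo(TensorShape(16U, 16U), 1, DataType::F32));
    b.allocator()->init(TensorInfo(TensorShape(16U, 16U), 1, DataType::F32));

    NEArithmeticAddition add;
    // 'a' is both input1 and output, so the addition runs in place without
    // any change to the configure() signature.
    add.configure(&a, &b, &a, ConvertPolicy::SATURATE);

    a.allocator()->allocate();
    b.allocator()->allocate();

    add.run();
    return 0;
}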