From 4a61653202afb018f4f259d3c144a735d73f0a20 Mon Sep 17 00:00:00 2001 From: Michele Di Giorgio Date: Thu, 4 Jun 2020 15:05:38 +0100 Subject: COMPMID-3480: Perform in-place computations in NEArithmeticAdditionKernel Change-Id: I0089657dd95d7c7b8592984def8e8de1d7e6d085 Signed-off-by: Michele Di Giorgio Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3308 Tested-by: Arm Jenkins Reviewed-by: Georgios Pinitas Comments-Addressed: Arm Jenkins --- .../NEON/kernels/NEArithmeticAdditionKernel.cpp | 121 ++++++++++++--------- 1 file changed, 71 insertions(+), 50 deletions(-) (limited to 'src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp') diff --git a/src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp b/src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp index 3878c764a6..a1c263a836 100644 --- a/src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp +++ b/src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp @@ -810,95 +810,110 @@ void add_U8_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out, Convert input1, input2, output); } -Status validate_arguments(const ITensorInfo &input1, const ITensorInfo &input2, const ITensorInfo &output, ConvertPolicy policy) +Status validate_arguments(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy) { ARM_COMPUTE_UNUSED(policy); - ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&input1); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, + ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input1); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::QSYMM16, DataType::F16, DataType::S32, DataType::F32); - ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input2, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::QSYMM16, DataType::F16, DataType::S32, DataType::F32); - const TensorShape out_shape = TensorShape::broadcast_shape(input1.tensor_shape(), input2.tensor_shape()); + const TensorShape out_shape = TensorShape::broadcast_shape(input1->tensor_shape(), input2->tensor_shape()); ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG((input1.tensor_shape().x() != input2.tensor_shape().x()) && ((input1.data_type() != input2.data_type()) || (input1.data_type() != output.data_type()) - || (input2.data_type() != output.data_type())), - "Broadcasting across width is supported on configurations where all tensors have the same data type"); // Validate in case of configured output - if(output.total_size() > 0) + if((output != nullptr) && (output->total_size() > 0)) { ARM_COMPUTE_RETURN_ERROR_ON_MSG( - !(input1.data_type() == DataType::U8 && input2.data_type() == DataType::U8 && output.data_type() == DataType::U8) - && !(input1.data_type() == DataType::U8 && input2.data_type() == DataType::U8 && output.data_type() == DataType::S16) - && !(input1.data_type() == DataType::U8 && input2.data_type() == DataType::S16 && output.data_type() == DataType::S16) - && !(input1.data_type() == DataType::S16 && input2.data_type() == DataType::U8 && output.data_type() == DataType::S16) - && !(input1.data_type() == DataType::S16 && input2.data_type() == DataType::S16 && output.data_type() == DataType::S16) - && !(input1.data_type() == DataType::S32 && input2.data_type() == DataType::S32 && output.data_type() == DataType::S32) - && !(input1.data_type() == DataType::F32 && input2.data_type() == DataType::F32 && output.data_type() == DataType::F32) - && !(input1.data_type() == DataType::F16 && input2.data_type() == DataType::F16 && output.data_type() == DataType::F16) - && !(input1.data_type() == DataType::QASYMM8 && input2.data_type() == DataType::QASYMM8 && output.data_type() == DataType::QASYMM8) - && !(input1.data_type() == DataType::QASYMM8_SIGNED && input2.data_type() == DataType::QASYMM8_SIGNED && output.data_type() == DataType::QASYMM8_SIGNED) - && !(input1.data_type() == DataType::QSYMM16 && input2.data_type() == DataType::QSYMM16 && output.data_type() == DataType::QSYMM16), + !(input1->data_type() == DataType::U8 && input2->data_type() == DataType::U8 && output->data_type() == DataType::U8) + && !(input1->data_type() == DataType::U8 && input2->data_type() == DataType::U8 && output->data_type() == DataType::S16) + && !(input1->data_type() == DataType::U8 && input2->data_type() == DataType::S16 && output->data_type() == DataType::S16) + && !(input1->data_type() == DataType::S16 && input2->data_type() == DataType::U8 && output->data_type() == DataType::S16) + && !(input1->data_type() == DataType::S16 && input2->data_type() == DataType::S16 && output->data_type() == DataType::S16) + && !(input1->data_type() == DataType::S32 && input2->data_type() == DataType::S32 && output->data_type() == DataType::S32) + && !(input1->data_type() == DataType::F32 && input2->data_type() == DataType::F32 && output->data_type() == DataType::F32) + && !(input1->data_type() == DataType::F16 && input2->data_type() == DataType::F16 && output->data_type() == DataType::F16) + && !(input1->data_type() == DataType::QASYMM8 && input2->data_type() == DataType::QASYMM8 && output->data_type() == DataType::QASYMM8) + && !(input1->data_type() == DataType::QASYMM8_SIGNED && input2->data_type() == DataType::QASYMM8_SIGNED && output->data_type() == DataType::QASYMM8_SIGNED) + && !(input1->data_type() == DataType::QSYMM16 && input2->data_type() == DataType::QSYMM16 && output->data_type() == DataType::QSYMM16), "You called addition with the wrong image formats"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output.tensor_shape(), 0), + ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output->tensor_shape(), 0), "Wrong shape for output"); + + ARM_COMPUTE_RETURN_ERROR_ON_MSG((input1->tensor_shape().x() != input2->tensor_shape().x()) + && ((input1->data_type() != input2->data_type()) || (input1->data_type() != output->data_type()) + || (input2->data_type() != output->data_type())), + "Broadcasting across width is supported on configurations where all tensors have the same data type"); + } + else + { + // Either auto-initialized output or in-place computation + ARM_COMPUTE_RETURN_ERROR_ON_MSG((input1->tensor_shape().x() != input2->tensor_shape().x()) && (input1->data_type() != input2->data_type()), + "Broadcasting across width is supported on configurations where all tensors have the same data type"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG((output == nullptr) && detail::have_different_dimensions(out_shape, input1->tensor_shape(), 0), + "In case of in-place computation the broadcast input must be input2"); } return Status{}; } -std::pair validate_and_configure_window(ITensorInfo &input1, ITensorInfo &input2, ITensorInfo &output) +std::pair validate_and_configure_window(ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output) { - const std::pair broadcast_pair = ITensorInfo::broadcast_shape_and_valid_region(input1, input2); + const std::pair broadcast_pair = ITensorInfo::broadcast_shape_and_valid_region(*input1, *input2); const TensorShape &out_shape = broadcast_pair.first; const ValidRegion &valid_region = broadcast_pair.second; + ITensorInfo *output_to_use = input1; + // Auto initialize output if not initialized + if(output != nullptr) { - set_shape_if_empty(output, out_shape); + set_shape_if_empty(*output, out_shape); - if(input1.data_type() == DataType::S16 || input2.data_type() == DataType::S16) + if(input1->data_type() == DataType::S16 || input2->data_type() == DataType::S16) { - set_format_if_unknown(output, Format::S16); + set_format_if_unknown(*output, Format::S16); } - if(input1.data_type() == DataType::S32 || input2.data_type() == DataType::S32) + if(input1->data_type() == DataType::S32 || input2->data_type() == DataType::S32) { - set_format_if_unknown(output, Format::S32); + set_format_if_unknown(*output, Format::S32); } - else if(input1.data_type() == DataType::F16 || input2.data_type() == DataType::F16) + else if(input1->data_type() == DataType::F16 || input2->data_type() == DataType::F16) { - set_format_if_unknown(output, Format::F16); + set_format_if_unknown(*output, Format::F16); } - else if(input1.data_type() == DataType::F32 || input2.data_type() == DataType::F32) + else if(input1->data_type() == DataType::F32 || input2->data_type() == DataType::F32) { - set_format_if_unknown(output, Format::F32); + set_format_if_unknown(*output, Format::F32); } - else if(input1.data_type() == DataType::QASYMM8 || input2.data_type() == DataType::QASYMM8) + else if(input1->data_type() == DataType::QASYMM8 || input2->data_type() == DataType::QASYMM8) { - set_data_type_if_unknown(output, DataType::QASYMM8); + set_data_type_if_unknown(*output, DataType::QASYMM8); } - else if(input1.data_type() == DataType::QASYMM8_SIGNED || input2.data_type() == DataType::QASYMM8_SIGNED) + else if(input1->data_type() == DataType::QASYMM8_SIGNED || input2->data_type() == DataType::QASYMM8_SIGNED) { - set_data_type_if_unknown(output, DataType::QASYMM8_SIGNED); + set_data_type_if_unknown(*output, DataType::QASYMM8_SIGNED); } - else if(input1.data_type() == DataType::QSYMM16 || input2.data_type() == DataType::QSYMM16) + else if(input1->data_type() == DataType::QSYMM16 || input2->data_type() == DataType::QSYMM16) { - set_data_type_if_unknown(output, DataType::QSYMM16); + set_data_type_if_unknown(*output, DataType::QSYMM16); } + + output_to_use = output; } Window win = calculate_max_window(valid_region, Steps()); // NEArithmeticAdditionKernel doesn't need padding so update_window_and_padding() can be skipped Coordinates coord; - coord.set_num_dimensions(output.num_dimensions()); - output.set_valid_region(valid_region); + coord.set_num_dimensions(output_to_use->num_dimensions()); + output_to_use->set_valid_region(valid_region); return std::make_pair(Status{}, win); } } // namespace @@ -908,13 +923,15 @@ NEArithmeticAdditionKernel::NEArithmeticAdditionKernel() { } -void NEArithmeticAdditionKernel::configure(const ITensor *input1, const ITensor *input2, ITensor *output, ConvertPolicy policy) +void NEArithmeticAdditionKernel::configure(ITensor *input1, const ITensor *input2, ITensor *output, ConvertPolicy policy) { - ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output); - ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*input1->info(), *input2->info(), *output->info(), policy)); + ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2); + ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input1->info(), input2->info(), + (output != nullptr) ? output->info() : nullptr, + policy)); // Configure kernel window - auto win_config = validate_and_configure_window(*input1->info(), *input2->info(), *output->info()); + auto win_config = validate_and_configure_window(input1->info(), input2->info(), (output != nullptr) ? output->info() : nullptr); ARM_COMPUTE_ERROR_THROW_ON(win_config.first); static std::map map_function = @@ -947,14 +964,20 @@ void NEArithmeticAdditionKernel::configure(const ITensor *input1, const ITensor _input1 = input1; _input2 = input2; - _output = output; + _output = input1; _policy = policy; + // Out-of-place calculation + if(output != nullptr) + { + _output = output; + } + std::string function_to_call("add_"); function_to_call += policy == ConvertPolicy::WRAP ? "wrap_" : "saturate_"; function_to_call += string_from_data_type(input1->info()->data_type()) + "_"; function_to_call += string_from_data_type(input2->info()->data_type()) + "_"; - function_to_call += string_from_data_type(output->info()->data_type()); + function_to_call += string_from_data_type(_output->info()->data_type()); auto it = map_function.find(function_to_call); @@ -968,10 +991,8 @@ void NEArithmeticAdditionKernel::configure(const ITensor *input1, const ITensor Status NEArithmeticAdditionKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy) { - ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input1, input2, output); - - ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*input1, *input2, *output, policy)); - ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(*input1->clone(), *input2->clone(), *output->clone()).first); + ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input1, input2, (output != nullptr) ? output : nullptr, policy)); + ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input1->clone().get(), input2->clone().get(), (output != nullptr) ? output->clone().get() : nullptr).first); return Status{}; } -- cgit v1.2.1