From 173ba9bbb19ea83f951318d9989e440768b4de8f Mon Sep 17 00:00:00 2001 From: Michalis Spyrou Date: Tue, 23 Jun 2020 17:25:43 +0100 Subject: COMPMID-3373: Async support to NEArithmetic* kernels/functions (Pt. 1) Added support on NEArithmeticAddition and NEArithmeticSubtraction Signed-off-by: Michalis Spyrou Change-Id: Ifa805f8455ef6eff1ee627752dc1c7fe9740ec47 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3451 Tested-by: Arm Jenkins Reviewed-by: Georgios Pinitas --- .../NEON/kernels/NEArithmeticAdditionKernel.cpp | 26 +++++++++------------- 1 file changed, 11 insertions(+), 15 deletions(-) (limited to 'src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp') diff --git a/src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp b/src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp index 3878c764a6..1459f7f250 100644 --- a/src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp +++ b/src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp @@ -853,7 +853,7 @@ Status validate_arguments(const ITensorInfo &input1, const ITensorInfo &input2, return Status{}; } -std::pair validate_and_configure_window(ITensorInfo &input1, ITensorInfo &input2, ITensorInfo &output) +std::pair validate_and_configure_window(const ITensorInfo &input1, const ITensorInfo &input2, ITensorInfo &output) { const std::pair broadcast_pair = ITensorInfo::broadcast_shape_and_valid_region(input1, input2); const TensorShape &out_shape = broadcast_pair.first; @@ -904,17 +904,17 @@ std::pair validate_and_configure_window(ITensorInfo &input1, ITe } // namespace NEArithmeticAdditionKernel::NEArithmeticAdditionKernel() - : _func(nullptr), _input1(nullptr), _input2(nullptr), _output(nullptr), _policy() + : _func(nullptr), _policy() { } -void NEArithmeticAdditionKernel::configure(const ITensor *input1, const ITensor *input2, ITensor *output, ConvertPolicy policy) +void NEArithmeticAdditionKernel::configure(const ITensorInfo *input1, const ITensorInfo *input2, ITensorInfo *output, ConvertPolicy policy) { ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output); - ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*input1->info(), *input2->info(), *output->info(), policy)); + ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*input1, *input2, *output, policy)); // Configure kernel window - auto win_config = validate_and_configure_window(*input1->info(), *input2->info(), *output->info()); + auto win_config = validate_and_configure_window(*input1, *input2, *output); ARM_COMPUTE_ERROR_THROW_ON(win_config.first); static std::map map_function = @@ -945,16 +945,13 @@ void NEArithmeticAdditionKernel::configure(const ITensor *input1, const ITensor #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ }; - _input1 = input1; - _input2 = input2; - _output = output; _policy = policy; std::string function_to_call("add_"); function_to_call += policy == ConvertPolicy::WRAP ? "wrap_" : "saturate_"; - function_to_call += string_from_data_type(input1->info()->data_type()) + "_"; - function_to_call += string_from_data_type(input2->info()->data_type()) + "_"; - function_to_call += string_from_data_type(output->info()->data_type()); + function_to_call += string_from_data_type(input1->data_type()) + "_"; + function_to_call += string_from_data_type(input2->data_type()) + "_"; + function_to_call += string_from_data_type(output->data_type()); auto it = map_function.find(function_to_call); @@ -976,13 +973,12 @@ Status NEArithmeticAdditionKernel::validate(const ITensorInfo *input1, const ITe return Status{}; } -void NEArithmeticAdditionKernel::run(const Window &window, const ThreadInfo &info) +void NEArithmeticAdditionKernel::run_op(const InputTensorMap &inputs, const OutputTensorMap &outputs, const Window &window, const ThreadInfo &info) { ARM_COMPUTE_UNUSED(info); ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window); - ARM_COMPUTE_ERROR_ON(_func == nullptr); - - (*_func)(_input1, _input2, _output, _policy, window); + // Dispatch kernel + (*_func)(inputs.at(TensorType::ACL_SRC_0), inputs.at(TensorType::ACL_SRC_1), outputs.at(TensorType::ACL_DST), _policy, window); } } // namespace arm_compute -- cgit v1.2.1