aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp')
-rw-r--r--src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp26
1 files changed, 11 insertions, 15 deletions
diff --git a/src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp b/src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp
index 3878c764a6..1459f7f250 100644
--- a/src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp
+++ b/src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp
@@ -853,7 +853,7 @@ Status validate_arguments(const ITensorInfo &input1, const ITensorInfo &input2,
return Status{};
}
-std::pair<Status, Window> validate_and_configure_window(ITensorInfo &input1, ITensorInfo &input2, ITensorInfo &output)
+std::pair<Status, Window> validate_and_configure_window(const ITensorInfo &input1, const ITensorInfo &input2, ITensorInfo &output)
{
const std::pair<TensorShape, ValidRegion> broadcast_pair = ITensorInfo::broadcast_shape_and_valid_region(input1, input2);
const TensorShape &out_shape = broadcast_pair.first;
@@ -904,17 +904,17 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo &input1, ITe
} // namespace
NEArithmeticAdditionKernel::NEArithmeticAdditionKernel()
- : _func(nullptr), _input1(nullptr), _input2(nullptr), _output(nullptr), _policy()
+ : _func(nullptr), _policy()
{
}
-void NEArithmeticAdditionKernel::configure(const ITensor *input1, const ITensor *input2, ITensor *output, ConvertPolicy policy)
+void NEArithmeticAdditionKernel::configure(const ITensorInfo *input1, const ITensorInfo *input2, ITensorInfo *output, ConvertPolicy policy)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
- ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*input1->info(), *input2->info(), *output->info(), policy));
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*input1, *input2, *output, policy));
// Configure kernel window
- auto win_config = validate_and_configure_window(*input1->info(), *input2->info(), *output->info());
+ auto win_config = validate_and_configure_window(*input1, *input2, *output);
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
static std::map<std::string, AddFunction *> map_function =
@@ -945,16 +945,13 @@ void NEArithmeticAdditionKernel::configure(const ITensor *input1, const ITensor
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
};
- _input1 = input1;
- _input2 = input2;
- _output = output;
_policy = policy;
std::string function_to_call("add_");
function_to_call += policy == ConvertPolicy::WRAP ? "wrap_" : "saturate_";
- function_to_call += string_from_data_type(input1->info()->data_type()) + "_";
- function_to_call += string_from_data_type(input2->info()->data_type()) + "_";
- function_to_call += string_from_data_type(output->info()->data_type());
+ function_to_call += string_from_data_type(input1->data_type()) + "_";
+ function_to_call += string_from_data_type(input2->data_type()) + "_";
+ function_to_call += string_from_data_type(output->data_type());
auto it = map_function.find(function_to_call);
@@ -976,13 +973,12 @@ Status NEArithmeticAdditionKernel::validate(const ITensorInfo *input1, const ITe
return Status{};
}
-void NEArithmeticAdditionKernel::run(const Window &window, const ThreadInfo &info)
+void NEArithmeticAdditionKernel::run_op(const InputTensorMap &inputs, const OutputTensorMap &outputs, const Window &window, const ThreadInfo &info)
{
ARM_COMPUTE_UNUSED(info);
ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window);
- ARM_COMPUTE_ERROR_ON(_func == nullptr);
-
- (*_func)(_input1, _input2, _output, _policy, window);
+ // Dispatch kernel
+ (*_func)(inputs.at(TensorType::ACL_SRC_0), inputs.at(TensorType::ACL_SRC_1), outputs.at(TensorType::ACL_DST), _policy, window);
}
} // namespace arm_compute