From 1fd2c80692ed8ecefc4d8deb783564ad19eaf70c Mon Sep 17 00:00:00 2001 From: Georgios Pinitas Date: Tue, 16 Jun 2020 17:44:46 +0100 Subject: COMPMID-3375: Port NEActivationLayer functions/kernels to run on different tensors. Signed-off-by: Georgios Pinitas Change-Id: I98782bb73e9dc0899ffb1796aca6f99714adea94 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3343 Reviewed-by: Michalis Spyrou Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins --- src/core/NEON/kernels/NEActivationLayerKernel.cpp | 72 +++++++++++------------ src/core/NEON/kernels/NEReshapeLayerKernel.cpp | 13 ++-- src/runtime/CPP/CPPScheduler.cpp | 8 +-- src/runtime/CPP/SingleThreadScheduler.cpp | 2 +- src/runtime/NEON/INEOperator.cpp | 4 +- src/runtime/NEON/functions/NEActivationLayer.cpp | 61 +++++++++++++++++-- src/runtime/NEON/functions/NELSTMLayer.cpp | 32 +++++----- src/runtime/NEON/functions/NERNNLayer.cpp | 10 ++-- src/runtime/NEON/functions/NEReshapeLayer.cpp | 42 ++++++++----- src/runtime/OMP/OMPScheduler.cpp | 2 +- 10 files changed, 154 insertions(+), 92 deletions(-) (limited to 'src') diff --git a/src/core/NEON/kernels/NEActivationLayerKernel.cpp b/src/core/NEON/kernels/NEActivationLayerKernel.cpp index ffbfd710f9..2c00a76305 100644 --- a/src/core/NEON/kernels/NEActivationLayerKernel.cpp +++ b/src/core/NEON/kernels/NEActivationLayerKernel.cpp @@ -95,7 +95,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, c return Status{}; } -std::pair validate_and_configure_window(ITensorInfo *input, ITensorInfo *output) +std::pair validate_and_configure_window(const ITensorInfo *input, ITensorInfo *output) { // Configure kernel window Window win = calculate_max_window(*input, Steps()); @@ -116,23 +116,15 @@ std::pair validate_and_configure_window(ITensorInfo *input, ITen } // namespace NEActivationLayerKernel::NEActivationLayerKernel() - : _input(nullptr), _output(nullptr), _func(nullptr), _act_info() + : _func(nullptr), _act_info() { } -void NEActivationLayerKernel::configure(ITensor *input, ITensor *output, ActivationLayerInfo activation_info) +void NEActivationLayerKernel::configure(const ITensorInfo *input, ITensorInfo *output, ActivationLayerInfo activation_info) { - ARM_COMPUTE_ERROR_ON_NULLPTR(input); + ARM_COMPUTE_ERROR_ON_NULLPTR(input, output); - _input = input; _act_info = activation_info; - _output = input; - - // Out-of-place calculation - if(output != nullptr) - { - _output = output; - } // Disabled activation, thus no operation needed if(!activation_info.enabled()) @@ -140,7 +132,7 @@ void NEActivationLayerKernel::configure(ITensor *input, ITensor *output, Activat _func = nullptr; } - ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), (output != nullptr) ? output->info() : nullptr, activation_info)); + ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input, output, activation_info)); // Activation functions : FP32 static std::map act_map_f32 = @@ -218,7 +210,7 @@ void NEActivationLayerKernel::configure(ITensor *input, ITensor *output, Activat }; - switch(input->info()->data_type()) + switch(input->data_type()) { case DataType::QASYMM8_SIGNED: _func = act_map_qasymm8_signed[activation_info.activation()]; @@ -242,14 +234,14 @@ void NEActivationLayerKernel::configure(ITensor *input, ITensor *output, Activat } // Configure kernel window - auto win_config = validate_and_configure_window(input->info(), (output != nullptr) ? output->info() : nullptr); + auto win_config = validate_and_configure_window(input, output); ARM_COMPUTE_ERROR_THROW_ON(win_config.first); ICPPKernel::configure(win_config.second); } template typename std::enable_if::value, void>::type -NEActivationLayerKernel::activation(const Window &window) +NEActivationLayerKernel::activation(const ITensor *src, ITensor *dst, const Window &window) { /** NEON vector tag type. */ using ExactTagType = typename wrapper::traits::neon_bitvector_tag_t; @@ -262,16 +254,16 @@ NEActivationLayerKernel::activation(const Window &window) Window win_collapsed = window.collapse_if_possible(window, Window::DimZ); win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1)); - Iterator input(_input, win_collapsed); - Iterator output(_output, win_collapsed); + Iterator input(src, win_collapsed); + Iterator output(dst, win_collapsed); // A small delta added to the input to prevent NAN values caused by zeros in inputs to SQRT #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - const auto delta = wrapper::vdup_n(static_cast(1e-7), ExactTagType{}); + const auto delta = wrapper::vdup_n(static_cast(1e-7), ExactTagType {}); #else /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ - const auto delta = wrapper::vdup_n(static_cast(1e-24), ExactTagType{}); + const auto delta = wrapper::vdup_n(static_cast(1e-24), ExactTagType {}); #endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */ - const auto const_1 = wrapper::vdup_n(static_cast(1.f), ExactTagType{}); + const auto const_1 = wrapper::vdup_n(static_cast(1.f), ExactTagType {}); const auto const_0 = wrapper::vdup_n(static_cast(0.f), ExactTagType{}); const auto const_6 = wrapper::vdup_n(static_cast(6.f), ExactTagType{}); const auto const_3 = wrapper::vdup_n(static_cast(3.f), ExactTagType{}); @@ -402,7 +394,7 @@ NEActivationLayerKernel::activation(const Window &window) } template -typename std::enable_if::value, void>::type NEActivationLayerKernel::activation(const Window &window) +typename std::enable_if::value, void>::type NEActivationLayerKernel::activation(const ITensor *src, ITensor *dst, const Window &window) { const int window_step_x = 16 / sizeof(T); const auto window_start_x = static_cast(window.x().start()); @@ -412,11 +404,11 @@ typename std::enable_if::value, void>::type NEActivat Window win_collapsed = window.collapse_if_possible(window, Window::DimZ); win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1)); - Iterator input(_input, win_collapsed); - Iterator output(_output, win_collapsed); + Iterator input(src, win_collapsed); + Iterator output(dst, win_collapsed); - const UniformQuantizationInfo qi_in = _input->info()->quantization_info().uniform(); - const UniformQuantizationInfo qi_out = _output->info()->quantization_info().uniform(); + const UniformQuantizationInfo qi_in = src->info()->quantization_info().uniform(); + const UniformQuantizationInfo qi_out = dst->info()->quantization_info().uniform(); const qasymm8x16_t va = vdupq_n_u8(quantize_qasymm8(_act_info.a(), qi_in)); const qasymm8x16_t vb = vdupq_n_u8(quantize_qasymm8(_act_info.b(), qi_in)); const qasymm8_t a = quantize_qasymm8(_act_info.a(), qi_in); @@ -579,7 +571,7 @@ typename std::enable_if::value, void>::type NEActivat } template -typename std::enable_if::value, void>::type NEActivationLayerKernel::activation(const Window &window) +typename std::enable_if::value, void>::type NEActivationLayerKernel::activation(const ITensor *src, ITensor *dst, const Window &window) { const int window_step_x = 16 / sizeof(T); const auto window_start_x = static_cast(window.x().start()); @@ -589,11 +581,11 @@ typename std::enable_if::value, void>::type NE Window win_collapsed = window.collapse_if_possible(window, Window::DimZ); win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1)); - Iterator input(_input, win_collapsed); - Iterator output(_output, win_collapsed); + Iterator input(src, win_collapsed); + Iterator output(dst, win_collapsed); - const UniformQuantizationInfo qi_in = _input->info()->quantization_info().uniform(); - const UniformQuantizationInfo qi_out = _output->info()->quantization_info().uniform(); + const UniformQuantizationInfo qi_in = src->info()->quantization_info().uniform(); + const UniformQuantizationInfo qi_out = dst->info()->quantization_info().uniform(); const qasymm8x16_signed_t va = vdupq_n_s8(quantize_qasymm8_signed(_act_info.a(), qi_in)); const qasymm8x16_signed_t vb = vdupq_n_s8(quantize_qasymm8_signed(_act_info.b(), qi_in)); const qasymm8_signed_t a = quantize_qasymm8_signed(_act_info.a(), qi_in); @@ -756,7 +748,7 @@ typename std::enable_if::value, void>::type NE } template -typename std::enable_if::value, void>::type NEActivationLayerKernel::activation(const Window &window) +typename std::enable_if::value, void>::type NEActivationLayerKernel::activation(const ITensor *src, ITensor *dst, const Window &window) { const int window_step_x = 16 / sizeof(T); const auto window_start_x = static_cast(window.x().start()); @@ -766,11 +758,11 @@ typename std::enable_if::value, void>::type NEActivat Window win_collapsed = window.collapse_if_possible(window, Window::DimZ); win_collapsed.set(Window::DimX, Window::Dimension(0, 1, 1)); - Iterator input(_input, win_collapsed); - Iterator output(_output, win_collapsed); + Iterator input(src, win_collapsed); + Iterator output(dst, win_collapsed); - const UniformQuantizationInfo qi_in = _input->info()->quantization_info().uniform(); - const UniformQuantizationInfo qi_out = _output->info()->quantization_info().uniform(); + const UniformQuantizationInfo qi_in = src->info()->quantization_info().uniform(); + const UniformQuantizationInfo qi_out = dst->info()->quantization_info().uniform(); const auto vconst_1 = vdupq_n_f32(1.f); const float32x4_t va_f32 = vdupq_n_f32(_act_info.a()); const float32x4_t vb_f32 = vdupq_n_f32(_act_info.b()); @@ -863,7 +855,9 @@ Status NEActivationLayerKernel::validate(const ITensorInfo *input, const ITensor return Status{}; } -void NEActivationLayerKernel::run(const Window &window, const ThreadInfo &info) +void NEActivationLayerKernel::run_op(const std::vector &inputs, + const std::vector &outputs, + const Window &window, const ThreadInfo &info) { // Early exit on disabled activation if(!_act_info.enabled()) @@ -876,5 +870,7 @@ void NEActivationLayerKernel::run(const Window &window, const ThreadInfo &info) ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(IKernel::window(), window); ARM_COMPUTE_ERROR_ON(_func == nullptr); - (this->*_func)(window); + ARM_COMPUTE_ERROR_ON(inputs.empty() || outputs.empty()); + + (this->*_func)(inputs[0].tensor, outputs[0].tensor, window); } diff --git a/src/core/NEON/kernels/NEReshapeLayerKernel.cpp b/src/core/NEON/kernels/NEReshapeLayerKernel.cpp index 600f8f9bf1..c141eecf75 100644 --- a/src/core/NEON/kernels/NEReshapeLayerKernel.cpp +++ b/src/core/NEON/kernels/NEReshapeLayerKernel.cpp @@ -86,29 +86,32 @@ void NEReshapeLayerKernel::configure(const ITensorInfo *input, ITensorInfo *outp INEKernel::configure(win); } -void NEReshapeLayerKernel::run_op(const std::vector &inputs, std::vector &outputs, const Window &window, const ThreadInfo &info) +void NEReshapeLayerKernel::run_op(const std::vector &inputs, const std::vector &outputs, const Window &window, const ThreadInfo &info) { ARM_COMPUTE_UNUSED(info); ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(INEKernel::window(), window); - switch(inputs[0]->second->info()->data_type()) + const auto src = inputs[0].tensor; + auto dst = outputs[0].tensor; + + switch(src->info()->data_type()) { case DataType::U8: case DataType::S8: case DataType::QASYMM8: case DataType::QASYMM8_SIGNED: - reshape_tensor(window, inputs[0]->second, outputs[0]->second); + reshape_tensor(window, src, dst); break; case DataType::U16: case DataType::S16: case DataType::F16: - reshape_tensor(window, inputs[0]->second, outputs[0]->second); + reshape_tensor(window, src, dst); break; case DataType::U32: case DataType::S32: case DataType::F32: - reshape_tensor(window, inputs[0]->second, outputs[0]->second); + reshape_tensor(window, src, dst); break; default: ARM_COMPUTE_ERROR("Unsupported data type!"); diff --git a/src/runtime/CPP/CPPScheduler.cpp b/src/runtime/CPP/CPPScheduler.cpp index db551590ea..41e1a2d647 100644 --- a/src/runtime/CPP/CPPScheduler.cpp +++ b/src/runtime/CPP/CPPScheduler.cpp @@ -363,7 +363,7 @@ void CPPScheduler::run_workloads(std::vector &workloads) } #endif /* DOXYGEN_SKIP_THIS */ -void CPPScheduler::schedule_common(ICPPKernel *kernel, const Hints &hints, std::vector &inputs, std::vector &outputs) +void CPPScheduler::schedule_common(ICPPKernel *kernel, const Hints &hints, const std::vector &inputs, const std::vector &outputs) { ARM_COMPUTE_ERROR_ON_MSG(!kernel, "The child class didn't set the kernel"); @@ -473,15 +473,15 @@ void CPPScheduler::schedule_common(ICPPKernel *kernel, const Hints &hints, std:: } } -void CPPScheduler::schedule_op(ICPPKernel *kernel, const Hints &hints, std::vector &inputs, std::vector &outputs) +void CPPScheduler::schedule_op(ICPPKernel *kernel, const Hints &hints, const std::vector &inputs, const std::vector &outputs) { schedule_common(kernel, hints, inputs, outputs); } void CPPScheduler::schedule(ICPPKernel *kernel, const Hints &hints) { - std::vector inputs; - std::vector outputs; + const std::vector inputs; + std::vector outputs; schedule_common(kernel, hints, inputs, outputs); } } // namespace arm_compute diff --git a/src/runtime/CPP/SingleThreadScheduler.cpp b/src/runtime/CPP/SingleThreadScheduler.cpp index 777f84bec8..8257628090 100644 --- a/src/runtime/CPP/SingleThreadScheduler.cpp +++ b/src/runtime/CPP/SingleThreadScheduler.cpp @@ -49,7 +49,7 @@ void SingleThreadScheduler::schedule(ICPPKernel *kernel, const Hints &hints) kernel->run(kernel->window(), info); } -void SingleThreadScheduler::schedule_op(ICPPKernel *kernel, const Hints &hints, std::vector &inputs, std::vector &outputs) +void SingleThreadScheduler::schedule_op(ICPPKernel *kernel, const Hints &hints, const std::vector &inputs, const std::vector &outputs) { ARM_COMPUTE_UNUSED(hints); ThreadInfo info; diff --git a/src/runtime/NEON/INEOperator.cpp b/src/runtime/NEON/INEOperator.cpp index c24d5c47f1..78790856ee 100644 --- a/src/runtime/NEON/INEOperator.cpp +++ b/src/runtime/NEON/INEOperator.cpp @@ -33,7 +33,7 @@ INEOperator::INEOperator(IRuntimeContext *ctx) { } -void INEOperator::run(std::vector &inputs, std::vector &outputs, std::vector &workspace) +void INEOperator::run(std::vector inputs, std::vector outputs, std::vector workspace) { ARM_COMPUTE_UNUSED(workspace); @@ -45,7 +45,7 @@ void INEOperator::run(std::vector &inputs, std::vector constants) +void INEOperator::prepare(std::vector constants) { ARM_COMPUTE_UNUSED(constants); } diff --git a/src/runtime/NEON/functions/NEActivationLayer.cpp b/src/runtime/NEON/functions/NEActivationLayer.cpp index e4d1125c79..889ff6b1f4 100644 --- a/src/runtime/NEON/functions/NEActivationLayer.cpp +++ b/src/runtime/NEON/functions/NEActivationLayer.cpp @@ -23,25 +23,76 @@ */ #include "arm_compute/runtime/NEON/functions/NEActivationLayer.h" +#include "arm_compute/core/Error.h" #include "arm_compute/core/NEON/kernels/NEActivationLayerKernel.h" +#include "arm_compute/core/experimental/Types.h" #include "arm_compute/runtime/IRuntimeContext.h" +#include "arm_compute/runtime/Tensor.h" #include "support/MemorySupport.h" namespace arm_compute { -NEActivationLayer::NEActivationLayer(IRuntimeContext *ctx) // NOLINT - : INESimpleFunctionNoBorder(ctx) +namespace experimental { -} -void NEActivationLayer::configure(ITensor *input, ITensor *output, ActivationLayerInfo activation_info) +void NEActivationLayer::configure(const ITensorInfo *input, ITensorInfo *output, const ActivationLayerInfo &activation_info) { auto k = arm_compute::support::cpp14::make_unique(); k->configure(input, output, activation_info); _kernel = std::move(k); } +Status NEActivationLayer::validate(const ITensorInfo *input, const ITensorInfo *output, const ActivationLayerInfo &activation_info) +{ + return NEActivationLayerKernel::validate(input, output, activation_info); +} + +MemoryRequirements NEActivationLayer::workspace() const +{ + return MemoryRequirements{}; +} +} // namespace experimental + +struct NEActivationLayer::Impl +{ + const ITensor *src{ nullptr }; + ITensor *dst{ nullptr }; + IRuntimeContext *ctx{ nullptr }; + std::unique_ptr op{ nullptr }; +}; + +NEActivationLayer::NEActivationLayer(IRuntimeContext *ctx) + : _impl(support::cpp14::make_unique()) +{ + _impl->ctx = ctx; +} + +NEActivationLayer::NEActivationLayer(NEActivationLayer &&) = default; + +NEActivationLayer &NEActivationLayer::operator=(NEActivationLayer &&) = default; + +NEActivationLayer::~NEActivationLayer() = default; + +void NEActivationLayer::configure(ITensor *input, ITensor *output, ActivationLayerInfo activation_info) +{ + ARM_COMPUTE_ERROR_ON_NULLPTR(input); + + _impl->src = input; + _impl->dst = output == nullptr ? input : output; + + _impl->op = arm_compute::support::cpp14::make_unique(); + _impl->op->configure(_impl->src->info(), _impl->dst->info(), activation_info); +} + Status NEActivationLayer::validate(const ITensorInfo *input, const ITensorInfo *output, const ActivationLayerInfo &act_info) { - return NEActivationLayerKernel::validate(input, output, act_info); + return experimental::NEActivationLayer::validate(input, output, act_info); +} + +void NEActivationLayer::run() +{ + const InputTensor src{ TensorType::ACL_SRC, _impl->src }; + OutputTensor dst{ TensorType::ACL_DST, _impl->dst }; + + _impl->op->run({ src }, { dst }, {}); } } // namespace arm_compute diff --git a/src/runtime/NEON/functions/NELSTMLayer.cpp b/src/runtime/NEON/functions/NELSTMLayer.cpp index f9d445fe71..0a111363e3 100644 --- a/src/runtime/NEON/functions/NELSTMLayer.cpp +++ b/src/runtime/NEON/functions/NELSTMLayer.cpp @@ -474,7 +474,7 @@ Status NELSTMLayer::validate(const ITensorInfo *input, RoundingPolicy::TO_ZERO)); ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&forget_gate, forget_gate_bias, &forget_gate, ConvertPolicy::SATURATE)); } - ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayerKernel::validate(&forget_gate, &forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC))); + ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&forget_gate, &forget_gate, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC))); // Validate input gate if(!lstm_params.has_cifg_opt()) @@ -508,7 +508,7 @@ Status NELSTMLayer::validate(const ITensorInfo *input, ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplicationKernel::validate(&input_gate, lstm_params.input_layer_norm_weights(), &input_gate, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO)); ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&input_gate, lstm_params.input_gate_bias(), &input_gate, ConvertPolicy::SATURATE)); } - ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayerKernel::validate(&input_gate, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC))); + ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&input_gate, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC))); } else { @@ -526,14 +526,14 @@ Status NELSTMLayer::validate(const ITensorInfo *input, RoundingPolicy::TO_ZERO)); ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&cell_state_tmp, cell_bias, &cell_state_tmp, ConvertPolicy::SATURATE)); } - ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayerKernel::validate(&cell_state_tmp, nullptr, activation_info)); + ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&cell_state_tmp, nullptr, activation_info)); ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplicationKernel::validate(&cell_state_tmp, &input_gate, &cell_state_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO)); ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplicationKernel::validate(&cell_state_tmp, &forget_gate, &cell_state_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO)); ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&cell_state_tmp, &cell_state_tmp, &cell_state_tmp, ConvertPolicy::SATURATE)); if(cell_threshold != 0.f) { - ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayerKernel::validate(&cell_state_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -cell_threshold, - cell_threshold))); + ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&cell_state_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -cell_threshold, + cell_threshold))); } // Validate output gate tmp @@ -559,18 +559,18 @@ Status NELSTMLayer::validate(const ITensorInfo *input, RoundingPolicy::TO_ZERO)); ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAddition::validate(&output_gate_tmp, output_gate_bias, &output_gate_tmp, ConvertPolicy::SATURATE)); } - ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayerKernel::validate(&output_gate_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC))); + ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&output_gate_tmp, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LOGISTIC))); // Validate output state - ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayerKernel::validate(&cell_state_tmp, &cell_state_tmp, activation_info)); + ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&cell_state_tmp, &cell_state_tmp, activation_info)); ARM_COMPUTE_RETURN_ON_ERROR(NEPixelWiseMultiplicationKernel::validate(&cell_state_tmp, &output_gate_tmp, &output_gate_tmp, 1, ConvertPolicy::SATURATE, RoundingPolicy::TO_ZERO)); if(lstm_params.has_projection()) { ARM_COMPUTE_RETURN_ON_ERROR(NEFullyConnectedLayer::validate(&output_gate_tmp, lstm_params.projection_weights(), lstm_params.projection_bias(), output_state_out)); if(projection_threshold != 0.f) { - ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayerKernel::validate(output_state_out, output_state_out, - ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -projection_threshold, projection_threshold))); + ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(output_state_out, output_state_out, + ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, -projection_threshold, projection_threshold))); } } @@ -612,7 +612,7 @@ void NELSTMLayer::run() NEScheduler::get().schedule(&_pixelwise_mul_forget_gate_coeff, Window::DimY); NEScheduler::get().schedule(&_accum_forget_gate_bias, Window::DimY); } - NEScheduler::get().schedule(&_activation_forget_gate, Window::DimY); + _activation_forget_gate.run(); if(_run_cifg_opt) { @@ -642,7 +642,7 @@ void NELSTMLayer::run() NEScheduler::get().schedule(&_pixelwise_mul_input_gate_coeff, Window::DimY); NEScheduler::get().schedule(&_accum_input_gate_bias, Window::DimY); } - NEScheduler::get().schedule(&_activation_input_gate, Window::DimY); + _activation_input_gate.run(); } _fully_connected_cell_state.run(); @@ -655,14 +655,14 @@ void NELSTMLayer::run() NEScheduler::get().schedule(&_pixelwise_mul_cell_gate_coeff, Window::DimY); NEScheduler::get().schedule(&_accum_cell_gate_bias, Window::DimY); } - NEScheduler::get().schedule(&_activation_cell_state, Window::DimY); + _activation_cell_state.run(); NEScheduler::get().schedule(&_pixelwise_mul_cell_state1, Window::DimY); NEScheduler::get().schedule(&_pixelwise_mul_cell_state2, Window::DimY); NEScheduler::get().schedule(&_accum_cell_state2, Window::DimY); if(_perform_cell_clipping) { - NEScheduler::get().schedule(&_cell_clip, Window::DimY); + _cell_clip.run(); } _fully_connected_output.run(); @@ -677,9 +677,9 @@ void NELSTMLayer::run() NEScheduler::get().schedule(&_pixelwise_mul_output_gate_coeff, Window::DimY); NEScheduler::get().schedule(&_accum_output_gate_bias, Window::DimY); } - NEScheduler::get().schedule(&_activation_output, Window::DimY); + _activation_output.run(); - NEScheduler::get().schedule(&_activation_output_state, Window::DimY); + _activation_output_state.run(); NEScheduler::get().schedule(&_pixelwise_mul_output_state2, Window::DimY); if(_has_projection_weights) @@ -687,7 +687,7 @@ void NELSTMLayer::run() _fully_connected_output_state.run(); if(_perform_projection_clipping) { - NEScheduler::get().schedule(&_projection_clip, Window::DimY); + _projection_clip.run(); } } diff --git a/src/runtime/NEON/functions/NERNNLayer.cpp b/src/runtime/NEON/functions/NERNNLayer.cpp index 154b060c3d..4a15777be9 100644 --- a/src/runtime/NEON/functions/NERNNLayer.cpp +++ b/src/runtime/NEON/functions/NERNNLayer.cpp @@ -34,8 +34,8 @@ namespace arm_compute { NERNNLayer::NERNNLayer(std::shared_ptr memory_manager) - : _memory_group(std::move(memory_manager)), _gemm_state_f(), _add_kernel(), _activation_kernel(), _fully_connected(memory_manager), _copy_kernel(), _fully_connected_out(), _gemm_output(), - _add_output(), _is_prepared(false) + : _memory_group(std::move(memory_manager)), _gemm_state_f(), _add_kernel(), _activation(), _fully_connected(memory_manager), _copy_kernel(), _fully_connected_out(), _gemm_output(), _add_output(), + _is_prepared(false) { } @@ -60,7 +60,7 @@ Status NERNNLayer::validate(const ITensorInfo *input, const ITensorInfo *weights ARM_COMPUTE_RETURN_ON_ERROR(NEFullyConnectedLayer::validate(input, weights, bias, &shape_info)); ARM_COMPUTE_RETURN_ON_ERROR(NEArithmeticAdditionKernel::validate(&shape_info, &shape_info, &shape_info, ConvertPolicy::SATURATE)); - ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayerKernel::validate(&shape_info, &shape_info, info)); + ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(&shape_info, &shape_info, info)); return Status{}; } @@ -95,7 +95,7 @@ void NERNNLayer::configure(const ITensor *input, const ITensor *weights, const I _fully_connected_out.allocator()->allocate(); _gemm_output.allocator()->allocate(); - _activation_kernel.configure(&_add_output, hidden_state, info); + _activation.configure(&_add_output, hidden_state, info); _add_output.allocator()->allocate(); _copy_kernel.configure(hidden_state, output); @@ -112,7 +112,7 @@ void NERNNLayer::run() _gemm_state_f.run(); NEScheduler::get().schedule(&_add_kernel, Window::DimY); - NEScheduler::get().schedule(&_activation_kernel, Window::DimY); + _activation.run(); // copy hidden out to output NEScheduler::get().schedule(&_copy_kernel, Window::DimY); diff --git a/src/runtime/NEON/functions/NEReshapeLayer.cpp b/src/runtime/NEON/functions/NEReshapeLayer.cpp index 680abef026..daf358e7db 100644 --- a/src/runtime/NEON/functions/NEReshapeLayer.cpp +++ b/src/runtime/NEON/functions/NEReshapeLayer.cpp @@ -44,7 +44,7 @@ void NEReshapeLayer::configure(const ITensorInfo *input, ITensorInfo *output) Status NEReshapeLayer::validate(const ITensorInfo *input, const ITensorInfo *output) { - return arm_compute::NEReshapeLayer::validate(input, output); + return arm_compute::NEReshapeLayerKernel::validate(input, output); } MemoryRequirements NEReshapeLayer::workspace() const @@ -53,32 +53,44 @@ MemoryRequirements NEReshapeLayer::workspace() const } } // namespace experimental -void NEReshapeLayer::configure(const ITensor *input, ITensor *output) +struct NEReshapeLayer::Impl { - _input = input; - _output = output; + const ITensor *src{ nullptr }; + ITensor *dst{ nullptr }; + std::unique_ptr op{ nullptr }; +}; - auto k = arm_compute::support::cpp14::make_unique(); - k->configure(input->info(), output->info()); - _kernel = std::move(k); +NEReshapeLayer::NEReshapeLayer() + : _impl(support::cpp14::make_unique()) +{ +} + +NEReshapeLayer::NEReshapeLayer(NEReshapeLayer &&) = default; + +NEReshapeLayer &NEReshapeLayer::operator=(NEReshapeLayer &&) = default; + +NEReshapeLayer::~NEReshapeLayer() = default; + +void NEReshapeLayer::configure(const ITensor *input, ITensor *output) +{ + _impl->src = input; + _impl->dst = output; + _impl->op = arm_compute::support::cpp14::make_unique(); + _impl->op->configure(input->info(), output->info()); } Status NEReshapeLayer::validate(const ITensorInfo *input, const ITensorInfo *output) { ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output); - ARM_COMPUTE_RETURN_ON_ERROR(NEReshapeLayerKernel::validate(input, output)); + ARM_COMPUTE_RETURN_ON_ERROR(experimental::NEReshapeLayer::validate(input, output)); return Status{}; } void NEReshapeLayer::run() { - InputOperatorTensors src_0 = std::make_pair(TensorType::ACL_SRC, _input); - OutputOperatorTensors dst_0 = std::make_pair(TensorType::ACL_DST, _output); - - std::vector inputs = { &src_0 }; - std::vector outputs = { &dst_0 }; - - NEScheduler::get().schedule_op(_kernel.get(), Window::DimY, inputs, outputs); + const InputTensor src{ TensorType::ACL_SRC, _impl->src }; + OutputTensor dst{ TensorType::ACL_DST, _impl->dst }; + _impl->op->run({ src }, { dst }, {}); } } // namespace arm_compute diff --git a/src/runtime/OMP/OMPScheduler.cpp b/src/runtime/OMP/OMPScheduler.cpp index a1851f03c3..6d6b285019 100644 --- a/src/runtime/OMP/OMPScheduler.cpp +++ b/src/runtime/OMP/OMPScheduler.cpp @@ -83,7 +83,7 @@ void OMPScheduler::schedule(ICPPKernel *kernel, const Hints &hints) } } -void OMPScheduler::schedule_op(ICPPKernel *kernel, const Hints &hints, std::vector &inputs, std::vector &outputs) +void OMPScheduler::schedule_op(ICPPKernel *kernel, const Hints &hints, const std::vector &inputs, const std::vector &outputs) { ARM_COMPUTE_ERROR_ON_MSG(!kernel, "The child class didn't set the kernel"); ARM_COMPUTE_ERROR_ON_MSG(hints.strategy() == StrategyHint::DYNAMIC, -- cgit v1.2.1