path: root/src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp
author     Michalis Spyrou <michalis.spyrou@arm.com>  2021-04-08 12:02:58 +0100
committer  Michalis Spyrou <michalis.spyrou@arm.com>  2021-04-19 13:45:08 +0000
commit     60c3b0e6821a80d78ffca5be30e05d062d071cd2 (patch)
tree       3e263a45aa9617cfd7704b2b33ea4337f1582321 /src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp
parent     4f1650f0c9919f0bac5024b8e31c0f754d25aec3 (diff)
download   ComputeLibrary-60c3b0e6821a80d78ffca5be30e05d062d071cd2.tar.gz
Port DepthwiseConvolution to new API
Resolves: COMPMID-4185

Change-Id: Ib5f22356356a022d567bb18d44ea272b62d10ebf
Signed-off-by: Michalis Spyrou <michalis.spyrou@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5424
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp')
-rw-r--r--  src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp  409
1 file changed, 176 insertions(+), 233 deletions(-)
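The patch below swaps NEDepthwiseConvolutionLayer's kernel and function members for a pimpl that forwards to the new operator-style API: operators are validated and configured against ITensorInfo metadata only, and the concrete tensors are bound at execution time through an ITensorPack. The following is a minimal sketch of that calling pattern, not part of the patch: the wrapper function run_depthwise_sketch, its signature, and the ITensorPack include path are illustrative assumptions, and tensor creation/allocation is omitted.

    #include "arm_compute/core/ITensorPack.h"
    #include "src/runtime/cpu/operators/CpuDepthwiseConvolution.h"

    using namespace arm_compute;

    // Hypothetical helper showing the operator-style flow this patch ports the layer onto.
    void run_depthwise_sketch(ITensor *src, const ITensor *weights, const ITensor *biases, ITensor *dst,
                              const PadStrideInfo &conv_info, unsigned int depth_multiplier,
                              const ActivationLayerInfo &act_info, const Size2D &dilation)
    {
        const ConvolutionInfo info{ conv_info, depth_multiplier, act_info, dilation };

        // Validation and configuration see only tensor metadata (ITensorInfo).
        ARM_COMPUTE_ERROR_THROW_ON(cpu::CpuDepthwiseConvolution::validate(
            src->info(), weights->info(), (biases != nullptr) ? biases->info() : nullptr, dst->info(), info));

        cpu::CpuDepthwiseConvolution op;
        op.configure(src->info(), weights->info(), (biases != nullptr) ? biases->info() : nullptr, dst->info(), info);

        // The concrete tensors are supplied late, at run time, through an ITensorPack.
        ITensorPack pack;
        pack.add_tensor(TensorType::ACL_SRC_0, src);
        pack.add_tensor(TensorType::ACL_SRC_1, weights);
        pack.add_tensor(TensorType::ACL_SRC_2, biases);
        pack.add_tensor(TensorType::ACL_DST_0, dst);
        op.run(pack);
    }

Separating metadata-time configuration from run-time tensor binding is what lets the shared cpu::CpuDepthwiseConvolution operator do the work, while the NEON wrappers in the diff keep only memory management and the tensor handles in their Impl structs.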
diff --git a/src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp b/src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp
index d17f6b5cd9..e1ceb0f083 100644
--- a/src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2020 Arm Limited.
+ * Copyright (c) 2017-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -27,54 +27,39 @@
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"
-#include "src/core/NEON/kernels/NEDepthwiseConvolutionLayerNativeKernel.h"
+#include "src/runtime/cpu/operators/CpuDepthwiseConvolution.h"
using namespace arm_compute::misc;
using namespace arm_compute::misc::shape_calculator;
namespace arm_compute
{
-namespace
-{
-Status validate_arguments_optimized(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const PadStrideInfo &conv_info,
- unsigned int depth_multiplier, const ActivationLayerInfo &act_info, const Size2D &dilation)
+NEDepthwiseConvolutionLayer::~NEDepthwiseConvolutionLayer() = default;
+
+struct NEDepthwiseConvolutionLayer::NEDepthwiseConvolutionLayerOptimizedInternal::Impl
{
- ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F16, DataType::F32);
- if(!is_data_type_quantized_per_channel(weights->data_type()))
- {
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);
- }
- ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN);
- ARM_COMPUTE_RETURN_ERROR_ON(dilation.x() < 1 || dilation.y() < 1);
- const size_t idx_w = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::WIDTH);
- const size_t idx_h = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::HEIGHT);
- ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(idx_w) + (weights->dimension(idx_w) - 1) * (dilation.x() - 1) > input->dimension(idx_w) + conv_info.pad_left() + conv_info.pad_right());
- ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(idx_h) + (weights->dimension(idx_h) - 1) * (dilation.y() - 1) > input->dimension(idx_h) + conv_info.pad_top() + conv_info.pad_bottom());
-
- if(biases != nullptr)
+ ITensor *src{ nullptr }; // SRC_0
+ ITensor *dst{ nullptr }; // DST_0
+ const ITensor *weights
{
- const unsigned int channel_idx = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::CHANNEL);
- ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1);
- ARM_COMPUTE_RETURN_ERROR_ON(biases->dimension(0) != weights->dimension(channel_idx));
- }
-
- ARM_COMPUTE_RETURN_ON_ERROR(NEDepthwiseConvolutionAssemblyDispatch::validate(input, weights, biases, output, conv_info, depth_multiplier, act_info, dilation));
-
- //Validate Activation Layer
- if(act_info.enabled())
+ nullptr
+ }; // SRC_1
+ const ITensor *biases
{
- ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(output, nullptr, act_info));
- }
- return Status{};
-}
-} // namespace
-
-NEDepthwiseConvolutionLayer::~NEDepthwiseConvolutionLayer() = default;
+ nullptr
+ }; // SRC_2
+ Tensor permuted_input{}; // INT_0
+ Tensor permuted_weights{}; // INT_1
+ Tensor permuted_output{}; // INT_2
+ Tensor workspace{}; // INT_3
+ Tensor packed_weights{}; // INT_4
+ std::shared_ptr<cpu::CpuDepthwiseConvolution> op{ nullptr };
+ bool is_prepared{ false };
+ bool permute{ false };
+};
NEDepthwiseConvolutionLayer::NEDepthwiseConvolutionLayerOptimizedInternal::NEDepthwiseConvolutionLayerOptimizedInternal(std::shared_ptr<IMemoryManager> memory_manager)
- : _memory_group(memory_manager), _dwc_optimized_func(memory_manager), _permute_input(), _permute_weights(), _permute_output(), _activationlayer_function(), _accumulator(), _permuted_input(),
- _permuted_weights(), _permuted_output(), _original_weights(nullptr), _has_bias(false), _is_quantized(false), _is_nchw(true), _permute(false), _is_activationlayer_enabled(false), _is_prepared(false)
+ : _memory_group(memory_manager), _impl(std::make_unique<Impl>())
{
}
@@ -87,65 +72,76 @@ void NEDepthwiseConvolutionLayer::NEDepthwiseConvolutionLayerOptimizedInternal::
const Size2D &dilation)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);
- // Perform validation step
- ARM_COMPUTE_ERROR_THROW_ON(NEDepthwiseConvolutionLayerOptimizedInternal::validate(input->info(), weights->info(), (biases == nullptr) ? nullptr : biases->info(),
- output->info(), conv_info, depth_multiplier, act_info, dilation));
-
- _original_weights = weights;
- _is_quantized = is_data_type_quantized_asymmetric(input->info()->data_type());
- _has_bias = biases != nullptr;
- _is_nchw = input->info()->data_layout() == DataLayout::NCHW;
- _permute = _is_nchw;
- _is_prepared = false;
- _is_activationlayer_enabled = act_info.enabled();
+
+    bool is_nchw = input->info()->data_layout() == DataLayout::NCHW;
+ _impl->src = input;
+ _impl->weights = weights;
+ _impl->biases = biases;
+ _impl->dst = output;
+    _impl->permute = is_nchw;
+
+ _impl->op = std::make_unique<cpu::CpuDepthwiseConvolution>();
+ ConvolutionInfo info{ conv_info, depth_multiplier, act_info, dilation };
+ _impl->op->configure(_impl->src->info(), _impl->weights->info(), _impl->biases == nullptr ? nullptr : _impl->biases->info(),
+ _impl->dst->info(), info);
// Configure pipeline
- ActivationLayerInfo act_info_to_use = ActivationLayerInfo();
- const bool is_relu = arm_compute::utils::info_helpers::is_relu(act_info);
- const bool is_relu6 = arm_compute::utils::info_helpers::is_relu6(act_info);
- _is_activationlayer_enabled = act_info.enabled() && !(is_relu || is_relu6);
- if(!_is_activationlayer_enabled)
+ ActivationLayerInfo act_info_to_use = ActivationLayerInfo();
+ const bool is_relu = arm_compute::utils::info_helpers::is_relu(act_info);
+ const bool is_relu6 = arm_compute::utils::info_helpers::is_relu6(act_info);
+ bool is_activationlayer_enabled = act_info.enabled() && !(is_relu || is_relu6);
+
+ if(!is_activationlayer_enabled)
{
act_info_to_use = act_info;
}
+ info = ConvolutionInfo{ conv_info, depth_multiplier, act_info_to_use, dilation };
- if(_is_nchw)
+ auto dwc_optimized_func = std::make_unique<cpu::CpuDepthwiseConvolutionAssemblyDispatch>();
+
+    if(is_nchw)
{
- _memory_group.manage(&_permuted_input);
- _memory_group.manage(&_permuted_output);
+ auto permute_input = std::make_unique<cpu::CpuPermute>();
+ auto permute_weights = std::make_unique<cpu::CpuPermute>();
+ auto permute_output = std::make_unique<cpu::CpuPermute>();
+
+ _memory_group.manage(&_impl->permuted_input);
+ _memory_group.manage(&_impl->permuted_weights);
+ _memory_group.manage(&_impl->permuted_output);
// Configure the function to transform the input tensor from NCHW -> NHWC
- _permute_input.configure(input, &_permuted_input, PermutationVector(2U, 0U, 1U));
- _permuted_input.info()->set_data_layout(DataLayout::NHWC);
+ permute_input->configure(input->info(), _impl->permuted_input.info(), PermutationVector(2U, 0U, 1U));
+ _impl->permuted_input.info()->set_data_layout(DataLayout::NHWC);
// Configure the function to transform the weights tensor from IHW -> HWI
- _permute_weights.configure(weights, &_permuted_weights, PermutationVector(2U, 0U, 1U));
- _permuted_weights.info()->set_data_layout(DataLayout::NHWC);
+ permute_weights->configure(weights->info(), _impl->permuted_weights.info(), PermutationVector(2U, 0U, 1U));
+ _impl->permuted_weights.info()->set_data_layout(DataLayout::NHWC);
- _permuted_output.info()->set_data_layout(DataLayout::NHWC);
- _permuted_output.info()->set_quantization_info(output->info()->quantization_info());
+ _impl->permuted_output.info()->set_data_layout(DataLayout::NHWC);
+ _impl->permuted_output.info()->set_quantization_info(output->info()->quantization_info());
// Configure optimized depthwise
- _dwc_optimized_func.configure(&_permuted_input, &_permuted_weights, biases, &_permuted_output, conv_info, depth_multiplier, act_info_to_use, dilation);
+ dwc_optimized_func->configure(_impl->permuted_input.info(), _impl->permuted_weights.info(), biases->info(), _impl->permuted_output.info(), info);
// Configure the function to transform the convoluted output to ACL's native ordering format NCHW
- _permuted_output.info()->set_data_layout(DataLayout::NHWC);
- _permute_output.configure(&_permuted_output, output, PermutationVector(1U, 2U, 0U));
+ _impl->permuted_output.info()->set_data_layout(DataLayout::NHWC);
+ permute_output->configure(_impl->permuted_output.info(), output->info(), PermutationVector(1U, 2U, 0U));
- // Allocate tensors
- _permuted_input.allocator()->allocate();
- _permuted_output.allocator()->allocate();
+ _impl->permuted_input.allocator()->allocate();
+ _impl->permuted_output.allocator()->allocate();
}
else
{
- _dwc_optimized_func.configure(input, weights, biases, output, conv_info, depth_multiplier, act_info_to_use, dilation);
+ dwc_optimized_func->configure(_impl->src->info(), _impl->weights->info(), biases->info(), _impl->dst->info(), info);
}
- // Configure activation
- if(_is_activationlayer_enabled)
- {
- _activationlayer_function.configure(output, nullptr, act_info);
- }
+ // Allocate memory based on the internal memory requirements
+ experimental::MemoryRequirements mem_req = dwc_optimized_func->workspace();
+ _impl->workspace.allocator()->init(TensorInfo(TensorShape{ mem_req[0].size }, 1, DataType::S8), mem_req[0].alignment);
+ _impl->packed_weights.allocator()->init(TensorInfo(TensorShape{ mem_req[1].size }, 1, DataType::S8), mem_req[1].alignment);
+
+ _impl->workspace.allocator()->allocate();
+ _impl->packed_weights.allocator()->allocate();
}
Status NEDepthwiseConvolutionLayer::NEDepthwiseConvolutionLayerOptimizedInternal::validate(const ITensorInfo *input,
@@ -157,63 +153,66 @@ Status NEDepthwiseConvolutionLayer::NEDepthwiseConvolutionLayerOptimizedInternal
const ActivationLayerInfo &act_info,
const Size2D &dilation)
{
- return validate_arguments_optimized(input, weights, biases, output, conv_info, depth_multiplier, act_info, dilation);
+ ConvolutionInfo info{ conv_info, depth_multiplier, act_info, dilation };
+ return cpu::CpuDepthwiseConvolution::validate(input, weights, biases, output, info);
}
void NEDepthwiseConvolutionLayer::NEDepthwiseConvolutionLayerOptimizedInternal::run()
{
prepare();
-
MemoryGroupResourceScope scope_mg(_memory_group);
- // Permute input
- if(_permute)
- {
- _permute_input.run();
- }
-
- // Run assembly function
- _dwc_optimized_func.run();
-
- // Permute output
- if(_is_nchw)
- {
- _permute_output.run();
- }
-
- // Run activation
- if(_is_activationlayer_enabled)
- {
- _activationlayer_function.run();
- }
+ ITensorPack pack;
+ pack.add_tensor(TensorType::ACL_SRC_0, _impl->src);
+ pack.add_tensor(TensorType::ACL_SRC_1, _impl->weights);
+ pack.add_tensor(TensorType::ACL_SRC_2, _impl->biases);
+ pack.add_tensor(TensorType::ACL_INT_0, &_impl->permuted_input);
+ pack.add_tensor(TensorType::ACL_INT_1, &_impl->permuted_weights);
+ pack.add_tensor(TensorType::ACL_INT_2, &_impl->permuted_output);
+ pack.add_tensor(TensorType::ACL_INT_3, &_impl->workspace);
+ pack.add_tensor(TensorType::ACL_INT_4, &_impl->packed_weights);
+ pack.add_tensor(TensorType::ACL_DST_0, _impl->dst);
+
+ _impl->op->run(pack);
}
void NEDepthwiseConvolutionLayer::NEDepthwiseConvolutionLayerOptimizedInternal::prepare()
{
- if(!_is_prepared)
+ if(!_impl->is_prepared)
{
// Permute weights
- if(_permute)
+ if(_impl->permute)
{
- _permuted_weights.allocator()->allocate();
- _permute_weights.run();
- _original_weights->mark_as_unused();
+ _impl->permuted_weights.allocator()->allocate();
+ _impl->weights->mark_as_unused();
}
- // Prepare optimized function
- _dwc_optimized_func.prepare();
- if(!_permuted_weights.is_used())
+ if(!_impl->permuted_weights.is_used())
{
- _permuted_weights.allocator()->free();
+ _impl->permuted_weights.allocator()->free();
}
- _is_prepared = true;
+ _impl->is_prepared = true;
}
}
+struct NEDepthwiseConvolutionLayer::NEDepthwiseConvolutionLayerGeneric::Impl
+{
+ Tensor permuted_input{};
+ Tensor permuted_weights{};
+ Tensor permuted_output{};
+ bool is_prepared{ false };
+ bool is_nchw{ false };
+ bool is_activationlayer_enabled{ false };
+ const ITensor *weights{ nullptr };
+ const ITensor *biases{ nullptr };
+ const ITensor *src{ nullptr };
+ ITensor *dst{ nullptr };
+ std::shared_ptr<cpu::CpuDepthwiseConvolution> op{ nullptr };
+};
+
NEDepthwiseConvolutionLayer::NEDepthwiseConvolutionLayerGeneric::NEDepthwiseConvolutionLayerGeneric()
- : _depthwise_conv_kernel(), _permute_input(), _permute_weights(), _permute_output(), _activationlayer_function(), _permuted_input(), _permuted_weights(), _permuted_output(), _is_prepared(false),
- _is_nchw(false), _is_activationlayer_enabled(false), _original_weights(nullptr)
+ : _impl(std::make_unique<Impl>())
{
}
@@ -224,45 +223,49 @@ void NEDepthwiseConvolutionLayer::NEDepthwiseConvolutionLayerGeneric::configure(
ARM_COMPUTE_ERROR_THROW_ON(NEDepthwiseConvolutionLayer::validate(input->info(), weights->info(), (biases == nullptr) ? nullptr : biases->info(),
output->info(), conv_info, depth_multiplier, act_info, dilation));
- _is_nchw = input->info()->data_layout() == DataLayout::NCHW;
- _is_prepared = !_is_nchw;
+ const ConvolutionInfo info{ conv_info, depth_multiplier, act_info, dilation };
+ _impl->op = std::make_unique<cpu::CpuDepthwiseConvolution>();
+ _impl->op->configure(input->info(), weights->info(), biases == nullptr ? nullptr : biases->info(), output->info(), info);
+
+ _impl->src = input;
+ _impl->dst = output;
+ _impl->weights = weights;
+ _impl->biases = biases;
+ _impl->is_nchw = input->info()->data_layout() == DataLayout::NCHW;
+ _impl->is_prepared = !_impl->is_nchw;
ITensor *input_to_use = input;
const ITensor *weights_to_use = weights;
ITensor *output_to_use = output;
- if(_is_nchw)
+ if(_impl->is_nchw)
{
- _permute_input.configure(input, &_permuted_input, PermutationVector(2U, 0U, 1U));
- _permuted_input.info()->set_data_layout(DataLayout::NHWC);
- input_to_use = &_permuted_input;
+ auto permute_input = std::make_unique<cpu::CpuPermute>();
+ auto permute_weights = std::make_unique<cpu::CpuPermute>();
- _permute_weights.configure(weights, &_permuted_weights, PermutationVector(2U, 0U, 1U));
- _permuted_weights.info()->set_data_layout(DataLayout::NHWC);
- weights_to_use = &_permuted_weights;
+ permute_input->configure(input->info(), _impl->permuted_input.info(), PermutationVector(2U, 0U, 1U));
+ _impl->permuted_input.info()->set_data_layout(DataLayout::NHWC);
+ input_to_use = &_impl->permuted_input;
- _permuted_output.allocator()->init(output->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(TensorShape()));
- output_to_use = &_permuted_output;
+ permute_weights->configure(weights->info(), _impl->permuted_weights.info(), PermutationVector(2U, 0U, 1U));
+ _impl->permuted_weights.info()->set_data_layout(DataLayout::NHWC);
+ weights_to_use = &_impl->permuted_weights;
+
+ _impl->permuted_output.allocator()->init(output->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(TensorShape()));
+ output_to_use = &_impl->permuted_output;
}
- _original_weights = weights_to_use;
- _depthwise_conv_kernel = std::make_unique<NEDepthwiseConvolutionLayerNativeKernel>();
- _depthwise_conv_kernel->configure(input_to_use, weights_to_use, biases, output_to_use, conv_info, depth_multiplier, dilation);
+ auto depthwise_conv_kernel = std::make_unique<cpu::kernels::CpuDepthwiseConvolutionNativeKernel>();
+ depthwise_conv_kernel->configure(input_to_use->info(), weights_to_use->info(), biases->info(), output_to_use->info(), info);
- if(_is_nchw)
+ if(_impl->is_nchw)
{
- _permute_output.configure(&_permuted_output, output, PermutationVector(1U, 2U, 0U));
- _permuted_output.info()->set_data_layout(DataLayout::NHWC);
+ auto permute_output = std::make_unique<cpu::CpuPermute>();
+ permute_output->configure(_impl->permuted_output.info(), output->info(), PermutationVector(1U, 2U, 0U));
+ _impl->permuted_output.info()->set_data_layout(DataLayout::NHWC);
- _permuted_input.allocator()->allocate();
- _permuted_weights.allocator()->allocate();
- _permuted_output.allocator()->allocate();
- }
-
- //Configure Activation Layer
- _is_activationlayer_enabled = act_info.enabled();
- if(_is_activationlayer_enabled)
- {
- _activationlayer_function.configure(output, nullptr, act_info);
+ _impl->permuted_input.allocator()->allocate();
+ _impl->permuted_weights.allocator()->allocate();
+ _impl->permuted_output.allocator()->allocate();
}
}
@@ -270,89 +273,53 @@ Status NEDepthwiseConvolutionLayer::NEDepthwiseConvolutionLayerGeneric::validate
const PadStrideInfo &conv_info,
unsigned int depth_multiplier, const ActivationLayerInfo &act_info, const Size2D &dilation)
{
- ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
- if(input->data_layout() == DataLayout::NCHW)
- {
- TensorShape permuted_input_shape = input->tensor_shape();
- TensorShape permuted_weights_shape = weights->tensor_shape();
- TensorShape permuted_output_shape = misc::shape_calculator::compute_depthwise_convolution_shape(*input, *weights, conv_info, depth_multiplier, dilation);
- permute(permuted_input_shape, PermutationVector(2U, 0U, 1U));
- permute(permuted_weights_shape, PermutationVector(2U, 0U, 1U));
- permute(permuted_output_shape, PermutationVector(2U, 0U, 1U));
-
- const TensorInfo permuted_input = TensorInfo(input->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(permuted_input_shape).set_data_layout(DataLayout::NHWC));
- const TensorInfo permuted_weights = TensorInfo(weights->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(permuted_weights_shape).set_data_layout(DataLayout::NHWC));
- const TensorInfo permuted_output = TensorInfo(output->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(permuted_output_shape).set_data_layout(DataLayout::NCHW));
-
- ARM_COMPUTE_RETURN_ON_ERROR(NEPermute::validate(input, &permuted_input, PermutationVector(2U, 0U, 1U)));
- ARM_COMPUTE_RETURN_ON_ERROR(NEPermute::validate(weights, &permuted_weights, PermutationVector(2U, 0U, 1U)));
- ARM_COMPUTE_RETURN_ON_ERROR(NEPermute::validate(&permuted_output, output, PermutationVector(1U, 2U, 0U)));
-
- ARM_COMPUTE_RETURN_ON_ERROR(NEDepthwiseConvolutionLayerNativeKernel::validate(&permuted_input, &permuted_weights, biases, &permuted_output, conv_info, depth_multiplier, dilation));
- }
- else
- {
- ARM_COMPUTE_RETURN_ON_ERROR(NEDepthwiseConvolutionLayerNativeKernel::validate(input, weights, biases, output, conv_info, depth_multiplier, dilation));
- }
-
- // Validate Activation Layer
- if(act_info.enabled())
- {
- ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(output, nullptr, act_info));
- }
-
- return Status{};
+ ConvolutionInfo info{ conv_info, depth_multiplier, act_info, dilation };
+ return cpu::CpuDepthwiseConvolution::validate(input, weights, biases, output, info);
}
void NEDepthwiseConvolutionLayer::NEDepthwiseConvolutionLayerGeneric::run()
{
- if(_is_nchw)
- {
- prepare();
- _permute_input.run();
- }
-
- NEScheduler::get().schedule(_depthwise_conv_kernel.get(), Window::DimY);
-
- if(_is_nchw)
- {
- _permute_output.run();
- }
-
- if(_is_activationlayer_enabled)
- {
- _activationlayer_function.run();
- }
+ ITensorPack pack;
+ pack.add_tensor(TensorType::ACL_SRC_0, _impl->src);
+ pack.add_tensor(TensorType::ACL_SRC_1, _impl->weights);
+ pack.add_tensor(TensorType::ACL_SRC_2, _impl->biases);
+ pack.add_tensor(TensorType::ACL_INT_0, &_impl->permuted_input);
+ pack.add_tensor(TensorType::ACL_INT_1, &_impl->permuted_weights);
+ pack.add_tensor(TensorType::ACL_INT_2, &_impl->permuted_output);
+ pack.add_tensor(TensorType::ACL_DST_0, _impl->dst);
+
+ _impl->op->run(pack);
}
-void NEDepthwiseConvolutionLayer::NEDepthwiseConvolutionLayerGeneric::prepare()
+NEDepthwiseConvolutionLayer::NEDepthwiseConvolutionLayer(std::shared_ptr<IMemoryManager> memory_manager)
+ : _memory_group(std::move(memory_manager)), _impl(std::make_unique<Impl>())
{
- if(!_is_prepared)
- {
- ARM_COMPUTE_ERROR_ON(!_original_weights->is_used());
-
- _permute_weights.run();
- _original_weights->mark_as_unused();
- _is_prepared = true;
- }
}
-NEDepthwiseConvolutionLayer::NEDepthwiseConvolutionLayer(std::shared_ptr<IMemoryManager> memory_manager)
- : _depth_conv_func(DepthwiseConvolutionFunction::GENERIC), _func_optimized(std::move(memory_manager)), _func_generic()
+#ifndef DOXYGEN_SKIP_THIS
+struct NEDepthwiseConvolutionLayer::NEDepthwiseConvolutionLayer::Impl
{
-}
+ DepthwiseConvolutionFunction depth_conv_func{ DepthwiseConvolutionFunction::OPTIMIZED };
+ NEDepthwiseConvolutionLayerOptimizedInternal func_optimized{ nullptr };
+ NEDepthwiseConvolutionLayerGeneric func_generic{};
+ std::shared_ptr<cpu::CpuDepthwiseConvolution> op{ nullptr };
+};
+#endif // DOXYGEN_SKIP_THIS
void NEDepthwiseConvolutionLayer::configure(ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const PadStrideInfo &conv_info, unsigned int depth_multiplier,
const ActivationLayerInfo &act_info, const Size2D &dilation)
{
- _depth_conv_func = get_depthwiseconvolution_function(input->info(), weights->info(), (biases != nullptr) ? biases->info() : nullptr, output->info(), conv_info, depth_multiplier, act_info, dilation);
- switch(_depth_conv_func)
+ const ConvolutionInfo info{ conv_info, depth_multiplier, act_info, dilation };
+ _impl->op = std::make_shared<cpu::CpuDepthwiseConvolution>();
+ _impl->depth_conv_func = _impl->op->get_depthwiseconvolution_function(input->info(), weights->info(), (biases != nullptr) ? biases->info() : nullptr, output->info(),
+ info);
+ switch(_impl->depth_conv_func)
{
case DepthwiseConvolutionFunction::OPTIMIZED:
- _func_optimized.configure(input, weights, biases, output, conv_info, depth_multiplier, act_info, dilation);
+ _impl->func_optimized.configure(input, weights, biases, output, conv_info, depth_multiplier, act_info, dilation);
break;
case DepthwiseConvolutionFunction::GENERIC:
- _func_generic.configure(input, weights, biases, output, conv_info, depth_multiplier, act_info, dilation);
+ _impl->func_generic.configure(input, weights, biases, output, conv_info, depth_multiplier, act_info, dilation);
break;
default:
ARM_COMPUTE_ERROR("Unsupported DepthwiseConvolutionFunction");
@@ -362,43 +329,19 @@ void NEDepthwiseConvolutionLayer::configure(ITensor *input, const ITensor *weigh
Status NEDepthwiseConvolutionLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const PadStrideInfo &conv_info,
unsigned int depth_multiplier, const ActivationLayerInfo &act_info, const Size2D &dilation)
{
- DepthwiseConvolutionFunction depth_conv_func = get_depthwiseconvolution_function(input, weights, biases, output, conv_info, depth_multiplier, act_info, dilation);
- switch(depth_conv_func)
- {
- case DepthwiseConvolutionFunction::OPTIMIZED:
- return NEDepthwiseConvolutionLayerOptimizedInternal::validate(input, weights, biases, output, conv_info, depth_multiplier, act_info, dilation);
- break;
- case DepthwiseConvolutionFunction::GENERIC:
- return NEDepthwiseConvolutionLayerGeneric::validate(input, weights, biases, output, conv_info, depth_multiplier, act_info, dilation);
- break;
- default:
- ARM_COMPUTE_ERROR("Unsupported DepthwiseConvolutionFunction");
- }
-}
-
-DepthwiseConvolutionFunction NEDepthwiseConvolutionLayer::get_depthwiseconvolution_function(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output,
- const PadStrideInfo &conv_info,
- unsigned int depth_multiplier, ActivationLayerInfo act_info, const Size2D &dilation)
-{
- if(bool(NEDepthwiseConvolutionLayerOptimizedInternal::validate(input, weights, biases, output, conv_info, depth_multiplier, act_info, dilation)))
- {
- return DepthwiseConvolutionFunction::OPTIMIZED;
- }
- else
- {
- return DepthwiseConvolutionFunction::GENERIC;
- }
+ ConvolutionInfo info{ conv_info, depth_multiplier, act_info, dilation };
+ return cpu::CpuDepthwiseConvolution::validate(input, weights, biases, output, info);
}
void NEDepthwiseConvolutionLayer::run()
{
- switch(_depth_conv_func)
+ switch(_impl->depth_conv_func)
{
case DepthwiseConvolutionFunction::OPTIMIZED:
- _func_optimized.run();
+ _impl->func_optimized.run();
break;
case DepthwiseConvolutionFunction::GENERIC:
- _func_generic.run();
+ _impl->func_generic.run();
break;
default:
ARM_COMPUTE_ERROR("DepthwiseConvolutionFunction not properly configured");
@@ -407,13 +350,13 @@ void NEDepthwiseConvolutionLayer::run()
void NEDepthwiseConvolutionLayer::prepare()
{
- switch(_depth_conv_func)
+ switch(_impl->depth_conv_func)
{
case DepthwiseConvolutionFunction::OPTIMIZED:
- _func_optimized.prepare();
+ _impl->func_optimized.prepare();
break;
case DepthwiseConvolutionFunction::GENERIC:
- _func_generic.prepare();
+ _impl->func_generic.prepare();
break;
default:
ARM_COMPUTE_ERROR("DepthwiseConvolutionFunction not properly configured");