From 60c3b0e6821a80d78ffca5be30e05d062d071cd2 Mon Sep 17 00:00:00 2001
From: Michalis Spyrou
Date: Thu, 8 Apr 2021 12:02:58 +0100
Subject: Port DepthwiseConvolution to new API

Resolves: COMPMID-4185

Change-Id: Ib5f22356356a022d567bb18d44ea272b62d10ebf
Signed-off-by: Michalis Spyrou
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5424
Reviewed-by: Michele Di Giorgio
Tested-by: Arm Jenkins
Comments-Addressed: Arm Jenkins
---
 .../CL/functions/CLDepthwiseConvolutionLayer.cpp   |  14 +-
 .../NEON/functions/NEDepthwiseConvolutionLayer.cpp | 409 +++++++--------
 .../NEDepthwiseConvolutionAssemblyDispatch.cpp     | 569 ---------------------
 .../cpu/operators/CpuDepthwiseConvolution.cpp      | 521 +++++++++++++++++++
 .../cpu/operators/CpuDepthwiseConvolution.h        | 230 +++++++++
 .../CpuDepthwiseConvolutionAssemblyDispatch.cpp    | 564 ++++++++++++++++++++
 .../CpuDepthwiseConvolutionAssemblyDispatch.h      |  97 ++++
 7 files changed, 1596 insertions(+), 808 deletions(-)
 delete mode 100644 src/runtime/NEON/functions/assembly/NEDepthwiseConvolutionAssemblyDispatch.cpp
 create mode 100644 src/runtime/cpu/operators/CpuDepthwiseConvolution.cpp
 create mode 100644 src/runtime/cpu/operators/CpuDepthwiseConvolution.h
 create mode 100644 src/runtime/cpu/operators/CpuDepthwiseConvolutionAssemblyDispatch.cpp
 create mode 100644 src/runtime/cpu/operators/CpuDepthwiseConvolutionAssemblyDispatch.h
(limited to 'src/runtime')

diff --git a/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp b/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp
index f7517a50a3..8e3d010786 100644
--- a/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp
+++ b/src/runtime/CL/functions/CLDepthwiseConvolutionLayer.cpp
@@ -82,9 +82,10 @@ Status validate_arguments_3x3(const ITensorInfo *input, const ITensorInfo *weigh
 
     if(needs_permute)
     {
-        TensorShape permuted_input_shape   = input->tensor_shape();
-        TensorShape permuted_weights_shape = weights->tensor_shape();
-        TensorShape permuted_output_shape  = shape_calculator::compute_depthwise_convolution_shape(*input, *weights, conv_info, depth_multiplier, dilation);
+        TensorShape permuted_input_shape   = input->tensor_shape();
+        TensorShape permuted_weights_shape = weights->tensor_shape();
+        const ConvolutionInfo info{ conv_info, depth_multiplier, ActivationLayerInfo(), dilation };
+        TensorShape permuted_output_shape = shape_calculator::compute_depthwise_convolution_shape(*input, *weights, info);
 
         permute(permuted_input_shape, PermutationVector(1U, 2U, 0U));
         permute(permuted_weights_shape, PermutationVector(1U, 2U, 0U));
@@ -272,9 +273,10 @@ Status CLDepthwiseConvolutionLayer::CLDepthwiseConvolutionLayerGeneric::validate
 
     if(needs_permute)
     {
-        TensorShape permuted_input_shape   = input->tensor_shape();
-        TensorShape permuted_weights_shape = weights->tensor_shape();
-        TensorShape permuted_output_shape  = shape_calculator::compute_depthwise_convolution_shape(*input, *weights, conv_info, depth_multiplier, dilation);
+        TensorShape permuted_input_shape   = input->tensor_shape();
+        TensorShape permuted_weights_shape = weights->tensor_shape();
+        const ConvolutionInfo info{ conv_info, depth_multiplier, ActivationLayerInfo(), dilation };
+        TensorShape permuted_output_shape = shape_calculator::compute_depthwise_convolution_shape(*input, *weights, info);
 
         permute(permuted_input_shape, PermutationVector(2U, 0U, 1U));
         permute(permuted_weights_shape, PermutationVector(2U, 0U, 1U));
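The CL hunks above only repackage the loose convolution parameters (conv_info, depth_multiplier, dilation, plus a default ActivationLayerInfo) into the new ConvolutionInfo aggregate before calling the shape calculator. A minimal sketch of the migrated call, assuming only the ConvolutionInfo field order visible in this patch:

    // Sketch (not part of the patch): shape computation through the new
    // ConvolutionInfo aggregate, mirroring the hunks above.
    #include "arm_compute/core/Types.h"
    #include "arm_compute/core/utils/misc/ShapeCalculator.h"

    using namespace arm_compute;

    TensorShape dwc_output_shape(const ITensorInfo &src, const ITensorInfo &weights)
    {
        const PadStrideInfo conv_info(1, 1, 1, 1); // stride 1, pad 1
        const Size2D        dilation(1U, 1U);
        // conv_info, depth_multiplier, act_info and dilation now travel together:
        const ConvolutionInfo info{ conv_info, 1 /* depth_multiplier */, ActivationLayerInfo(), dilation };
        return misc::shape_calculator::compute_depthwise_convolution_shape(src, weights, info);
    }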
diff --git a/src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp b/src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp
index d17f6b5cd9..e1ceb0f083 100644
--- a/src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp
+++ b/src/runtime/NEON/functions/NEDepthwiseConvolutionLayer.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2017-2020 Arm Limited.
+ * Copyright (c) 2017-2021 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -27,54 +27,39 @@
 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
 #include "arm_compute/core/utils/quantization/AsymmHelpers.h"
 #include "arm_compute/runtime/NEON/NEScheduler.h"
-#include "src/core/NEON/kernels/NEDepthwiseConvolutionLayerNativeKernel.h"
+#include "src/runtime/cpu/operators/CpuDepthwiseConvolution.h"
 
 using namespace arm_compute::misc;
 using namespace arm_compute::misc::shape_calculator;
 
 namespace arm_compute
 {
-namespace
-{
-Status validate_arguments_optimized(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const PadStrideInfo &conv_info,
-                                    unsigned int depth_multiplier, const ActivationLayerInfo &act_info, const Size2D &dilation)
+NEDepthwiseConvolutionLayer::~NEDepthwiseConvolutionLayer() = default;
+
+struct NEDepthwiseConvolutionLayer::NEDepthwiseConvolutionLayerOptimizedInternal::Impl
 {
-    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
-    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F16, DataType::F32);
-    if(!is_data_type_quantized_per_channel(weights->data_type()))
-    {
-        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);
-    }
-    ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN);
-    ARM_COMPUTE_RETURN_ERROR_ON(dilation.x() < 1 || dilation.y() < 1);
-    const size_t idx_w = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::WIDTH);
-    const size_t idx_h = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::HEIGHT);
-    ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(idx_w) + (weights->dimension(idx_w) - 1) * (dilation.x() - 1) > input->dimension(idx_w) + conv_info.pad_left() + conv_info.pad_right());
-    ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(idx_h) + (weights->dimension(idx_h) - 1) * (dilation.y() - 1) > input->dimension(idx_h) + conv_info.pad_top() + conv_info.pad_bottom());
-
-    if(biases != nullptr)
+    ITensor       *src{ nullptr }; // SRC_0
+    ITensor       *dst{ nullptr }; // DST_0
+    const ITensor *weights
     {
-        const unsigned int channel_idx = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::CHANNEL);
-        ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1);
-        ARM_COMPUTE_RETURN_ERROR_ON(biases->dimension(0) != weights->dimension(channel_idx));
-    }
-
-    ARM_COMPUTE_RETURN_ON_ERROR(NEDepthwiseConvolutionAssemblyDispatch::validate(input, weights, biases, output, conv_info, depth_multiplier, act_info, dilation));
-
-    //Validate Activation Layer
-    if(act_info.enabled())
+        nullptr
+    }; // SRC_1
+    const ITensor *biases
     {
-        ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(output, nullptr, act_info));
-    }
-    return Status{};
-}
-} // namespace
-
-NEDepthwiseConvolutionLayer::~NEDepthwiseConvolutionLayer() = default;
+        nullptr
+    }; // SRC_2
+    Tensor permuted_input{};   // INT_0
+    Tensor permuted_weights{}; // INT_1
+    Tensor permuted_output{};  // INT_2
+    Tensor workspace{};        // INT_3
+    Tensor packed_weights{};   // INT_4
+    std::shared_ptr op{ nullptr };
+    bool is_prepared{ false };
+    bool permute{ false };
+};
 NEDepthwiseConvolutionLayer::NEDepthwiseConvolutionLayerOptimizedInternal::NEDepthwiseConvolutionLayerOptimizedInternal(std::shared_ptr memory_manager)
-    : _memory_group(memory_manager), _dwc_optimized_func(memory_manager), _permute_input(), _permute_weights(), _permute_output(), _activationlayer_function(), _accumulator(), _permuted_input(),
-      _permuted_weights(), _permuted_output(), _original_weights(nullptr), _has_bias(false), _is_quantized(false), _is_nchw(true), _permute(false), _is_activationlayer_enabled(false), _is_prepared(false)
+    : _memory_group(memory_manager), _impl(std::make_unique())
 {
 }
 
@@ -87,65 +72,76 @@ void NEDepthwiseConvolutionLayer::NEDepthwiseConvolutionLayerOptimizedInternal::
                                                                                  const Size2D &dilation)
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);
-    // Perform validation step
-    ARM_COMPUTE_ERROR_THROW_ON(NEDepthwiseConvolutionLayerOptimizedInternal::validate(input->info(), weights->info(), (biases == nullptr) ? nullptr : biases->info(),
-                                                                                      output->info(), conv_info, depth_multiplier, act_info, dilation));
-
-    _original_weights = weights;
-    _is_quantized     = is_data_type_quantized_asymmetric(input->info()->data_type());
-    _has_bias         = biases != nullptr;
-    _is_nchw          = input->info()->data_layout() == DataLayout::NCHW;
-    _permute          = _is_nchw;
-    _is_prepared      = false;
-    _is_activationlayer_enabled = act_info.enabled();
+
+    bool is_nhwc   = input->info()->data_layout() == DataLayout::NCHW;
+    _impl->src     = input;
+    _impl->weights = weights;
+    _impl->biases  = biases;
+    _impl->dst     = output;
+    _impl->permute = is_nhwc;
+
+    _impl->op = std::make_unique();
+    ConvolutionInfo info{ conv_info, depth_multiplier, act_info, dilation };
+    _impl->op->configure(_impl->src->info(), _impl->weights->info(), _impl->biases == nullptr ? nullptr : _impl->biases->info(),
+                         _impl->dst->info(), info);
 
     // Configure pipeline
-    ActivationLayerInfo act_info_to_use = ActivationLayerInfo();
-    const bool          is_relu         = arm_compute::utils::info_helpers::is_relu(act_info);
-    const bool          is_relu6        = arm_compute::utils::info_helpers::is_relu6(act_info);
-    _is_activationlayer_enabled         = act_info.enabled() && !(is_relu || is_relu6);
-    if(!_is_activationlayer_enabled)
+    ActivationLayerInfo act_info_to_use            = ActivationLayerInfo();
+    const bool          is_relu                    = arm_compute::utils::info_helpers::is_relu(act_info);
+    const bool          is_relu6                   = arm_compute::utils::info_helpers::is_relu6(act_info);
+    bool                is_activationlayer_enabled = act_info.enabled() && !(is_relu || is_relu6);
+
+    if(!is_activationlayer_enabled)
    {
         act_info_to_use = act_info;
     }
+    info = ConvolutionInfo{ conv_info, depth_multiplier, act_info_to_use, dilation };
 
-    if(_is_nchw)
+    auto dwc_optimized_func = std::make_unique();
+
+    if(is_nhwc)
     {
-        _memory_group.manage(&_permuted_input);
-        _memory_group.manage(&_permuted_output);
+        auto permute_input   = std::make_unique();
+        auto permute_weights = std::make_unique();
+        auto permute_output  = std::make_unique();
+
+        _memory_group.manage(&_impl->permuted_input);
+        _memory_group.manage(&_impl->permuted_weights);
+        _memory_group.manage(&_impl->permuted_output);
 
         // Configure the function to transform the input tensor from NCHW -> NHWC
-        _permute_input.configure(input, &_permuted_input, PermutationVector(2U, 0U, 1U));
-        _permuted_input.info()->set_data_layout(DataLayout::NHWC);
+        permute_input->configure(input->info(), _impl->permuted_input.info(), PermutationVector(2U, 0U, 1U));
+        _impl->permuted_input.info()->set_data_layout(DataLayout::NHWC);
 
         // Configure the function to transform the weights tensor from IHW -> HWI
-        _permute_weights.configure(weights, &_permuted_weights, PermutationVector(2U, 0U, 1U));
-        _permuted_weights.info()->set_data_layout(DataLayout::NHWC);
+        permute_weights->configure(weights->info(), _impl->permuted_weights.info(), PermutationVector(2U, 0U, 1U));
+        _impl->permuted_weights.info()->set_data_layout(DataLayout::NHWC);
 
-        _permuted_output.info()->set_data_layout(DataLayout::NHWC);
-        _permuted_output.info()->set_quantization_info(output->info()->quantization_info());
+        _impl->permuted_output.info()->set_data_layout(DataLayout::NHWC);
+        _impl->permuted_output.info()->set_quantization_info(output->info()->quantization_info());
 
         // Configure optimized depthwise
-        _dwc_optimized_func.configure(&_permuted_input, &_permuted_weights, biases, &_permuted_output, conv_info, depth_multiplier, act_info_to_use, dilation);
+        dwc_optimized_func->configure(_impl->permuted_input.info(), _impl->permuted_weights.info(), biases->info(), _impl->permuted_output.info(), info);
 
         // Configure the function to transform the convoluted output to ACL's native ordering format NCHW
-        _permuted_output.info()->set_data_layout(DataLayout::NHWC);
-        _permute_output.configure(&_permuted_output, output, PermutationVector(1U, 2U, 0U));
+        _impl->permuted_output.info()->set_data_layout(DataLayout::NHWC);
+        permute_output->configure(_impl->permuted_output.info(), output->info(), PermutationVector(1U, 2U, 0U));
 
-        // Allocate tensors
-        _permuted_input.allocator()->allocate();
-        _permuted_output.allocator()->allocate();
+        _impl->permuted_input.allocator()->allocate();
+        _impl->permuted_output.allocator()->allocate();
     }
     else
     {
-        _dwc_optimized_func.configure(input, weights, biases, output, conv_info, depth_multiplier, act_info_to_use, dilation);
+        dwc_optimized_func->configure(_impl->src->info(), _impl->weights->info(), biases->info(), _impl->dst->info(), info);
     }
 
-    // Configure activation
-    if(_is_activationlayer_enabled)
-    {
-        _activationlayer_function.configure(output, nullptr, act_info);
-    }
+    // Allocate memory based on the internal memory requirements
+    experimental::MemoryRequirements mem_req = dwc_optimized_func->workspace();
+    _impl->workspace.allocator()->init(TensorInfo(TensorShape{ mem_req[0].size }, 1, DataType::S8), mem_req[0].alignment);
+    _impl->packed_weights.allocator()->init(TensorInfo(TensorShape{ mem_req[1].size }, 1, DataType::S8), mem_req[1].alignment);
+
+    _impl->workspace.allocator()->allocate();
+    _impl->packed_weights.allocator()->allocate();
 }
 
 Status NEDepthwiseConvolutionLayer::NEDepthwiseConvolutionLayerOptimizedInternal::validate(const ITensorInfo *input,
@@ -157,63 +153,66 @@ Status NEDepthwiseConvolutionLayer::NEDepthwiseConvolutionLayerOptimizedInternal
                                                                                            const ActivationLayerInfo &act_info,
                                                                                            const Size2D &dilation)
 {
-    return validate_arguments_optimized(input, weights, biases, output, conv_info, depth_multiplier, act_info, dilation);
+    ConvolutionInfo info{ conv_info, depth_multiplier, act_info, dilation };
+    return cpu::CpuDepthwiseConvolution::validate(input, weights, biases, output, info);
 }
 
 void NEDepthwiseConvolutionLayer::NEDepthwiseConvolutionLayerOptimizedInternal::run()
 {
     prepare();
-
     MemoryGroupResourceScope scope_mg(_memory_group);
 
-    // Permute input
-    if(_permute)
-    {
-        _permute_input.run();
-    }
-
-    // Run assembly function
-    _dwc_optimized_func.run();
-
-    // Permute output
-    if(_is_nchw)
-    {
-        _permute_output.run();
-    }
-
-    // Run activation
-    if(_is_activationlayer_enabled)
-    {
-        _activationlayer_function.run();
-    }
+    ITensorPack pack;
+    pack.add_tensor(TensorType::ACL_SRC_0, _impl->src);
+    pack.add_tensor(TensorType::ACL_SRC_1, _impl->weights);
+    pack.add_tensor(TensorType::ACL_SRC_2, _impl->biases);
+    pack.add_tensor(TensorType::ACL_INT_0, &_impl->permuted_input);
+    pack.add_tensor(TensorType::ACL_INT_1, &_impl->permuted_weights);
+    pack.add_tensor(TensorType::ACL_INT_2, &_impl->permuted_output);
+    pack.add_tensor(TensorType::ACL_INT_3, &_impl->workspace);
+    pack.add_tensor(TensorType::ACL_INT_4, &_impl->packed_weights);
+    pack.add_tensor(TensorType::ACL_DST_0, _impl->dst);
+
+    _impl->op->run(pack);
 }
 
 void NEDepthwiseConvolutionLayer::NEDepthwiseConvolutionLayerOptimizedInternal::prepare()
 {
-    if(!_is_prepared)
+    if(!_impl->is_prepared)
     {
         // Permute weights
-        if(_permute)
+        if(_impl->permute)
         {
-            _permuted_weights.allocator()->allocate();
-            _permute_weights.run();
-            _original_weights->mark_as_unused();
+            _impl->permuted_weights.allocator()->allocate();
+            _impl->weights->mark_as_unused();
         }
 
-        // Prepare optimized function
-        _dwc_optimized_func.prepare();
-        if(!_permuted_weights.is_used())
+        if(!_impl->permuted_weights.is_used())
         {
-            _permuted_weights.allocator()->free();
+            _impl->permuted_weights.allocator()->free();
        }
 
-        _is_prepared = true;
+        _impl->is_prepared = true;
     }
 }
 
+struct NEDepthwiseConvolutionLayer::NEDepthwiseConvolutionLayerGeneric::Impl
+{
+    Tensor permuted_input{};
+    Tensor permuted_weights{};
+    Tensor permuted_output{};
+    bool   is_prepared{ false };
+    bool   is_nchw{ false };
+    bool   is_activationlayer_enabled{ false };
+    const ITensor *weights{ nullptr };
+    const ITensor *biases{ nullptr };
+    const ITensor *src{ nullptr };
+    ITensor       *dst{ nullptr };
+    std::shared_ptr op{ nullptr };
+};
+
 NEDepthwiseConvolutionLayer::NEDepthwiseConvolutionLayerGeneric::NEDepthwiseConvolutionLayerGeneric()
-    : _depthwise_conv_kernel(), _permute_input(), _permute_weights(), _permute_output(), _activationlayer_function(), _permuted_input(), _permuted_weights(), _permuted_output(), _is_prepared(false),
-      _is_nchw(false), _is_activationlayer_enabled(false), _original_weights(nullptr)
+    : _impl(std::make_unique())
 {
 }
 
@@ -224,45 +223,49 @@ void NEDepthwiseConvolutionLayer::NEDepthwiseConvolutionLayerGeneric::configure(
     ARM_COMPUTE_ERROR_THROW_ON(NEDepthwiseConvolutionLayer::validate(input->info(), weights->info(), (biases == nullptr) ? nullptr : biases->info(),
                                                                      output->info(), conv_info, depth_multiplier, act_info, dilation));
 
-    _is_nchw     = input->info()->data_layout() == DataLayout::NCHW;
-    _is_prepared = !_is_nchw;
+    const ConvolutionInfo info{ conv_info, depth_multiplier, act_info, dilation };
+    _impl->op = std::make_unique();
+    _impl->op->configure(input->info(), weights->info(), biases == nullptr ? nullptr : biases->info(), output->info(), info);
+
+    _impl->src         = input;
+    _impl->dst         = output;
+    _impl->weights     = weights;
+    _impl->biases      = biases;
+    _impl->is_nchw     = input->info()->data_layout() == DataLayout::NCHW;
+    _impl->is_prepared = !_impl->is_nchw;
 
     ITensor       *input_to_use   = input;
     const ITensor *weights_to_use = weights;
     ITensor       *output_to_use  = output;
 
-    if(_is_nchw)
+    if(_impl->is_nchw)
     {
-        _permute_input.configure(input, &_permuted_input, PermutationVector(2U, 0U, 1U));
-        _permuted_input.info()->set_data_layout(DataLayout::NHWC);
-        input_to_use = &_permuted_input;
+        auto permute_input   = std::make_unique();
+        auto permute_weights = std::make_unique();
 
-        _permute_weights.configure(weights, &_permuted_weights, PermutationVector(2U, 0U, 1U));
-        _permuted_weights.info()->set_data_layout(DataLayout::NHWC);
-        weights_to_use = &_permuted_weights;
+        permute_input->configure(input->info(), _impl->permuted_input.info(), PermutationVector(2U, 0U, 1U));
+        _impl->permuted_input.info()->set_data_layout(DataLayout::NHWC);
+        input_to_use = &_impl->permuted_input;
 
-        _permuted_output.allocator()->init(output->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(TensorShape()));
-        output_to_use = &_permuted_output;
+        permute_weights->configure(weights->info(), _impl->permuted_weights.info(), PermutationVector(2U, 0U, 1U));
+        _impl->permuted_weights.info()->set_data_layout(DataLayout::NHWC);
+        weights_to_use = &_impl->permuted_weights;
+
+        _impl->permuted_output.allocator()->init(output->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(TensorShape()));
+        output_to_use = &_impl->permuted_output;
     }
 
-    _original_weights      = weights_to_use;
-    _depthwise_conv_kernel = std::make_unique();
-    _depthwise_conv_kernel->configure(input_to_use, weights_to_use, biases, output_to_use, conv_info, depth_multiplier, dilation);
+    auto depthwise_conv_kernel = std::make_unique();
+    depthwise_conv_kernel->configure(input_to_use->info(), weights_to_use->info(), biases->info(), output_to_use->info(), info);
 
-    if(_is_nchw)
+    if(_impl->is_nchw)
     {
-        _permute_output.configure(&_permuted_output, output, PermutationVector(1U, 2U, 0U));
-        _permuted_output.info()->set_data_layout(DataLayout::NHWC);
+        auto permute_output = std::make_unique();
+        permute_output->configure(_impl->permuted_output.info(), output->info(), PermutationVector(1U, 2U, 0U));
+        _impl->permuted_output.info()->set_data_layout(DataLayout::NHWC);
 
-        _permuted_input.allocator()->allocate();
-        _permuted_weights.allocator()->allocate();
-        _permuted_output.allocator()->allocate();
-    }
-
-    //Configure Activation Layer
-    _is_activationlayer_enabled = act_info.enabled();
-    if(_is_activationlayer_enabled)
-    {
-        _activationlayer_function.configure(output, nullptr, act_info);
+        _impl->permuted_input.allocator()->allocate();
+        _impl->permuted_weights.allocator()->allocate();
+        _impl->permuted_output.allocator()->allocate();
     }
 }
 
@@ -270,89 +273,53 @@ Status NEDepthwiseConvolutionLayer::NEDepthwiseConvolutionLayerGeneric::validate
                                                                          const PadStrideInfo &conv_info,
                                                                          unsigned int depth_multiplier, const ActivationLayerInfo &act_info, const Size2D &dilation)
 {
-    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
-    if(input->data_layout() == DataLayout::NCHW)
-    {
-        TensorShape permuted_input_shape   = input->tensor_shape();
-        TensorShape permuted_weights_shape = weights->tensor_shape();
-        TensorShape permuted_output_shape  = misc::shape_calculator::compute_depthwise_convolution_shape(*input, *weights, conv_info, depth_multiplier, dilation);
-        permute(permuted_input_shape, PermutationVector(2U, 0U, 1U));
-        permute(permuted_weights_shape, PermutationVector(2U, 0U, 1U));
-        permute(permuted_output_shape, PermutationVector(2U, 0U, 1U));
-
-        const TensorInfo permuted_input   = TensorInfo(input->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(permuted_input_shape).set_data_layout(DataLayout::NHWC));
-        const TensorInfo permuted_weights = TensorInfo(weights->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(permuted_weights_shape).set_data_layout(DataLayout::NHWC));
-        const TensorInfo permuted_output  = TensorInfo(output->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(permuted_output_shape).set_data_layout(DataLayout::NCHW));
-
-        ARM_COMPUTE_RETURN_ON_ERROR(NEPermute::validate(input, &permuted_input, PermutationVector(2U, 0U, 1U)));
-        ARM_COMPUTE_RETURN_ON_ERROR(NEPermute::validate(weights, &permuted_weights, PermutationVector(2U, 0U, 1U)));
-        ARM_COMPUTE_RETURN_ON_ERROR(NEPermute::validate(&permuted_output, output, PermutationVector(1U, 2U, 0U)));
-
-        ARM_COMPUTE_RETURN_ON_ERROR(NEDepthwiseConvolutionLayerNativeKernel::validate(&permuted_input, &permuted_weights, biases, &permuted_output, conv_info, depth_multiplier, dilation));
-    }
-    else
-    {
-        ARM_COMPUTE_RETURN_ON_ERROR(NEDepthwiseConvolutionLayerNativeKernel::validate(input, weights, biases, output, conv_info, depth_multiplier, dilation));
-    }
-
-    // Validate Activation Layer
-    if(act_info.enabled())
-    {
-        ARM_COMPUTE_RETURN_ON_ERROR(NEActivationLayer::validate(output, nullptr, act_info));
-    }
-
-    return Status{};
+    ConvolutionInfo info{ conv_info, depth_multiplier, act_info, dilation };
+    return cpu::CpuDepthwiseConvolution::validate(input, weights, biases, output, info);
 }
 
 void NEDepthwiseConvolutionLayer::NEDepthwiseConvolutionLayerGeneric::run()
 {
-    if(_is_nchw)
-    {
-        prepare();
-        _permute_input.run();
-    }
-
-    NEScheduler::get().schedule(_depthwise_conv_kernel.get(), Window::DimY);
-
-    if(_is_nchw)
-    {
-        _permute_output.run();
-    }
-
-    if(_is_activationlayer_enabled)
-    {
-        _activationlayer_function.run();
-    }
+    ITensorPack pack;
+    pack.add_tensor(TensorType::ACL_SRC_0, _impl->src);
+    pack.add_tensor(TensorType::ACL_SRC_1, _impl->weights);
+    pack.add_tensor(TensorType::ACL_SRC_2, _impl->biases);
+    pack.add_tensor(TensorType::ACL_INT_0, &_impl->permuted_input);
+    pack.add_tensor(TensorType::ACL_INT_1, &_impl->permuted_weights);
+    pack.add_tensor(TensorType::ACL_INT_2, &_impl->permuted_output);
+    pack.add_tensor(TensorType::ACL_DST_0, _impl->dst);
+
+    _impl->op->run(pack);
 }
 
-void NEDepthwiseConvolutionLayer::NEDepthwiseConvolutionLayerGeneric::prepare()
+NEDepthwiseConvolutionLayer::NEDepthwiseConvolutionLayer(std::shared_ptr memory_manager)
+    : _memory_group(std::move(memory_manager)), _impl(std::make_unique())
 {
-    if(!_is_prepared)
-    {
-        ARM_COMPUTE_ERROR_ON(!_original_weights->is_used());
-
-        _permute_weights.run();
-        _original_weights->mark_as_unused();
-        _is_prepared = true;
-    }
 }
 
-NEDepthwiseConvolutionLayer::NEDepthwiseConvolutionLayer(std::shared_ptr memory_manager)
-    : _depth_conv_func(DepthwiseConvolutionFunction::GENERIC), _func_optimized(std::move(memory_manager)), _func_generic()
+#ifndef DOXYGEN_SKIP_THIS
+struct NEDepthwiseConvolutionLayer::NEDepthwiseConvolutionLayer::Impl
 {
-}
+    DepthwiseConvolutionFunction                 depth_conv_func{ DepthwiseConvolutionFunction::OPTIMIZED };
+    NEDepthwiseConvolutionLayerOptimizedInternal func_optimized{ nullptr };
+    NEDepthwiseConvolutionLayerGeneric           func_generic{};
+    std::shared_ptr op{ nullptr };
+};
+#endif // DOXYGEN_SKIP_THIS
 
 void NEDepthwiseConvolutionLayer::configure(ITensor *input, const ITensor *weights, const ITensor *biases, ITensor *output, const PadStrideInfo &conv_info, unsigned int depth_multiplier,
                                             const ActivationLayerInfo &act_info, const Size2D &dilation)
 {
-    _depth_conv_func = get_depthwiseconvolution_function(input->info(), weights->info(), (biases != nullptr) ? biases->info() : nullptr, output->info(), conv_info, depth_multiplier, act_info, dilation);
-    switch(_depth_conv_func)
+    const ConvolutionInfo info{ conv_info, depth_multiplier, act_info, dilation };
+    _impl->op = std::make_shared();
+    _impl->depth_conv_func = _impl->op->get_depthwiseconvolution_function(input->info(), weights->info(), (biases != nullptr) ? biases->info() : nullptr, output->info(),
+                                                                          info);
+    switch(_impl->depth_conv_func)
     {
         case DepthwiseConvolutionFunction::OPTIMIZED:
-            _func_optimized.configure(input, weights, biases, output, conv_info, depth_multiplier, act_info, dilation);
+            _impl->func_optimized.configure(input, weights, biases, output, conv_info, depth_multiplier, act_info, dilation);
             break;
         case DepthwiseConvolutionFunction::GENERIC:
-            _func_generic.configure(input, weights, biases, output, conv_info, depth_multiplier, act_info, dilation);
+            _impl->func_generic.configure(input, weights, biases, output, conv_info, depth_multiplier, act_info, dilation);
             break;
         default:
            ARM_COMPUTE_ERROR("Unsupported DepthwiseConvolutionFunction");
@@ -362,43 +329,19 @@ void NEDepthwiseConvolutionLayer::configure(ITensor *input, const ITensor *weigh
 Status NEDepthwiseConvolutionLayer::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const PadStrideInfo &conv_info,
                                              unsigned int depth_multiplier, const ActivationLayerInfo &act_info, const Size2D &dilation)
 {
-    DepthwiseConvolutionFunction depth_conv_func = get_depthwiseconvolution_function(input, weights, biases, output, conv_info, depth_multiplier, act_info, dilation);
-    switch(depth_conv_func)
-    {
-        case DepthwiseConvolutionFunction::OPTIMIZED:
-            return NEDepthwiseConvolutionLayerOptimizedInternal::validate(input, weights, biases, output, conv_info, depth_multiplier, act_info, dilation);
-            break;
-        case DepthwiseConvolutionFunction::GENERIC:
-            return NEDepthwiseConvolutionLayerGeneric::validate(input, weights, biases, output, conv_info, depth_multiplier, act_info, dilation);
-            break;
-        default:
-            ARM_COMPUTE_ERROR("Unsupported DepthwiseConvolutionFunction");
-    }
-}
-
-DepthwiseConvolutionFunction NEDepthwiseConvolutionLayer::get_depthwiseconvolution_function(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output,
-                                                                                            const PadStrideInfo &conv_info,
-                                                                                            unsigned int depth_multiplier, ActivationLayerInfo act_info, const Size2D &dilation)
-{
-    if(bool(NEDepthwiseConvolutionLayerOptimizedInternal::validate(input, weights, biases, output, conv_info, depth_multiplier, act_info, dilation)))
-    {
-        return DepthwiseConvolutionFunction::OPTIMIZED;
-    }
-    else
-    {
-        return DepthwiseConvolutionFunction::GENERIC;
-    }
+    ConvolutionInfo info{ conv_info, depth_multiplier, act_info, dilation };
+    return cpu::CpuDepthwiseConvolution::validate(input, weights, biases, output, info);
 }
 
 void NEDepthwiseConvolutionLayer::run()
 {
-    switch(_depth_conv_func)
+    switch(_impl->depth_conv_func)
     {
         case DepthwiseConvolutionFunction::OPTIMIZED:
-            _func_optimized.run();
+            _impl->func_optimized.run();
            break;
        case DepthwiseConvolutionFunction::GENERIC:
-            _func_generic.run();
+            _impl->func_generic.run();
             break;
         default:
             ARM_COMPUTE_ERROR("DepthwiseConvolutionFunction not properly configured");
@@ -407,13 +350,13 @@ void NEDepthwiseConvolutionLayer::run()
 
 void NEDepthwiseConvolutionLayer::prepare()
 {
-    switch(_depth_conv_func)
+    switch(_impl->depth_conv_func)
     {
         case DepthwiseConvolutionFunction::OPTIMIZED:
-            _func_optimized.prepare();
+            _impl->func_optimized.prepare();
             break;
         case DepthwiseConvolutionFunction::GENERIC:
-            _func_generic.prepare();
+            _impl->func_generic.prepare();
             break;
         default:
             ARM_COMPUTE_ERROR("DepthwiseConvolutionFunction not properly configured");
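With this port, NEDepthwiseConvolutionLayer keeps only thin state behind a pimpl and forwards to the stateless cpu::CpuDepthwiseConvolution operator: configure() works on ITensorInfo objects, while run() hands the actual tensors over in an ITensorPack on every call. A hedged usage sketch of that calling convention, reusing the slot assignments from the run() methods above (the src/weights/biases/dst, workspace and op names are placeholders):

    // Sketch: driving the new operator interface directly. The TensorType
    // slots mirror NEDepthwiseConvolutionLayerOptimizedInternal::run() above.
    ITensorPack pack;
    pack.add_tensor(TensorType::ACL_SRC_0, src);              // input
    pack.add_tensor(TensorType::ACL_SRC_1, weights);          // weights
    pack.add_tensor(TensorType::ACL_SRC_2, biases);           // bias (may be nullptr)
    pack.add_tensor(TensorType::ACL_INT_3, &workspace);       // scratch buffer
    pack.add_tensor(TensorType::ACL_INT_4, &packed_weights);  // packed weights
    pack.add_tensor(TensorType::ACL_DST_0, dst);              // output
    op->run(pack); // the operator holds no tensor pointers between calls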
diff --git a/src/runtime/NEON/functions/assembly/NEDepthwiseConvolutionAssemblyDispatch.cpp b/src/runtime/NEON/functions/assembly/NEDepthwiseConvolutionAssemblyDispatch.cpp
deleted file mode 100644
index 101df98b7d..0000000000
--- a/src/runtime/NEON/functions/assembly/NEDepthwiseConvolutionAssemblyDispatch.cpp
+++ /dev/null
@@ -1,569 +0,0 @@
-/*
- * Copyright (c) 2019-2020 Arm Limited.
- *
- * SPDX-License-Identifier: MIT
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to
- * deal in the Software without restriction, including without limitation the
- * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
- * sell copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
-
-#include "arm_compute/runtime/NEON/functions/assembly/NEDepthwiseConvolutionAssemblyDispatch.h"
-
-#include "arm_compute/core/ITensor.h"
-#include "arm_compute/core/Utils.h"
-#include "arm_compute/core/utils/misc/InfoHelpers.h"
-#include "arm_compute/core/utils/misc/ShapeCalculator.h"
-#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
-#include "src/core/CPP/Validate.h"
-#include "src/core/NEON/kernels/assembly/NEDepthwiseConvolutionAssemblyKernelWrapper.h"
-#include "src/core/NEON/kernels/convolution/depthwise/depthwise_dilated.hpp"
-#include "src/core/NEON/kernels/convolution/depthwise/depthwise_quantized_dilated.hpp"
-#include "src/core/helpers/AutoConfiguration.h"
-
-#include "arm_compute/runtime/NEON/NEScheduler.h"
-
-#include
-
-namespace arm_compute
-{
-namespace
-{
-std::unique_ptr get_qasymm8_convolver(int kernel_size, int stride_x,
-                                      int n_batches, int in_rows, int in_cols, int n_channels,
-                                      int dilation_factor, neon_convolution_kernels::ActivationFunction activation,
-                                      const qasymm8::QAsymm8Params &wqinfo, const qasymm8::QAsymm8Params &iqinfo, const qasymm8::QAsymm8Params &oqinfo,
-                                      const qasymm8::QAsymm8RescaleParams &rescale_params,
-                                      int padding_top, int padding_left, int padding_bottom, int padding_right)
-{
-    switch(kernel_size)
-    {
-        case 3:
-        {
-            switch(stride_x)
-            {
-                case 1:
-                    return std::make_unique>(
-                               n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right);
-                case 2:
-                    return std::make_unique>(
-                               n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right);
-                default:
-                    return nullptr;
-            }
-        }
-        case 5:
-        {
-            switch(stride_x)
-            {
-                case 1:
-                    return std::make_unique>(
-                               n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right);
-                case 2:
-                    return std::make_unique>(
-                               n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right);
-                default:
-                    return nullptr;
-            }
-        }
-        default:
-            return nullptr;
-    }
-}
-
-std::unique_ptr get_qsymm8_perchannel_convolver(int kernel_size, int stride_x,
-                                                int n_batches, int in_rows, int in_cols, int n_channels,
-                                                neon_convolution_kernels::ActivationFunction activation,
-                                                const qsymm8::QSymm8PerChannelParams &wqinfo, const qasymm8::QAsymm8Params &iqinfo, const qasymm8::QAsymm8Params &oqinfo,
-                                                const qsymm8::QSymm8PerChannelRescaleParams &rescale_params,
-                                                int padding_top, int padding_left, int padding_bottom, int padding_right)
-{
-    switch(kernel_size)
-    {
-        case 3:
-        {
-            switch(stride_x)
-            {
-                case 1:
-                    return std::make_unique>(
-                               n_batches, in_rows, in_cols, n_channels, activation, wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right);
-                case 2:
-                    return std::make_unique>(
-                               n_batches, in_rows, in_cols, n_channels, activation, wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right);
-                default:
-                    return nullptr;
-            }
-        }
-        case 5:
-        {
-            switch(stride_x)
-            {
-                case 1:
-                    return std::make_unique>(
-                               n_batches, in_rows, in_cols, n_channels, activation, wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right);
-                case 2:
-                    return std::make_unique>(
-                               n_batches, in_rows, in_cols, n_channels, activation, wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right);
-                default:
-                    return nullptr;
-            }
-        }
-        default:
-            return nullptr;
-    }
-}
-
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-std::unique_ptr get_fp16_convolver(int kernel_size, int stride_x,
-                                   int n_batches, int in_rows, int in_cols, int n_channels,
-                                   int dilation_factor, neon_convolution_kernels::ActivationFunction activation,
-                                   int padding_top, int padding_left, int padding_bottom, int padding_right)
-{
-    switch(kernel_size)
-    {
-        case 3:
-        {
-            switch(stride_x)
-            {
-                case 1:
-                    return std::make_unique>(
-                               n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right);
-                case 2:
-                    return std::make_unique>(
-                               n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right);
-                default:
-                    return nullptr;
-            }
-        }
-        case 5:
-        {
-            switch(stride_x)
-            {
-                case 1:
-                    return std::make_unique>(
-                               n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right);
-                case 2:
-                    return std::make_unique>(
-                               n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right);
-                default:
-                    return nullptr;
-            }
-        }
-        default:
-            return nullptr;
-    }
-}
-#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-
-std::unique_ptr get_fp32_convolver(int kernel_size, int stride_x,
-                                   int n_batches, int in_rows, int in_cols, int n_channels,
-                                   int dilation_factor, neon_convolution_kernels::ActivationFunction activation,
-                                   int padding_top, int padding_left, int padding_bottom, int padding_right)
-{
-    switch(kernel_size)
-    {
-        case 3:
-        {
-            switch(stride_x)
-            {
-                case 1:
-                    return std::make_unique>(
-                               n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right);
-                case 2:
-                    return std::make_unique>(
-                               n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right);
-                default:
-                    return nullptr;
-            }
-        }
-        case 5:
-        {
-            switch(stride_x)
-            {
-                case 1:
-                    return std::make_unique>(
-                               n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right);
-                case 2:
-                    return std::make_unique>(
-                               n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right);
-                default:
-                    return nullptr;
-            }
-        }
-        default:
-            return nullptr;
-    }
-}
-
-std::unique_ptr create_convolver(const ITensor      *input,
-                                 const ITensor      *weights,
-                                 ITensor            *output,
-                                 PadStrideInfo       conv_info,
-                                 ActivationLayerInfo act_info,
-                                 const Size2D       &dilation)
-{
-    ARM_COMPUTE_UNUSED(dilation);
-    const DataType    data_type = input->info()->data_type();
-    const TensorShape shape     = input->info()->tensor_shape();
-
-    const int n_batches       = shape[3];
-    const int in_rows         = shape.z();
-    const int in_cols         = shape.y();
-    const int n_channels      = shape.x();
-    const int dilation_factor = dilation.x();
-    const int padding_top     = conv_info.pad_top();
-    const int padding_left    = conv_info.pad_left();
-    const int padding_bottom  = conv_info.pad_bottom();
-    const int padding_right   = conv_info.pad_right();
-
-    const bool is_uniform_quantized    = (data_type == DataType::QASYMM8) && (weights->info()->data_type() == DataType::QASYMM8);
-    const bool is_perchannel_quantized = (data_type == DataType::QASYMM8) && (weights->info()->data_type() == DataType::QSYMM8_PER_CHANNEL);
-
-    const unsigned int stride_x    = conv_info.stride().first;
-    const unsigned int kernel_size = weights->info()->tensor_shape().y();
-
-    // Map activation function
-    neon_convolution_kernels::ActivationFunction activation = neon_convolution_kernels::ActivationFunction::None;
-    if(arm_compute::utils::info_helpers::is_relu(act_info))
-    {
-        activation = neon_convolution_kernels::ActivationFunction::ReLU;
-    }
-    else if(arm_compute::utils::info_helpers::is_relu6(act_info))
-    {
-        activation = neon_convolution_kernels::ActivationFunction::ReLU6;
-    }
-
-    // Create quantized convolver
-    if(is_uniform_quantized)
-    {
-        const UniformQuantizationInfo input_qinfo   = input->info()->quantization_info().uniform();
-        const UniformQuantizationInfo weights_qinfo = weights->info()->quantization_info().uniform();
-        const UniformQuantizationInfo output_qinfo  = output->info()->quantization_info().uniform();
-
-        // Check that quantization info are in the range [0, 255]
-        ARM_COMPUTE_ERROR_ON(input_qinfo.offset < 0 || input_qinfo.offset > 255);
-        ARM_COMPUTE_ERROR_ON(weights_qinfo.offset < 0 || weights_qinfo.offset > 255);
-        ARM_COMPUTE_ERROR_ON(output_qinfo.offset < 0 || output_qinfo.offset > 255);
-        const qasymm8::QAsymm8Params iqinfo{ static_cast(input_qinfo.offset), input_qinfo.scale };
-        const qasymm8::QAsymm8Params wqinfo{ static_cast(weights_qinfo.offset), weights_qinfo.scale };
-        const qasymm8::QAsymm8Params oqinfo{ static_cast(output_qinfo.offset), output_qinfo.scale };
-
-        // Calculate rescale parameters
-        const float fmultipler  = iqinfo.scale * wqinfo.scale / oqinfo.scale;
-        int32_t     qmultiplier = 0;
-        int32_t     qshift      = 0;
-        quantization::calculate_quantized_multiplier_less_than_one(fmultipler, &qmultiplier, &qshift);
-        qasymm8::QAsymm8RescaleParams rescale_params(qshift, qmultiplier, fmultipler);
-
-        return get_qasymm8_convolver(kernel_size, stride_x, n_batches, in_rows, in_cols, n_channels, dilation_factor, activation,
-                                     wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right);
-    }
-    else if(is_perchannel_quantized)
-    {
-        const UniformQuantizationInfo input_qinfo   = input->info()->quantization_info().uniform();
-        const QuantizationInfo        weights_qinfo = weights->info()->quantization_info();
-        const UniformQuantizationInfo output_qinfo  = output->info()->quantization_info().uniform();
-
-        // Check that quantization info are in the range [0, 255]
-        ARM_COMPUTE_ERROR_ON(input_qinfo.offset < 0 || input_qinfo.offset > 255);
-        ARM_COMPUTE_ERROR_ON(output_qinfo.offset < 0 || output_qinfo.offset > 255);
-        const qasymm8::QAsymm8Params         iqinfo{ static_cast(input_qinfo.offset), input_qinfo.scale };
-        const qsymm8::QSymm8PerChannelParams wqinfo{ weights_qinfo.scale() };
-        const qasymm8::QAsymm8Params         oqinfo{ static_cast(output_qinfo.offset), output_qinfo.scale };
-
-        // Calculate rescale parameters
-        std::vector fmultipliers;
-        std::vector qmultipliers;
-        std::vector qshifts;
-
-        for(auto const s : wqinfo.scales)
-        {
-            const float fmultipler  = iqinfo.scale * s / oqinfo.scale;
-            int32_t     qmultiplier = 0;
-            int32_t     qshift      = 0;
-            quantization::calculate_quantized_multiplier_less_than_one(fmultipler, &qmultiplier, &qshift);
-            fmultipliers.push_back(fmultipler);
-            qmultipliers.push_back(qmultiplier);
-            qshifts.push_back(qshift);
-        }
-
-        qsymm8::QSymm8PerChannelRescaleParams rescale_params(qshifts, qmultipliers, fmultipliers);
-
-        return get_qsymm8_perchannel_convolver(kernel_size, stride_x, n_batches, in_rows, in_cols, n_channels, activation,
-                                               wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right);
-    }
-    else
-    {
-        // Create float convolver
-        switch(data_type)
-        {
-#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-            case DataType::F16:
-            {
-                return get_fp16_convolver(kernel_size, stride_x, n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right);
-            }
-#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-            case DataType::F32:
-            {
-                return get_fp32_convolver(kernel_size, stride_x, n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right);
-            }
-            default:
-                return nullptr;
-        }
-    }
-}
-} // namespace
-
-struct NEDepthwiseConvolutionAssemblyDispatch::LocalImpl
-{
-    std::unique_ptr _dwc_assembly_kernel{ nullptr };
-    NEDepthwiseConvolutionAssemblyKernelWrapper _dwc_acl_kernel{};
-};
-
-#ifndef DOXYGEN_SKIP_THIS
-NEDepthwiseConvolutionAssemblyDispatch::NEDepthwiseConvolutionAssemblyDispatch(std::shared_ptr memory_manager)
-    : _memory_group(std::move(memory_manager)), _input(nullptr), _weights(nullptr), _bias(nullptr), _output(nullptr), _packed_weights(), _workspace(), _is_prepared(false),
-      _pImpl(std::make_unique())
-{
-}
-#endif /* DOXYGEN_SKIP_THIS */
-
-NEDepthwiseConvolutionAssemblyDispatch::~NEDepthwiseConvolutionAssemblyDispatch() = default;
-
-void NEDepthwiseConvolutionAssemblyDispatch::configure(const ITensor             *input,
-                                                       const ITensor             *weights,
-                                                       const ITensor             *bias,
-                                                       ITensor                   *output,
-                                                       const PadStrideInfo       &conv_info,
-                                                       unsigned int               depth_multiplier,
-                                                       const ActivationLayerInfo &act_info,
-                                                       const Size2D              &dilation)
-{
-    ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);
-    ARM_COMPUTE_UNUSED(depth_multiplier);
-    ARM_COMPUTE_ERROR_THROW_ON(NEDepthwiseConvolutionAssemblyDispatch::validate(input->info(),
-                                                                                weights->info(),
-                                                                                bias != nullptr ? bias->info() : nullptr,
-                                                                                output->info(),
-                                                                                conv_info,
-                                                                                depth_multiplier,
-                                                                                act_info,
-                                                                                dilation));
-
-    // Output auto inizialitation if not yet initialized
-    const TensorShape output_shape = misc::shape_calculator::compute_depthwise_convolution_shape(*input->info(), *weights->info(), conv_info, depth_multiplier, dilation);
-    auto_init_if_empty(*output->info(), input->info()->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(output_shape).set_quantization_info(output->info()->quantization_info()));
-
-    _input       = input;
-    _weights     = weights;
-    _bias        = bias;
-    _output      = output;
-    _is_prepared = false;
-
-    // Create convolver
-    _pImpl->_dwc_assembly_kernel = create_convolver(input, weights, output, conv_info, act_info, dilation);
-    ARM_COMPUTE_ERROR_ON(_pImpl->_dwc_assembly_kernel == nullptr);
-
-    // Create assembly kernel wrapper
-    _pImpl->_dwc_acl_kernel.configure(_pImpl->_dwc_assembly_kernel.get());
-
-    constexpr size_t alignment = 128;
-
-    // Create workspace
-    const unsigned int num_threads    = NEScheduler::get().num_threads();
-    const size_t       workspace_size = _pImpl->_dwc_assembly_kernel->get_working_space_size(num_threads);
-    ARM_COMPUTE_ERROR_ON_MSG(workspace_size == 0, "Workspace size cannot be 0 !");
-    _workspace.allocator()->init(TensorInfo(TensorShape{ workspace_size }, 1, DataType::S8), alignment);
-    _memory_group.manage(&_workspace);
-    _workspace.allocator()->allocate();
-
-    // Create packing tensor
-    const size_t pack_tensor_size = _pImpl->_dwc_assembly_kernel->get_packed_params_size();
-    ARM_COMPUTE_ERROR_ON_MSG(pack_tensor_size == 0, "Pack tensor size cannot be 0 !");
-    _packed_weights.allocator()->init(TensorInfo(TensorShape{ pack_tensor_size }, 1, DataType::S8), alignment);
-}
-
-Status NEDepthwiseConvolutionAssemblyDispatch::validate(const ITensorInfo         *input,
-                                                        const ITensorInfo         *weights,
-                                                        const ITensorInfo         *bias,
-                                                        const ITensorInfo         *output,
-                                                        const PadStrideInfo       &conv_info,
-                                                        unsigned int               depth_multiplier,
-                                                        const ActivationLayerInfo &act_info,
-                                                        const Size2D              &dilation)
-{
-    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
-    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
-    if(weights->data_type() != DataType::QSYMM8_PER_CHANNEL)
-    {
-        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);
-    }
-    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, weights);
-
-    // Validate convolver
-    ARM_COMPUTE_RETURN_ERROR_ON(!is_optimized_supported(input, weights, conv_info, depth_multiplier, dilation));
-
-    // Validate activation
-    const bool is_relu  = arm_compute::utils::info_helpers::is_relu(act_info);
-    const bool is_relu6 = arm_compute::utils::info_helpers::is_relu6(act_info);
-    ARM_COMPUTE_RETURN_ERROR_ON(act_info.enabled() && !(is_relu || is_relu6));
-
-    // Check bias
-    if(bias != nullptr)
-    {
-        unsigned int channel_idx = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::CHANNEL);
-        ARM_COMPUTE_RETURN_ERROR_ON(bias->num_dimensions() > 1);
-        ARM_COMPUTE_RETURN_ERROR_ON(bias->dimension(0) != weights->dimension(channel_idx));
-    }
-
-    // Check output
-    if(output->total_size() != 0)
-    {
-        const TensorShape output_shape = misc::shape_calculator::compute_depthwise_convolution_shape(*input, *weights, conv_info, depth_multiplier, dilation);
-        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), output_shape);
-        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
-    }
-
-    // The uniform quantization case will only have 1 scale value in the weights quantization info
-    const UniformQuantizationInfo input_qinfo   = input->quantization_info().uniform();
-    const QuantizationInfo        weights_qinfo = weights->quantization_info();
-    const UniformQuantizationInfo output_qinfo  = output->quantization_info().uniform();
-    for(auto const s : weights_qinfo.scale())
-    {
-        const float fmultipler = input_qinfo.scale * s / output_qinfo.scale;
-        ARM_COMPUTE_RETURN_ERROR_ON(fmultipler > 1.f);
-    }
-
-    return Status{};
-}
-
-bool NEDepthwiseConvolutionAssemblyDispatch::is_optimized_supported(const ITensorInfo *input,
-                                                                    const ITensorInfo *weights,
-                                                                    PadStrideInfo      conv_info,
-                                                                    unsigned int       depth_multiplier,
-                                                                    const Size2D      &dilation)
-{
-    ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights);
-
-    // Reshape input shape if in NHWC format
-    const DataLayout data_layout = input->data_layout();
-    TensorShape      in_shape{ input->tensor_shape() };
-    if(data_layout == DataLayout::NHWC)
-    {
-        in_shape.set(Window::DimX, input->tensor_shape().y());
-        in_shape.set(Window::DimY, input->tensor_shape().z());
-        in_shape.set(Window::DimZ, input->tensor_shape().x());
-    }
-
-    // Check data type
-    // TODO (COMPMID-3004): Add assembly optimized routine for QASYMM8_SIGNED NEDepthwiseConvolutionLayer
-    const DataType input_type            = input->data_type();
-    const bool     is_input_type_valid   = is_data_type_float(input_type) || input_type == DataType::QASYMM8;
-    const DataType weights_type          = weights->data_type();
-    const bool     is_weights_type_valid = is_data_type_float(weights_type) || weights_type == DataType::QASYMM8 || weights_type == DataType::QASYMM8_SIGNED
-                                           || weights_type == DataType::QSYMM8_PER_CHANNEL;
-
-    // Check weighs size
-    std::set supported_kernel_sizes = { 3, 5 };
-    const unsigned int width_idx  = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
-    const unsigned int height_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
-    const unsigned int kernel_w   = weights->dimension(width_idx);
-    const unsigned int kernel_h   = weights->dimension(height_idx);
-    bool               weights_supported = (kernel_w == kernel_h) && (supported_kernel_sizes.count(kernel_w) != 0);
-
-    // Check for supported strides
-    const auto &strides           = conv_info.stride();
-    bool        supported_strides = (strides.first == strides.second) && ((strides.first == 1) || (strides.first == 2));
-
-    // Check for supported padding
-    const auto    pad_top         = conv_info.pad_top();
-    const auto    pad_right       = conv_info.pad_right();
-    const auto    pad_bottom      = conv_info.pad_bottom();
-    const auto    pad_left        = conv_info.pad_left();
-    PadStrideInfo same_pad        = calculate_same_pad(in_shape, TensorShape(kernel_w, kernel_h), conv_info, DataLayout::NCHW, dilation);
-    bool          is_same_padding = (pad_top == same_pad.pad_top()) && (pad_right == same_pad.pad_right()) && (pad_bottom == same_pad.pad_bottom()) && (pad_left == same_pad.pad_left());
-    bool          is_valid_padding = (pad_top == 0) && (pad_right == 0) && (pad_bottom == 0) && (pad_left == 0);
-    bool          supported_padding = is_same_padding || is_valid_padding;
-    // TODO(COMPMID-2464): Enable once dilated conv with stride 2 is supported
-    bool is_dilation_supported = ((dilation == Size2D(1U, 1U)) || ((dilation.x() == dilation.y()) && strides.first == 1));
-
-    if(weights_type == DataType::QSYMM8_PER_CHANNEL)
-    {
-        is_dilation_supported = is_dilation_supported && (dilation == Size2D(1U, 1U));
-    }
-
-    return is_input_type_valid && is_weights_type_valid && weights_supported && supported_strides && supported_padding && (depth_multiplier == 1) && is_dilation_supported;
-}
-
-void NEDepthwiseConvolutionAssemblyDispatch::run()
-{
-    // Prepare assembly kernel
-    prepare();
-
-    MemoryGroupResourceScope scope_mg(_memory_group);
-
-    // Setup inputs/outputs
-    ARM_COMPUTE_ERROR_ON(_workspace.buffer() == nullptr);
-    _pImpl->_dwc_assembly_kernel->set_working_space(static_cast(_workspace.buffer()));
-
-    ARM_COMPUTE_ERROR_ON(_input->buffer() == nullptr);
-    const int   input_element_size = _input->info()->element_size();
-    const int   input_batch_stride = _input->info()->strides_in_bytes()[3] / input_element_size;
-    const int   input_row_stride   = _input->info()->strides_in_bytes().z() / input_element_size;
-    const int   input_col_stride   = _input->info()->strides_in_bytes().y() / input_element_size;
-    const void *input_ptr          = _input->buffer() + _input->info()->offset_first_element_in_bytes();
-    _pImpl->_dwc_assembly_kernel->set_input(input_ptr, input_batch_stride, input_row_stride, input_col_stride);
-
-    ARM_COMPUTE_ERROR_ON(_output->buffer() == nullptr);
-    const int output_element_size = _output->info()->element_size();
-    const int output_batch_stride = _output->info()->strides_in_bytes()[3] / output_element_size;
-    const int output_row_stride   = _output->info()->strides_in_bytes().z() / output_element_size;
-    const int output_col_stride   = _output->info()->strides_in_bytes().y() / output_element_size;
-    void     *output_ptr          = _output->buffer() + _output->info()->offset_first_element_in_bytes();
-    _pImpl->_dwc_assembly_kernel->set_output(output_ptr, output_batch_stride, output_row_stride, output_col_stride);
-
-    // Schedule assembly kernel
-    NEScheduler::get().schedule(&_pImpl->_dwc_acl_kernel, Window::DimX);
-}
-
-void NEDepthwiseConvolutionAssemblyDispatch::prepare()
-{
-    if(!_is_prepared)
-    {
-        _packed_weights.allocator()->allocate();
-        ARM_COMPUTE_ERROR_ON(_packed_weights.buffer() == nullptr);
-
-        // Pack weights and bias
-        const int weights_element_size = _weights->info()->element_size();
-        const int weights_row_stride   = _weights->info()->strides_in_bytes().z() / weights_element_size;
-        const int weights_col_stride   = _weights->info()->strides_in_bytes().y() / weights_element_size;
-        _pImpl->_dwc_assembly_kernel->pack_params(_packed_weights.buffer(),
-                                                  _weights->buffer() + _weights->info()->offset_first_element_in_bytes(),
-                                                  weights_row_stride,
-                                                  weights_col_stride,
-                                                  (_bias != nullptr) ? _bias->buffer() : nullptr);
-        _pImpl->_dwc_assembly_kernel->set_packed_params_buffer(_packed_weights.buffer());
-
-        _weights->mark_as_unused();
-        if(_bias != nullptr)
-        {
-            _bias->mark_as_unused();
-        }
-        _is_prepared = true;
-    }
-}
-} // namespace arm_compute
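The dispatcher deleted above sized its scratch and weight-packing buffers by querying the assembly kernel directly (get_working_space_size(), get_packed_params_size()) and owned the resulting tensors. In the new split, the cpu operator reports the same needs through experimental::MemoryRequirements and the runtime wrapper allocates them, as in the configure() hunk earlier. A minimal sketch of that protocol, assuming the two-entry order (workspace, then packed weights) used in this patch:

    // Sketch of the new workspace protocol; mirrors the allocation code in
    // NEDepthwiseConvolutionLayerOptimizedInternal::configure() above.
    experimental::MemoryRequirements mem_req = dwc_optimized_func->workspace();
    Tensor workspace{};
    Tensor packed_weights{};
    // Each entry carries a byte size and an alignment; S8 serves as an opaque
    // byte element type for the backing buffers.
    workspace.allocator()->init(TensorInfo(TensorShape{ mem_req[0].size }, 1, DataType::S8), mem_req[0].alignment);
    packed_weights.allocator()->init(TensorInfo(TensorShape{ mem_req[1].size }, 1, DataType::S8), mem_req[1].alignment);
    workspace.allocator()->allocate();
    packed_weights.allocator()->allocate();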
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "src/runtime/cpu/operators/CpuDepthwiseConvolution.h" + +#include "arm_compute/core/TensorInfo.h" +#include "arm_compute/core/Validate.h" +#include "arm_compute/core/utils/misc/InfoHelpers.h" +#include "arm_compute/core/utils/misc/ShapeCalculator.h" +#include "arm_compute/runtime/NEON/NEScheduler.h" +#include "src/core/cpu/kernels/CpuDepthwiseConvolutionNativeKernel.h" + +namespace arm_compute +{ +namespace cpu +{ +namespace +{ +Status validate_arguments_optimized(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const ConvolutionInfo &info) +{ + ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F16, DataType::F32); + if(!is_data_type_quantized_per_channel(weights->data_type())) + { + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights); + } + ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN); + ARM_COMPUTE_RETURN_ERROR_ON(info.dilation.x() < 1 || info.dilation.y() < 1); + const size_t idx_w = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::WIDTH); + const size_t idx_h = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::HEIGHT); + ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(idx_w) + (weights->dimension(idx_w) - 1) * (info.dilation.x() - 1) > input->dimension(idx_w) + info.pad_stride_info.pad_left() + + info.pad_stride_info.pad_right()); + ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(idx_h) + (weights->dimension(idx_h) - 1) * (info.dilation.y() - 1) > input->dimension(idx_h) + info.pad_stride_info.pad_top() + + info.pad_stride_info.pad_bottom()); + + if(biases != nullptr) + { + const unsigned int channel_idx = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::CHANNEL); + ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1); + ARM_COMPUTE_RETURN_ERROR_ON(biases->dimension(0) != weights->dimension(channel_idx)); + } + + ARM_COMPUTE_RETURN_ON_ERROR(CpuDepthwiseConvolutionAssemblyDispatch::validate(input, weights, biases, output, info)); + + //Validate Activation Layer + if(info.act_info.enabled()) + { + ARM_COMPUTE_RETURN_ON_ERROR(CpuActivation::validate(output, nullptr, info.act_info)); + } + return Status{}; +} +} // namespace + 
+CpuDepthwiseConvolution::CpuDepthwiseConvolutionOptimizedInternal::CpuDepthwiseConvolutionOptimizedInternal() + : _dwc_optimized_func(nullptr), _permute_input(nullptr), _permute_weights(nullptr), _permute_output(nullptr), _activationlayer_function(nullptr), _has_bias(false), _is_quantized(false), + _is_nchw(true), _permute(false), _is_activationlayer_enabled(false), _is_prepared(false) +{ +} + +void CpuDepthwiseConvolution::CpuDepthwiseConvolutionOptimizedInternal::configure(ITensorInfo *input, + const ITensorInfo *weights, + const ITensorInfo *biases, + ITensorInfo *output, + const ConvolutionInfo &info) +{ + ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output); + // Perform validation step + ARM_COMPUTE_ERROR_THROW_ON(CpuDepthwiseConvolutionOptimizedInternal::validate(input, weights, (biases == nullptr) ? nullptr : biases, + output, info)); + + _is_quantized = is_data_type_quantized_asymmetric(input->data_type()); + _has_bias = biases != nullptr; + _is_nchw = input->data_layout() == DataLayout::NCHW; + _permute = _is_nchw; + _is_prepared = false; + + // Configure pipeline + ActivationLayerInfo act_info_to_use = ActivationLayerInfo(); + const bool is_relu = arm_compute::utils::info_helpers::is_relu(info.act_info); + const bool is_relu6 = arm_compute::utils::info_helpers::is_relu6(info.act_info); + _is_activationlayer_enabled = info.act_info.enabled() && !(is_relu || is_relu6); + + if(!_is_activationlayer_enabled) + { + act_info_to_use = info.act_info; + } + + _dwc_optimized_func = std::make_unique(); + if(_is_nchw) + { + _permute_input = std::make_unique(); + _permute_weights = std::make_unique(); + _permute_output = std::make_unique(); + + auto input_perm = std::make_unique(); + auto weights_perm = std::make_unique(); + auto output_perm = std::make_unique(); + + // Configure the function to transform the input tensor from NCHW -> NHWC + _permute_input->configure(input, input_perm.get(), PermutationVector(2U, 0U, 1U)); + input_perm->set_data_layout(DataLayout::NHWC); + + // Configure the function to transform the weights tensor from IHW -> HWI + _permute_weights->configure(weights, weights_perm.get(), PermutationVector(2U, 0U, 1U)); + weights_perm->set_data_layout(DataLayout::NHWC); + + output_perm->set_data_layout(DataLayout::NHWC); + output_perm->set_quantization_info(output->quantization_info()); + + // Configure optimized depthwise + _dwc_optimized_func->configure(input_perm.get(), weights_perm.get(), biases, output_perm.get(), info); + + // Configure the function to transform the convoluted output to ACL's native ordering format NCHW + output_perm->set_data_layout(DataLayout::NHWC); + _permute_output->configure(output_perm.get(), output, PermutationVector(1U, 2U, 0U)); + } + else + { + _dwc_optimized_func->configure(input, weights, biases, output, info); + } + + // Configure activation + if(_is_activationlayer_enabled) + { + _activationlayer_function = std::make_unique(); + _activationlayer_function->configure(output, nullptr, info.act_info); + } +} + +Status CpuDepthwiseConvolution::CpuDepthwiseConvolutionOptimizedInternal::validate(const ITensorInfo *input, + const ITensorInfo *weights, + const ITensorInfo *biases, + const ITensorInfo *output, + const ConvolutionInfo &info) +{ + return validate_arguments_optimized(input, weights, biases, output, info); +} + +void CpuDepthwiseConvolution::CpuDepthwiseConvolutionOptimizedInternal::run(ITensorPack &tensors) +{ + ARM_COMPUTE_ERROR_ON_MSG(tensors.empty(), "No inputs provided"); + prepare(tensors); + + auto bias = 
tensors.get_const_tensor(TensorType::ACL_SRC_2); + auto dst = tensors.get_tensor(TensorType::ACL_DST_0); + auto workspace = tensors.get_tensor(TensorType::ACL_INT_3); + auto packed_weights = tensors.get_tensor(TensorType::ACL_INT_4); + + // Permute input + if(_permute) + { + ITensorPack pack; + auto src = tensors.get_tensor(TensorType::ACL_SRC_0); + auto src_perm = tensors.get_tensor(TensorType::ACL_INT_0); + pack.add_tensor(TensorType::ACL_SRC, src); + pack.add_tensor(TensorType::ACL_DST, src_perm); + _permute_input->run(pack); + } + + // Run assembly function + if(_is_nchw) + { + auto src_perm = tensors.get_tensor(TensorType::ACL_INT_0); + auto weights_perm = tensors.get_tensor(TensorType::ACL_INT_1); + auto dst_perm = tensors.get_tensor(TensorType::ACL_INT_2); + + ITensorPack pack; + pack.add_tensor(TensorType::ACL_SRC_0, src_perm); + pack.add_tensor(TensorType::ACL_SRC_1, weights_perm); + pack.add_tensor(TensorType::ACL_SRC_2, bias); + pack.add_tensor(TensorType::ACL_INT_0, workspace); + pack.add_tensor(TensorType::ACL_INT_1, packed_weights); + pack.add_tensor(TensorType::ACL_DST, dst_perm); + _dwc_optimized_func->run(pack); + } + else + { + auto src = tensors.get_tensor(TensorType::ACL_SRC_0); + auto weights = tensors.get_tensor(TensorType::ACL_SRC_1); + auto dst = tensors.get_tensor(TensorType::ACL_DST); + + ITensorPack pack; + pack.add_tensor(TensorType::ACL_SRC_0, src); + pack.add_tensor(TensorType::ACL_SRC_1, weights); + pack.add_tensor(TensorType::ACL_SRC_2, bias); + pack.add_tensor(TensorType::ACL_INT_0, workspace); + pack.add_tensor(TensorType::ACL_INT_1, packed_weights); + pack.add_tensor(TensorType::ACL_DST, dst); + _dwc_optimized_func->run(pack); + } + + // Permute output + if(_is_nchw) + { + ITensorPack pack; + auto dst_perm = tensors.get_tensor(TensorType::ACL_INT_2); + pack.add_tensor(TensorType::ACL_SRC, dst_perm); + pack.add_tensor(TensorType::ACL_DST, dst); + _permute_output->run(pack); + } + + // Run activation + if(_is_activationlayer_enabled) + { + ITensorPack pack; + pack.add_tensor(TensorType::ACL_SRC, dst); + pack.add_tensor(TensorType::ACL_DST, dst); + _activationlayer_function->run(pack); + } +} + +void CpuDepthwiseConvolution::CpuDepthwiseConvolutionOptimizedInternal::prepare(ITensorPack &tensors) +{ + if(!_is_prepared) + { + auto weights = tensors.get_const_tensor(TensorType::ACL_SRC_1); + auto bias = tensors.get_const_tensor(TensorType::ACL_SRC_2); + auto packed_weights = tensors.get_tensor(TensorType::ACL_INT_4); + + // Permute weights + if(_permute) + { + auto permuted_weights = tensors.get_tensor(TensorType::ACL_INT_1); + + ITensorPack pack; + pack.add_tensor(TensorType::ACL_SRC, weights); + pack.add_tensor(TensorType::ACL_DST, permuted_weights); + _permute_weights->run(pack); + + ITensorPack pack_opt; + pack_opt.add_const_tensor(TensorType::ACL_SRC_1, permuted_weights); + pack_opt.add_tensor(TensorType::ACL_SRC_2, bias); + pack_opt.add_tensor(TensorType::ACL_INT_1, packed_weights); + + // Prepare optimized function + _dwc_optimized_func->prepare(pack_opt); + } + else + { + ITensorPack pack_opt; + pack_opt.add_tensor(TensorType::ACL_SRC_1, weights); + pack_opt.add_tensor(TensorType::ACL_SRC_2, bias); + pack_opt.add_tensor(TensorType::ACL_INT_1, packed_weights); + + // Prepare optimized function + _dwc_optimized_func->prepare(pack_opt); + } + + _is_prepared = true; + } +} + +CpuDepthwiseConvolution::CpuDepthwiseConvolutionGeneric::CpuDepthwiseConvolutionGeneric() + : _depthwise_conv_kernel(nullptr), _permute_input(nullptr), _permute_weights(nullptr), 
_permute_output(nullptr), _activationlayer_function(nullptr), _is_nchw(true), _is_prepared(false),
+      _is_activationlayer_enabled(false)
+{
+}
+
+void CpuDepthwiseConvolution::CpuDepthwiseConvolutionGeneric::configure(ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *output, const ConvolutionInfo &info)
+{
+    ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);
+    ARM_COMPUTE_ERROR_THROW_ON(CpuDepthwiseConvolution::validate(input, weights, (biases == nullptr) ? nullptr : biases,
+                                                                 output, info));
+
+    _is_nchw     = input->data_layout() == DataLayout::NCHW;
+    _is_prepared = !_is_nchw;
+
+    ITensorInfo       *input_to_use   = input;
+    const ITensorInfo *weights_to_use = weights;
+    ITensorInfo       *output_to_use  = output;
+
+    auto input_perm   = std::make_unique<TensorInfo>();
+    auto weights_perm = std::make_unique<TensorInfo>();
+    auto output_perm  = std::make_unique<TensorInfo>(output->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(TensorShape()));
+
+    if(_is_nchw)
+    {
+        _permute_input   = std::make_unique<CpuPermute>();
+        _permute_weights = std::make_unique<CpuPermute>();
+
+        _permute_input->configure(input, input_perm.get(), PermutationVector(2U, 0U, 1U));
+        input_perm->set_data_layout(DataLayout::NHWC);
+        input_to_use = input_perm.get();
+
+        _permute_weights->configure(weights, weights_perm.get(), PermutationVector(2U, 0U, 1U));
+        weights_perm->set_data_layout(DataLayout::NHWC);
+        weights_to_use = weights_perm.get();
+
+        output_to_use = output_perm.get();
+    }
+
+    _depthwise_conv_kernel = std::make_unique<cpu::kernels::CpuDepthwiseConvolutionNativeKernel>();
+    _depthwise_conv_kernel->configure(input_to_use, weights_to_use, biases, output_to_use, info);
+
+    if(_is_nchw)
+    {
+        _permute_output = std::make_unique<CpuPermute>();
+        _permute_output->configure(output_perm.get(), output, PermutationVector(1U, 2U, 0U));
+        output_perm->set_data_layout(DataLayout::NHWC);
+    }
+
+    // Configure Activation Layer
+    _is_activationlayer_enabled = info.act_info.enabled();
+    if(_is_activationlayer_enabled)
+    {
+        _activationlayer_function = std::make_unique<CpuActivation>();
+        _activationlayer_function->configure(output, nullptr, info.act_info);
+    }
+}
+
+Status CpuDepthwiseConvolution::CpuDepthwiseConvolutionGeneric::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output,
+                                                                         const ConvolutionInfo &info)
+{
+    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
+    if(input->data_layout() == DataLayout::NCHW)
+    {
+        TensorShape permuted_input_shape   = input->tensor_shape();
+        TensorShape permuted_weights_shape = weights->tensor_shape();
+        TensorShape permuted_output_shape  = misc::shape_calculator::compute_depthwise_convolution_shape(*input, *weights, info);
+        permute(permuted_input_shape, PermutationVector(2U, 0U, 1U));
+        permute(permuted_weights_shape, PermutationVector(2U, 0U, 1U));
+        permute(permuted_output_shape, PermutationVector(2U, 0U, 1U));
+
+        const TensorInfo permuted_input   = TensorInfo(input->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(permuted_input_shape).set_data_layout(DataLayout::NHWC));
+        const TensorInfo permuted_weights = TensorInfo(weights->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(permuted_weights_shape).set_data_layout(DataLayout::NHWC));
+        const TensorInfo permuted_output  = TensorInfo(output->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(permuted_output_shape).set_data_layout(DataLayout::NCHW));
+
+        ARM_COMPUTE_RETURN_ON_ERROR(CpuPermute::validate(input, &permuted_input, PermutationVector(2U, 0U, 1U)));
+        ARM_COMPUTE_RETURN_ON_ERROR(CpuPermute::validate(weights,
&permuted_weights, PermutationVector(2U, 0U, 1U))); + ARM_COMPUTE_RETURN_ON_ERROR(CpuPermute::validate(&permuted_output, output, PermutationVector(1U, 2U, 0U))); + + ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuDepthwiseConvolutionNativeKernel::validate(&permuted_input, &permuted_weights, biases, &permuted_output, info)); + } + else + { + ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuDepthwiseConvolutionNativeKernel::validate(input, weights, biases, output, info)); + } + + // Validate Activation Layer + if(info.act_info.enabled()) + { + ARM_COMPUTE_RETURN_ON_ERROR(CpuActivation::validate(output, nullptr, info.act_info)); + } + + return Status{}; +} + +void CpuDepthwiseConvolution::CpuDepthwiseConvolutionGeneric::run(ITensorPack &tensors) +{ + auto src = tensors.get_const_tensor(TensorType::ACL_SRC_0); + auto weights = tensors.get_const_tensor(TensorType::ACL_SRC_1); + auto biases = tensors.get_const_tensor(TensorType::ACL_SRC_2); + auto dst = tensors.get_tensor(TensorType::ACL_DST_0); + + if(_is_nchw) + { + prepare(tensors); + auto src_perm = tensors.get_tensor(TensorType::ACL_INT_0); + auto weights_perm = tensors.get_tensor(TensorType::ACL_INT_1); + auto dst_perm = tensors.get_tensor(TensorType::ACL_INT_2); + + ITensorPack pack; + pack.add_tensor(TensorType::ACL_SRC, src); + pack.add_tensor(TensorType::ACL_DST, src_perm); + _permute_input->run(pack); + + ITensorPack pack_depth; + pack_depth.add_const_tensor(TensorType::ACL_SRC_0, src_perm); + pack_depth.add_const_tensor(TensorType::ACL_SRC_1, weights_perm); + pack_depth.add_tensor(TensorType::ACL_SRC_2, biases); + pack_depth.add_tensor(TensorType::ACL_DST, dst_perm); + NEScheduler::get().schedule_op(_depthwise_conv_kernel.get(), Window::DimY, _depthwise_conv_kernel->window(), pack_depth); + } + else + { + ITensorPack pack_depth; + pack_depth.add_tensor(TensorType::ACL_SRC_0, src); + pack_depth.add_tensor(TensorType::ACL_SRC_1, weights); + pack_depth.add_tensor(TensorType::ACL_SRC_2, biases); + pack_depth.add_tensor(TensorType::ACL_DST, dst); + NEScheduler::get().schedule_op(_depthwise_conv_kernel.get(), Window::DimY, _depthwise_conv_kernel->window(), pack_depth); + } + + if(_is_nchw) + { + ITensorPack pack; + auto dst_perm = tensors.get_tensor(TensorType::ACL_INT_2); + pack.add_tensor(TensorType::ACL_SRC, dst_perm); + pack.add_tensor(TensorType::ACL_DST, dst); + _permute_output->run(pack); + } + + if(_is_activationlayer_enabled) + { + ITensorPack pack; + pack.add_tensor(TensorType::ACL_SRC, dst); + pack.add_tensor(TensorType::ACL_DST, dst); + _activationlayer_function->run(pack); + } +} + +void CpuDepthwiseConvolution::CpuDepthwiseConvolutionGeneric::prepare(ITensorPack &tensors) +{ + if(!_is_prepared) + { + auto weights = tensors.get_const_tensor(TensorType::ACL_SRC_1); + auto weights_perm = tensors.get_tensor(TensorType::ACL_INT_1); + + ARM_COMPUTE_ERROR_ON(!weights->is_used()); + + ITensorPack pack; + pack.add_tensor(TensorType::ACL_SRC, weights); + pack.add_tensor(TensorType::ACL_DST, weights_perm); + + _permute_weights->run(pack); + weights->mark_as_unused(); + _is_prepared = true; + } +} + +CpuDepthwiseConvolution::CpuDepthwiseConvolution() + : _depth_conv_func(DepthwiseConvolutionFunction::GENERIC), _func_optimized(), _func_generic() +{ +} + +void CpuDepthwiseConvolution::configure(ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *output, const ConvolutionInfo &info) +{ + _depth_conv_func = get_depthwiseconvolution_function(input, weights, (biases != nullptr) ? 
biases : nullptr, output, info); + switch(_depth_conv_func) + { + case DepthwiseConvolutionFunction::OPTIMIZED: + _func_optimized.configure(input, weights, biases, output, info); + break; + case DepthwiseConvolutionFunction::GENERIC: + _func_generic.configure(input, weights, biases, output, info); + break; + default: + ARM_COMPUTE_ERROR("Unsupported DepthwiseConvolutionFunction"); + } +} + +Status CpuDepthwiseConvolution::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const ConvolutionInfo &info) +{ + DepthwiseConvolutionFunction depth_conv_func = get_depthwiseconvolution_function(input, weights, biases, output, info); + switch(depth_conv_func) + { + case DepthwiseConvolutionFunction::OPTIMIZED: + return CpuDepthwiseConvolutionOptimizedInternal::validate(input, weights, biases, output, info); + break; + case DepthwiseConvolutionFunction::GENERIC: + return CpuDepthwiseConvolutionGeneric::validate(input, weights, biases, output, info); + break; + default: + ARM_COMPUTE_ERROR("Unsupported DepthwiseConvolutionFunction"); + } +} + +DepthwiseConvolutionFunction CpuDepthwiseConvolution::get_depthwiseconvolution_function(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, + const ConvolutionInfo &info) +{ + if(bool(CpuDepthwiseConvolutionOptimizedInternal::validate(input, weights, biases, output, info))) + { + return DepthwiseConvolutionFunction::OPTIMIZED; + } + else + { + return DepthwiseConvolutionFunction::GENERIC; + } +} + +void CpuDepthwiseConvolution::run(ITensorPack &tensors) +{ + switch(_depth_conv_func) + { + case DepthwiseConvolutionFunction::OPTIMIZED: + _func_optimized.run(tensors); + break; + case DepthwiseConvolutionFunction::GENERIC: + _func_generic.run(tensors); + break; + default: + ARM_COMPUTE_ERROR("DepthwiseConvolutionFunction not properly configured"); + } +} + +void CpuDepthwiseConvolution::prepare(ITensorPack &tensors) +{ + switch(_depth_conv_func) + { + case DepthwiseConvolutionFunction::OPTIMIZED: + _func_optimized.prepare(tensors); + break; + case DepthwiseConvolutionFunction::GENERIC: + _func_generic.prepare(tensors); + break; + default: + ARM_COMPUTE_ERROR("DepthwiseConvolutionFunction not properly configured"); + } +} +} // namespace cpu +} // namespace arm_compute diff --git a/src/runtime/cpu/operators/CpuDepthwiseConvolution.h b/src/runtime/cpu/operators/CpuDepthwiseConvolution.h new file mode 100644 index 0000000000..e39cb7db4d --- /dev/null +++ b/src/runtime/cpu/operators/CpuDepthwiseConvolution.h @@ -0,0 +1,230 @@ +/* + * Copyright (c) 2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef ARM_COMPUTE_CPU_DEPTHWISECONVOLUTION_H
+#define ARM_COMPUTE_CPU_DEPTHWISECONVOLUTION_H
+
+#include "arm_compute/core/ITensorInfo.h"
+#include "arm_compute/core/experimental/Types.h"
+#include "src/core/cpu/ICpuKernel.h"
+#include "src/core/cpu/kernels/CpuDepthwiseConvolutionNativeKernel.h"
+#include "src/runtime/cpu/ICpuOperator.h"
+#include "src/runtime/cpu/operators/CpuActivation.h"
+#include "src/runtime/cpu/operators/CpuDepthwiseConvolutionAssemblyDispatch.h"
+#include "src/runtime/cpu/operators/CpuPermute.h"
+
+#include <memory>
+
+namespace arm_compute
+{
+namespace cpu
+{
+/** Function to execute a depthwise convolution.
+ */
+class CpuDepthwiseConvolution : public ICpuOperator
+{
+public:
+    /** Default constructor */
+    CpuDepthwiseConvolution();
+    /** Initialize the function's source, destination, weights and convolution information.
+     *
+     * @param[in, out] input   Source tensor info. Data type supported: QASYMM8/QASYMM8_SIGNED/F16/F32
+     * @param[out]     output  Destination tensor info. Data type supported: same as @p input.
+     * @param[in]      weights Weights tensor info. These are 3D tensor infos with shape [kernel_x, kernel_y, IFM].
+     *                         Data type supported: Same as @p input or QASYMM8/QASYMM8_SIGNED/QSYMM8_PER_CHANNEL when @p input is QASYMM8/QASYMM8_SIGNED.
+     * @param[in]      biases  Biases tensor info. A 1D tensor with shape [IFM]. Must be nullptr if not needed.
+     *                         Data type supported: Same as @p input, S32 when input is QASYMM8/QASYMM8_SIGNED.
+     * @param[in]      info    Depthwise convolution meta-data.
+     */
+    void configure(ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *output, const ConvolutionInfo &info);
+
+    /** Static function to check if given info will lead to a valid configuration of @ref CpuDepthwiseConvolution
+     *
+     * @param[in] input   Source tensor info. Data type supported: QASYMM8/QASYMM8_SIGNED/F16/F32
+     * @param[in] output  Destination tensor info. Data type supported: same as @p input.
+     * @param[in] weights Weights tensor info. These are 3D tensors info with shape [kernel_x, kernel_y, IFM].
+     *                    Data type supported: Same as @p input or QASYMM8/QASYMM8_SIGNED/QSYMM8_PER_CHANNEL when @p input is QASYMM8/QASYMM8_SIGNED.
+     * @param[in] biases  Biases tensor info. A 1D tensor with shape [IFM]. Must be nullptr if not needed.
+     *                    Data type supported: Same as @p input, S32 when input is QASYMM8/QASYMM8_SIGNED.
+     * @param[in] info    Depthwise convolution meta-data.
+     *
+     * @return a status
+     */
+    static Status validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const ConvolutionInfo &info);
+
+    /** Static function to choose the best depthwise convolution function for @ref CpuDepthwiseConvolution
+     *
+     * @param[in] input   Source tensor info. Data type supported: QASYMM8/QASYMM8_SIGNED/F16/F32
+     * @param[in] weights Weights tensor info. These are 3D tensors with shape [kernel_x, kernel_y, IFM].
+     *                    Data type supported: Same as @p input or QASYMM8/QASYMM8_SIGNED/QSYMM8_PER_CHANNEL when @p input is QASYMM8/QASYMM8_SIGNED.
+     * @param[in] biases  Biases tensor info. A 1D tensor with shape [IFM]. Must be nullptr if not needed.
+     *                    Data type supported: Same as @p input, S32 when input is QASYMM8/QASYMM8_SIGNED.
+ * @param[in] output Destination tensor. Data type supported: same as @p input. + * @param[in] info Depthwise convolution meta-data. + * + * @return a Depthwise Convolution Function + */ + static DepthwiseConvolutionFunction get_depthwiseconvolution_function(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, + const ConvolutionInfo &info); + + // Inherited methods overriden: + void run(ITensorPack &tensors) override; + void prepare(ITensorPack &tensors) override; + +private: + /** Basic function to execute optimized depthwise convolution routines. This function calls the following kernels: + * + * @note At the moment 3x3 and 5x5 convolution of stride 1, 2 are supported + * + * -# @ref NEFillBorderKernel (if pad_x or pad_y > 0) and no assembly kernel implementation is present + * -# @ref CpuDepthwiseConvolution3x3Kernel if 3x3 and no assembly kernel implementation is present + * -# @ref NEDepthwiseConvolutionAssemblyDispatch if assembly kernel implementation is present + * -# @ref NEDirectConvolutionLayerOutputStageKernel if re-quantization of output is required + * -# @ref NEActivationLayer if fused activation is required + * + */ + class CpuDepthwiseConvolutionOptimizedInternal : public ICpuOperator + { + public: + /** Default constructor */ + CpuDepthwiseConvolutionOptimizedInternal(); + /** Prevent instances of this class from being copied (As this class contains pointers) */ + CpuDepthwiseConvolutionOptimizedInternal(const CpuDepthwiseConvolutionOptimizedInternal &) = delete; + /** Default move constructor */ + CpuDepthwiseConvolutionOptimizedInternal(CpuDepthwiseConvolutionOptimizedInternal &&) = default; + /** Prevent instances of this class from being copied (As this class contains pointers) */ + CpuDepthwiseConvolutionOptimizedInternal &operator=(const CpuDepthwiseConvolutionOptimizedInternal &) = delete; + /** Default move assignment operator */ + CpuDepthwiseConvolutionOptimizedInternal &operator=(CpuDepthwiseConvolutionOptimizedInternal &&) = default; + /** Default destructor */ + ~CpuDepthwiseConvolutionOptimizedInternal() = default; + /** Initialize the function's source, destination, kernels and border_size. + * + * @param[in, out] input Source tensor info. Data type supported: QASYMM8/QASYMM8_SIGNED/F16/F32. (Written to only for border filling). + * @param[in] weights Weights tensor info. These are 3D tensors with shape [kernel_x, kernel_y, IFM]. Data type supported: Same as @p input. + * @param[in] biases Biases tensor info. A 1D tensor with shape [IFM]. Must be nullptr if not needed. + * Data type supported: Same as @p input, S32 when input is QASYMM8/QASYMM8_SIGNED. + * @param[out] output Destination tensor info. Data type supported: same as @p input. + * @param[in] info Depthwise convolution meta-data. + */ + void configure(ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *output, const ConvolutionInfo &info); + + /** Static function to check if given info will lead to a valid configuration of @ref CpuDepthwiseConvolution3x3 + * + * @param[in] input Source tensor info. Data type supported: QASYMM8/QASYMM8_SIGNED/F16/F32. (Written to only for border filling). + * @param[in] weights Weights tensor info. These are 3D tensors with shape [kernel_x, kernel_y, IFM]. Data type supported: Same as @p input. + * @param[in] biases Biases tensor info. A 1D tensor with shape [IFM]. Must be nullptr if not needed. 
+         *                    Data type supported: Same as @p input, S32 when input is QASYMM8/QASYMM8_SIGNED.
+         * @param[in] output  Destination tensor info. Data type supported: same as @p input.
+         * @param[in] info    Depthwise convolution meta-data.
+         *
+         * @return a status
+         */
+        static Status validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const ConvolutionInfo &info);
+
+        // Inherited methods overridden:
+        void run(ITensorPack &tensors) override;
+        void prepare(ITensorPack &tensors) override;
+
+    private:
+        std::unique_ptr<CpuDepthwiseConvolutionAssemblyDispatch> _dwc_optimized_func{ nullptr };
+        std::unique_ptr<CpuPermute>                              _permute_input{ nullptr };
+        std::unique_ptr<CpuPermute>                              _permute_weights{ nullptr };
+        std::unique_ptr<CpuPermute>                              _permute_output{ nullptr };
+        std::unique_ptr<CpuActivation>                           _activationlayer_function{ nullptr };
+        bool                                                     _has_bias{ false };
+        bool                                                     _is_quantized{ false };
+        bool                                                     _is_nchw{ true };
+        bool                                                     _permute{ false };
+        bool                                                     _is_activationlayer_enabled{ false };
+        bool                                                     _is_prepared{ false };
+    };
+
+    /** Basic function to execute a generic depthwise convolution. This function calls the following kernel:
+     *
+     * -# @ref CpuDepthwiseConvolutionNativeKernel
+     *
+     */
+    class CpuDepthwiseConvolutionGeneric : public ICpuOperator
+    {
+    public:
+        /** Default constructor */
+        CpuDepthwiseConvolutionGeneric();
+        /** Prevent instances of this class from being copied (As this class contains pointers) */
+        CpuDepthwiseConvolutionGeneric(const CpuDepthwiseConvolutionGeneric &) = delete;
+        /** Default move constructor */
+        CpuDepthwiseConvolutionGeneric(CpuDepthwiseConvolutionGeneric &&) = default;
+        /** Prevent instances of this class from being copied (As this class contains pointers) */
+        CpuDepthwiseConvolutionGeneric &operator=(const CpuDepthwiseConvolutionGeneric &) = delete;
+        /** Default move assignment operator */
+        CpuDepthwiseConvolutionGeneric &operator=(CpuDepthwiseConvolutionGeneric &&) = default;
+        /** Default destructor */
+        ~CpuDepthwiseConvolutionGeneric() = default;
+        /** Initialize the function's source, destination, weights and convolution information.
+         *
+         * @param[in, out] input   Source tensor info. Data type supported: QASYMM8/QASYMM8_SIGNED/F16/F32. (Written to only for border filling).
+         * @param[out]     output  Destination tensor info. Data type supported: same as @p input.
+         * @param[in]      weights Weights tensor info. These are 3D tensors with shape [kernel_x, kernel_y, IFM].
+         *                         Data type supported: Same as @p input or QASYMM8/QASYMM8_SIGNED/QSYMM8_PER_CHANNEL when @p input is QASYMM8/QASYMM8_SIGNED.
+         * @param[in]      biases  Biases tensor info. A 1D tensor with shape [IFM]. Must be nullptr if not needed.
+         *                         Data type supported: Same as @p input, S32 when input is QASYMM8/QASYMM8_SIGNED.
+         * @param[in]      info    Depthwise convolution meta-data.
+         */
+        void configure(ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *output, const ConvolutionInfo &info);
+
+        /** Static function to check if given info will lead to a valid configuration of @ref CpuDepthwiseConvolutionGeneric
+         *
+         * @param[in] input   Source tensor info. Data type supported: QASYMM8/QASYMM8_SIGNED/F16/F32. (Written to only for border filling).
+         * @param[in] output  Destination tensor info. Data type supported: same as @p input.
+         * @param[in] weights Weights tensor info. These are 3D tensors with shape [kernel_x, kernel_y, IFM].
+         *                    Data type supported: Same as @p input or QASYMM8/QASYMM8_SIGNED/QSYMM8_PER_CHANNEL when @p input is QASYMM8/QASYMM8_SIGNED.
+         * @param[in] biases  Biases tensor info.
A 1D tensor with shape [IFM]. Must be nullptr if not needed.
+         *                    Data type supported: Same as @p input, S32 when input is QASYMM8/QASYMM8_SIGNED.
+         * @param[in] info    Depthwise convolution meta-data.
+         *
+         * @return a status
+         */
+        static Status validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const ConvolutionInfo &info);
+
+        // Inherited methods overridden:
+        void run(ITensorPack &tensors) override;
+        void prepare(ITensorPack &tensors) override;
+
+    private:
+        std::unique_ptr<kernels::CpuDepthwiseConvolutionNativeKernel> _depthwise_conv_kernel{ nullptr };
+        std::unique_ptr<CpuPermute>                                   _permute_input{ nullptr };
+        std::unique_ptr<CpuPermute>                                   _permute_weights{ nullptr };
+        std::unique_ptr<CpuPermute>                                   _permute_output{ nullptr };
+        std::unique_ptr<CpuActivation>                                _activationlayer_function{ nullptr };
+        bool                                                          _is_nchw{ true };
+        bool                                                          _is_prepared{ false };
+        bool                                                          _is_activationlayer_enabled{ false };
+    };
+
+    DepthwiseConvolutionFunction             _depth_conv_func;
+    CpuDepthwiseConvolutionOptimizedInternal _func_optimized;
+    CpuDepthwiseConvolutionGeneric           _func_generic;
+};
+} // namespace cpu
+} // namespace arm_compute
+#endif /* ARM_COMPUTE_CPU_DEPTHWISECONVOLUTION_H */
diff --git a/src/runtime/cpu/operators/CpuDepthwiseConvolutionAssemblyDispatch.cpp b/src/runtime/cpu/operators/CpuDepthwiseConvolutionAssemblyDispatch.cpp
new file mode 100644
index 0000000000..5f5304cded
--- /dev/null
+++ b/src/runtime/cpu/operators/CpuDepthwiseConvolutionAssemblyDispatch.cpp
@@ -0,0 +1,564 @@
+/*
+ * Copyright (c) 2019-2021 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+
+#include "src/runtime/cpu/operators/CpuDepthwiseConvolutionAssemblyDispatch.h"
+
+#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/Utils.h"
+#include "arm_compute/core/utils/misc/InfoHelpers.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
+#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
+#include "src/core/CPP/Validate.h"
+#include "src/core/NEON/kernels/assembly/NEDepthwiseConvolutionAssemblyKernelWrapper.h"
+#include "src/core/NEON/kernels/convolution/depthwise/depthwise_dilated.hpp"
+#include "src/core/NEON/kernels/convolution/depthwise/depthwise_quantized_dilated.hpp"
+#include "src/core/helpers/AutoConfiguration.h"
+
+#include "arm_compute/runtime/NEON/NEScheduler.h"
+
+#include <set>
+
+namespace arm_compute
+{
+namespace cpu
+{
+namespace
+{
+std::unique_ptr<depthwise::IDepthwiseConvolution> get_qasymm8_convolver(int kernel_size, int stride_x,
+                                                                        int n_batches, int in_rows, int in_cols, int n_channels,
+                                                                        int dilation_factor, neon_convolution_kernels::ActivationFunction activation,
+                                                                        const qasymm8::QAsymm8Params &wqinfo, const qasymm8::QAsymm8Params &iqinfo, const qasymm8::QAsymm8Params &oqinfo,
+                                                                        const qasymm8::QAsymm8RescaleParams &rescale_params,
+                                                                        int padding_top, int padding_left, int padding_bottom, int padding_right)
+{
+    switch(kernel_size)
+    {
+        case 3:
+        {
+            switch(stride_x)
+            {
+                case 1:
+                    return std::make_unique<depthwise::QAsymm8DilatedDepthwiseConvolution<2, 2, 3, 3, 1, 1>>(
+                               n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right);
+                case 2:
+                    return std::make_unique<depthwise::QAsymm8DilatedDepthwiseConvolution<2, 2, 3, 3, 2, 2>>(
+                               n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right);
+                default:
+                    return nullptr;
+            }
+        }
+        case 5:
+        {
+            switch(stride_x)
+            {
+                case 1:
+                    return std::make_unique<depthwise::QAsymm8DilatedDepthwiseConvolution<3, 3, 5, 5, 1, 1>>(
+                               n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right);
+                case 2:
+                    return std::make_unique<depthwise::QAsymm8DilatedDepthwiseConvolution<3, 3, 5, 5, 2, 2>>(
+                               n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right);
+                default:
+                    return nullptr;
+            }
+        }
+        default:
+            return nullptr;
+    }
+}
+
+std::unique_ptr<depthwise::IDepthwiseConvolution> get_qsymm8_perchannel_convolver(int kernel_size, int stride_x,
+                                                                                  int n_batches, int in_rows, int in_cols, int n_channels,
+                                                                                  neon_convolution_kernels::ActivationFunction activation,
+                                                                                  const qsymm8::QSymm8PerChannelParams &wqinfo, const qasymm8::QAsymm8Params &iqinfo, const qasymm8::QAsymm8Params &oqinfo,
+                                                                                  const qsymm8::QSymm8PerChannelRescaleParams &rescale_params,
+                                                                                  int padding_top, int padding_left, int padding_bottom, int padding_right)
+{
+    switch(kernel_size)
+    {
+        case 3:
+        {
+            switch(stride_x)
+            {
+                case 1:
+                    return std::make_unique<depthwise::QSymm8HybridPerChannelDepthwiseConvolution<2, 2, 3, 3, 1, 1>>(
+                               n_batches, in_rows, in_cols, n_channels, activation, wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right);
+                case 2:
+                    return std::make_unique<depthwise::QSymm8HybridPerChannelDepthwiseConvolution<2, 2, 3, 3, 2, 2>>(
+                               n_batches, in_rows, in_cols, n_channels, activation, wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right);
+                default:
+                    return nullptr;
+            }
+        }
+        case 5:
+        {
+            switch(stride_x)
+            {
+                case 1:
+                    return std::make_unique<depthwise::QSymm8HybridPerChannelDepthwiseConvolution<3, 3, 5, 5, 1, 1>>(
+                               n_batches, in_rows, in_cols, n_channels, activation, wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right);
+                case 2:
+                    return std::make_unique<depthwise::QSymm8HybridPerChannelDepthwiseConvolution<3, 3, 5, 5, 2, 2>>(
+                               n_batches, in_rows,
in_cols, n_channels, activation, wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right);
+                default:
+                    return nullptr;
+            }
+        }
+        default:
+            return nullptr;
+    }
+}
+
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+std::unique_ptr<depthwise::IDepthwiseConvolution> get_fp16_convolver(int kernel_size, int stride_x,
+                                                                     int n_batches, int in_rows, int in_cols, int n_channels,
+                                                                     int dilation_factor, neon_convolution_kernels::ActivationFunction activation,
+                                                                     int padding_top, int padding_left, int padding_bottom, int padding_right)
+{
+    switch(kernel_size)
+    {
+        case 3:
+        {
+            switch(stride_x)
+            {
+                case 1:
+                    return std::make_unique<depthwise::DilatedDepthwiseConvolution<3, 3, 3, 3, 1, 1, float16_t, float16_t, float16_t>>(
+                               n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right);
+                case 2:
+                    return std::make_unique<depthwise::DilatedDepthwiseConvolution<3, 3, 3, 3, 2, 2, float16_t, float16_t, float16_t>>(
+                               n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right);
+                default:
+                    return nullptr;
+            }
+        }
+        case 5:
+        {
+            switch(stride_x)
+            {
+                case 1:
+                    return std::make_unique<depthwise::DilatedDepthwiseConvolution<3, 3, 5, 5, 1, 1, float16_t, float16_t, float16_t>>(
+                               n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right);
+                case 2:
+                    return std::make_unique<depthwise::DilatedDepthwiseConvolution<3, 3, 5, 5, 2, 2, float16_t, float16_t, float16_t>>(
+                               n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right);
+                default:
+                    return nullptr;
+            }
+        }
+        default:
+            return nullptr;
+    }
+}
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+
+std::unique_ptr<depthwise::IDepthwiseConvolution> get_fp32_convolver(int kernel_size, int stride_x,
+                                                                     int n_batches, int in_rows, int in_cols, int n_channels,
+                                                                     int dilation_factor, neon_convolution_kernels::ActivationFunction activation,
+                                                                     int padding_top, int padding_left, int padding_bottom, int padding_right)
+{
+    switch(kernel_size)
+    {
+        case 3:
+        {
+            switch(stride_x)
+            {
+                case 1:
+                    return std::make_unique<depthwise::DilatedDepthwiseConvolution<4, 4, 3, 3, 1, 1, float, float, float>>(
+                               n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right);
+                case 2:
+                    return std::make_unique<depthwise::DilatedDepthwiseConvolution<3, 3, 3, 3, 2, 2, float, float, float>>(
+                               n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right);
+                default:
+                    return nullptr;
+            }
+        }
+        case 5:
+        {
+            switch(stride_x)
+            {
+                case 1:
+                    return std::make_unique<depthwise::DilatedDepthwiseConvolution<4, 4, 5, 5, 1, 1, float, float, float>>(
+                               n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right);
+                case 2:
+                    return std::make_unique<depthwise::DilatedDepthwiseConvolution<3, 3, 5, 5, 2, 2, float, float, float>>(
+                               n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right);
+                default:
+                    return nullptr;
+            }
+        }
+        default:
+            return nullptr;
+    }
+}
+
+std::unique_ptr<depthwise::IDepthwiseConvolution> create_convolver(const ITensorInfo     *input,
+                                                                   const ITensorInfo     *weights,
+                                                                   ITensorInfo           *output,
+                                                                   const ConvolutionInfo &info)
+{
+    const DataType    data_type = input->data_type();
+    const TensorShape shape     = input->tensor_shape();
+
+    const int n_batches       = shape[3];
+    const int in_rows         = shape.z();
+    const int in_cols         = shape.y();
+    const int n_channels      = shape.x();
+    const int dilation_factor = info.dilation.x();
+    const int padding_top     = info.pad_stride_info.pad_top();
+    const int padding_left    = info.pad_stride_info.pad_left();
+    const int padding_bottom  = info.pad_stride_info.pad_bottom();
+    const int padding_right   = info.pad_stride_info.pad_right();
+
+    const bool is_uniform_quantized    = (data_type == DataType::QASYMM8) && (weights->data_type() == DataType::QASYMM8);
+    const bool is_perchannel_quantized = (data_type == DataType::QASYMM8) && (weights->data_type() ==
DataType::QSYMM8_PER_CHANNEL);
+
+    const unsigned int stride_x    = info.pad_stride_info.stride().first;
+    const unsigned int kernel_size = weights->tensor_shape().y();
+
+    // Map activation function
+    neon_convolution_kernels::ActivationFunction activation = neon_convolution_kernels::ActivationFunction::None;
+    if(arm_compute::utils::info_helpers::is_relu(info.act_info))
+    {
+        activation = neon_convolution_kernels::ActivationFunction::ReLU;
+    }
+    else if(arm_compute::utils::info_helpers::is_relu6(info.act_info))
+    {
+        activation = neon_convolution_kernels::ActivationFunction::ReLU6;
+    }
+
+    // Create quantized convolver
+    if(is_uniform_quantized)
+    {
+        const UniformQuantizationInfo input_qinfo   = input->quantization_info().uniform();
+        const UniformQuantizationInfo weights_qinfo = weights->quantization_info().uniform();
+        const UniformQuantizationInfo output_qinfo  = output->quantization_info().uniform();
+
+        // Check that quantization info are in the range [0, 255]
+        ARM_COMPUTE_ERROR_ON(input_qinfo.offset < 0 || input_qinfo.offset > 255);
+        ARM_COMPUTE_ERROR_ON(weights_qinfo.offset < 0 || weights_qinfo.offset > 255);
+        ARM_COMPUTE_ERROR_ON(output_qinfo.offset < 0 || output_qinfo.offset > 255);
+        const qasymm8::QAsymm8Params iqinfo{ static_cast<uint8_t>(input_qinfo.offset), input_qinfo.scale };
+        const qasymm8::QAsymm8Params wqinfo{ static_cast<uint8_t>(weights_qinfo.offset), weights_qinfo.scale };
+        const qasymm8::QAsymm8Params oqinfo{ static_cast<uint8_t>(output_qinfo.offset), output_qinfo.scale };
+
+        // Calculate rescale parameters
+        const float fmultipler  = iqinfo.scale * wqinfo.scale / oqinfo.scale;
+        int32_t     qmultiplier = 0;
+        int32_t     qshift      = 0;
+        quantization::calculate_quantized_multiplier_less_than_one(fmultipler, &qmultiplier, &qshift);
+        qasymm8::QAsymm8RescaleParams rescale_params(qshift, qmultiplier, fmultipler);
+
+        return get_qasymm8_convolver(kernel_size, stride_x, n_batches, in_rows, in_cols, n_channels, dilation_factor, activation,
+                                     wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right);
+    }
+    else if(is_perchannel_quantized)
+    {
+        const UniformQuantizationInfo input_qinfo   = input->quantization_info().uniform();
+        const QuantizationInfo        weights_qinfo = weights->quantization_info();
+        const UniformQuantizationInfo output_qinfo  = output->quantization_info().uniform();
+
+        // Check that quantization info are in the range [0, 255]
+        ARM_COMPUTE_ERROR_ON(input_qinfo.offset < 0 || input_qinfo.offset > 255);
+        ARM_COMPUTE_ERROR_ON(output_qinfo.offset < 0 || output_qinfo.offset > 255);
+        const qasymm8::QAsymm8Params         iqinfo{ static_cast<uint8_t>(input_qinfo.offset), input_qinfo.scale };
+        const qsymm8::QSymm8PerChannelParams wqinfo{ weights_qinfo.scale() };
+        const qasymm8::QAsymm8Params         oqinfo{ static_cast<uint8_t>(output_qinfo.offset), output_qinfo.scale };
+
+        // Calculate rescale parameters
+        std::vector<float>   fmultipliers;
+        std::vector<int32_t> qmultipliers;
+        std::vector<int32_t> qshifts;
+
+        for(auto const s : wqinfo.scales)
+        {
+            const float fmultipler  = iqinfo.scale * s / oqinfo.scale;
+            int32_t     qmultiplier = 0;
+            int32_t     qshift      = 0;
+            quantization::calculate_quantized_multiplier_less_than_one(fmultipler, &qmultiplier, &qshift);
+            fmultipliers.push_back(fmultipler);
+            qmultipliers.push_back(qmultiplier);
+            qshifts.push_back(qshift);
+        }
+
+        qsymm8::QSymm8PerChannelRescaleParams rescale_params(qshifts, qmultipliers, fmultipliers);
+
+        return get_qsymm8_perchannel_convolver(kernel_size, stride_x, n_batches, in_rows, in_cols, n_channels, activation,
+                                               wqinfo, iqinfo, oqinfo, rescale_params,
padding_top, padding_left, padding_bottom, padding_right);
+    }
+    else
+    {
+        // Create float convolver
+        switch(data_type)
+        {
+#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+            case DataType::F16:
+            {
+                return get_fp16_convolver(kernel_size, stride_x, n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right);
+            }
+#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+            case DataType::F32:
+            {
+                return get_fp32_convolver(kernel_size, stride_x, n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right);
+            }
+            default:
+                return nullptr;
+        }
+    }
+}
+} // namespace
+
+struct CpuDepthwiseConvolutionAssemblyDispatch::LocalImpl
+{
+    std::unique_ptr<depthwise::IDepthwiseConvolution> dwc_assembly_kernel{ nullptr };
+    NEDepthwiseConvolutionAssemblyKernelWrapper       dwc_acl_kernel{};
+    bool                                              is_prepared{ false };
+    experimental::MemoryRequirements                  mem_req{};
+};
+
+#ifndef DOXYGEN_SKIP_THIS
+CpuDepthwiseConvolutionAssemblyDispatch::CpuDepthwiseConvolutionAssemblyDispatch()
+    : _pImpl(std::make_unique<LocalImpl>())
+{
+}
+#endif /* DOXYGEN_SKIP_THIS */
+
+CpuDepthwiseConvolutionAssemblyDispatch::~CpuDepthwiseConvolutionAssemblyDispatch() = default;
+
+void CpuDepthwiseConvolutionAssemblyDispatch::configure(const ITensorInfo     *input,
+                                                        const ITensorInfo     *weights,
+                                                        const ITensorInfo     *bias,
+                                                        ITensorInfo           *output,
+                                                        const ConvolutionInfo &info)
+{
+    ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);
+    ARM_COMPUTE_UNUSED(bias);
+    ARM_COMPUTE_ERROR_THROW_ON(CpuDepthwiseConvolutionAssemblyDispatch::validate(input,
+                                                                                 weights,
+                                                                                 bias != nullptr ? bias : nullptr,
+                                                                                 output,
+                                                                                 info));
+
+    // Output auto initialization if not yet initialized
+    const TensorShape output_shape = misc::shape_calculator::compute_depthwise_convolution_shape(*input, *weights, info);
+    auto_init_if_empty(*output, input->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(output_shape).set_quantization_info(output->quantization_info()));
+
+    _pImpl->is_prepared = false;
+
+    // Create convolver
+    _pImpl->dwc_assembly_kernel = create_convolver(input, weights, output, info);
+    ARM_COMPUTE_ERROR_ON(_pImpl->dwc_assembly_kernel == nullptr);
+
+    // Create assembly kernel wrapper
+    _pImpl->dwc_acl_kernel.configure(_pImpl->dwc_assembly_kernel.get());
+
+    constexpr size_t alignment = 128;
+
+    // Create workspace
+    const unsigned int num_threads    = NEScheduler::get().num_threads();
+    const size_t       workspace_size = _pImpl->dwc_assembly_kernel->get_working_space_size(num_threads);
+    ARM_COMPUTE_ERROR_ON_MSG(workspace_size == 0, "Workspace size cannot be 0 !");
+    _pImpl->mem_req.push_back({ TensorType::ACL_INT_0, workspace_size, alignment });
+
+    // Create packing tensor
+    const size_t pack_tensor_size = _pImpl->dwc_assembly_kernel->get_packed_params_size();
+    ARM_COMPUTE_ERROR_ON_MSG(pack_tensor_size == 0, "Pack tensor size cannot be 0 !");
+
+    _pImpl->mem_req.push_back({ TensorType::ACL_INT_1, pack_tensor_size, alignment });
+}
+
+experimental::MemoryRequirements CpuDepthwiseConvolutionAssemblyDispatch::workspace() const
+{
+    return _pImpl->mem_req;
+}
+
+Status CpuDepthwiseConvolutionAssemblyDispatch::validate(const ITensorInfo     *input,
+                                                         const ITensorInfo     *weights,
+                                                         const ITensorInfo     *bias,
+                                                         const ITensorInfo     *output,
+                                                         const ConvolutionInfo &info)
+{
+    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input);
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32);
+    if(weights->data_type() !=
DataType::QSYMM8_PER_CHANNEL)
+    {
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);
+    }
+    ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, weights);
+
+    // Validate convolver
+    ARM_COMPUTE_RETURN_ERROR_ON(!is_optimized_supported(input, weights, info));
+
+    // Validate activation
+    const bool is_relu  = arm_compute::utils::info_helpers::is_relu(info.act_info);
+    const bool is_relu6 = arm_compute::utils::info_helpers::is_relu6(info.act_info);
+    ARM_COMPUTE_RETURN_ERROR_ON(info.act_info.enabled() && !(is_relu || is_relu6));
+
+    // Check bias
+    if(bias != nullptr)
+    {
+        unsigned int channel_idx = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::CHANNEL);
+        ARM_COMPUTE_RETURN_ERROR_ON(bias->num_dimensions() > 1);
+        ARM_COMPUTE_RETURN_ERROR_ON(bias->dimension(0) != weights->dimension(channel_idx));
+    }
+
+    // Check output
+    if(output->total_size() != 0)
+    {
+        const TensorShape output_shape = misc::shape_calculator::compute_depthwise_convolution_shape(*input, *weights, info);
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), output_shape);
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output);
+    }
+
+    // The uniform quantization case will only have 1 scale value in the weights quantization info
+    const UniformQuantizationInfo input_qinfo   = input->quantization_info().uniform();
+    const QuantizationInfo        weights_qinfo = weights->quantization_info();
+    const UniformQuantizationInfo output_qinfo  = output->quantization_info().uniform();
+    for(auto const s : weights_qinfo.scale())
+    {
+        const float fmultipler = input_qinfo.scale * s / output_qinfo.scale;
+        ARM_COMPUTE_RETURN_ERROR_ON(fmultipler > 1.f);
+    }
+
+    return Status{};
+}
+
+bool CpuDepthwiseConvolutionAssemblyDispatch::is_optimized_supported(const ITensorInfo     *input,
+                                                                     const ITensorInfo     *weights,
+                                                                     const ConvolutionInfo &info)
+{
+    ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights);
+
+    // Reshape input shape if in NHWC format
+    const DataLayout data_layout = input->data_layout();
+    TensorShape      in_shape{ input->tensor_shape() };
+    if(data_layout == DataLayout::NHWC)
+    {
+        in_shape.set(Window::DimX, input->tensor_shape().y());
+        in_shape.set(Window::DimY, input->tensor_shape().z());
+        in_shape.set(Window::DimZ, input->tensor_shape().x());
+    }
+
+    // Check data type
+    // TODO (COMPMID-3004): Add assembly optimized routine for QASYMM8_SIGNED NEDepthwiseConvolutionLayer
+    const DataType input_type            = input->data_type();
+    const bool     is_input_type_valid   = is_data_type_float(input_type) || input_type == DataType::QASYMM8;
+    const DataType weights_type          = weights->data_type();
+    const bool     is_weights_type_valid = is_data_type_float(weights_type) || weights_type == DataType::QASYMM8 || weights_type == DataType::QASYMM8_SIGNED
+                                           || weights_type == DataType::QSYMM8_PER_CHANNEL;
+
+    // Check weights size
+    std::set<unsigned int> supported_kernel_sizes = { 3, 5 };
+    const unsigned int     width_idx              = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+    const unsigned int     height_idx             = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+    const unsigned int     kernel_w               = weights->dimension(width_idx);
+    const unsigned int     kernel_h               = weights->dimension(height_idx);
+    bool                   weights_supported      = (kernel_w == kernel_h) && (supported_kernel_sizes.count(kernel_w) != 0);
+
+    // Check for supported strides
+    const auto &strides           = info.pad_stride_info.stride();
+    bool        supported_strides = (strides.first == strides.second) && ((strides.first == 1) ||
(strides.first == 2));
+
+    // Check for supported padding
+    const auto    pad_top    = info.pad_stride_info.pad_top();
+    const auto    pad_right  = info.pad_stride_info.pad_right();
+    const auto    pad_bottom = info.pad_stride_info.pad_bottom();
+    const auto    pad_left   = info.pad_stride_info.pad_left();
+    PadStrideInfo same_pad   = calculate_same_pad(in_shape, TensorShape(kernel_w, kernel_h), info.pad_stride_info, DataLayout::NCHW, info.dilation);
+    bool          is_same_padding   = (pad_top == same_pad.pad_top()) && (pad_right == same_pad.pad_right()) && (pad_bottom == same_pad.pad_bottom()) && (pad_left == same_pad.pad_left());
+    bool          is_valid_padding  = (pad_top == 0) && (pad_right == 0) && (pad_bottom == 0) && (pad_left == 0);
+    bool          supported_padding = is_same_padding || is_valid_padding;
+    // TODO(COMPMID-2464): Enable once dilated conv with stride 2 is supported
+    bool is_dilation_supported = ((info.dilation == Size2D(1U, 1U)) || ((info.dilation.x() == info.dilation.y()) && strides.first == 1));
+
+    if(weights_type == DataType::QSYMM8_PER_CHANNEL)
+    {
+        is_dilation_supported = is_dilation_supported && (info.dilation == Size2D(1U, 1U));
+    }
+
+    return is_input_type_valid && is_weights_type_valid && weights_supported && supported_strides && supported_padding && (info.depth_multiplier == 1) && is_dilation_supported;
+}
+
+void CpuDepthwiseConvolutionAssemblyDispatch::run(ITensorPack &tensors)
+{
+    // Prepare assembly kernel
+    prepare(tensors);
+
+    auto src       = tensors.get_tensor(TensorType::ACL_SRC_0);
+    auto workspace = tensors.get_tensor(TensorType::ACL_INT_0);
+    auto dst       = tensors.get_tensor(TensorType::ACL_DST);
+
+    // Setup inputs/outputs
+    ARM_COMPUTE_ERROR_ON(workspace == nullptr || workspace->buffer() == nullptr);
+    _pImpl->dwc_assembly_kernel->set_working_space(static_cast<void *>(workspace->buffer()));
+
+    const int   input_element_size = src->info()->element_size();
+    const int   input_batch_stride = src->info()->strides_in_bytes()[3] / input_element_size;
+    const int   input_row_stride   = src->info()->strides_in_bytes().z() / input_element_size;
+    const int   input_col_stride   = src->info()->strides_in_bytes().y() / input_element_size;
+    const void *input_ptr          = src->buffer() + src->info()->offset_first_element_in_bytes();
+    _pImpl->dwc_assembly_kernel->set_input(input_ptr, input_batch_stride, input_row_stride, input_col_stride);
+
+    ARM_COMPUTE_ERROR_ON(dst->buffer() == nullptr);
+    const int output_element_size = dst->info()->element_size();
+    const int output_batch_stride = dst->info()->strides_in_bytes()[3] / output_element_size;
+    const int output_row_stride   = dst->info()->strides_in_bytes().z() / output_element_size;
+    const int output_col_stride   = dst->info()->strides_in_bytes().y() / output_element_size;
+    void     *output_ptr          = dst->buffer() + dst->info()->offset_first_element_in_bytes();
+    _pImpl->dwc_assembly_kernel->set_output(output_ptr, output_batch_stride, output_row_stride, output_col_stride);
+
+    // Schedule assembly kernel
+    NEScheduler::get().schedule(&_pImpl->dwc_acl_kernel, Window::DimX);
+}
+
+void CpuDepthwiseConvolutionAssemblyDispatch::prepare(ITensorPack &tensors)
+{
+    if(!_pImpl->is_prepared)
+    {
+        auto weights        = tensors.get_const_tensor(TensorType::ACL_SRC_1);
+        auto bias           = tensors.get_const_tensor(TensorType::ACL_SRC_2);
+        auto packed_weights = tensors.get_tensor(TensorType::ACL_INT_1);
+
+        ARM_COMPUTE_ERROR_ON(packed_weights->buffer() == nullptr);
+
+        // Pack weights and bias
+        const int weights_element_size = weights->info()->element_size();
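+        // strides_in_bytes() is expressed in bytes, while the assembly kernel's
+        // pack_params() addresses the weight tensor in elements, hence the
+        // divisions by the element size below.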
+ const int weights_row_stride = weights->info()->strides_in_bytes().z() / weights_element_size; + const int weights_col_stride = weights->info()->strides_in_bytes().y() / weights_element_size; + _pImpl->dwc_assembly_kernel->pack_params(packed_weights->buffer(), + weights->buffer() + weights->info()->offset_first_element_in_bytes(), + weights_row_stride, + weights_col_stride, + (bias != nullptr) ? bias->buffer() : nullptr); + _pImpl->dwc_assembly_kernel->set_packed_params_buffer(packed_weights->buffer()); + + weights->mark_as_unused(); + if(bias != nullptr) + { + bias->mark_as_unused(); + } + _pImpl->is_prepared = true; + } +} +} // namespace cpu +} // namespace arm_compute diff --git a/src/runtime/cpu/operators/CpuDepthwiseConvolutionAssemblyDispatch.h b/src/runtime/cpu/operators/CpuDepthwiseConvolutionAssemblyDispatch.h new file mode 100644 index 0000000000..6aac74c3ef --- /dev/null +++ b/src/runtime/cpu/operators/CpuDepthwiseConvolutionAssemblyDispatch.h @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2019-2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef ARM_COMPUTE_CPU_DEPTHWISECONVOLUTIONASSEMBLYDISPATCH_H +#define ARM_COMPUTE_CPU_DEPTHWISECONVOLUTIONASSEMBLYDISPATCH_H + +#include "src/runtime/cpu/ICpuOperator.h" + +namespace arm_compute +{ +namespace cpu +{ +/** Depthwise convolution assembly kernel glue */ +class CpuDepthwiseConvolutionAssemblyDispatch : public ICpuOperator +{ +public: + CpuDepthwiseConvolutionAssemblyDispatch(); + /** Prevent instances of this class from being copied (As this class contains pointers) */ + CpuDepthwiseConvolutionAssemblyDispatch(const CpuDepthwiseConvolutionAssemblyDispatch &) = delete; + /** Default move constructor */ + CpuDepthwiseConvolutionAssemblyDispatch(CpuDepthwiseConvolutionAssemblyDispatch &&) = default; + /** Prevent instances of this class from being copied (As this class contains pointers) */ + CpuDepthwiseConvolutionAssemblyDispatch &operator=(const CpuDepthwiseConvolutionAssemblyDispatch &) = delete; + /** Default move assignment operator */ + CpuDepthwiseConvolutionAssemblyDispatch &operator=(CpuDepthwiseConvolutionAssemblyDispatch &&) = default; + /** Default destructor */ + ~CpuDepthwiseConvolutionAssemblyDispatch(); + /** Initialize the function's source, destination, kernels and border_size. + * + * @note Supports only NHWC format + * + * @param[in] input Source tensor info. Data type supported: QASYMM8/F16/F32. 
(Written to only for border filling).
+     * @param[in]  weights Weights tensor info. These are 3D tensors with shape [W, H, IFM]. Data type supported: Same as @p input.
+     * @param[in]  bias    (Optional) Biases tensor info. A 1D tensor with shape [IFM]. Must be nullptr if not needed.
+     *                     Data type supported: Same as @p input.
+     * @param[out] output  Destination tensor info. Data type supported: same as @p input.
+     * @param[in]  info    Depthwise convolution meta-data.
+     */
+    void configure(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *bias, ITensorInfo *output, const ConvolutionInfo &info);
+    /** Static function to check if given info will lead to a valid configuration of @ref CpuDepthwiseConvolutionAssemblyDispatch
+     *
+     * @note Supports only NHWC format
+     *
+     * @param[in]  input   Source tensor info. Data type supported: QASYMM8/F16/F32. (Written to only for border filling).
+     * @param[in]  weights Weights tensor info. These are 3D tensors with shape [W, H, IFM]. Data type supported: Same as @p input.
+     * @param[in]  bias    (Optional) Biases tensor info. A 1D tensor with shape [IFM]. Must be nullptr if not needed.
+     *                     Data type supported: Same as @p input.
+     * @param[out] output  Destination tensor info. Data type supported: same as @p input.
+     * @param[in]  info    Depthwise convolution meta-data.
+     *
+     * @return An error status
+     */
+    static Status validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *bias, const ITensorInfo *output, const ConvolutionInfo &info);
+    /** Check if the optimized kernel can be used for the given kernel sizes and strides
+     *
+     * @warning Even if this returns true the inputs and outputs might need to get permuted as the only layout supported is NHWC
+     *
+     * @param[in] input   Input tensor info.
+     * @param[in] weights Weights tensor info.
+     * @param[in] info    Depthwise convolution meta-data.
+     *
+     * @return True if the assembly kernel could be used else false. Note that transformations of input/output could be needed.
+     */
+    static bool is_optimized_supported(const ITensorInfo *input, const ITensorInfo *weights, const ConvolutionInfo &info);
+
+    // Inherited methods overridden:
+    void run(ITensorPack &tensors) override;
+    void prepare(ITensorPack &tensors) override;
+    experimental::MemoryRequirements workspace() const override;
+
+private:
+    struct LocalImpl;
+    std::unique_ptr<LocalImpl> _pImpl;
+};
+} // namespace cpu
+} // namespace arm_compute
+#endif /* ARM_COMPUTE_CPU_DEPTHWISECONVOLUTIONASSEMBLYDISPATCH_H */
--
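To close, a hedged sketch of how a runtime might drive this stateless operator through an ITensorPack, matching the ACL_* slots that run() and prepare() above read from. The names `src`, `weights`, `bias`, `dst`, `workspace` and `packed_weights` are placeholders for valid, allocated ITensor pointers (sized per `op.workspace()`), and `info` is a ConvolutionInfo configured as earlier; none of this is part of the patch itself:

    // Sketch: wiring an ITensorPack for CpuDepthwiseConvolutionAssemblyDispatch.
    cpu::CpuDepthwiseConvolutionAssemblyDispatch op;
    op.configure(src->info(), weights->info(), bias->info(), dst->info(), info);

    ITensorPack pack;
    pack.add_tensor(TensorType::ACL_SRC_0, src);            // input, read by run()
    pack.add_tensor(TensorType::ACL_SRC_1, weights);        // weights, consumed by prepare()
    pack.add_tensor(TensorType::ACL_SRC_2, bias);           // optional bias, consumed by prepare()
    pack.add_tensor(TensorType::ACL_INT_0, workspace);      // per-thread scratch, read by run()
    pack.add_tensor(TensorType::ACL_INT_1, packed_weights); // packed parameters, written by prepare()
    pack.add_tensor(TensorType::ACL_DST, dst);              // output, written by run()

    op.prepare(pack); // one-off weight/bias packing
    op.run(pack);     // schedules the assembly kernel wrapper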