From 60c3b0e6821a80d78ffca5be30e05d062d071cd2 Mon Sep 17 00:00:00 2001 From: Michalis Spyrou Date: Thu, 8 Apr 2021 12:02:58 +0100 Subject: Port DepthwiseConvolution to new API Resolves: COMPMID-4185 Change-Id: Ib5f22356356a022d567bb18d44ea272b62d10ebf Signed-off-by: Michalis Spyrou Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5424 Reviewed-by: Michele Di Giorgio Tested-by: Arm Jenkins Comments-Addressed: Arm Jenkins --- .../cpu/operators/CpuDepthwiseConvolution.cpp | 521 +++++++++++++++++++ .../cpu/operators/CpuDepthwiseConvolution.h | 230 +++++++++ .../CpuDepthwiseConvolutionAssemblyDispatch.cpp | 564 +++++++++++++++++++++ .../CpuDepthwiseConvolutionAssemblyDispatch.h | 97 ++++ 4 files changed, 1412 insertions(+) create mode 100644 src/runtime/cpu/operators/CpuDepthwiseConvolution.cpp create mode 100644 src/runtime/cpu/operators/CpuDepthwiseConvolution.h create mode 100644 src/runtime/cpu/operators/CpuDepthwiseConvolutionAssemblyDispatch.cpp create mode 100644 src/runtime/cpu/operators/CpuDepthwiseConvolutionAssemblyDispatch.h (limited to 'src/runtime/cpu/operators') diff --git a/src/runtime/cpu/operators/CpuDepthwiseConvolution.cpp b/src/runtime/cpu/operators/CpuDepthwiseConvolution.cpp new file mode 100644 index 0000000000..183a2af0cd --- /dev/null +++ b/src/runtime/cpu/operators/CpuDepthwiseConvolution.cpp @@ -0,0 +1,521 @@ +/* + * Copyright (c) 2021 Arm Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */
+#include "src/runtime/cpu/operators/CpuDepthwiseConvolution.h"
+
+#include "arm_compute/core/TensorInfo.h"
+#include "arm_compute/core/Validate.h"
+#include "arm_compute/core/utils/misc/InfoHelpers.h"
+#include "arm_compute/core/utils/misc/ShapeCalculator.h"
+#include "arm_compute/runtime/NEON/NEScheduler.h"
+#include "src/core/cpu/kernels/CpuDepthwiseConvolutionNativeKernel.h"
+
+namespace arm_compute
+{
+namespace cpu
+{
+namespace
+{
+Status validate_arguments_optimized(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const ConvolutionInfo &info)
+{
+    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
+    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F16, DataType::F32);
+    if(!is_data_type_quantized_per_channel(weights->data_type()))
+    {
+        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights);
+    }
+    ARM_COMPUTE_RETURN_ERROR_ON(input->data_layout() == DataLayout::UNKNOWN);
+    ARM_COMPUTE_RETURN_ERROR_ON(info.dilation.x() < 1 || info.dilation.y() < 1);
+    const size_t idx_w = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::WIDTH);
+    const size_t idx_h = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::HEIGHT);
+    ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(idx_w) + (weights->dimension(idx_w) - 1) * (info.dilation.x() - 1) > input->dimension(idx_w) + info.pad_stride_info.pad_left() +
+                                info.pad_stride_info.pad_right());
+    ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(idx_h) + (weights->dimension(idx_h) - 1) * (info.dilation.y() - 1) > input->dimension(idx_h) + info.pad_stride_info.pad_top() +
+                                info.pad_stride_info.pad_bottom());
+
+    if(biases != nullptr)
+    {
+        const unsigned int channel_idx = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::CHANNEL);
+        ARM_COMPUTE_RETURN_ERROR_ON(biases->num_dimensions() > 1);
+        ARM_COMPUTE_RETURN_ERROR_ON(biases->dimension(0) != weights->dimension(channel_idx));
+    }
+
+    ARM_COMPUTE_RETURN_ON_ERROR(CpuDepthwiseConvolutionAssemblyDispatch::validate(input, weights, biases, output, info));
+
+    // Validate the activation that may have to run as a separate operator
+    if(info.act_info.enabled())
+    {
+        ARM_COMPUTE_RETURN_ON_ERROR(CpuActivation::validate(output, nullptr, info.act_info));
+    }
+    return Status{};
+}
+} // namespace
+
+CpuDepthwiseConvolution::CpuDepthwiseConvolutionOptimizedInternal::CpuDepthwiseConvolutionOptimizedInternal()
+    : _dwc_optimized_func(nullptr), _permute_input(nullptr), _permute_weights(nullptr), _permute_output(nullptr), _activationlayer_function(nullptr), _has_bias(false), _is_quantized(false),
+      _is_nchw(true), _permute(false), _is_activationlayer_enabled(false), _is_prepared(false)
+{
+}
+
+void CpuDepthwiseConvolution::CpuDepthwiseConvolutionOptimizedInternal::configure(ITensorInfo           *input,
+                                                                                  const ITensorInfo     *weights,
+                                                                                  const ITensorInfo     *biases,
+                                                                                  ITensorInfo           *output,
+                                                                                  const ConvolutionInfo &info)
+{
+    ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);
+    // Perform validation step (biases may legitimately be nullptr and is forwarded as-is)
+    ARM_COMPUTE_ERROR_THROW_ON(CpuDepthwiseConvolutionOptimizedInternal::validate(input, weights, biases,
+                                                                                  output, info));
+
+    _is_quantized = is_data_type_quantized_asymmetric(input->data_type());
+    _has_bias     = biases != nullptr;
+    _is_nchw      = input->data_layout() == DataLayout::NCHW;
+    _permute      = _is_nchw;
+    _is_prepared  = false;
+
+    // Configure pipeline. NOTE(review): act_info_to_use is computed below but never passed on - confirm intent
+    ActivationLayerInfo act_info_to_use = ActivationLayerInfo();
+    const bool          is_relu         = arm_compute::utils::info_helpers::is_relu(info.act_info);
+    const bool          is_relu6        = arm_compute::utils::info_helpers::is_relu6(info.act_info);
+    _is_activationlayer_enabled         = info.act_info.enabled() && !(is_relu || is_relu6);
+
+    if(!_is_activationlayer_enabled)
+    {
+        act_info_to_use = info.act_info;
+    }
+
+    _dwc_optimized_func = std::make_unique<CpuDepthwiseConvolutionAssemblyDispatch>();
+    if(_is_nchw)
+    {
+        _permute_input   = std::make_unique<CpuPermute>();
+        _permute_weights = std::make_unique<CpuPermute>();
+        _permute_output  = std::make_unique<CpuPermute>();
+
+        auto input_perm   = std::make_unique<TensorInfo>();
+        auto weights_perm = std::make_unique<TensorInfo>();
+        auto output_perm  = std::make_unique<TensorInfo>();
+
+        // Configure the function to transform the input tensor from NCHW -> NHWC
+        _permute_input->configure(input, input_perm.get(), PermutationVector(2U, 0U, 1U));
+        input_perm->set_data_layout(DataLayout::NHWC);
+
+        // Configure the function to transform the weights tensor from IHW -> HWI
+        _permute_weights->configure(weights, weights_perm.get(), PermutationVector(2U, 0U, 1U));
+        weights_perm->set_data_layout(DataLayout::NHWC);
+
+        output_perm->set_data_layout(DataLayout::NHWC);
+        output_perm->set_quantization_info(output->quantization_info());
+
+        // Configure optimized depthwise
+        _dwc_optimized_func->configure(input_perm.get(), weights_perm.get(), biases, output_perm.get(), info);
+
+        // Configure the function to transform the convoluted output to ACL's native ordering format NCHW
+        output_perm->set_data_layout(DataLayout::NHWC);
+        _permute_output->configure(output_perm.get(), output, PermutationVector(1U, 2U, 0U));
+    }
+    else
+    {
+        _dwc_optimized_func->configure(input, weights, biases, output, info);
+    }
+
+    // Configure activation
+    if(_is_activationlayer_enabled)
+    {
+        _activationlayer_function = std::make_unique<CpuActivation>();
+        _activationlayer_function->configure(output, nullptr, info.act_info);
+    }
+}
+
+Status CpuDepthwiseConvolution::CpuDepthwiseConvolutionOptimizedInternal::validate(const ITensorInfo     *input,
+                                                                                   const ITensorInfo     *weights,
+                                                                                   const ITensorInfo     *biases,
+                                                                                   const ITensorInfo     *output,
+                                                                                   const ConvolutionInfo &info)
+{
+    return validate_arguments_optimized(input, weights, biases, output, info);
+}
+
+void CpuDepthwiseConvolution::CpuDepthwiseConvolutionOptimizedInternal::run(ITensorPack &tensors)
+{
+    ARM_COMPUTE_ERROR_ON_MSG(tensors.empty(), "No inputs provided");
+    prepare(tensors);
+
+    auto bias           = tensors.get_const_tensor(TensorType::ACL_SRC_2);
+    auto dst            = tensors.get_tensor(TensorType::ACL_DST_0);
+    auto workspace      = tensors.get_tensor(TensorType::ACL_INT_3);
+    auto packed_weights = tensors.get_tensor(TensorType::ACL_INT_4);
+
+    // Permute input
+    if(_permute)
+    {
+        ITensorPack pack;
+        auto        src      = tensors.get_tensor(TensorType::ACL_SRC_0);
+        auto        src_perm = tensors.get_tensor(TensorType::ACL_INT_0);
+        pack.add_tensor(TensorType::ACL_SRC, src);
+        pack.add_tensor(TensorType::ACL_DST, src_perm);
+        _permute_input->run(pack);
+    }
+
+    // Run assembly function
+    if(_is_nchw)
+    {
+        auto src_perm     = tensors.get_tensor(TensorType::ACL_INT_0);
+        auto weights_perm = tensors.get_tensor(TensorType::ACL_INT_1);
+        auto dst_perm     = tensors.get_tensor(TensorType::ACL_INT_2);
+
+        ITensorPack pack;
+        pack.add_tensor(TensorType::ACL_SRC_0, src_perm);
+        pack.add_tensor(TensorType::ACL_SRC_1, weights_perm);
+        pack.add_tensor(TensorType::ACL_SRC_2, bias);
+        pack.add_tensor(TensorType::ACL_INT_0, workspace);
+        pack.add_tensor(TensorType::ACL_INT_1, packed_weights);
+        pack.add_tensor(TensorType::ACL_DST, dst_perm);
+        _dwc_optimized_func->run(pack);
+    }
+    else
+    {
+        auto src     = tensors.get_tensor(TensorType::ACL_SRC_0);
+        auto weights = tensors.get_tensor(TensorType::ACL_SRC_1);
+        auto dst     = tensors.get_tensor(TensorType::ACL_DST);
+
+        ITensorPack pack;
+        pack.add_tensor(TensorType::ACL_SRC_0, src);
+        pack.add_tensor(TensorType::ACL_SRC_1, weights);
+        pack.add_tensor(TensorType::ACL_SRC_2, bias);
+        pack.add_tensor(TensorType::ACL_INT_0, workspace);
+        pack.add_tensor(TensorType::ACL_INT_1, packed_weights);
+        pack.add_tensor(TensorType::ACL_DST, dst);
+        _dwc_optimized_func->run(pack);
+    }
+
+    // Permute output
+    if(_is_nchw)
+    {
+        ITensorPack pack;
+        auto        dst_perm = tensors.get_tensor(TensorType::ACL_INT_2);
+        pack.add_tensor(TensorType::ACL_SRC, dst_perm);
+        pack.add_tensor(TensorType::ACL_DST, dst);
+        _permute_output->run(pack);
+    }
+
+    // Run activation
+    if(_is_activationlayer_enabled)
+    {
+        ITensorPack pack;
+        pack.add_tensor(TensorType::ACL_SRC, dst);
+        pack.add_tensor(TensorType::ACL_DST, dst);
+        _activationlayer_function->run(pack);
+    }
+}
+
+void CpuDepthwiseConvolution::CpuDepthwiseConvolutionOptimizedInternal::prepare(ITensorPack &tensors)
+{
+    if(!_is_prepared)
+    {
+        auto weights        = tensors.get_const_tensor(TensorType::ACL_SRC_1);
+        auto bias           = tensors.get_const_tensor(TensorType::ACL_SRC_2);
+        auto packed_weights = tensors.get_tensor(TensorType::ACL_INT_4);
+
+        // Permute weights
+        if(_permute)
+        {
+            auto permuted_weights = tensors.get_tensor(TensorType::ACL_INT_1);
+
+            ITensorPack pack;
+            pack.add_tensor(TensorType::ACL_SRC, weights);
+            pack.add_tensor(TensorType::ACL_DST, permuted_weights);
+            _permute_weights->run(pack);
+
+            ITensorPack pack_opt;
+            pack_opt.add_const_tensor(TensorType::ACL_SRC_1, permuted_weights);
+            pack_opt.add_tensor(TensorType::ACL_SRC_2, bias);
+            pack_opt.add_tensor(TensorType::ACL_INT_1, packed_weights);
+
+            // Prepare optimized function
+            _dwc_optimized_func->prepare(pack_opt);
+        }
+        else
+        {
+            ITensorPack pack_opt;
+            pack_opt.add_tensor(TensorType::ACL_SRC_1, weights);
+            pack_opt.add_tensor(TensorType::ACL_SRC_2, bias);
+            pack_opt.add_tensor(TensorType::ACL_INT_1, packed_weights);
+
+            // Prepare optimized function
+            _dwc_optimized_func->prepare(pack_opt);
+        }
+
+        _is_prepared = true;
+    }
+}
+
+CpuDepthwiseConvolution::CpuDepthwiseConvolutionGeneric::CpuDepthwiseConvolutionGeneric()
+    : _depthwise_conv_kernel(nullptr), _permute_input(nullptr), _permute_weights(nullptr), _permute_output(nullptr), _activationlayer_function(nullptr), _is_nchw(true), _is_prepared(false),
+      _is_activationlayer_enabled(false)
+{
+}
+
+void CpuDepthwiseConvolution::CpuDepthwiseConvolutionGeneric::configure(ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *output, const ConvolutionInfo &info)
+{
+    ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output);
+    ARM_COMPUTE_ERROR_THROW_ON(CpuDepthwiseConvolution::validate(input, weights, biases,
+                                                                 output, info));
+
+    _is_nchw     = input->data_layout() == DataLayout::NCHW;
+    _is_prepared = !_is_nchw;
+
+    ITensorInfo       *input_to_use   = input;
+    const ITensorInfo *weights_to_use = weights;
+    ITensorInfo       *output_to_use  = output;
+
+    auto input_perm   = std::make_unique<TensorInfo>();
+    auto weights_perm = std::make_unique<TensorInfo>();
+    auto output_perm  = std::make_unique<TensorInfo>(output->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(TensorShape()));
+
+    if(_is_nchw)
+    {
+        _permute_input   = std::make_unique<CpuPermute>();
+        _permute_weights = std::make_unique<CpuPermute>();
+
+        _permute_input->configure(input, input_perm.get(), PermutationVector(2U, 0U, 1U));
+        input_perm->set_data_layout(DataLayout::NHWC);
+        input_to_use = input_perm.get();
+
+        _permute_weights->configure(weights, weights_perm.get(), PermutationVector(2U, 0U, 1U));
+        weights_perm->set_data_layout(DataLayout::NHWC);
+        weights_to_use = weights_perm.get();
+
+        output_to_use = output_perm.get();
+    }
+
+    _depthwise_conv_kernel = std::make_unique<cpu::kernels::CpuDepthwiseConvolutionNativeKernel>();
+    _depthwise_conv_kernel->configure(input_to_use, weights_to_use, biases, output_to_use, info);
+
+    if(_is_nchw)
+    {
+        _permute_output = std::make_unique<CpuPermute>();
+        _permute_output->configure(output_perm.get(), output, PermutationVector(1U, 2U, 0U));
+        output_perm->set_data_layout(DataLayout::NHWC);
+    }
+
+    // Configure Activation Layer
+    _is_activationlayer_enabled = info.act_info.enabled();
+    if(_is_activationlayer_enabled)
+    {
+        _activationlayer_function = std::make_unique<CpuActivation>();
+        _activationlayer_function->configure(output, nullptr, info.act_info);
+    }
+}
+
+Status CpuDepthwiseConvolution::CpuDepthwiseConvolutionGeneric::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output,
+                                                                         const ConvolutionInfo &info)
+{
+    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, weights, output);
+    if(input->data_layout() == DataLayout::NCHW)
+    {
+        TensorShape permuted_input_shape   = input->tensor_shape();
+        TensorShape permuted_weights_shape = weights->tensor_shape();
+        TensorShape permuted_output_shape  = misc::shape_calculator::compute_depthwise_convolution_shape(*input, *weights, info);
+        permute(permuted_input_shape, PermutationVector(2U, 0U, 1U));
+        permute(permuted_weights_shape, PermutationVector(2U, 0U, 1U));
+        permute(permuted_output_shape, PermutationVector(2U, 0U, 1U));
+
+        const TensorInfo permuted_input   = TensorInfo(input->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(permuted_input_shape).set_data_layout(DataLayout::NHWC));
+        const TensorInfo permuted_weights = TensorInfo(weights->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(permuted_weights_shape).set_data_layout(DataLayout::NHWC));
+        const TensorInfo permuted_output  = TensorInfo(output->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(permuted_output_shape).set_data_layout(DataLayout::NCHW));
+
+        ARM_COMPUTE_RETURN_ON_ERROR(CpuPermute::validate(input, &permuted_input, PermutationVector(2U, 0U, 1U)));
+        ARM_COMPUTE_RETURN_ON_ERROR(CpuPermute::validate(weights, &permuted_weights, PermutationVector(2U, 0U, 1U)));
+        ARM_COMPUTE_RETURN_ON_ERROR(CpuPermute::validate(&permuted_output, output, PermutationVector(1U, 2U, 0U)));
+
+        ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuDepthwiseConvolutionNativeKernel::validate(&permuted_input, &permuted_weights, biases, &permuted_output, info));
+    }
+    else
+    {
+        ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuDepthwiseConvolutionNativeKernel::validate(input, weights, biases, output, info));
+    }
+
+    // Validate Activation Layer
+    if(info.act_info.enabled())
+    {
+        ARM_COMPUTE_RETURN_ON_ERROR(CpuActivation::validate(output, nullptr, info.act_info));
+    }
+
+    return Status{};
+}
+
+void CpuDepthwiseConvolution::CpuDepthwiseConvolutionGeneric::run(ITensorPack &tensors)
+{
+    auto src     = tensors.get_const_tensor(TensorType::ACL_SRC_0);
+    auto weights = tensors.get_const_tensor(TensorType::ACL_SRC_1);
+    auto biases  = tensors.get_const_tensor(TensorType::ACL_SRC_2);
+    auto dst     = tensors.get_tensor(TensorType::ACL_DST_0);
+
+    if(_is_nchw)
+    {
+        prepare(tensors);
+        auto src_perm     = tensors.get_tensor(TensorType::ACL_INT_0);
+        auto weights_perm = tensors.get_tensor(TensorType::ACL_INT_1);
+        auto dst_perm     = tensors.get_tensor(TensorType::ACL_INT_2);
+
+        ITensorPack pack;
+        pack.add_tensor(TensorType::ACL_SRC, src);
+        pack.add_tensor(TensorType::ACL_DST, src_perm);
+        _permute_input->run(pack);
+
+        ITensorPack pack_depth;
+        pack_depth.add_const_tensor(TensorType::ACL_SRC_0, src_perm);
+        pack_depth.add_const_tensor(TensorType::ACL_SRC_1, weights_perm);
+        pack_depth.add_tensor(TensorType::ACL_SRC_2, biases);
+        pack_depth.add_tensor(TensorType::ACL_DST, dst_perm);
+        NEScheduler::get().schedule_op(_depthwise_conv_kernel.get(), Window::DimY, _depthwise_conv_kernel->window(), pack_depth);
+    }
+    else
+    {
+        ITensorPack pack_depth;
+        pack_depth.add_tensor(TensorType::ACL_SRC_0, src);
+        pack_depth.add_tensor(TensorType::ACL_SRC_1, weights);
+        pack_depth.add_tensor(TensorType::ACL_SRC_2, biases);
+        pack_depth.add_tensor(TensorType::ACL_DST, dst);
+        NEScheduler::get().schedule_op(_depthwise_conv_kernel.get(), Window::DimY, _depthwise_conv_kernel->window(), pack_depth);
+    }
+
+    if(_is_nchw)
+    {
+        ITensorPack pack;
+        auto        dst_perm = tensors.get_tensor(TensorType::ACL_INT_2);
+        pack.add_tensor(TensorType::ACL_SRC, dst_perm);
+        pack.add_tensor(TensorType::ACL_DST, dst);
+        _permute_output->run(pack);
+    }
+
+    if(_is_activationlayer_enabled)
+    {
+        ITensorPack pack;
+        pack.add_tensor(TensorType::ACL_SRC, dst);
+        pack.add_tensor(TensorType::ACL_DST, dst);
+        _activationlayer_function->run(pack);
+    }
+}
+
+void CpuDepthwiseConvolution::CpuDepthwiseConvolutionGeneric::prepare(ITensorPack &tensors)
+{
+    if(!_is_prepared)
+    {
+        auto weights      = tensors.get_const_tensor(TensorType::ACL_SRC_1);
+        auto weights_perm = tensors.get_tensor(TensorType::ACL_INT_1);
+
+        ARM_COMPUTE_ERROR_ON(!weights->is_used());
+
+        ITensorPack pack;
+        pack.add_tensor(TensorType::ACL_SRC, weights);
+        pack.add_tensor(TensorType::ACL_DST, weights_perm);
+
+        _permute_weights->run(pack);
+        weights->mark_as_unused();
+        _is_prepared = true;
+    }
+}
+
+CpuDepthwiseConvolution::CpuDepthwiseConvolution()
+    : _depth_conv_func(DepthwiseConvolutionFunction::GENERIC), _func_optimized(), _func_generic()
+{
+}
+
+void CpuDepthwiseConvolution::configure(ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *output, const ConvolutionInfo &info)
+{
+    _depth_conv_func = get_depthwiseconvolution_function(input, weights, biases, output, info);
+    switch(_depth_conv_func)
+    {
+        case DepthwiseConvolutionFunction::OPTIMIZED:
+            _func_optimized.configure(input, weights, biases, output, info);
+            break;
+        case DepthwiseConvolutionFunction::GENERIC:
+            _func_generic.configure(input, weights, biases, output, info);
+            break;
+        default:
+            ARM_COMPUTE_ERROR("Unsupported DepthwiseConvolutionFunction");
+    }
+}
+
+Status CpuDepthwiseConvolution::validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const ConvolutionInfo &info)
+{
+    DepthwiseConvolutionFunction depth_conv_func = get_depthwiseconvolution_function(input, weights, biases, output, info);
+    switch(depth_conv_func)
+    {
+        case DepthwiseConvolutionFunction::OPTIMIZED:
+            return CpuDepthwiseConvolutionOptimizedInternal::validate(input, weights, biases, output, info);
+            break;
+        case DepthwiseConvolutionFunction::GENERIC:
+            return CpuDepthwiseConvolutionGeneric::validate(input, weights, biases, output, info);
+            break;
+        default:
+            ARM_COMPUTE_ERROR("Unsupported DepthwiseConvolutionFunction");
+    }
+}
+
+DepthwiseConvolutionFunction CpuDepthwiseConvolution::get_depthwiseconvolution_function(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output,
+                                                                                        const ConvolutionInfo &info)
+{
+    if(bool(CpuDepthwiseConvolutionOptimizedInternal::validate(input, weights, biases, output, info)))
+    {
+        return DepthwiseConvolutionFunction::OPTIMIZED;
+    }
+    else
+    {
+        return DepthwiseConvolutionFunction::GENERIC;
+    }
+}
+
+void CpuDepthwiseConvolution::run(ITensorPack &tensors)
+{
+    switch(_depth_conv_func)
+    {
+        case DepthwiseConvolutionFunction::OPTIMIZED:
+            _func_optimized.run(tensors);
+            break;
+        case DepthwiseConvolutionFunction::GENERIC:
+            _func_generic.run(tensors);
+            break;
+        default:
+            ARM_COMPUTE_ERROR("DepthwiseConvolutionFunction not properly configured");
+    }
+}
+
+void CpuDepthwiseConvolution::prepare(ITensorPack &tensors)
+{
+    switch(_depth_conv_func)
+    {
+        case DepthwiseConvolutionFunction::OPTIMIZED:
+            _func_optimized.prepare(tensors);
+            break;
+        case DepthwiseConvolutionFunction::GENERIC:
+            _func_generic.prepare(tensors);
+            break;
+        default:
+            ARM_COMPUTE_ERROR("DepthwiseConvolutionFunction not properly configured");
+    }
+}
+} // namespace cpu
+} // namespace arm_compute
diff --git a/src/runtime/cpu/operators/CpuDepthwiseConvolution.h b/src/runtime/cpu/operators/CpuDepthwiseConvolution.h
new file mode 100644
index 0000000000..e39cb7db4d
--- /dev/null
+++ b/src/runtime/cpu/operators/CpuDepthwiseConvolution.h
@@ -0,0 +1,230 @@
+/*
+ * Copyright (c) 2021 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#ifndef ARM_COMPUTE_CPU_DEPTHWISECONVOLUTION_H
+#define ARM_COMPUTE_CPU_DEPTHWISECONVOLUTION_H
+
+#include "arm_compute/core/ITensorInfo.h"
+#include "arm_compute/core/experimental/Types.h"
+#include "src/core/cpu/ICpuKernel.h"
+#include "src/core/cpu/kernels/CpuDepthwiseConvolutionNativeKernel.h"
+#include "src/runtime/cpu/ICpuOperator.h"
+#include "src/runtime/cpu/operators/CpuActivation.h"
+#include "src/runtime/cpu/operators/CpuDepthwiseConvolutionAssemblyDispatch.h"
+#include "src/runtime/cpu/operators/CpuPermute.h"
+
+#include <memory>
+
+namespace arm_compute
+{
+namespace cpu
+{
+/** Function to execute a depthwise convolution.
+ */
+class CpuDepthwiseConvolution : public ICpuOperator
+{
+public:
+    /** Default constructor */
+    CpuDepthwiseConvolution();
+    /** Initialize the function's source, destination, weights and convolution information.
+     *
+     * @param[in, out] input   Source tensor info. Data type supported: QASYMM8/QASYMM8_SIGNED/F16/F32
+     * @param[out]     output  Destination tensor info. Data type supported: same as @p input.
+     * @param[in]      weights Weights tensor info. These are 3D tensor infos with shape [kernel_x, kernel_y, IFM].
+     *                         Data type supported: Same as @p input or QASYMM8/QASYMM8_SIGNED/QSYMM8_PER_CHANNEL when @p input is QASYMM8/QASYMM8_SIGNED.
+     * @param[in]      biases  Biases tensor info. A 1D tensor with shape [IFM]. Must be nullptr if not needed.
+     *                         Data type supported: Same as @p input, S32 when input is QASYMM8/QASYMM8_SIGNED.
+     * @param[in]      info    Depthwise convolution meta-data.
+     */
+    void configure(ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *output, const ConvolutionInfo &info);
+
+    /** Static function to check if given info will lead to a valid configuration of @ref CpuDepthwiseConvolution
+     *
+     * @param[in] input   Source tensor info. Data type supported: QASYMM8/QASYMM8_SIGNED/F16/F32
+     * @param[in] output  Destination tensor info. Data type supported: same as @p input.
+     * @param[in] weights Weights tensor info. These are 3D tensor infos with shape [kernel_x, kernel_y, IFM].
+     *                    Data type supported: Same as @p input or QASYMM8/QASYMM8_SIGNED/QSYMM8_PER_CHANNEL when @p input is QASYMM8/QASYMM8_SIGNED.
+     * @param[in] biases  Biases tensor info. A 1D tensor with shape [IFM]. Must be nullptr if not needed.
+     *                    Data type supported: Same as @p input, S32 when input is QASYMM8/QASYMM8_SIGNED.
+     * @param[in] info    Depthwise convolution meta-data.
+     *
+     * @return a status
+     */
+    static Status validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const ConvolutionInfo &info);
+
+    /** Static function to choose the best depthwise convolution function for @ref CpuDepthwiseConvolution
+     *
+     * @param[in] input   Source tensor info. Data type supported: QASYMM8/QASYMM8_SIGNED/F16/F32
+     * @param[in] weights Weights tensor info. These are 3D tensor infos with shape [kernel_x, kernel_y, IFM].
+     *                    Data type supported: Same as @p input or QASYMM8/QASYMM8_SIGNED/QSYMM8_PER_CHANNEL when @p input is QASYMM8/QASYMM8_SIGNED.
+     * @param[in] biases  Biases tensor info. A 1D tensor with shape [IFM]. Must be nullptr if not needed.
+     *                    Data type supported: Same as @p input, S32 when input is QASYMM8/QASYMM8_SIGNED.
+     * @param[in] output  Destination tensor info. Data type supported: same as @p input.
+     * @param[in] info    Depthwise convolution meta-data.
+     *
+     * @return a Depthwise Convolution Function
+     */
+    static DepthwiseConvolutionFunction get_depthwiseconvolution_function(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output,
+                                                                          const ConvolutionInfo &info);
+
+    // Inherited methods overridden:
+    void run(ITensorPack &tensors) override;
+    void prepare(ITensorPack &tensors) override;
+
+private:
+    /** Basic function to execute optimized depthwise convolution routines. This function calls the following operators:
+     *
+     * @note At the moment 3x3 and 5x5 convolution of stride 1, 2 are supported
+     *
+     * -# @ref CpuPermute on the input, if the data layout is NCHW
+     * -# @ref CpuPermute on the weights (once, during prepare), if the data layout is NCHW
+     * -# @ref CpuDepthwiseConvolutionAssemblyDispatch
+     * -# @ref CpuPermute on the output, if the data layout is NCHW
+     * -# @ref CpuActivation if the activation is enabled and is neither ReLU nor ReLU6
+     *
+     */
+    class CpuDepthwiseConvolutionOptimizedInternal : public ICpuOperator
+    {
+    public:
+        /** Default constructor */
+        CpuDepthwiseConvolutionOptimizedInternal();
+        /** Prevent instances of this class from being copied (As this class contains pointers) */
+        CpuDepthwiseConvolutionOptimizedInternal(const CpuDepthwiseConvolutionOptimizedInternal &) = delete;
+        /** Default move constructor */
+        CpuDepthwiseConvolutionOptimizedInternal(CpuDepthwiseConvolutionOptimizedInternal &&) = default;
+        /** Prevent instances of this class from being copied (As this class contains pointers) */
+        CpuDepthwiseConvolutionOptimizedInternal &operator=(const CpuDepthwiseConvolutionOptimizedInternal &) = delete;
+        /** Default move assignment operator */
+        CpuDepthwiseConvolutionOptimizedInternal &operator=(CpuDepthwiseConvolutionOptimizedInternal &&) = default;
+        /** Default destructor */
+        ~CpuDepthwiseConvolutionOptimizedInternal() = default;
+        /** Initialize the function's source, destination, kernels and border_size.
+         *
+         * @param[in, out] input   Source tensor info. Data type supported: QASYMM8/QASYMM8_SIGNED/F16/F32. (Written to only for border filling).
+         * @param[in]      weights Weights tensor info. These are 3D tensor infos with shape [kernel_x, kernel_y, IFM]. Data type supported: Same as @p input.
+         * @param[in]      biases  Biases tensor info. A 1D tensor with shape [IFM]. Must be nullptr if not needed.
+         *                         Data type supported: Same as @p input, S32 when input is QASYMM8/QASYMM8_SIGNED.
+         * @param[out]     output  Destination tensor info. Data type supported: same as @p input.
+         * @param[in]      info    Depthwise convolution meta-data.
+         */
+        void configure(ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *output, const ConvolutionInfo &info);
+
+        /** Static function to check if given info will lead to a valid configuration of @ref CpuDepthwiseConvolutionOptimizedInternal
+         *
+         * @param[in] input   Source tensor info. Data type supported: QASYMM8/QASYMM8_SIGNED/F16/F32. (Written to only for border filling).
+         * @param[in] weights Weights tensor info. These are 3D tensor infos with shape [kernel_x, kernel_y, IFM]. Data type supported: Same as @p input.
+         * @param[in] biases  Biases tensor info. A 1D tensor with shape [IFM]. Must be nullptr if not needed.
+         *                    Data type supported: Same as @p input, S32 when input is QASYMM8/QASYMM8_SIGNED.
+         * @param[in] output  Destination tensor info. Data type supported: same as @p input.
+         * @param[in] info    Depthwise convolution meta-data.
+         *
+         * @return a status
+         */
+        static Status validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const ConvolutionInfo &info);
+
+        // Inherited methods overridden:
+        void run(ITensorPack &tensors) override;
+        void prepare(ITensorPack &tensors) override;
+
+    private:
+        std::unique_ptr<CpuDepthwiseConvolutionAssemblyDispatch> _dwc_optimized_func{ nullptr };
+        std::unique_ptr<CpuPermute>                              _permute_input{ nullptr };
+        std::unique_ptr<CpuPermute>                              _permute_weights{ nullptr };
+        std::unique_ptr<CpuPermute>                              _permute_output{ nullptr };
+        std::unique_ptr<CpuActivation>                           _activationlayer_function{ nullptr };
+        bool                                                     _has_bias{ false };
+        bool                                                     _is_quantized{ false };
+        bool                                                     _is_nchw{ true };
+        bool                                                     _permute{ false };
+        bool                                                     _is_activationlayer_enabled{ false };
+        bool                                                     _is_prepared{ false };
+    };
+
+    /** Basic function to execute a generic depthwise convolution. This function calls the following kernel:
+     *
+     * -# @ref CpuDepthwiseConvolutionNativeKernel
+     *
+     */
+    class CpuDepthwiseConvolutionGeneric : public ICpuOperator
+    {
+    public:
+        /** Default constructor */
+        CpuDepthwiseConvolutionGeneric();
+        /** Prevent instances of this class from being copied (As this class contains pointers) */
+        CpuDepthwiseConvolutionGeneric(const CpuDepthwiseConvolutionGeneric &) = delete;
+        /** Default move constructor */
+        CpuDepthwiseConvolutionGeneric(CpuDepthwiseConvolutionGeneric &&) = default;
+        /** Prevent instances of this class from being copied (As this class contains pointers) */
+        CpuDepthwiseConvolutionGeneric &operator=(const CpuDepthwiseConvolutionGeneric &) = delete;
+        /** Default move assignment operator */
+        CpuDepthwiseConvolutionGeneric &operator=(CpuDepthwiseConvolutionGeneric &&) = default;
+        /** Default destructor */
+        ~CpuDepthwiseConvolutionGeneric() = default;
+        /** Initialize the function's source, destination, weights and convolution information.
+         *
+         * @param[in, out] input   Source tensor info. Data type supported: QASYMM8/QASYMM8_SIGNED/F16/F32. (Written to only for border filling).
+         * @param[out]     output  Destination tensor info. Data type supported: same as @p input.
+         * @param[in]      weights Weights tensor info. These are 3D tensor infos with shape [kernel_x, kernel_y, IFM].
+         *                         Data type supported: Same as @p input or QASYMM8/QASYMM8_SIGNED/QSYMM8_PER_CHANNEL when @p input is QASYMM8/QASYMM8_SIGNED.
+         * @param[in]      biases  Biases tensor info. A 1D tensor with shape [IFM]. Must be nullptr if not needed.
+         *                         Data type supported: Same as @p input, S32 when input is QASYMM8/QASYMM8_SIGNED.
+         * @param[in]      info    Depthwise convolution meta-data.
+         */
+        void configure(ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, ITensorInfo *output, const ConvolutionInfo &info);
+
+        /** Static function to check if given info will lead to a valid configuration of @ref CpuDepthwiseConvolutionGeneric
+         *
+         * @param[in] input   Source tensor info. Data type supported: QASYMM8/QASYMM8_SIGNED/F16/F32. (Written to only for border filling).
+         * @param[in] output  Destination tensor info. Data type supported: same as @p input.
+         * @param[in] weights Weights tensor info. These are 3D tensor infos with shape [kernel_x, kernel_y, IFM].
+         *                    Data type supported: Same as @p input or QASYMM8/QASYMM8_SIGNED/QSYMM8_PER_CHANNEL when @p input is QASYMM8/QASYMM8_SIGNED.
+         * @param[in] biases  Biases tensor info. A 1D tensor with shape [IFM]. Must be nullptr if not needed.
+         *                    Data type supported: Same as @p input, S32 when input is QASYMM8/QASYMM8_SIGNED.
+         * @param[in] info    Depthwise convolution meta-data.
+         *
+         * @return a status
+         */
+        static Status validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *biases, const ITensorInfo *output, const ConvolutionInfo &info);
+
+        // Inherited methods overridden:
+        void run(ITensorPack &tensors) override;
+        void prepare(ITensorPack &tensors) override;
+
+    private:
+        std::unique_ptr<cpu::kernels::CpuDepthwiseConvolutionNativeKernel> _depthwise_conv_kernel{ nullptr };
+        std::unique_ptr<CpuPermute>                                        _permute_input{ nullptr };
+        std::unique_ptr<CpuPermute>                                        _permute_weights{ nullptr };
+        std::unique_ptr<CpuPermute>                                        _permute_output{ nullptr };
+        std::unique_ptr<CpuActivation>                                     _activationlayer_function{ nullptr };
+        bool                                                               _is_nchw{ true };
+        bool                                                               _is_prepared{ false };
+        bool                                                               _is_activationlayer_enabled{ false };
+    };
+
+    DepthwiseConvolutionFunction             _depth_conv_func;
+    CpuDepthwiseConvolutionOptimizedInternal _func_optimized;
+    CpuDepthwiseConvolutionGeneric           _func_generic;
+};
+} // namespace cpu
+} // namespace arm_compute
+#endif /* ARM_COMPUTE_CPU_DEPTHWISECONVOLUTION_H */
diff --git a/src/runtime/cpu/operators/CpuDepthwiseConvolutionAssemblyDispatch.cpp
b/src/runtime/cpu/operators/CpuDepthwiseConvolutionAssemblyDispatch.cpp new file mode 100644 index 0000000000..5f5304cded --- /dev/null +++ b/src/runtime/cpu/operators/CpuDepthwiseConvolutionAssemblyDispatch.cpp @@ -0,0 +1,564 @@ +/* + * Copyright (c) 2019-2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ + +#include "src/runtime/cpu/operators/CpuDepthwiseConvolutionAssemblyDispatch.h" + +#include "arm_compute/core/ITensor.h" +#include "arm_compute/core/Utils.h" +#include "arm_compute/core/utils/misc/InfoHelpers.h" +#include "arm_compute/core/utils/misc/ShapeCalculator.h" +#include "arm_compute/core/utils/quantization/AsymmHelpers.h" +#include "src/core/CPP/Validate.h" +#include "src/core/NEON/kernels/assembly/NEDepthwiseConvolutionAssemblyKernelWrapper.h" +#include "src/core/NEON/kernels/convolution/depthwise/depthwise_dilated.hpp" +#include "src/core/NEON/kernels/convolution/depthwise/depthwise_quantized_dilated.hpp" +#include "src/core/helpers/AutoConfiguration.h" + +#include "arm_compute/runtime/NEON/NEScheduler.h" + +#include + +namespace arm_compute +{ +namespace cpu +{ +namespace +{ +/** Create the QASYMM8 (uniformly quantized weights) assembly convolver for a kernel size / stride combination. + * + * All arguments other than @p kernel_size and @p stride_x are forwarded unchanged to the selected convolver's constructor. + * + * @return The convolver when kernel_size is 3 or 5 and stride_x is 1 or 2, nullptr otherwise. + */ +std::unique_ptr get_qasymm8_convolver(int kernel_size, int stride_x, + int n_batches, int in_rows, int in_cols, int n_channels, + int dilation_factor, neon_convolution_kernels::ActivationFunction activation, + const qasymm8::QAsymm8Params &wqinfo, const qasymm8::QAsymm8Params &iqinfo, const qasymm8::QAsymm8Params &oqinfo, + const qasymm8::QAsymm8RescaleParams &rescale_params, + int padding_top, int padding_left, int padding_bottom, int padding_right) +{ + switch(kernel_size) + { + case 3: + { + switch(stride_x) + { + case 1: + return std::make_unique>( + n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right); + case 2: + return std::make_unique>( + n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right); + default: + return nullptr; + } + } + case 5: + { + switch(stride_x) + { + case 1: + return std::make_unique>( + n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left,
padding_bottom, padding_right); + case 2: + return std::make_unique>( + n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right); + default: + return nullptr; + } + } + default: + return nullptr; + } +} + +/** Create the assembly convolver used with QSYMM8_PER_CHANNEL weights for a kernel size / stride combination. + * + * @note Unlike the other convolver factories in this file, this one takes no dilation factor. + * + * @return The convolver when kernel_size is 3 or 5 and stride_x is 1 or 2, nullptr otherwise. + */ +std::unique_ptr get_qsymm8_perchannel_convolver(int kernel_size, int stride_x, + int n_batches, int in_rows, int in_cols, int n_channels, + neon_convolution_kernels::ActivationFunction activation, + const qsymm8::QSymm8PerChannelParams &wqinfo, const qasymm8::QAsymm8Params &iqinfo, const qasymm8::QAsymm8Params &oqinfo, + const qsymm8::QSymm8PerChannelRescaleParams &rescale_params, + int padding_top, int padding_left, int padding_bottom, int padding_right) +{ + switch(kernel_size) + { + case 3: + { + switch(stride_x) + { + case 1: + return std::make_unique>( + n_batches, in_rows, in_cols, n_channels, activation, wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right); + case 2: + return std::make_unique>( + n_batches, in_rows, in_cols, n_channels, activation, wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right); + default: + return nullptr; + } + } + case 5: + { + switch(stride_x) + { + case 1: + return std::make_unique>( + n_batches, in_rows, in_cols, n_channels, activation, wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right); + case 2: + return std::make_unique>( + n_batches, in_rows, in_cols, n_channels, activation, wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right); + default: + return nullptr; + } + } + default: + return nullptr; + } +} + +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +/** Create the F16 assembly convolver for a kernel size / stride combination. + * + * Only compiled when __ARM_FEATURE_FP16_VECTOR_ARITHMETIC is defined. + * + * @return The convolver when kernel_size is 3 or 5 and stride_x is 1 or 2, nullptr otherwise. + */ +std::unique_ptr get_fp16_convolver(int kernel_size, int stride_x, + int n_batches, int in_rows, int in_cols, int n_channels, + int dilation_factor, neon_convolution_kernels::ActivationFunction
activation, + int padding_top, int padding_left, int padding_bottom, int padding_right) +{ + switch(kernel_size) + { + case 3: + { + switch(stride_x) + { + case 1: + return std::make_unique>( + n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right); + case 2: + return std::make_unique>( + n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right); + default: + return nullptr; + } + } + case 5: + { + switch(stride_x) + { + case 1: + return std::make_unique>( + n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right); + case 2: + return std::make_unique>( + n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right); + default: + return nullptr; + } + } + default: + return nullptr; + } +} +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + +/** Create the F32 assembly convolver for a kernel size / stride combination. + * + * @return The convolver when kernel_size is 3 or 5 and stride_x is 1 or 2, nullptr otherwise. + */ +std::unique_ptr get_fp32_convolver(int kernel_size, int stride_x, + int n_batches, int in_rows, int in_cols, int n_channels, + int dilation_factor, neon_convolution_kernels::ActivationFunction activation, + int padding_top, int padding_left, int padding_bottom, int padding_right) +{ + switch(kernel_size) + { + case 3: + { + switch(stride_x) + { + case 1: + return std::make_unique>( + n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right); + case 2: + return std::make_unique>( + n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right); + default: + return nullptr; + } + } + case 5: + { + switch(stride_x) + { + case 1: + return std::make_unique>( + n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right); + case 2: + return std::make_unique>( +
n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right); + default: + return nullptr; + } + } + default: + return nullptr; + } +} + +/** Build the assembly convolver that matches the data types of @p input and @p weights. + * + * The input shape is read as x = channels, y = columns, z = rows and dimension 3 = batches. + * Only the x components of the dilation and of the stride are used; the kernel size is taken from the y dimension of the weights. + * + * @return The convolver, or nullptr when the data type, kernel size or stride has no assembly implementation. + */ +std::unique_ptr create_convolver(const ITensorInfo *input, + const ITensorInfo *weights, + ITensorInfo *output, + const ConvolutionInfo &info) +{ + const DataType data_type = input->data_type(); + const TensorShape shape = input->tensor_shape(); + + const int n_batches = shape[3]; + const int in_rows = shape.z(); + const int in_cols = shape.y(); + const int n_channels = shape.x(); + const int dilation_factor = info.dilation.x(); + const int padding_top = info.pad_stride_info.pad_top(); + const int padding_left = info.pad_stride_info.pad_left(); + const int padding_bottom = info.pad_stride_info.pad_bottom(); + const int padding_right = info.pad_stride_info.pad_right(); + + const bool is_uniform_quantized = (data_type == DataType::QASYMM8) && (weights->data_type() == DataType::QASYMM8); + const bool is_perchannel_quantized = (data_type == DataType::QASYMM8) && (weights->data_type() == DataType::QSYMM8_PER_CHANNEL); + + const unsigned int stride_x = info.pad_stride_info.stride().first; + const unsigned int kernel_size = weights->tensor_shape().y(); + + // Map activation function (anything other than ReLU / ReLU6 maps to None) + neon_convolution_kernels::ActivationFunction activation = neon_convolution_kernels::ActivationFunction::None; + if(arm_compute::utils::info_helpers::is_relu(info.act_info)) + { + activation = neon_convolution_kernels::ActivationFunction::ReLU; + } + else if(arm_compute::utils::info_helpers::is_relu6(info.act_info)) + { + activation = neon_convolution_kernels::ActivationFunction::ReLU6; + } + + // Create quantized convolver + if(is_uniform_quantized) + { + const UniformQuantizationInfo input_qinfo = input->quantization_info().uniform(); + const UniformQuantizationInfo weights_qinfo = weights->quantization_info().uniform(); + const UniformQuantizationInfo output_qinfo =
output->quantization_info().uniform(); + + // Check that quantization info are in the range [0, 255] + ARM_COMPUTE_ERROR_ON(input_qinfo.offset < 0 || input_qinfo.offset > 255); + ARM_COMPUTE_ERROR_ON(weights_qinfo.offset < 0 || weights_qinfo.offset > 255); + ARM_COMPUTE_ERROR_ON(output_qinfo.offset < 0 || output_qinfo.offset > 255); + const qasymm8::QAsymm8Params iqinfo{ static_cast(input_qinfo.offset), input_qinfo.scale }; + const qasymm8::QAsymm8Params wqinfo{ static_cast(weights_qinfo.offset), weights_qinfo.scale }; + const qasymm8::QAsymm8Params oqinfo{ static_cast(output_qinfo.offset), output_qinfo.scale }; + + // Calculate rescale parameters + const float fmultipler = iqinfo.scale * wqinfo.scale / oqinfo.scale; + int32_t qmultiplier = 0; + int32_t qshift = 0; + quantization::calculate_quantized_multiplier_less_than_one(fmultipler, &qmultiplier, &qshift); + qasymm8::QAsymm8RescaleParams rescale_params(qshift, qmultiplier, fmultipler); + + return get_qasymm8_convolver(kernel_size, stride_x, n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, + wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right); + } + else if(is_perchannel_quantized) + { + const UniformQuantizationInfo input_qinfo = input->quantization_info().uniform(); + const QuantizationInfo weights_qinfo = weights->quantization_info(); + const UniformQuantizationInfo output_qinfo = output->quantization_info().uniform(); + + // Check that quantization info are in the range [0, 255] + ARM_COMPUTE_ERROR_ON(input_qinfo.offset < 0 || input_qinfo.offset > 255); + ARM_COMPUTE_ERROR_ON(output_qinfo.offset < 0 || output_qinfo.offset > 255); + const qasymm8::QAsymm8Params iqinfo{ static_cast(input_qinfo.offset), input_qinfo.scale }; + const qsymm8::QSymm8PerChannelParams wqinfo{ weights_qinfo.scale() }; + const qasymm8::QAsymm8Params oqinfo{ static_cast(output_qinfo.offset), output_qinfo.scale }; + + // Calculate rescale parameters + std::vector 
fmultipliers; + std::vector qmultipliers; + std::vector qshifts; + + /* One rescale triple (float multiplier, fixed-point multiplier, shift) is derived per weights scale */ + for(auto const s : wqinfo.scales) + { + const float fmultipler = iqinfo.scale * s / oqinfo.scale; + int32_t qmultiplier = 0; + int32_t qshift = 0; + quantization::calculate_quantized_multiplier_less_than_one(fmultipler, &qmultiplier, &qshift); + fmultipliers.push_back(fmultipler); + qmultipliers.push_back(qmultiplier); + qshifts.push_back(qshift); + } + + qsymm8::QSymm8PerChannelRescaleParams rescale_params(qshifts, qmultipliers, fmultipliers); + + return get_qsymm8_perchannel_convolver(kernel_size, stride_x, n_batches, in_rows, in_cols, n_channels, activation, + wqinfo, iqinfo, oqinfo, rescale_params, padding_top, padding_left, padding_bottom, padding_right); + } + else + { + // Create float convolver + switch(data_type) + { +#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + case DataType::F16: + { + return get_fp16_convolver(kernel_size, stride_x, n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right); + } +#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC + case DataType::F32: + { + return get_fp32_convolver(kernel_size, stride_x, n_batches, in_rows, in_cols, n_channels, dilation_factor, activation, padding_top, padding_left, padding_bottom, padding_right); + } + default: + return nullptr; + } + } +} +} // namespace + +/** Private state of the dispatcher, kept behind the _pImpl pointer. */ +struct CpuDepthwiseConvolutionAssemblyDispatch::LocalImpl +{ + std::unique_ptr dwc_assembly_kernel{ nullptr }; /**< Assembly convolver created by configure() */ + NEDepthwiseConvolutionAssemblyKernelWrapper dwc_acl_kernel{}; /**< Kernel wrapper around the convolver; this is what run() schedules */ + bool is_prepared{ false }; /**< Set once prepare() has packed the weights and bias */ + experimental::MemoryRequirements mem_req{}; /**< Auxiliary memory requirements returned by workspace() */ +}; + +#ifndef DOXYGEN_SKIP_THIS +CpuDepthwiseConvolutionAssemblyDispatch::CpuDepthwiseConvolutionAssemblyDispatch() + : _pImpl(std::make_unique()) +{ +} +#endif /* DOXYGEN_SKIP_THIS */ + +CpuDepthwiseConvolutionAssemblyDispatch::~CpuDepthwiseConvolutionAssemblyDispatch() = default; + +void CpuDepthwiseConvolutionAssemblyDispatch::configure(const ITensorInfo *input, +
const ITensorInfo *weights, + const ITensorInfo *bias, + ITensorInfo *output, + const ConvolutionInfo &info) +{ + ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights, output); + ARM_COMPUTE_UNUSED(bias); + ARM_COMPUTE_ERROR_THROW_ON(CpuDepthwiseConvolutionAssemblyDispatch::validate(input, + weights, + bias, + output, + info)); + + // Output auto initialization if not yet initialized + const TensorShape output_shape = misc::shape_calculator::compute_depthwise_convolution_shape(*input, *weights, info); + auto_init_if_empty(*output, input->clone()->set_is_resizable(true).reset_padding().set_tensor_shape(output_shape).set_quantization_info(output->quantization_info())); + + // A (re)configuration invalidates the packed weights and the previously reported memory requirements. + // Without the clear() a second configure() would append duplicate ACL_INT_0 / ACL_INT_1 entries. + _pImpl->is_prepared = false; + _pImpl->mem_req.clear(); + + // Create convolver + _pImpl->dwc_assembly_kernel = create_convolver(input, weights, output, info); + ARM_COMPUTE_ERROR_ON(_pImpl->dwc_assembly_kernel == nullptr); + + // Create assembly kernel wrapper + _pImpl->dwc_acl_kernel.configure(_pImpl->dwc_assembly_kernel.get()); + + constexpr size_t alignment = 128; + + // Create workspace (sized for the scheduler's current number of threads) + const unsigned int num_threads = NEScheduler::get().num_threads(); + const size_t workspace_size = _pImpl->dwc_assembly_kernel->get_working_space_size(num_threads); + ARM_COMPUTE_ERROR_ON_MSG(workspace_size == 0, "Workspace size cannot be 0 !"); + _pImpl->mem_req.push_back({ TensorType::ACL_INT_0, workspace_size, alignment }); + + // Create packing tensor + const size_t pack_tensor_size = _pImpl->dwc_assembly_kernel->get_packed_params_size(); + ARM_COMPUTE_ERROR_ON_MSG(pack_tensor_size == 0, "Pack tensor size cannot be 0 !"); + + _pImpl->mem_req.push_back({ TensorType::ACL_INT_1, pack_tensor_size, alignment }); +} + +/** Return the auxiliary memory recorded by configure(): ACL_INT_0 is the working space, ACL_INT_1 the packed weights/bias buffer. */ +experimental::MemoryRequirements CpuDepthwiseConvolutionAssemblyDispatch::workspace() const +{ + return _pImpl->mem_req; +} + +Status CpuDepthwiseConvolutionAssemblyDispatch::validate(const ITensorInfo *input, + const ITensorInfo *weights, + const ITensorInfo *bias, + const ITensorInfo *output, + const
ConvolutionInfo &info) +{ + ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(input); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::F16, DataType::F32); + if(weights->data_type() != DataType::QSYMM8_PER_CHANNEL) + { + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, weights); + } + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_LAYOUT(input, weights); + + // Validate convolver + ARM_COMPUTE_RETURN_ERROR_ON(!is_optimized_supported(input, weights, info)); + + // Validate activation + const bool is_relu = arm_compute::utils::info_helpers::is_relu(info.act_info); + const bool is_relu6 = arm_compute::utils::info_helpers::is_relu6(info.act_info); + ARM_COMPUTE_RETURN_ERROR_ON(info.act_info.enabled() && !(is_relu || is_relu6)); + + // Check bias + if(bias != nullptr) + { + unsigned int channel_idx = get_data_layout_dimension_index(input->data_layout(), DataLayoutDimension::CHANNEL); + ARM_COMPUTE_RETURN_ERROR_ON(bias->num_dimensions() > 1); + ARM_COMPUTE_RETURN_ERROR_ON(bias->dimension(0) != weights->dimension(channel_idx)); + } + + // Check output + if(output->total_size() != 0) + { + const TensorShape output_shape = misc::shape_calculator::compute_depthwise_convolution_shape(*input, *weights, info); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(output->tensor_shape(), output_shape); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); + } + + // The uniform quantization case will only have 1 scale value in the weights quantization info + const UniformQuantizationInfo input_qinfo = input->quantization_info().uniform(); + const QuantizationInfo weights_qinfo = weights->quantization_info(); + const UniformQuantizationInfo output_qinfo = output->quantization_info().uniform(); + for(auto const s : weights_qinfo.scale()) + { + const float fmultipler = input_qinfo.scale * s / output_qinfo.scale; + ARM_COMPUTE_RETURN_ERROR_ON(fmultipler > 1.f); + } + + return Status{}; +} + +bool 
CpuDepthwiseConvolutionAssemblyDispatch::is_optimized_supported(const ITensorInfo *input, + const ITensorInfo *weights, + const ConvolutionInfo &info) +{ + ARM_COMPUTE_ERROR_ON_NULLPTR(input, weights); + + // For NHWC inputs, rearrange the shape so that x/y/z hold the former y/z/x dimensions; the padding query further down is made with DataLayout::NCHW + const DataLayout data_layout = input->data_layout(); + TensorShape in_shape{ input->tensor_shape() }; + if(data_layout == DataLayout::NHWC) + { + in_shape.set(Window::DimX, input->tensor_shape().y()); + in_shape.set(Window::DimY, input->tensor_shape().z()); + in_shape.set(Window::DimZ, input->tensor_shape().x()); + } + + // Check data type + // TODO (COMPMID-3004): Add assembly optimized routine for QASYMM8_SIGNED NEDepthwiseConvolutionLayer + const DataType input_type = input->data_type(); + const bool is_input_type_valid = is_data_type_float(input_type) || input_type == DataType::QASYMM8; + const DataType weights_type = weights->data_type(); + const bool is_weights_type_valid = is_data_type_float(weights_type) || weights_type == DataType::QASYMM8 || weights_type == DataType::QASYMM8_SIGNED + || weights_type == DataType::QSYMM8_PER_CHANNEL; + + // Check weights size: only square kernels of size 3 or 5 are supported + std::set supported_kernel_sizes = { 3, 5 }; + const unsigned int width_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH); + const unsigned int height_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT); + const unsigned int kernel_w = weights->dimension(width_idx); + const unsigned int kernel_h = weights->dimension(height_idx); + bool weights_supported = (kernel_w == kernel_h) && (supported_kernel_sizes.count(kernel_w) != 0); + + // Check for supported strides: equal in both directions and either 1 or 2 + const auto &strides = info.pad_stride_info.stride(); + bool supported_strides = (strides.first == strides.second) && ((strides.first == 1) || (strides.first == 2)); + + // Check for supported padding + const auto pad_top = info.pad_stride_info.pad_top(); + const auto pad_right = info.pad_stride_info.pad_right(); + const auto pad_bottom =
info.pad_stride_info.pad_bottom(); + const auto pad_left = info.pad_stride_info.pad_left(); + PadStrideInfo same_pad = calculate_same_pad(in_shape, TensorShape(kernel_w, kernel_h), info.pad_stride_info, DataLayout::NCHW, info.dilation); + bool is_same_padding = (pad_top == same_pad.pad_top()) && (pad_right == same_pad.pad_right()) && (pad_bottom == same_pad.pad_bottom()) && (pad_left == same_pad.pad_left()); + bool is_valid_padding = (pad_top == 0) && (pad_right == 0) && (pad_bottom == 0) && (pad_left == 0); + bool supported_padding = is_same_padding || is_valid_padding; + // TODO(COMPMID-2464): Enable once dilated conv with stride 2 is supported + bool is_dilation_supported = ((info.dilation == Size2D(1U, 1U)) || ((info.dilation.x() == info.dilation.y()) && strides.first == 1)); + + if(weights_type == DataType::QSYMM8_PER_CHANNEL) + { + is_dilation_supported = is_dilation_supported && (info.dilation == Size2D(1U, 1U)); + } + + return is_input_type_valid && is_weights_type_valid && weights_supported && supported_strides && supported_padding && (info.depth_multiplier == 1) && is_dilation_supported; +} + +/** Run the depthwise convolution. + * + * Packs the weights on the first call (see prepare()), then binds the working space (ACL_INT_0), the source (ACL_SRC_0) + * and the destination (ACL_DST) of @p tensors to the assembly kernel and schedules it. + */ +void CpuDepthwiseConvolutionAssemblyDispatch::run(ITensorPack &tensors) +{ + // Prepare assembly kernel + prepare(tensors); + + auto src = tensors.get_tensor(TensorType::ACL_SRC_0); + auto workspace = tensors.get_tensor(TensorType::ACL_INT_0); + auto dst = tensors.get_tensor(TensorType::ACL_DST); + + // Setup inputs/outputs + // Use || here: with && a null workspace would be dereferenced by the very check meant to catch it + ARM_COMPUTE_ERROR_ON(workspace == nullptr || workspace->buffer() == nullptr); + _pImpl->dwc_assembly_kernel->set_working_space(static_cast(workspace->buffer())); + + // Check the source tensor before reading its strides and buffer (the workspace was already checked above) + ARM_COMPUTE_ERROR_ON(src == nullptr || src->buffer() == nullptr); + const int input_element_size = src->info()->element_size(); + const int input_batch_stride = src->info()->strides_in_bytes()[3] / input_element_size; + const int input_row_stride = src->info()->strides_in_bytes().z() / input_element_size; + const int input_col_stride = src->info()->strides_in_bytes().y() / input_element_size;
+ const void *input_ptr = src->buffer() + src->info()->offset_first_element_in_bytes(); + _pImpl->dwc_assembly_kernel->set_input(input_ptr, input_batch_stride, input_row_stride, input_col_stride); + + ARM_COMPUTE_ERROR_ON(dst == nullptr || dst->buffer() == nullptr); + const int output_element_size = dst->info()->element_size(); + const int output_batch_stride = dst->info()->strides_in_bytes()[3] / output_element_size; + const int output_row_stride = dst->info()->strides_in_bytes().z() / output_element_size; + const int output_col_stride = dst->info()->strides_in_bytes().y() / output_element_size; + void *output_ptr = dst->buffer() + dst->info()->offset_first_element_in_bytes(); + _pImpl->dwc_assembly_kernel->set_output(output_ptr, output_batch_stride, output_row_stride, output_col_stride); + + // Schedule assembly kernel + NEScheduler::get().schedule(&_pImpl->dwc_acl_kernel, Window::DimX); +} + +/** Pack the weights (ACL_SRC_1) and the optional bias (ACL_SRC_2) into the ACL_INT_1 buffer. + * + * The packing is done once: later calls return immediately until configure() resets the prepared flag. + * The weights and bias tensors are marked as unused afterwards. + */ +void CpuDepthwiseConvolutionAssemblyDispatch::prepare(ITensorPack &tensors) +{ + if(!_pImpl->is_prepared) + { + auto weights = tensors.get_const_tensor(TensorType::ACL_SRC_1); + auto bias = tensors.get_const_tensor(TensorType::ACL_SRC_2); + auto packed_weights = tensors.get_tensor(TensorType::ACL_INT_1); + + // Both tensors are dereferenced below; the bias is optional and handled separately + ARM_COMPUTE_ERROR_ON(weights == nullptr); + ARM_COMPUTE_ERROR_ON(packed_weights == nullptr || packed_weights->buffer() == nullptr); + + // Pack weights and bias + const int weights_element_size = weights->info()->element_size(); + const int weights_row_stride = weights->info()->strides_in_bytes().z() / weights_element_size; + const int weights_col_stride = weights->info()->strides_in_bytes().y() / weights_element_size; + _pImpl->dwc_assembly_kernel->pack_params(packed_weights->buffer(), + weights->buffer() + weights->info()->offset_first_element_in_bytes(), + weights_row_stride, + weights_col_stride, + (bias != nullptr) ?
bias->buffer() + bias->info()->offset_first_element_in_bytes() : nullptr); + _pImpl->dwc_assembly_kernel->set_packed_params_buffer(packed_weights->buffer()); + + weights->mark_as_unused(); + if(bias != nullptr) + { + bias->mark_as_unused(); + } + _pImpl->is_prepared = true; + } +} +} // namespace cpu +} // namespace arm_compute diff --git a/src/runtime/cpu/operators/CpuDepthwiseConvolutionAssemblyDispatch.h b/src/runtime/cpu/operators/CpuDepthwiseConvolutionAssemblyDispatch.h new file mode 100644 index 0000000000..6aac74c3ef --- /dev/null +++ b/src/runtime/cpu/operators/CpuDepthwiseConvolutionAssemblyDispatch.h @@ -0,0 +1,97 @@ +/* + * Copyright (c) 2019-2021 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE.
+ */ +#ifndef ARM_COMPUTE_CPU_DEPTHWISECONVOLUTIONASSEMBLYDISPATCH_H +#define ARM_COMPUTE_CPU_DEPTHWISECONVOLUTIONASSEMBLYDISPATCH_H + +#include "src/runtime/cpu/ICpuOperator.h" + +namespace arm_compute +{ +namespace cpu +{ +/** Depthwise convolution assembly kernel glue */ +class CpuDepthwiseConvolutionAssemblyDispatch : public ICpuOperator +{ +public: + /** Default constructor */ + CpuDepthwiseConvolutionAssemblyDispatch(); + /** Prevent instances of this class from being copied (As this class contains pointers) */ + CpuDepthwiseConvolutionAssemblyDispatch(const CpuDepthwiseConvolutionAssemblyDispatch &) = delete; + /** Default move constructor */ + CpuDepthwiseConvolutionAssemblyDispatch(CpuDepthwiseConvolutionAssemblyDispatch &&) = default; + /** Prevent instances of this class from being copied (As this class contains pointers) */ + CpuDepthwiseConvolutionAssemblyDispatch &operator=(const CpuDepthwiseConvolutionAssemblyDispatch &) = delete; + /** Default move assignment operator */ + CpuDepthwiseConvolutionAssemblyDispatch &operator=(CpuDepthwiseConvolutionAssemblyDispatch &&) = default; + /** Default destructor */ + ~CpuDepthwiseConvolutionAssemblyDispatch(); + /** Configure the operator: validates the arguments, auto-initializes @p output if it is empty, creates the assembly convolver and records the auxiliary memory requirements. + * + * @note Supports only NHWC format + * + * @param[in] input Source tensor info. Data type supported: QASYMM8/F16/F32. + * @param[in] weights Weights tensor info. These are 3D tensors with shape [W, H, IFM]. Data type supported: Same as @p input, or QSYMM8_PER_CHANNEL when @p input is QASYMM8. + * @param[in] bias (Optional) Biases tensor info. A 1D tensor with shape [IFM]. Must be nullptr if not needed. + * Data type supported: Same as @p input. + * @param[out] output Destination tensor info. Data type supported: same as @p input. + * @param[in] info Depthwise convolution meta-data.
+ */ + void configure(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *bias, ITensorInfo *output, const ConvolutionInfo &info); + /** Static function to check if given info will lead to a valid configuration of @ref CpuDepthwiseConvolutionAssemblyDispatch + * + * @note Supports only NHWC format + * + * @param[in] input Source tensor info. Data type supported: QASYMM8/F16/F32. + * @param[in] weights Weights tensor info. These are 3D tensors with shape [W, H, IFM]. Data type supported: Same as @p input, or QSYMM8_PER_CHANNEL when @p input is QASYMM8. + * @param[in] bias (Optional) Biases tensor info. A 1D tensor with shape [IFM]. Must be nullptr if not needed. + * Data type supported: Same as @p input. + * @param[in] output Destination tensor info. Data type supported: same as @p input. + * @param[in] info Depthwise convolution meta-data. + * + * @return An error status + */ + static Status validate(const ITensorInfo *input, const ITensorInfo *weights, const ITensorInfo *bias, const ITensorInfo *output, const ConvolutionInfo &info); + /** Check if the optimized kernel can be used for the given kernel sizes and strides + * + * @warning Even if this returns true the inputs and outputs might need to get permuted as the only layout supported is NHWC + * + * @param[in] input Input tensor info. + * @param[in] weights Weights tensor info. + * @param[in] info Depthwise convolution meta-data. + * + * @return True if the assembly kernel could be used else false. Note that transformations of input/output could be needed.
+ */ + static bool is_optimized_supported(const ITensorInfo *input, const ITensorInfo *weights, const ConvolutionInfo &info); + + // Inherited methods overridden: + void run(ITensorPack &tensors) override; + void prepare(ITensorPack &tensors) override; + experimental::MemoryRequirements workspace() const override; + +private: + struct LocalImpl; /**< Implementation state, defined in the source file */ + std::unique_ptr _pImpl; /**< Owning pointer to the implementation state */ +}; +} // namespace cpu +} // namespace arm_compute +#endif /* ARM_COMPUTE_CPU_DEPTHWISECONVOLUTIONASSEMBLYDISPATCH_H */ -- cgit v1.2.1