From d02d5edfa15ba6c04a9986a8a362a945cb38ac31 Mon Sep 17 00:00:00 2001 From: Michele Di Giorgio Date: Fri, 22 Jan 2021 09:47:04 +0000 Subject: Integrate improved CPU depthwise convolution kernels * Replace assembly kernels for depthwise convolution with more optimized ones. * Add int8 assembly kernels. * Fix implicit padding on optimized kernels Resolves: COMPMID-3867, COMPMID-4361 Change-Id: I0b0867e05f61be4f368f62190d55e14d0ab3ebf2 Signed-off-by: Michele Di Giorgio Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5622 Tested-by: Arm Jenkins Reviewed-by: Georgios Pinitas --- .../cpu/kernels/CpuDepthwiseConv2dNativeKernel.cpp | 33 +- .../CpuDepthwiseConv2dAssemblyWrapperKernel.cpp | 359 +++++++++++++++++++++ .../CpuDepthwiseConv2dAssemblyWrapperKernel.h | 120 +++++++ .../internal/CpuPool2dAssemblyWrapperKernel.cpp | 7 +- 4 files changed, 516 insertions(+), 3 deletions(-) create mode 100644 src/core/cpu/kernels/internal/CpuDepthwiseConv2dAssemblyWrapperKernel.cpp create mode 100644 src/core/cpu/kernels/internal/CpuDepthwiseConv2dAssemblyWrapperKernel.h (limited to 'src/core/cpu/kernels') diff --git a/src/core/cpu/kernels/CpuDepthwiseConv2dNativeKernel.cpp b/src/core/cpu/kernels/CpuDepthwiseConv2dNativeKernel.cpp index 4ddb35f2d5..eac9baaf01 100644 --- a/src/core/cpu/kernels/CpuDepthwiseConv2dNativeKernel.cpp +++ b/src/core/cpu/kernels/CpuDepthwiseConv2dNativeKernel.cpp @@ -28,7 +28,6 @@ #include "arm_compute/core/utils/misc/ShapeCalculator.h" #include "arm_compute/core/utils/quantization/AsymmHelpers.h" #include "src/core/CPP/Validate.h" -#include "src/core/NEON/kernels/convolution/depthwise/impl_qa8_qa8.hpp" #include "src/core/NEON/wrapper/traits.h" #include "src/core/NEON/wrapper/wrapper.h" #include "src/core/helpers/AutoConfiguration.h" @@ -98,6 +97,38 @@ struct DepthwiseConvolutionRunInfo } }; +inline int32x4_t saturating_doubling_high_mul(const int32x4_t &a, const int32_t &b) +{ + return vqrdmulhq_n_s32(a, b); +} + +inline int32_t 
saturating_doubling_high_mul(const int32_t &a, const int32_t &b) +{ + return vget_lane_s32(vqrdmulh_n_s32(vdup_n_s32(a), b), 0); +} + +inline int32x4_t rounding_divide_by_exp2(const int32x4_t &x, const int exponent) +{ + const int32x4_t shift = vdupq_n_s32(-exponent); + const int32x4_t fixup = vshrq_n_s32(vandq_s32(x, shift), 31); + const int32x4_t fixed = vqaddq_s32(x, fixup); + return vrshlq_s32(fixed, shift); +} + +inline int32x2_t rounding_divide_by_exp2(const int32x2_t &x, const int exponent) +{ + const int32x2_t shift = vdup_n_s32(-exponent); + const int32x2_t fixup = vshr_n_s32(vand_s32(x, shift), 31); + const int32x2_t fixed = vqadd_s32(x, fixup); + return vrshl_s32(fixed, shift); +} + +inline int32_t rounding_divide_by_exp2(const int32_t &x, const int exponent) +{ + const int32x2_t xs = vdup_n_s32(x); + return vget_lane_s32(rounding_divide_by_exp2(xs, exponent), 0); +} + inline bool is_valid_input_region(int32_t base_w, uint32_t base_h, uint32_t w, uint32_t h, const DepthwiseConvolutionRunInfo &run_info, const Size2D &dilation) { const int32_t current_h = base_h + h * dilation.y(); diff --git a/src/core/cpu/kernels/internal/CpuDepthwiseConv2dAssemblyWrapperKernel.cpp b/src/core/cpu/kernels/internal/CpuDepthwiseConv2dAssemblyWrapperKernel.cpp new file mode 100644 index 0000000000..f5c63b763f --- /dev/null +++ b/src/core/cpu/kernels/internal/CpuDepthwiseConv2dAssemblyWrapperKernel.cpp @@ -0,0 +1,359 @@ +/* + * Copyright (c) 2021 Arm Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ */ +#include "src/core/cpu/kernels/internal/CpuDepthwiseConv2dAssemblyWrapperKernel.h" + +#include "arm_compute/core/Utils.h" +#include "arm_compute/core/Validate.h" +#include "arm_compute/core/utils/misc/ShapeCalculator.h" +#include "arm_compute/core/utils/quantization/AsymmHelpers.h" +#include "src/core/CPP/Validate.h" +#include "src/core/helpers/AutoConfiguration.h" +#include "src/core/helpers/WindowHelpers.h" +#include "src/core/utils/AssemblyUtils.h" + +#include "src/core/NEON/kernels/assembly/depthwise.hpp" + +#include "depthwise_common.hpp" + +#include + +namespace arm_compute +{ +namespace cpu +{ +namespace kernels +{ +using namespace arm_compute::misc::shape_calculator; + +namespace +{ +constexpr unsigned int idx_width = 1; +constexpr unsigned int idx_height = 2; +constexpr unsigned int idx_channels = 0; +constexpr unsigned int idx_batches = 3; + +template +void create_arm_dwc(const ITensorInfo *src, const ITensorInfo *weights, ITensorInfo *dst, + const ConvolutionInfo &info, const CPUInfo &cpu_info, + std::unique_ptr &kernel) +{ + unsigned int stride_cols{}; + unsigned int stride_rows{}; + std::tie(stride_cols, stride_rows) = info.pad_stride_info.stride(); + + const arm_conv::PaddingValues padding = assembly_utils::map_to_arm_conv_padding(info.pad_stride_info); + + const unsigned int n_batches = src->dimension(idx_batches); + const unsigned int src_rows = src->dimension(idx_height); + const unsigned int src_cols = src->dimension(idx_width); + const unsigned int n_channels = src->dimension(idx_channels); + const unsigned int dst_rows = dst->dimension(idx_height); + const unsigned int dst_cols = dst->dimension(idx_width); + + const unsigned int kernel_cols = weights->dimension(idx_width); + const unsigned int kernel_rows = weights->dimension(idx_height); + + const arm_gemm::Activation activation = assembly_utils::map_to_arm_gemm_activation(info.act_info); + + arm_conv::depthwise::DepthwiseArgs args(&cpu_info, kernel_rows, kernel_cols, stride_rows, 
stride_cols, + n_batches, src_rows, src_cols, n_channels, dst_rows, dst_cols, info.depth_multiplier, + padding, activation, nullptr); + + // Configure assembly depthwise convolution kernel + auto dwc_kernel_asm = arm_conv::depthwise::depthwise(args); + if(dwc_kernel_asm == nullptr) + { + // Configuration not supported: leave kernel unconfigured + return; + } + + kernel = std::move(dwc_kernel_asm); +} + +template +void create_arm_dwc_quant(const ITensorInfo *src, const ITensorInfo *weights, ITensorInfo *dst, + const ConvolutionInfo &info, const CPUInfo &cpu_info, + std::unique_ptr &kernel, + std::vector &multipliers, std::vector &right_shifts, std::vector &left_shifts) +{ + unsigned int stride_cols{}; + unsigned int stride_rows{}; + std::tie(stride_cols, stride_rows) = info.pad_stride_info.stride(); + + const arm_conv::PaddingValues padding = assembly_utils::map_to_arm_conv_padding(info.pad_stride_info); + + const unsigned int n_batches = src->dimension(idx_batches); + const unsigned int src_rows = src->dimension(idx_height); + const unsigned int src_cols = src->dimension(idx_width); + const unsigned int n_channels = src->dimension(idx_channels); + const unsigned int dst_rows = dst->dimension(idx_height); + const unsigned int dst_cols = dst->dimension(idx_width); + + const unsigned int kernel_cols = weights->dimension(idx_width); + const unsigned int kernel_rows = weights->dimension(idx_height); + + const arm_gemm::Activation activation = assembly_utils::map_to_arm_gemm_activation(info.act_info); + + arm_conv::depthwise::DepthwiseArgs args(&cpu_info, kernel_rows, kernel_cols, stride_rows, stride_cols, + n_batches, src_rows, src_cols, n_channels, dst_rows, dst_cols, info.depth_multiplier, + padding, activation, nullptr); + + const auto src_qinfo = src->quantization_info().uniform(); + const auto weights_qinfo = weights->quantization_info(); + const auto dst_qinfo = dst->quantization_info().uniform(); + + const unsigned int num_filters = weights_qinfo.scale().size(); + 
multipliers.resize(num_filters); + std::vector dst_shifts(num_filters); + quantization::compute_quantized_multipliers_and_shifts(src, + weights, + dst, + multipliers.data(), + dst_shifts.data()); + + // Quantize activation bounds + int32_t min_activation = std::numeric_limits::lowest(); + int32_t max_activation = std::numeric_limits::max(); + if(info.act_info.enabled()) + { + std::tie(min_activation, max_activation) = get_quantized_activation_min_max(info.act_info, src->data_type(), dst_qinfo); + } + + // Set quantization parameters for assembly kernels + arm_gemm::Requantize32 requant_args{}; + if(is_data_type_quantized_per_channel(weights->data_type())) + { + left_shifts.resize(num_filters); + right_shifts.resize(num_filters); + bool need_left_shift = false; // Select more optimized path if left shift is not needed + for(unsigned int i = 0; i < num_filters; ++i) + { + left_shifts[i] = std::max(-dst_shifts[i], static_cast(0)); + right_shifts[i] = std::min(-dst_shifts[i], static_cast(0)); + if(dst_shifts[i] < 0 && !need_left_shift) + { + need_left_shift = true; + } + } + + requant_args = arm_gemm::Requantize32(nullptr, + 0, + src_qinfo.offset, + weights_qinfo.uniform().offset, + dst_qinfo.offset, + (need_left_shift) ? 
left_shifts.data() : nullptr, + right_shifts.data(), + multipliers.data(), + static_cast(min_activation), + static_cast(max_activation)); + } + else + { + requant_args = arm_gemm::Requantize32(nullptr, + 0, + src_qinfo.offset, + weights_qinfo.uniform().offset, + dst_qinfo.offset, + -dst_shifts[0], + multipliers[0], + static_cast(min_activation), + static_cast(max_activation)); + } + + // Configure assembly depthwise convolution kernel with requantization + auto dwc_kernel_asm = arm_conv::depthwise::depthwise(args, requant_args); + if(dwc_kernel_asm == nullptr) + { + // Configuration not supported: leave kernel unconfigured + return; + } + + kernel = std::move(dwc_kernel_asm); +} +} // namespace + +CpuDepthwiseConv2dAssemblyWrapperKernel::CpuDepthwiseConv2dAssemblyWrapperKernel() + : _kernel_asm(nullptr), + _multipliers(), + _left_shifts(), + _right_shifts() +{ +} + +CpuDepthwiseConv2dAssemblyWrapperKernel::~CpuDepthwiseConv2dAssemblyWrapperKernel() = default; + +void CpuDepthwiseConv2dAssemblyWrapperKernel::configure(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *, ITensorInfo *dst, + const ConvolutionInfo &info, const CPUInfo &cpu_info) +{ + ARM_COMPUTE_UNUSED(cpu_info); + ARM_COMPUTE_ERROR_ON_NULLPTR(src, weights, dst); + + // Destination initialization if not yet initialized + const TensorShape dst_shape = compute_depthwise_convolution_shape(*src, *weights, info); + auto_init_if_empty(*dst, src->clone()->set_tensor_shape(dst_shape)); + +#if defined(__aarch64__) + switch(src->data_type()) + { + case DataType::QASYMM8: + if(is_data_type_quantized_per_channel(weights->data_type())) + { + create_arm_dwc_quant(src, weights, dst, info, cpu_info, _kernel_asm, _multipliers, _right_shifts, _left_shifts); + } + else + { + create_arm_dwc_quant(src, weights, dst, info, cpu_info, _kernel_asm, _multipliers, _right_shifts, _left_shifts); + } + break; + case DataType::QASYMM8_SIGNED: + create_arm_dwc_quant(src, weights, dst, info, cpu_info, _kernel_asm, 
_multipliers, _right_shifts, _left_shifts); + break; +#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) + case DataType::F16: + create_arm_dwc(src, weights, dst, info, cpu_info, _kernel_asm); + break; +#endif // defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) + case DataType::F32: + create_arm_dwc(src, weights, dst, info, cpu_info, _kernel_asm); + break; + default: + break; + } +#endif // defined(__aarch64__) + + Window win = calculate_max_window(*dst, Steps()); + ICpuKernel::configure(win); +} + +Status CpuDepthwiseConv2dAssemblyWrapperKernel::validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *bias, const ITensorInfo *dst, const ConvolutionInfo &info) +{ + ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src, dst); + +#if !defined(__aarch64__) + ARM_COMPUTE_RETURN_ERROR_MSG("32-bit is not supported by assembly kernels"); +#endif // !defined(__aarch64__) + ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(src); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(src->data_layout() != DataLayout::NHWC, "Only NHWC is supported by assembly kernels"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(info.dilation != Size2D(1, 1), "Assembly kernels do not support dilation != (1, 1)"); + + if(is_data_type_quantized_per_channel(weights->data_type())) + { + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(weights, 1, DataType::QSYMM8_PER_CHANNEL); + ARM_COMPUTE_RETURN_ERROR_ON(weights->dimension(0) != weights->quantization_info().scale().size()); + } + else + { + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, weights); + } + + if(bias != nullptr) + { + ARM_COMPUTE_RETURN_ERROR_ON(bias->num_dimensions() > 1); + ARM_COMPUTE_RETURN_ERROR_ON(bias->dimension(0) != weights->dimension(0)); + + if(is_data_type_quantized(src->data_type())) + { + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(bias, 1, DataType::S32); + } + else + { + 
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, bias); + } + } + + if(dst->total_size() > 0) + { + const TensorShape dst_shape = misc::shape_calculator::compute_depthwise_convolution_shape(*src, *weights, info); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DIMENSIONS(dst->tensor_shape(), dst_shape); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(src, dst); + } + return Status{}; +} + +void CpuDepthwiseConv2dAssemblyWrapperKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) +{ + ARM_COMPUTE_ERROR_ON_NULLPTR(_kernel_asm.get()); + ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this); + ARM_COMPUTE_UNUSED(window); + ARM_COMPUTE_UNUSED(info); + + ARM_COMPUTE_ERROR_ON(tensors.empty()); + + const ITensor *src = tensors.get_const_tensor(TensorType::ACL_SRC_0); + ITensor *dst = tensors.get_tensor(TensorType::ACL_DST); + ITensor *workspace = tensors.get_tensor(TensorType::ACL_INT_0); + ITensor *storage = tensors.get_tensor(TensorType::ACL_INT_1); + + const auto src_ptr = src->buffer() + src->info()->offset_first_element_in_bytes(); + auto dst_ptr = dst->buffer() + dst->info()->offset_first_element_in_bytes(); + auto working_space = workspace->buffer() + workspace->info()->offset_first_element_in_bytes(); + auto parameters_ptr = storage->buffer() + storage->info()->offset_first_element_in_bytes(); + + const auto src_shape = src->info()->tensor_shape(); + const auto dst_shape = dst->info()->tensor_shape(); + const auto src_padding = src->info()->padding(); + const auto dst_padding = dst->info()->padding(); + + const size_t ld_src_col = src_shape[0] + src_padding.left + src_padding.right; + const size_t ld_src_row = ld_src_col * (src_shape[1] + src_padding.top + src_padding.bottom); + const size_t ld_src_batch = ld_src_row * src_shape[2]; + const size_t ld_dst_col = dst_shape[0] + dst_padding.left + dst_padding.right; + const size_t ld_dst_row = ld_dst_col * (dst_shape[1] + dst_padding.top + dst_padding.bottom); + const size_t ld_dst_batch = 
ld_dst_row * dst_shape[2]; + + _kernel_asm->execute(src_ptr, ld_src_col, ld_src_row, ld_src_batch, + parameters_ptr, + dst_ptr, ld_dst_col, ld_dst_row, ld_dst_batch, + working_space, info.thread_id, info.num_threads); +} + +void CpuDepthwiseConv2dAssemblyWrapperKernel::pack_parameters(void *parameters_ptr, void *bias_ptr, void *weights_ptr, size_t ld_weights_col, size_t ld_weight_row) +{ + _kernel_asm->pack_parameters(parameters_ptr, bias_ptr, weights_ptr, ld_weights_col, ld_weight_row); +} + +size_t CpuDepthwiseConv2dAssemblyWrapperKernel::get_storage_size() const +{ + return _kernel_asm->get_storage_size(); +} + +size_t CpuDepthwiseConv2dAssemblyWrapperKernel::get_working_size(unsigned int num_threads, unsigned int num_input_channels) const +{ + return _kernel_asm->get_working_size(num_threads, num_input_channels); +} + +bool CpuDepthwiseConv2dAssemblyWrapperKernel::is_configured() const +{ + return _kernel_asm != nullptr; +} + +const char *CpuDepthwiseConv2dAssemblyWrapperKernel::name() const +{ + return "CpuDepthwiseConv2dAssemblyWrapperKernel"; +} +} // namespace kernels +} // namespace cpu +} // namespace arm_compute diff --git a/src/core/cpu/kernels/internal/CpuDepthwiseConv2dAssemblyWrapperKernel.h b/src/core/cpu/kernels/internal/CpuDepthwiseConv2dAssemblyWrapperKernel.h new file mode 100644 index 0000000000..8ff44441e9 --- /dev/null +++ b/src/core/cpu/kernels/internal/CpuDepthwiseConv2dAssemblyWrapperKernel.h @@ -0,0 +1,120 @@ +/* + * Copyright (c) 2019-2021 Arm Limited. 
+ * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#ifndef ARM_COMPUTE_CPU_DEPTHWISE_CONV2D_ASSEMBLY_WRAPPER_KERNEL_H +#define ARM_COMPUTE_CPU_DEPTHWISE_CONV2D_ASSEMBLY_WRAPPER_KERNEL_H + +#include "arm_compute/core/Types.h" +#include "src/core/common/Macros.h" +#include "src/core/cpu/ICpuKernel.h" + +namespace arm_conv +{ +namespace depthwise +{ +// Forward declarations +class IDepthwiseCommon; +} // depthwise +} // arm_conv + +namespace arm_compute +{ +namespace cpu +{ +namespace kernels +{ +/** This class is a wrapper for the depthwise convolution assembly kernels. */ +class CpuDepthwiseConv2dAssemblyWrapperKernel final : public ICpuKernel +{ +public: + /** Default constructor */ + CpuDepthwiseConv2dAssemblyWrapperKernel(); + ~CpuDepthwiseConv2dAssemblyWrapperKernel(); + ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuDepthwiseConv2dAssemblyWrapperKernel); + + /** Initialise the kernel's src and dst. 
+ * + * @param[in] src Source tensor info. Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32. + * @param[in] weights Weights tensor info. These are 3D tensors with shape [kernel_x, kernel_y, IFM]. + * Data type supported: same as @p src or QASYMM8/QASYMM8_SIGNED/QSYMM8_PER_CHANNEL when @p src is QASYMM8/QASYMM8_SIGNED. + * @param[in] bias Bias tensor info. A 1D tensor with shape [IFM]. Must be nullptr if not needed. + * Data type supported: same as @p src, S32 when @p src is QASYMM8/QASYMM8_SIGNED. + * @param[out] dst Destination tensor info. Data type supported: same as @p src. + * @param[in] info Depthwise convolution layer meta-data. + * @param[in] cpu_info CPU information needed to select the most appropriate kernel. + */ + void configure(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *bias, ITensorInfo *dst, const ConvolutionInfo &info, const CPUInfo &cpu_info); + + /** Indicates whether or not this function can be used to process the given parameters. + * + * Similar to @ref CpuDepthwiseConv2dAssemblyWrapperKernel::configure() + * + * @return a status. + */ + static Status validate(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *bias, const ITensorInfo *dst, const ConvolutionInfo &info); + + // Inherited methods overridden: + void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override; + const char *name() const override; + + /** Pack bias and weights in a storage space for the assembly kernel + * + * @param[in] parameters_ptr Pointer to storage space. + * @param[in] bias_ptr Pointer to bias buffer. + * @param[in] weights_ptr Pointer to weights buffer. + * @param[in] ld_weights_col Columns displacement for the weights tensor. + * @param[in] ld_weights_row Rows displacement for the weights tensor. 
+ */ + void pack_parameters(void *parameters_ptr, void *bias_ptr, void *weights_ptr, size_t ld_weights_col, size_t ld_weights_row); + + /** Get the amount of storage space required for the rearranged weights and bias. + * + * @return size of workspace + */ + size_t get_storage_size() const; + + /** Get size of the workspace needed by the assembly kernel. + * + * @param[in] num_threads Maximum number of threads that are going to be spawned. + * @param[in] num_input_channels Number of channels of the input tensor. + * + * @return size of workspace + */ + size_t get_working_size(unsigned int num_threads, unsigned int num_input_channels) const; + + /** Was the asm kernel successfully configured? + * + * @return True if the asm kernel is configured and ready to run + */ + bool is_configured() const; + +private: + std::unique_ptr _kernel_asm; + std::vector _multipliers{}; + std::vector _left_shifts{}; + std::vector _right_shifts{}; +}; +} // namespace kernels +} // namespace cpu +} // namespace arm_compute +#endif /* ARM_COMPUTE_CPU_DEPTHWISE_CONV2D_ASSEMBLY_WRAPPER_KERNEL_H */ diff --git a/src/core/cpu/kernels/internal/CpuPool2dAssemblyWrapperKernel.cpp b/src/core/cpu/kernels/internal/CpuPool2dAssemblyWrapperKernel.cpp index c78ffb9848..89dd27a20a 100644 --- a/src/core/cpu/kernels/internal/CpuPool2dAssemblyWrapperKernel.cpp +++ b/src/core/cpu/kernels/internal/CpuPool2dAssemblyWrapperKernel.cpp @@ -43,11 +43,13 @@ using namespace arm_compute::misc::shape_calculator; void CpuPool2dAssemblyWrapperKernel::configure(const ITensorInfo *src, ITensorInfo *dst, const PoolingLayerInfo &info, const CPUInfo &cpu_info) { + ARM_COMPUTE_UNUSED(cpu_info); ARM_COMPUTE_ERROR_ON_NULLPTR(src, dst); // dst initialization if not yet initialized auto_init_if_empty(*dst, src->clone()->set_tensor_shape(compute_pool_shape(*src, info))); +#if defined(__aarch64__) const bool requantize = src->quantization_info() != dst->quantization_info(); switch(src->data_type()) @@ -83,6 +85,7 @@ void 
CpuPool2dAssemblyWrapperKernel::configure(const ITensorInfo *src, ITensorIn default: break; } +#endif // defined(__aarch64__) Window win = calculate_max_window(*dst, Steps()); INEKernel::configure(win); @@ -192,7 +195,7 @@ void CpuPool2dAssemblyWrapperKernel::create_arm_pooling(const ITensorInfo *src, arm_conv::pooling::PoolingStride stride{}; std::tie(stride.cols, stride.rows) = info.pad_stride_info.stride(); - const arm_conv::pooling::PaddingValues padding{ info.pad_stride_info.pad_left(), info.pad_stride_info.pad_top(), info.pad_stride_info.pad_right(), info.pad_stride_info.pad_bottom() }; + const arm_conv::PaddingValues padding{ info.pad_stride_info.pad_left(), info.pad_stride_info.pad_top(), info.pad_stride_info.pad_right(), info.pad_stride_info.pad_bottom() }; constexpr unsigned int idx_width = 1; constexpr unsigned int idx_height = 2; @@ -231,7 +234,7 @@ void CpuPool2dAssemblyWrapperKernel::create_arm_pooling_requant(const ITensorInf arm_conv::pooling::PoolingStride stride{}; std::tie(stride.cols, stride.rows) = info.pad_stride_info.stride(); - const arm_conv::pooling::PaddingValues padding{ info.pad_stride_info.pad_left(), info.pad_stride_info.pad_top(), info.pad_stride_info.pad_right(), info.pad_stride_info.pad_bottom() }; + const arm_conv::PaddingValues padding{ info.pad_stride_info.pad_left(), info.pad_stride_info.pad_top(), info.pad_stride_info.pad_right(), info.pad_stride_info.pad_bottom() }; constexpr unsigned int idx_width = 1; constexpr unsigned int idx_height = 2; -- cgit v1.2.1